prithivMLmods committed on
Commit
c96923d
·
verified ·
1 Parent(s): 39acb3e

update [kernels:flash-attn2] (cleaned) ✅

Browse files
Files changed (1) hide show
  1. app.py +167 -29
app.py CHANGED
@@ -27,8 +27,6 @@ from transformers.image_utils import load_image
27
  from gradio.themes import Soft
28
  from gradio.themes.utils import colors, fonts, sizes
29
 
30
- # --- Theme and CSS Definition ---
31
-
32
  colors.steel_blue = colors.Color(
33
  name="steel_blue",
34
  c50="#EBF3F8",
@@ -36,7 +34,7 @@ colors.steel_blue = colors.Color(
36
  c200="#A8CCE1",
37
  c300="#7DB3D2",
38
  c400="#529AC3",
39
- c500="#4682B4", # SteelBlue base color
40
  c600="#3E72A0",
41
  c700="#36638C",
42
  c800="#2E5378",
@@ -91,49 +89,161 @@ class SteelBlueTheme(Soft):
91
 
92
  steel_blue_theme = SteelBlueTheme()
93
 
94
- # Constants for text generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  MAX_MAX_NEW_TOKENS = 2048
96
  DEFAULT_MAX_NEW_TOKENS = 1024
97
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
98
 
99
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
100
 
101
- # Load Cosmos-Reason1-7B
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  MODEL_ID_M = "nvidia/Cosmos-Reason1-7B"
103
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
104
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
105
  MODEL_ID_M,
106
- attn_implementation="flash_attention_2",
107
  trust_remote_code=True,
108
  torch_dtype=torch.float16
109
  ).to(device).eval()
110
 
111
- # Load DocScope
112
  MODEL_ID_X = "prithivMLmods/docscopeOCR-7B-050425-exp"
113
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
114
  model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
115
  MODEL_ID_X,
116
- attn_implementation="flash_attention_2",
117
  trust_remote_code=True,
118
  torch_dtype=torch.float16
119
  ).to(device).eval()
120
 
121
- # Load Relaxed
122
  MODEL_ID_Z = "Ertugrul/Qwen2.5-VL-7B-Captioner-Relaxed"
123
  processor_z = AutoProcessor.from_pretrained(MODEL_ID_Z, trust_remote_code=True)
124
  model_z = Qwen2_5_VLForConditionalGeneration.from_pretrained(
125
  MODEL_ID_Z,
126
- attn_implementation="flash_attention_2",
127
  trust_remote_code=True,
128
  torch_dtype=torch.float16
129
  ).to(device).eval()
130
 
131
- # Load visionOCR
132
  MODEL_ID_V = "prithivMLmods/visionOCR-3B-061125"
133
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
134
  model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
135
  MODEL_ID_V,
136
- attn_implementation="flash_attention_2",
137
  trust_remote_code=True,
138
  torch_dtype=torch.float16
139
  ).to(device).eval()
@@ -159,13 +269,32 @@ def downsample_video(video_path):
159
  vidcap.release()
160
  return frames
161
 
162
- @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  def generate_image(model_name: str, text: str, image: Image.Image,
164
  max_new_tokens: int = 1024,
165
  temperature: float = 0.6,
166
  top_p: float = 0.9,
167
  top_k: int = 50,
168
- repetition_penalty: float = 1.2):
 
169
  """
170
  Generates responses using the selected model for image input.
171
  Yields raw text and Markdown-formatted text.
@@ -212,13 +341,14 @@ def generate_image(model_name: str, text: str, image: Image.Image,
212
  time.sleep(0.01)
213
  yield buffer, buffer
214
 
215
- @spaces.GPU
216
  def generate_video(model_name: str, text: str, video_path: str,
217
  max_new_tokens: int = 1024,
218
  temperature: float = 0.6,
219
  top_p: float = 0.9,
220
  top_k: int = 50,
221
- repetition_penalty: float = 1.2):
 
222
  """
223
  Generates responses using the selected model for video input.
224
  Yields raw text and Markdown-formatted text.
@@ -276,7 +406,6 @@ def generate_video(model_name: str, text: str, video_path: str,
276
  time.sleep(0.01)
277
  yield buffer, buffer
278
 
279
- # Define examples for image and video inference
280
  image_examples = [
281
  ["Perform OCR on the text in the image.", "images/1.jpg"],
282
  ["Explain the scene in detail.", "images/2.jpg"]
@@ -287,16 +416,6 @@ video_examples = [
287
  ["Identify the main actions in the video", "videos/2.mp4"]
288
  ]
289
 
290
- css = """
291
- #main-title h1 {
292
- font-size: 2.3em !important;
293
- }
294
- #output-title h2 {
295
- font-size: 2.1em !important;
296
- }
297
- """
298
-
299
- # Create the Gradio Interface
300
  with gr.Blocks() as demo:
301
  gr.Markdown("# **DocScope R1**", elem_id="main-title")
302
  with gr.Row():
@@ -332,14 +451,33 @@ with gr.Blocks() as demo:
332
  value="Cosmos-Reason1-7B"
333
  )
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  image_submit.click(
336
  fn=generate_image,
337
- inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
338
  outputs=[raw_output, markdown_output]
339
  )
340
  video_submit.click(
341
  fn=generate_video,
342
- inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
343
  outputs=[raw_output, markdown_output]
344
  )
345
 
 
27
  from gradio.themes import Soft
28
  from gradio.themes.utils import colors, fonts, sizes
29
 
 
 
30
  colors.steel_blue = colors.Color(
31
  name="steel_blue",
32
  c50="#EBF3F8",
 
34
  c200="#A8CCE1",
35
  c300="#7DB3D2",
36
  c400="#529AC3",
37
+ c500="#4682B4",
38
  c600="#3E72A0",
39
  c700="#36638C",
40
  c800="#2E5378",
 
89
 
90
  steel_blue_theme = SteelBlueTheme()
91
 
92
+ css = """
93
+ #main-title h1 {
94
+ font-size: 2.3em !important;
95
+ }
96
+ #output-title h2 {
97
+ font-size: 2.2em !important;
98
+ }
99
+
100
+ /* RadioAnimated Styles */
101
+ .ra-wrap{ width: fit-content; }
102
+ .ra-inner{
103
+ position: relative; display: inline-flex; align-items: center; gap: 0; padding: 6px;
104
+ background: var(--neutral-200); border-radius: 9999px; overflow: hidden;
105
+ }
106
+ .ra-input{ display: none; }
107
+ .ra-label{
108
+ position: relative; z-index: 2; padding: 8px 16px;
109
+ font-family: inherit; font-size: 14px; font-weight: 600;
110
+ color: var(--neutral-500); cursor: pointer; transition: color 0.2s; white-space: nowrap;
111
+ }
112
+ .ra-highlight{
113
+ position: absolute; z-index: 1; top: 6px; left: 6px;
114
+ height: calc(100% - 12px); border-radius: 9999px;
115
+ background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
116
+ transition: transform 0.2s, width 0.2s;
117
+ }
118
+ .ra-input:checked + .ra-label{ color: black; }
119
+
120
+ /* Dark mode adjustments for Radio */
121
+ .dark .ra-inner { background: var(--neutral-800); }
122
+ .dark .ra-label { color: var(--neutral-400); }
123
+ .dark .ra-highlight { background: var(--neutral-600); }
124
+ .dark .ra-input:checked + .ra-label { color: white; }
125
+
126
+ #gpu-duration-container {
127
+ padding: 10px;
128
+ border-radius: 8px;
129
+ background: var(--background-fill-secondary);
130
+ border: 1px solid var(--border-color-primary);
131
+ margin-top: 10px;
132
+ }
133
+ """
134
+
135
  MAX_MAX_NEW_TOKENS = 2048
136
  DEFAULT_MAX_NEW_TOKENS = 1024
137
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
138
 
139
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
140
 
141
+ class RadioAnimated(gr.HTML):
142
+ def __init__(self, choices, value=None, **kwargs):
143
+ if not choices or len(choices) < 2:
144
+ raise ValueError("RadioAnimated requires at least 2 choices.")
145
+ if value is None:
146
+ value = choices[0]
147
+
148
+ uid = uuid.uuid4().hex[:8]
149
+ group_name = f"ra-{uid}"
150
+
151
+ inputs_html = "\n".join(
152
+ f"""
153
+ <input class="ra-input" type="radio" name="{group_name}" id="{group_name}-{i}" value="{c}">
154
+ <label class="ra-label" for="{group_name}-{i}">{c}</label>
155
+ """
156
+ for i, c in enumerate(choices)
157
+ )
158
+
159
+ html_template = f"""
160
+ <div class="ra-wrap" data-ra="{uid}">
161
+ <div class="ra-inner">
162
+ <div class="ra-highlight"></div>
163
+ {inputs_html}
164
+ </div>
165
+ </div>
166
+ """
167
+
168
+ js_on_load = r"""
169
+ (() => {
170
+ const wrap = element.querySelector('.ra-wrap');
171
+ const inner = element.querySelector('.ra-inner');
172
+ const highlight = element.querySelector('.ra-highlight');
173
+ const inputs = Array.from(element.querySelectorAll('.ra-input'));
174
+
175
+ if (!inputs.length) return;
176
+
177
+ const choices = inputs.map(i => i.value);
178
+
179
+ function setHighlightByIndex(idx) {
180
+ const n = choices.length;
181
+ const pct = 100 / n;
182
+ highlight.style.width = `calc(${pct}% - 6px)`;
183
+ highlight.style.transform = `translateX(${idx * 100}%)`;
184
+ }
185
+
186
+ function setCheckedByValue(val, shouldTrigger=false) {
187
+ const idx = Math.max(0, choices.indexOf(val));
188
+ inputs.forEach((inp, i) => { inp.checked = (i === idx); });
189
+ setHighlightByIndex(idx);
190
+
191
+ props.value = choices[idx];
192
+ if (shouldTrigger) trigger('change', props.value);
193
+ }
194
+
195
+ setCheckedByValue(props.value ?? choices[0], false);
196
+
197
+ inputs.forEach((inp) => {
198
+ inp.addEventListener('change', () => {
199
+ setCheckedByValue(inp.value, true);
200
+ });
201
+ });
202
+ })();
203
+ """
204
+
205
+ super().__init__(
206
+ value=value,
207
+ html_template=html_template,
208
+ js_on_load=js_on_load,
209
+ **kwargs
210
+ )
211
+
212
+ def apply_gpu_duration(val: str):
213
+ return int(val)
214
+
215
  MODEL_ID_M = "nvidia/Cosmos-Reason1-7B"
216
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
217
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
218
  MODEL_ID_M,
219
+ attn_implementation="kernels-community/flash-attn2",
220
  trust_remote_code=True,
221
  torch_dtype=torch.float16
222
  ).to(device).eval()
223
 
 
224
  MODEL_ID_X = "prithivMLmods/docscopeOCR-7B-050425-exp"
225
  processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
226
  model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
227
  MODEL_ID_X,
228
+ attn_implementation="kernels-community/flash-attn2",
229
  trust_remote_code=True,
230
  torch_dtype=torch.float16
231
  ).to(device).eval()
232
 
 
233
  MODEL_ID_Z = "Ertugrul/Qwen2.5-VL-7B-Captioner-Relaxed"
234
  processor_z = AutoProcessor.from_pretrained(MODEL_ID_Z, trust_remote_code=True)
235
  model_z = Qwen2_5_VLForConditionalGeneration.from_pretrained(
236
  MODEL_ID_Z,
237
+ attn_implementation="kernels-community/flash-attn2",
238
  trust_remote_code=True,
239
  torch_dtype=torch.float16
240
  ).to(device).eval()
241
 
 
242
  MODEL_ID_V = "prithivMLmods/visionOCR-3B-061125"
243
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
244
  model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
245
  MODEL_ID_V,
246
+ attn_implementation="kernels-community/flash-attn2",
247
  trust_remote_code=True,
248
  torch_dtype=torch.float16
249
  ).to(device).eval()
 
269
  vidcap.release()
270
  return frames
271
 
272
+ def calc_timeout_image(model_name: str, text: str, image: Image.Image,
273
+ max_new_tokens: int, temperature: float, top_p: float,
274
+ top_k: int, repetition_penalty: float, gpu_timeout: int):
275
+ """Calculate GPU timeout duration for image inference."""
276
+ try:
277
+ return int(gpu_timeout)
278
+ except:
279
+ return 60
280
+
281
+ def calc_timeout_video(model_name: str, text: str, video_path: str,
282
+ max_new_tokens: int, temperature: float, top_p: float,
283
+ top_k: int, repetition_penalty: float, gpu_timeout: int):
284
+ """Calculate GPU timeout duration for video inference."""
285
+ try:
286
+ return int(gpu_timeout)
287
+ except:
288
+ return 60
289
+
290
+ @spaces.GPU(duration=calc_timeout_image)
291
  def generate_image(model_name: str, text: str, image: Image.Image,
292
  max_new_tokens: int = 1024,
293
  temperature: float = 0.6,
294
  top_p: float = 0.9,
295
  top_k: int = 50,
296
+ repetition_penalty: float = 1.2,
297
+ gpu_timeout: int = 60):
298
  """
299
  Generates responses using the selected model for image input.
300
  Yields raw text and Markdown-formatted text.
 
341
  time.sleep(0.01)
342
  yield buffer, buffer
343
 
344
+ @spaces.GPU(duration=calc_timeout_video)
345
  def generate_video(model_name: str, text: str, video_path: str,
346
  max_new_tokens: int = 1024,
347
  temperature: float = 0.6,
348
  top_p: float = 0.9,
349
  top_k: int = 50,
350
+ repetition_penalty: float = 1.2,
351
+ gpu_timeout: int = 90):
352
  """
353
  Generates responses using the selected model for video input.
354
  Yields raw text and Markdown-formatted text.
 
406
  time.sleep(0.01)
407
  yield buffer, buffer
408
 
 
409
  image_examples = [
410
  ["Perform OCR on the text in the image.", "images/1.jpg"],
411
  ["Explain the scene in detail.", "images/2.jpg"]
 
416
  ["Identify the main actions in the video", "videos/2.mp4"]
417
  ]
418
 
 
 
 
 
 
 
 
 
 
 
419
  with gr.Blocks() as demo:
420
  gr.Markdown("# **DocScope R1**", elem_id="main-title")
421
  with gr.Row():
 
451
  value="Cosmos-Reason1-7B"
452
  )
453
 
454
+ with gr.Row(elem_id="gpu-duration-container"):
455
+ with gr.Column():
456
+ gr.Markdown("**GPU Duration (seconds)**")
457
+ radioanimated_gpu_duration = RadioAnimated(
458
+ choices=["60", "90", "120", "180", "240", "300"],
459
+ value="60",
460
+ elem_id="radioanimated_gpu_duration"
461
+ )
462
+ gpu_duration_state = gr.Number(value=60, visible=False)
463
+
464
+ gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")
465
+
466
+ radioanimated_gpu_duration.change(
467
+ fn=apply_gpu_duration,
468
+ inputs=radioanimated_gpu_duration,
469
+ outputs=[gpu_duration_state],
470
+ api_visibility="private"
471
+ )
472
+
473
  image_submit.click(
474
  fn=generate_image,
475
+ inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
476
  outputs=[raw_output, markdown_output]
477
  )
478
  video_submit.click(
479
  fn=generate_video,
480
+ inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
481
  outputs=[raw_output, markdown_output]
482
  )
483