Commit ab05948 (verified) · Translsis committed · Parent: 4c0f830

Update app.py

Files changed (1):
  1. app.py (+462, -78)
app.py CHANGED
@@ -3,27 +3,128 @@ import gradio as gr
  import torch
  import numpy as np
  from PIL import Image
- from transformers import Sam3Processor, Sam3Model
  import requests
  import warnings
  warnings.filterwarnings("ignore")

  # Global model and processor
  device = "cuda" if torch.cuda.is_available() else "cpu"
- model = Sam3Model.from_pretrained("facebook/sam3", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)
- processor = Sam3Processor.from_pretrained("facebook/sam3")

  @spaces.GPU()
- def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
      """
-     Perform promptable concept segmentation using SAM3.
-     Returns format compatible with gr.AnnotatedImage: (image, [(mask, label), ...])
      """
      if image is None:
-         return None, " Please upload an image."

      if not text.strip():
-         return (image, []), " Please enter a text prompt."

      try:
          inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)
@@ -44,29 +145,205 @@ def segment(image: Image.Image, text: str, threshold: float, mask_threshold: flo

          n_masks = len(results['masks'])
          if n_masks == 0:
-             return (image, []), f" No objects found matching '{text}' (try adjusting thresholds)."

-         # Format for AnnotatedImage: list of (mask, label) tuples
-         # mask should be numpy array with values 0-1 (float) matching image dimensions
          annotations = []
          for i, (mask, score) in enumerate(zip(results['masks'], results['scores'])):
-             # Convert binary mask to float numpy array (0-1 range)
              mask_np = mask.cpu().numpy().astype(np.float32)
              label = f"{text} #{i+1} ({score:.2f})"
              annotations.append((mask_np, label))

-         scores_text = ", ".join([f"{s:.2f}" for s in results['scores'].cpu().numpy()[:5]])
-         info = f"✅ Found **{n_masks}** objects matching **'{text}'**\nConfidence scores: {scores_text}{'...' if n_masks > 5 else ''}"

-         # Return tuple: (base_image, list_of_annotations)
-         return (image, annotations), info

      except Exception as e:
-         return (image, []), f" Error during segmentation: {str(e)}"

  def clear_all():
      """Clear all inputs and outputs"""
-     return None, "", None, 0.5, 0.5, "📝 Enter a prompt and click **Segment** to start."

  def segment_example(image_path: str, prompt: str):
      """Handle example clicks"""
@@ -80,92 +357,199 @@ def segment_example(image_path: str, prompt: str):
  with gr.Blocks(
      theme=gr.themes.Soft(),
      title="SAM3 - Promptable Concept Segmentation",
-     css=".gradio-container {max-width: 1400px !important;}"
  ) as demo:
      gr.Markdown(
          """
          # SAM3 - Promptable Concept Segmentation (PCS)

          **SAM3** performs zero-shot instance segmentation using natural language prompts.
-         Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks.

          Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
          """
      )

-     gr.Markdown("### Inputs")
-     with gr.Row(variant="panel"):
-         image_input = gr.Image(
-             label="Input Image",
-             type="pil",
-             height=400,
-         )
-         # AnnotatedImage expects: (base_image, [(mask, label), ...])
-         image_output = gr.AnnotatedImage(
-             label="Output (Segmented Image)",
-             height=400,
-             show_legend=True,
-         )

-     with gr.Row():
-         text_input = gr.Textbox(
-             label="Text Prompt",
-             placeholder="e.g., person, ear, cat, bicycle...",
-             scale=3
-         )
-         clear_btn = gr.Button("🔍 Clear", size="sm", variant="secondary")
-
-     with gr.Row():
-         thresh_slider = gr.Slider(
-             minimum=0.0,
-             maximum=1.0,
-             value=0.5,
-             step=0.01,
-             label="Detection Threshold",
-             info="Higher = fewer detections"
-         )
-         mask_thresh_slider = gr.Slider(
-             minimum=0.0,
-             maximum=1.0,
-             value=0.5,
-             step=0.01,
-             label="Mask Threshold",
-             info="Higher = sharper masks"
-         )

-     info_output = gr.Markdown(
-         value="📝 Enter a prompt and click **Segment** to start.",
-         label="Info / Results"
      )

-     segment_btn = gr.Button("🎯 Segment", variant="primary", size="lg")

-     gr.Examples(
-         examples=[
-             ["http://images.cocodataset.org/val2017/000000077595.jpg", "cat"],
-         ],
-         inputs=[image_input, text_input],
-         outputs=[image_output, info_output],
-         fn=segment_example,
-         cache_examples=False,
      )

-     clear_btn.click(
-         fn=clear_all,
-         outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output]
      )

-     segment_btn.click(
-         fn=segment,
-         inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
-         outputs=[image_output, info_output]
      )

      gr.Markdown(
          """
          ### Notes
          - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
-         - Click on segments in the output to see labels
          - GPU recommended for faster inference
          """
      )
 
@@ -3,27 +3,128 @@ import gradio as gr
  import torch
  import numpy as np
  from PIL import Image
  import requests
  import warnings
+ import json
+ import os
+ from datetime import datetime
+ from threading import Thread
+ from queue import Queue
+ import time
  warnings.filterwarnings("ignore")

  # Global model and processor
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ from transformers import Sam3Processor, Sam3Model
+ model = Sam3Model.from_pretrained("DiffusionWave/sam3", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)
+ processor = Sam3Processor.from_pretrained("DiffusionWave/sam3")
+
+ # Background processing queue
+ job_queue = Queue()
+ results_store = {}
+ job_counter = 0
+
+ # History storage
+ HISTORY_DIR = "segmentation_history"
+ HISTORY_FILE = os.path.join(HISTORY_DIR, "history.json")
+ CROPS_DIR = os.path.join(HISTORY_DIR, "crops")
+ os.makedirs(HISTORY_DIR, exist_ok=True)
+ os.makedirs(CROPS_DIR, exist_ok=True)
+
+ def load_history():
+     """Load segmentation history from file"""
+     if os.path.exists(HISTORY_FILE):
+         try:
+             with open(HISTORY_FILE, 'r') as f:
+                 return json.load(f)
+         except:
+             return []
+     return []
+
+ def save_history(history):
+     """Save segmentation history to file"""
+     with open(HISTORY_FILE, 'w') as f:
+         json.dump(history, f, indent=2)
+
+ def crop_segmented_objects(image: Image.Image, masks, text: str, timestamp: str):
+     """
+     Crop individual objects from masks and save them
+     Returns list of cropped image paths
+     """
+     cropped_paths = []
+     image_np = np.array(image)
+
+     for i, mask in enumerate(masks):
+         # Convert mask to numpy if needed
+         if isinstance(mask, torch.Tensor):
+             mask_np = mask.cpu().numpy()
+         else:
+             mask_np = mask
+
+         # Find bounding box of the mask
+         rows = np.any(mask_np, axis=1)
+         cols = np.any(mask_np, axis=0)
+
+         if not rows.any() or not cols.any():
+             continue
+
+         y_min, y_max = np.where(rows)[0][[0, -1]]
+         x_min, x_max = np.where(cols)[0][[0, -1]]
+
+         # Add padding (10 pixels)
+         padding = 10
+         y_min = max(0, y_min - padding)
+         y_max = min(image_np.shape[0], y_max + padding)
+         x_min = max(0, x_min - padding)
+         x_max = min(image_np.shape[1], x_max + padding)
+
+         # Crop the image
+         cropped = image_np[y_min:y_max, x_min:x_max]
+
+         # Apply mask to cropped region (transparent background)
+         mask_crop = mask_np[y_min:y_max, x_min:x_max]
+
+         # Create RGBA image
+         cropped_rgba = np.zeros((*cropped.shape[:2], 4), dtype=np.uint8)
+         cropped_rgba[:, :, :3] = cropped
+         cropped_rgba[:, :, 3] = (mask_crop * 255).astype(np.uint8)
+
+         # Save cropped image
+         crop_filename = f"crop_{timestamp.replace(':', '-').replace(' ', '_')}_{text}_{i+1}.png"
+         crop_path = os.path.join(CROPS_DIR, crop_filename)
+         Image.fromarray(cropped_rgba).save(crop_path)
+         cropped_paths.append(crop_path)
+
+     return cropped_paths
+
+ def add_to_history(image_path, prompt, n_masks, scores, timestamp, crop_paths):
+     """Add a new entry to history"""
+     history = load_history()
+     entry = {
+         "id": len(history) + 1,
+         "timestamp": timestamp,
+         "image_path": image_path,
+         "prompt": prompt,
+         "n_masks": n_masks,
+         "scores": scores,
+         "crop_paths": crop_paths
+     }
+     history.insert(0, entry)  # Add to beginning
+     # Keep only last 100 entries
+     history = history[:100]
+     save_history(history)
+     return history

  @spaces.GPU()
+ def segment_core(image: Image.Image, text: str, threshold: float, mask_threshold: float, save_crops: bool = True):
      """
+     Core segmentation function - can be called independently
      """
      if image is None:
+         return None, "❌ Please upload an image.", None, []

      if not text.strip():
+         return (image, []), "❌ Please enter a text prompt.", None, []

      try:
          inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)
@@ -44,29 +145,205 @@ def segment(image: Image.Image, text: str, threshold: float, mask_threshold: flo

          n_masks = len(results['masks'])
          if n_masks == 0:
+             return (image, []), f"❌ No objects found matching '{text}' (try adjusting thresholds).", None, []

+         # Format for AnnotatedImage
          annotations = []
          for i, (mask, score) in enumerate(zip(results['masks'], results['scores'])):
              mask_np = mask.cpu().numpy().astype(np.float32)
              label = f"{text} #{i+1} ({score:.2f})"
              annotations.append((mask_np, label))

+         scores_list = results['scores'].cpu().numpy().tolist()
+         scores_text = ", ".join([f"{s:.2f}" for s in scores_list[:5]])
+         info = f"✅ Found **{n_masks}** objects matching **'{text}'**\n"
+         info += f"Confidence scores: {scores_text}{'...' if n_masks > 5 else ''}\n"
+
+         # Crop objects if requested
+         cropped_images = []
+         if save_crops:
+             timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+             crop_paths = crop_segmented_objects(image, results['masks'], text, timestamp)
+             info += f"✂️ Extracted **{len(crop_paths)}** cropped objects"
+
+             # Load cropped images for display
+             for path in crop_paths[:10]:  # Limit to 10 for display
+                 if os.path.exists(path):
+                     cropped_images.append(Image.open(path))
+         else:
+             crop_paths = []

+         metadata = {
+             "n_masks": n_masks,
+             "scores": scores_list,
+             "crop_paths": crop_paths,
+             "masks": results['masks']
+         }
+
+         return (image, annotations), info, metadata, cropped_images

      except Exception as e:
+         return (image, []), f"❌ Error during segmentation: {str(e)}", None, []
+
+ def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
+     """
+     Frontend segment function - with history saving
+     """
+     result, info, metadata, cropped_images = segment_core(image, text, threshold, mask_threshold, save_crops=True)
+
+     # Save to history if successful
+     if metadata and metadata["n_masks"] > 0:
+         timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+         # Save image temporarily
+         img_filename = f"img_{int(time.time())}.jpg"
+         img_path = os.path.join(HISTORY_DIR, img_filename)
+         image.save(img_path)
+
+         add_to_history(
+             img_path,
+             text,
+             metadata["n_masks"],
+             metadata["scores"],
+             timestamp,
+             metadata["crop_paths"]
+         )
+
+     return result, info, cropped_images
+
+ def background_worker():
+     """
+     Background worker thread - processes jobs independently
+     """
+     while True:
+         job = job_queue.get()
+         if job is None:
+             break
+
+         job_id, image, text, threshold, mask_threshold = job
+
+         try:
+             result, info, metadata, cropped_images = segment_core(image, text, threshold, mask_threshold, save_crops=True)
+             results_store[job_id] = {
+                 "status": "completed",
+                 "result": result,
+                 "info": info,
+                 "metadata": metadata,
+                 "cropped_images": cropped_images,
+                 "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+             }
+
+             # Save to history
+             if metadata and metadata["n_masks"] > 0:
+                 img_filename = f"bg_img_{job_id}.jpg"
+                 img_path = os.path.join(HISTORY_DIR, img_filename)
+                 image.save(img_path)
+                 add_to_history(
+                     img_path,
+                     text,
+                     metadata["n_masks"],
+                     metadata["scores"],
+                     results_store[job_id]["timestamp"],
+                     metadata["crop_paths"]
+                 )
+
+         except Exception as e:
+             results_store[job_id] = {
+                 "status": "failed",
+                 "error": str(e),
+                 "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+             }
+
+         job_queue.task_done()
+
+ # Start background worker
+ worker_thread = Thread(target=background_worker, daemon=True)
+ worker_thread.start()
+
+ def submit_background_job(image, text, threshold, mask_threshold):
+     """Submit a job to background queue"""
+     global job_counter
+     if image is None or not text.strip():
+         return "❌ Please provide image and text prompt.", ""
+
+     job_counter += 1
+     job_id = job_counter
+
+     job_queue.put((job_id, image, text, threshold, mask_threshold))
+     results_store[job_id] = {"status": "processing"}
+
+     return f"✅ Job #{job_id} submitted to background queue.", f"{job_id}"
+
+ def check_background_job(job_id_str):
+     """Check status of background job"""
+     if not job_id_str.strip():
+         return "❌ Please enter a job ID.", None, []
+
+     try:
+         job_id = int(job_id_str)
+         if job_id not in results_store:
+             return f"❌ Job #{job_id} not found.", None, []
+
+         job_data = results_store[job_id]
+         status = job_data["status"]
+
+         if status == "processing":
+             return f"⏳ Job #{job_id} is still processing...", None, []
+         elif status == "completed":
+             return (
+                 f"✅ Job #{job_id} completed!\n{job_data['info']}",
+                 job_data["result"],
+                 job_data.get("cropped_images", [])
+             )
+         else:
+             return f"❌ Job #{job_id} failed: {job_data.get('error', 'Unknown error')}", None, []
+
+     except ValueError:
+         return "❌ Invalid job ID format.", None, []
+
+ def load_history_display():
+     """Load and format history for display"""
+     history = load_history()
+     if not history:
+         return "📝 No history yet. Start segmenting images!"
+
+     display = "## Segmentation History\n\n"
+     for entry in history[:20]:  # Show last 20
+         display += f"**#{entry['id']}** - {entry['timestamp']}\n"
+         display += f"- Prompt: `{entry['prompt']}`\n"
+         display += f"- Found: {entry['n_masks']} objects\n"
+         display += f"- Cropped: {len(entry.get('crop_paths', []))} images\n"
+         display += f"- Top scores: {', '.join([f'{s:.2f}' for s in entry['scores'][:3]])}\n\n"
+
+     return display
+
+ def load_history_item(item_id):
+     """Load a specific history item with cropped images"""
+     history = load_history()
+     for entry in history:
+         if entry['id'] == int(item_id):
+             info = f"**History item #{entry['id']}**\n"
+             info += f"Timestamp: {entry['timestamp']}\n"
+             info += f"Prompt: `{entry['prompt']}`\n"
+             info += f"Objects found: {entry['n_masks']}\n"
+             info += f"Cropped images: {len(entry.get('crop_paths', []))}"
+
+             image = None
+             if os.path.exists(entry['image_path']):
+                 image = Image.open(entry['image_path'])
+
+             # Load cropped images
+             cropped_images = []
+             for crop_path in entry.get('crop_paths', [])[:10]:
+                 if os.path.exists(crop_path):
+                     cropped_images.append(Image.open(crop_path))
+
+             return image, entry['prompt'], info, cropped_images
+
+     return None, "", f"❌ History item #{item_id} not found", []

  def clear_all():
      """Clear all inputs and outputs"""
+     return None, "", None, 0.5, 0.5, "📝 Enter a prompt and click **Segment** to start.", []

  def segment_example(image_path: str, prompt: str):
      """Handle example clicks"""
@@ -80,92 +357,199 @@ def segment_example(image_path: str, prompt: str):
  with gr.Blocks(
      theme=gr.themes.Soft(),
      title="SAM3 - Promptable Concept Segmentation",
+     css=".gradio-container {max-width: 1600px !important;}"
  ) as demo:
      gr.Markdown(
          """
          # SAM3 - Promptable Concept Segmentation (PCS)

          **SAM3** performs zero-shot instance segmentation using natural language prompts.
+         Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks + cropped objects.

          Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
          """
      )

+     with gr.Tabs():
+         # Tab 1: Main Segmentation
+         with gr.Tab("🎯 Segmentation"):
+             gr.Markdown("### Inputs")
+             with gr.Row(variant="panel"):
+                 image_input = gr.Image(
+                     label="Input Image",
+                     type="pil",
+                     height=400,
+                 )
+                 image_output = gr.AnnotatedImage(
+                     label="Output (Segmented Image)",
+                     height=400,
+                     show_legend=True,
+                 )
+
+             with gr.Row():
+                 text_input = gr.Textbox(
+                     label="Text Prompt",
+                     placeholder="e.g., person, ear, cat, bicycle...",
+                     scale=3
+                 )
+                 clear_btn = gr.Button("🔄 Clear", size="sm", variant="secondary")
+
+             with gr.Row():
+                 thresh_slider = gr.Slider(
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.5,
+                     step=0.01,
+                     label="Detection Threshold",
+                     info="Higher = fewer detections"
+                 )
+                 mask_thresh_slider = gr.Slider(
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.5,
+                     step=0.01,
+                     label="Mask Threshold",
+                     info="Higher = sharper masks"
+                 )
+
+             info_output = gr.Markdown(
+                 value="📝 Enter a prompt and click **Segment** to start.",
+                 label="Info / Results"
+             )
+
+             segment_btn = gr.Button("🎯 Segment Now", variant="primary", size="lg")
+
+             gr.Markdown("### ✂️ Cropped Objects")
+             cropped_gallery = gr.Gallery(
+                 label="Extracted Objects",
+                 columns=5,
+                 height=300,
+                 object_fit="contain"
+             )
+
+             gr.Examples(
+                 examples=[
+                     ["http://images.cocodataset.org/val2017/000000077595.jpg", "cat"],
+                 ],
+                 inputs=[image_input, text_input],
+                 outputs=[image_output, info_output, cropped_gallery],
+                 fn=segment_example,
+                 cache_examples=False,
+             )
+
+         # Tab 2: Background Processing
+         with gr.Tab("⚙️ Background Processing"):
+             gr.Markdown(
+                 """
+                 ### Background Job Queue
+                 Submit segmentation jobs that run independently in the background.
+                 Useful for batch processing or when you want to continue working while processing.
+                 """
+             )
+
+             with gr.Row():
+                 bg_image_input = gr.Image(label="Image", type="pil", height=300)
+                 bg_status_output = gr.Markdown("📝 Submit a job to start background processing.")
+
+             with gr.Row():
+                 bg_text_input = gr.Textbox(label="Text Prompt", placeholder="e.g., person, car...")
+                 bg_job_id_output = gr.Textbox(label="Job ID", interactive=False)
+
+             with gr.Row():
+                 bg_thresh = gr.Slider(0.0, 1.0, 0.5, 0.01, label="Detection Threshold")
+                 bg_mask_thresh = gr.Slider(0.0, 1.0, 0.5, 0.01, label="Mask Threshold")
+
+             bg_submit_btn = gr.Button("📤 Submit Background Job", variant="primary")
+
+             gr.Markdown("---")
+             gr.Markdown("### Check Job Status")
+
+             with gr.Row():
+                 check_job_id = gr.Textbox(label="Enter Job ID", placeholder="e.g., 1")
+                 check_btn = gr.Button("🔍 Check Status", variant="secondary")
+
+             check_status_output = gr.Markdown("Enter a job ID and click Check Status.")
+             check_result_output = gr.AnnotatedImage(label="Result", height=400)
+
+             gr.Markdown("### Cropped Objects from Job")
+             check_cropped_gallery = gr.Gallery(
+                 label="Extracted Objects",
+                 columns=5,
+                 height=300,
+                 object_fit="contain"
+             )
+
+         # Tab 3: History
+         with gr.Tab("📚 History"):
+             gr.Markdown("### Segmentation History")
+
+             with gr.Row():
+                 refresh_history_btn = gr.Button("🔄 Refresh History", variant="secondary")
+                 history_item_id = gr.Textbox(label="Load History Item #", placeholder="Enter ID")
+                 load_history_btn = gr.Button("📂 Load Item", variant="primary")
+
+             history_display = gr.Markdown(load_history_display())
+
+             gr.Markdown("---")
+             gr.Markdown("### Loaded History Item")
+
+             with gr.Row():
+                 history_image = gr.Image(label="Original Image", type="pil", height=300)
+                 history_info = gr.Markdown("Select a history item to view.")
+
+             history_prompt = gr.Textbox(label="Prompt", interactive=False)
+
+             gr.Markdown("### Cropped Objects from History")
+             history_cropped_gallery = gr.Gallery(
+                 label="Extracted Objects",
+                 columns=5,
+                 height=300,
+                 object_fit="contain"
+             )

+     # Event handlers
+     clear_btn.click(
+         fn=clear_all,
+         outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output, cropped_gallery]
+     )

+     segment_btn.click(
+         fn=segment,
+         inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
+         outputs=[image_output, info_output, cropped_gallery]
      )

+     bg_submit_btn.click(
+         fn=submit_background_job,
+         inputs=[bg_image_input, bg_text_input, bg_thresh, bg_mask_thresh],
+         outputs=[bg_status_output, bg_job_id_output]
+     )

+     check_btn.click(
+         fn=check_background_job,
+         inputs=[check_job_id],
+         outputs=[check_status_output, check_result_output, check_cropped_gallery]
      )

+     refresh_history_btn.click(
+         fn=load_history_display,
+         outputs=[history_display]
      )

+     load_history_btn.click(
+         fn=load_history_item,
+         inputs=[history_item_id],
+         outputs=[history_image, history_prompt, history_info, history_cropped_gallery]
      )

      gr.Markdown(
          """
          ### Notes
          - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3)
+         - Background jobs run independently and are tracked by Job ID
+         - All segmented objects are automatically cropped and saved
+         - Cropped images have transparent backgrounds (PNG format)
+         - History is saved automatically and persists across sessions
          - GPU recommended for faster inference
          """
      )
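
The new `add_to_history` persists a single JSON list at `segmentation_history/history.json` (newest entry first), and each entry references the saved source image plus the RGBA crop PNGs written by `crop_segmented_objects`. A minimal sketch of reading that history back outside the app, assuming the default directory layout the code above creates (this reader script is illustrative, not part of the commit):

```python
import json
from pathlib import Path

from PIL import Image

# Default layout written by app.py: segmentation_history/history.json plus crops/*.png
HISTORY_FILE = Path("segmentation_history") / "history.json"

# history.json holds a list of entries, newest first (see add_to_history)
entries = json.loads(HISTORY_FILE.read_text()) if HISTORY_FILE.exists() else []

for entry in entries[:5]:
    # Keys written by add_to_history: id, timestamp, image_path, prompt, n_masks, scores, crop_paths
    print(f"#{entry['id']} {entry['timestamp']}: '{entry['prompt']}' -> {entry['n_masks']} objects")
    for crop_path in entry.get("crop_paths", []):
        if Path(crop_path).exists():
            crop = Image.open(crop_path)  # RGBA PNG; the alpha channel carries the segmentation mask
            print(f"  crop {crop.size}, mode={crop.mode}")
```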