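# SAM3 promptable concept segmentation demo (Gradio app).
# Segments objects matching a natural-language prompt, crops each instance to
# a transparent PNG, and keeps a persistent on-disk history of runs.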
import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
import requests
import warnings
import json
import os
from datetime import datetime
from threading import Thread
from queue import Queue
import time
warnings.filterwarnings("ignore")

from transformers import Sam3Processor, Sam3Model

# Global model and processor (half precision on GPU, full precision on CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Sam3Model.from_pretrained(
    "DiffusionWave/sam3",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
processor = Sam3Processor.from_pretrained("DiffusionWave/sam3")

# Background processing queue
job_queue = Queue()
results_store = {}
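# NOTE: results_store is shared between Gradio handlers and the worker thread;
# plain per-job-ID dict get/set is sufficient for this single-worker demo.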
job_counter = 0

# History storage
HISTORY_DIR = "segmentation_history"
HISTORY_FILE = os.path.join(HISTORY_DIR, "history.json")
CROPS_DIR = os.path.join(HISTORY_DIR, "crops")
os.makedirs(HISTORY_DIR, exist_ok=True)
os.makedirs(CROPS_DIR, exist_ok=True)

def load_history():
    """Load segmentation history from file"""
    if os.path.exists(HISTORY_FILE):
        try:
            with open(HISTORY_FILE, 'r') as f:
                return json.load(f)
        except (json.JSONDecodeError, OSError):
            return []
    return []

def save_history(history):
    """Save segmentation history to file"""
    with open(HISTORY_FILE, 'w') as f:
        json.dump(history, f, indent=2)

def crop_segmented_objects(image: Image.Image, masks, text: str, timestamp: str):
    """
    Crop individual objects from masks and save them
    Returns list of cropped image paths
    """
    cropped_paths = []
    image_np = np.array(image)
    
    for i, mask in enumerate(masks):
        # Convert mask to numpy if needed
        if isinstance(mask, torch.Tensor):
            mask_np = mask.cpu().numpy()
        else:
            mask_np = mask
        
        # Find bounding box of the mask
        rows = np.any(mask_np, axis=1)
        cols = np.any(mask_np, axis=0)
        
        if not rows.any() or not cols.any():
            continue
            
        y_min, y_max = np.where(rows)[0][[0, -1]]
        x_min, x_max = np.where(cols)[0][[0, -1]]
        
        # Add padding (10 pixels)
        padding = 10
        y_min = max(0, y_min - padding)
        y_max = min(image_np.shape[0], y_max + padding)
        x_min = max(0, x_min - padding)
        x_max = min(image_np.shape[1], x_max + padding)
        
        # Crop the image
        cropped = image_np[y_min:y_max, x_min:x_max]
        
        # Apply mask to cropped region (transparent background)
        mask_crop = mask_np[y_min:y_max, x_min:x_max]
        
        # Create RGBA image
        cropped_rgba = np.zeros((*cropped.shape[:2], 4), dtype=np.uint8)
        cropped_rgba[:, :, :3] = cropped
        cropped_rgba[:, :, 3] = (mask_crop * 255).astype(np.uint8)
        
        # Save cropped image
        # Sanitize the prompt so it is safe to embed in a filename
        safe_text = "".join(c if c.isalnum() or c in "-_" else "_" for c in text.strip())
        crop_filename = f"crop_{timestamp.replace(':', '-').replace(' ', '_')}_{safe_text}_{i+1}.png"
        crop_path = os.path.join(CROPS_DIR, crop_filename)
        Image.fromarray(cropped_rgba).save(crop_path)
        cropped_paths.append(crop_path)
    
    return cropped_paths

def add_to_history(image_path, prompt, n_masks, scores, timestamp, crop_paths):
    """Add a new entry to history"""
    history = load_history()
    entry = {
        "id": len(history) + 1,
        "timestamp": timestamp,
        "image_path": image_path,
        "prompt": prompt,
        "n_masks": n_masks,
        "scores": scores,
        "crop_paths": crop_paths
    }
    history.insert(0, entry)  # Add to beginning
    # Keep only last 100 entries
    history = history[:100]
    save_history(history)
    return history

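# @spaces.GPU allocates a GPU for the duration of the call on Hugging Face
# ZeroGPU Spaces; it is a no-op when running outside Spaces.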
@spaces.GPU()
def segment_core(image: Image.Image, text: str, threshold: float, mask_threshold: float, save_crops: bool = True):
    """
    Core segmentation function - can be called independently
    """
    if image is None:
        return None, "❌ Please upload an image.", None, []
    
    if not text.strip():
        return (image, []), "❌ Please enter a text prompt.", None, []
    
    try:
        inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)
        
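        # Cast floating-point inputs to the model's dtype (fp16 on GPU) so
        # they match the half-precision weights.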
        for key in inputs:
            if inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)
        
        with torch.no_grad():
            outputs = model(**inputs)
        
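        # Convert raw model outputs into per-instance binary masks resized to
        # the original image resolution.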
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]
        
        n_masks = len(results['masks'])
        if n_masks == 0:
            return (image, []), f"❌ No objects found matching '{text}' (try adjusting thresholds).", None, []
        
        # Format for AnnotatedImage
        annotations = []
        for i, (mask, score) in enumerate(zip(results['masks'], results['scores'])):
            mask_np = mask.cpu().numpy().astype(np.float32)
            label = f"{text} #{i+1} ({score:.2f})"
            annotations.append((mask_np, label))
        
        scores_list = results['scores'].cpu().numpy().tolist()
        scores_text = ", ".join([f"{s:.2f}" for s in scores_list[:5]])
        info = f"✅ Found **{n_masks}** objects matching **'{text}'**\n"
        info += f"Confidence scores: {scores_text}{'...' if n_masks > 5 else ''}\n"
        
        # Crop objects if requested
        cropped_images = []
        if save_crops:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            crop_paths = crop_segmented_objects(image, results['masks'], text, timestamp)
            info += f"✂️ Extracted **{len(crop_paths)}** cropped objects"
            
            # Load cropped images for display
            for path in crop_paths[:10]:  # Limit to 10 for display
                if os.path.exists(path):
                    cropped_images.append(Image.open(path))
        else:
            crop_paths = []
        
        metadata = {
            "n_masks": n_masks, 
            "scores": scores_list,
            "crop_paths": crop_paths,
            "masks": results['masks']
        }
        
        return (image, annotations), info, metadata, cropped_images
        
    except Exception as e:
        return (image, []), f"❌ Error during segmentation: {str(e)}", None, []

def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """
    Frontend segment function - with history saving
    """
    result, info, metadata, cropped_images = segment_core(image, text, threshold, mask_threshold, save_crops=True)
    
    # Save to history if successful
    if metadata and metadata["n_masks"] > 0:
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # Save image temporarily
        img_filename = f"img_{int(time.time())}.jpg"
        img_path = os.path.join(HISTORY_DIR, img_filename)
        image.convert("RGB").save(img_path)  # JPEG cannot store alpha; convert RGBA uploads first
        
        add_to_history(
            img_path,
            text,
            metadata["n_masks"],
            metadata["scores"],
            timestamp,
            metadata["crop_paths"]
        )
    
    return result, info, cropped_images

def background_worker():
    """
    Background worker thread - processes jobs independently
    """
    while True:
        job = job_queue.get()
        if job is None:
            break
        
        job_id, image, text, threshold, mask_threshold = job
        
        try:
            result, info, metadata, cropped_images = segment_core(image, text, threshold, mask_threshold, save_crops=True)
            results_store[job_id] = {
                "status": "completed",
                "result": result,
                "info": info,
                "metadata": metadata,
                "cropped_images": cropped_images,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }
            
            # Save to history
            if metadata and metadata["n_masks"] > 0:
                img_filename = f"bg_img_{job_id}.jpg"
                img_path = os.path.join(HISTORY_DIR, img_filename)
                image.convert("RGB").save(img_path)  # JPEG cannot store alpha; convert RGBA uploads first
                add_to_history(
                    img_path, 
                    text, 
                    metadata["n_masks"], 
                    metadata["scores"], 
                    results_store[job_id]["timestamp"],
                    metadata["crop_paths"]
                )
                
        except Exception as e:
            results_store[job_id] = {
                "status": "failed",
                "error": str(e),
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }
        
        job_queue.task_done()

# Start background worker
worker_thread = Thread(target=background_worker, daemon=True)
worker_thread.start()

def submit_background_job(image, text, threshold, mask_threshold):
    """Submit a job to background queue"""
    global job_counter
    if image is None or not text.strip():
        return "❌ Please provide image and text prompt.", ""
    
    job_counter += 1
    job_id = job_counter
    
    # Mark the job as processing *before* enqueuing it, so a fast worker
    # cannot have its completed result overwritten by this handler.
    results_store[job_id] = {"status": "processing"}
    job_queue.put((job_id, image, text, threshold, mask_threshold))
    
    return f"✅ Job #{job_id} submitted to background queue.", f"{job_id}"

def check_background_job(job_id_str):
    """Check status of background job"""
    if not job_id_str.strip():
        return "❌ Please enter a job ID.", None, []
    
    try:
        job_id = int(job_id_str)
        if job_id not in results_store:
            return f"❌ Job #{job_id} not found.", None, []
        
        job_data = results_store[job_id]
        status = job_data["status"]
        
        if status == "processing":
            return f"⏳ Job #{job_id} is still processing...", None, []
        elif status == "completed":
            return (
                f"✅ Job #{job_id} completed!\n{job_data['info']}", 
                job_data["result"],
                job_data.get("cropped_images", [])
            )
        else:
            return f"❌ Job #{job_id} failed: {job_data.get('error', 'Unknown error')}", None, []
            
    except ValueError:
        return "❌ Invalid job ID format.", None, []

def load_history_display():
    """Load and format history for display"""
    history = load_history()
    if not history:
        return "📝 No history yet. Start segmenting images!"
    
    display = "## Segmentation History\n\n"
    for entry in history[:20]:  # Show last 20
        display += f"**#{entry['id']}** - {entry['timestamp']}\n"
        display += f"- Prompt: `{entry['prompt']}`\n"
        display += f"- Found: {entry['n_masks']} objects\n"
        display += f"- Cropped: {len(entry.get('crop_paths', []))} images\n"
        display += f"- Top scores: {', '.join([f'{s:.2f}' for s in entry['scores'][:3]])}\n\n"
    
    return display

def load_history_item(item_id):
    """Load a specific history item with cropped images"""
    try:
        target_id = int(item_id)
    except (TypeError, ValueError):
        return None, "", "❌ Please enter a valid numeric history ID.", []
    history = load_history()
    for entry in history:
        if entry['id'] == target_id:
            info = f"**History item #{entry['id']}**\n"
            info += f"Timestamp: {entry['timestamp']}\n"
            info += f"Prompt: `{entry['prompt']}`\n"
            info += f"Objects found: {entry['n_masks']}\n"
            info += f"Cropped images: {len(entry.get('crop_paths', []))}"
            
            image = None
            if os.path.exists(entry['image_path']):
                image = Image.open(entry['image_path'])
            
            # Load cropped images
            cropped_images = []
            for crop_path in entry.get('crop_paths', [])[:10]:
                if os.path.exists(crop_path):
                    cropped_images.append(Image.open(crop_path))
            
            return image, entry['prompt'], info, cropped_images
    
    return None, "", f"❌ History item #{item_id} not found", []

def clear_all():
    """Clear all inputs and outputs"""
    return None, "", None, 0.5, 0.5, "📝 Enter a prompt and click **Segment** to start.", []

def segment_example(image, prompt: str):
    """Handle example clicks (Gradio may pass a PIL image or a URL/path string)"""
    if isinstance(image, str):
        if image.startswith("http"):
            image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
        else:
            image = Image.open(image).convert("RGB")
    return segment(image, prompt, 0.5, 0.5)

# Gradio Interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="SAM3 - Promptable Concept Segmentation",
    css=".gradio-container {max-width: 1600px !important;}"
) as demo:
    gr.Markdown(
        """
        # SAM3 - Promptable Concept Segmentation (PCS)
        
        **SAM3** performs zero-shot instance segmentation using natural language prompts.
        Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks + cropped objects.
        
        Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
        """
    )
    
    with gr.Tabs():
        # Tab 1: Main Segmentation
        with gr.Tab("🎯 Segmentation"):
            gr.Markdown("### Inputs")
            with gr.Row(variant="panel"):
                image_input = gr.Image(
                    label="Input Image",
                    type="pil",
                    height=400,
                )
                image_output = gr.AnnotatedImage(
                    label="Output (Segmented Image)",
                    height=400,
                    show_legend=True,
                )
            
            with gr.Row():
                text_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="e.g., person, ear, cat, bicycle...",
                    scale=3
                )
                clear_btn = gr.Button("🔄 Clear", size="sm", variant="secondary")
            
            with gr.Row():
                thresh_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Detection Threshold",
                    info="Higher = fewer detections"
                )
                mask_thresh_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Mask Threshold",
                    info="Higher = sharper masks"
                )
            
            info_output = gr.Markdown(
                value="📝 Enter a prompt and click **Segment** to start.",
                label="Info / Results"
            )
            
            segment_btn = gr.Button("🎯 Segment Now", variant="primary", size="lg")
            
            gr.Markdown("### ✂️ Cropped Objects")
            cropped_gallery = gr.Gallery(
                label="Extracted Objects",
                columns=5,
                height=300,
                object_fit="contain"
            )
            
            gr.Examples(
                examples=[
                    ["http://images.cocodataset.org/val2017/000000077595.jpg", "cat"],
                ],
                inputs=[image_input, text_input],
                outputs=[image_output, info_output, cropped_gallery],
                fn=segment_example,
                cache_examples=False,
            )
        
        # Tab 2: Background Processing
        with gr.Tab("⚙️ Background Processing"):
            gr.Markdown(
                """
                ### Background Job Queue
                Submit segmentation jobs that run independently in the background.
                Useful for batch processing or when you want to continue working while processing.
                """
            )
            
            with gr.Row():
                bg_image_input = gr.Image(label="Image", type="pil", height=300)
                bg_status_output = gr.Markdown("📝 Submit a job to start background processing.")
            
            with gr.Row():
                bg_text_input = gr.Textbox(label="Text Prompt", placeholder="e.g., person, car...")
                bg_job_id_output = gr.Textbox(label="Job ID", interactive=False)
            
            with gr.Row():
                bg_thresh = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Detection Threshold")
                bg_mask_thresh = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Mask Threshold")
            
            bg_submit_btn = gr.Button("📤 Submit Background Job", variant="primary")
            
            gr.Markdown("---")
            gr.Markdown("### Check Job Status")
            
            with gr.Row():
                check_job_id = gr.Textbox(label="Enter Job ID", placeholder="e.g., 1")
                check_btn = gr.Button("🔍 Check Status", variant="secondary")
            
            check_status_output = gr.Markdown("Enter a job ID and click Check Status.")
            check_result_output = gr.AnnotatedImage(label="Result", height=400)
            
            gr.Markdown("### Cropped Objects from Job")
            check_cropped_gallery = gr.Gallery(
                label="Extracted Objects",
                columns=5,
                height=300,
                object_fit="contain"
            )
        
        # Tab 3: History
        with gr.Tab("📚 History"):
            gr.Markdown("### Segmentation History")
            
            with gr.Row():
                refresh_history_btn = gr.Button("🔄 Refresh History", variant="secondary")
                history_item_id = gr.Textbox(label="Load History Item #", placeholder="Enter ID")
                load_history_btn = gr.Button("📂 Load Item", variant="primary")
            
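            # Rendered once at app start; the Refresh button re-runs load_history_display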
            history_display = gr.Markdown(load_history_display())
            
            gr.Markdown("---")
            gr.Markdown("### Loaded History Item")
            
            with gr.Row():
                history_image = gr.Image(label="Original Image", type="pil", height=300)
                history_info = gr.Markdown("Select a history item to view.")
            
            history_prompt = gr.Textbox(label="Prompt", interactive=False)
            
            gr.Markdown("### Cropped Objects from History")
            history_cropped_gallery = gr.Gallery(
                label="Extracted Objects",
                columns=5,
                height=300,
                object_fit="contain"
            )
    
    # Event handlers
    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output, cropped_gallery]
    )
    
    segment_btn.click(
        fn=segment,
        inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
        outputs=[image_output, info_output, cropped_gallery]
    )
    
    bg_submit_btn.click(
        fn=submit_background_job,
        inputs=[bg_image_input, bg_text_input, bg_thresh, bg_mask_thresh],
        outputs=[bg_status_output, bg_job_id_output]
    )
    
    check_btn.click(
        fn=check_background_job,
        inputs=[check_job_id],
        outputs=[check_status_output, check_result_output, check_cropped_gallery]
    )
    
    refresh_history_btn.click(
        fn=load_history_display,
        outputs=[history_display]
    )
    
    load_history_btn.click(
        fn=load_history_item,
        inputs=[history_item_id],
        outputs=[history_image, history_prompt, history_info, history_cropped_gallery]
    )
    
    gr.Markdown(
        """
        ### Notes
        - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3), loaded here from the `DiffusionWave/sam3` checkpoint
        - Background jobs run independently and are tracked by Job ID
        - All segmented objects are automatically cropped and saved
        - Cropped images have transparent backgrounds (PNG format)
        - History is saved automatically and persists across sessions
        - GPU recommended for faster inference
        """
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)