import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
import requests
import warnings
import json
import os
from datetime import datetime
from threading import Thread
from queue import Queue
import time

warnings.filterwarnings("ignore")
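
# Device selection and model loading. Sam3Model / Sam3Processor are assumed to be
# available in the installed transformers build; "DiffusionWave/sam3" is the
# checkpoint this Space loads, presumably a re-upload of the official
# facebook/sam3 weights.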
device = "cuda" if torch.cuda.is_available() else "cpu"

from transformers import Sam3Processor, Sam3Model

model = Sam3Model.from_pretrained(
    "DiffusionWave/sam3",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
processor = Sam3Processor.from_pretrained("DiffusionWave/sam3")
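
# Background-job state shared between the Gradio handlers and the worker thread.
# Plain dict/int access is assumed safe here because CPython's GIL serializes it
# and Gradio queues events by default.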
job_queue = Queue()
results_store = {}
job_counter = 0
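
# On-disk persistence: history metadata lives in a JSON file, and cropped
# objects are saved as RGBA PNGs under a sibling directory.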
HISTORY_DIR = "segmentation_history"
HISTORY_FILE = os.path.join(HISTORY_DIR, "history.json")
CROPS_DIR = os.path.join(HISTORY_DIR, "crops")
os.makedirs(HISTORY_DIR, exist_ok=True)
os.makedirs(CROPS_DIR, exist_ok=True)


def load_history():
    """Load segmentation history from file"""
    if os.path.exists(HISTORY_FILE):
        try:
            with open(HISTORY_FILE, 'r') as f:
                return json.load(f)
        except (json.JSONDecodeError, OSError):
            # A corrupt or unreadable history file falls back to an empty history.
            return []
    return []


def save_history(history):
    """Save segmentation history to file"""
    with open(HISTORY_FILE, 'w') as f:
        json.dump(history, f, indent=2)


def crop_segmented_objects(image: Image.Image, masks, text: str, timestamp: str):
    """
    Crop individual objects from masks and save them.
    Returns a list of cropped image paths.
    """
    cropped_paths = []
    # Force RGB so the 3-channel copy below works even for RGBA/grayscale uploads.
    image_np = np.array(image.convert("RGB"))

    for i, mask in enumerate(masks):
        if isinstance(mask, torch.Tensor):
            mask_np = mask.cpu().numpy()
        else:
            mask_np = mask

        # Bounding box of the mask.
        rows = np.any(mask_np, axis=1)
        cols = np.any(mask_np, axis=0)
        if not rows.any() or not cols.any():
            continue
        y_min, y_max = np.where(rows)[0][[0, -1]]
        x_min, x_max = np.where(cols)[0][[0, -1]]

        # Pad the box, clamped to the image bounds.
        padding = 10
        y_min = max(0, y_min - padding)
        y_max = min(image_np.shape[0], y_max + padding)
        x_min = max(0, x_min - padding)
        x_max = min(image_np.shape[1], x_max + padding)

        cropped = image_np[y_min:y_max, x_min:x_max]
        mask_crop = mask_np[y_min:y_max, x_min:x_max]

        # Compose an RGBA crop with the mask as the alpha channel.
        cropped_rgba = np.zeros((*cropped.shape[:2], 4), dtype=np.uint8)
        cropped_rgba[:, :, :3] = cropped
        cropped_rgba[:, :, 3] = (mask_crop * 255).astype(np.uint8)

        # Sanitize the prompt so arbitrary text cannot break the file path.
        safe_text = "".join(c if c.isalnum() else "_" for c in text)
        crop_filename = f"crop_{timestamp.replace(':', '-').replace(' ', '_')}_{safe_text}_{i+1}.png"
        crop_path = os.path.join(CROPS_DIR, crop_filename)
        Image.fromarray(cropped_rgba).save(crop_path)
        cropped_paths.append(crop_path)

    return cropped_paths
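
# The mask handling above assumes each mask is a 2-D array at the original image
# resolution, which is what post_process_instance_segmentation is expected to
# return when target_sizes is supplied.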


def add_to_history(image_path, prompt, n_masks, scores, timestamp, crop_paths):
    """Add a new entry to history"""
    history = load_history()
    entry = {
        "id": len(history) + 1,
        "timestamp": timestamp,
        "image_path": image_path,
        "prompt": prompt,
        "n_masks": n_masks,
        "scores": scores,
        "crop_paths": crop_paths
    }
    history.insert(0, entry)

    # Keep only the 100 most recent entries.
    history = history[:100]
    save_history(history)
    return history


@spaces.GPU()
def segment_core(image: Image.Image, text: str, threshold: float, mask_threshold: float, save_crops: bool = True):
    """
    Core segmentation function - can be called independently.
    """
    if image is None:
        return None, "❌ Please upload an image.", None, []

    if not text.strip():
        return (image, []), "❌ Please enter a text prompt.", None, []

    try:
        inputs = processor(images=image, text=text.strip(), return_tensors="pt").to(device)

        # Cast float inputs to the model dtype (fp16 on GPU, fp32 on CPU).
        for key in inputs:
            if inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)

        with torch.no_grad():
            outputs = model(**inputs)

        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]

        n_masks = len(results['masks'])
        if n_masks == 0:
            return (image, []), f"❌ No objects found matching '{text}' (try adjusting thresholds).", None, []

        # Build (mask, label) pairs for gr.AnnotatedImage.
        annotations = []
        for i, (mask, score) in enumerate(zip(results['masks'], results['scores'])):
            mask_np = mask.cpu().numpy().astype(np.float32)
            label = f"{text} #{i+1} ({score:.2f})"
            annotations.append((mask_np, label))

        scores_list = results['scores'].cpu().numpy().tolist()
        scores_text = ", ".join([f"{s:.2f}" for s in scores_list[:5]])
        info = f"✅ Found **{n_masks}** objects matching **'{text}'**\n"
        info += f"Confidence scores: {scores_text}{'...' if n_masks > 5 else ''}\n"

        cropped_images = []
        if save_crops:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            crop_paths = crop_segmented_objects(image, results['masks'], text, timestamp)
            info += f"✂️ Extracted **{len(crop_paths)}** cropped objects"

            # Show at most the first 10 crops in the gallery.
            for path in crop_paths[:10]:
                if os.path.exists(path):
                    cropped_images.append(Image.open(path))
        else:
            crop_paths = []

        metadata = {
            "n_masks": n_masks,
            "scores": scores_list,
            "crop_paths": crop_paths,
            "masks": results['masks']
        }

        return (image, annotations), info, metadata, cropped_images

    except Exception as e:
        return (image, []), f"❌ Error during segmentation: {str(e)}", None, []
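
# Hedged usage sketch: segment_core can also be called outside the UI, e.g. from
# a notebook ("photo.jpg" is a hypothetical file):
#
#   img = Image.open("photo.jpg").convert("RGB")
#   result, info, meta, crops = segment_core(img, "dog", 0.5, 0.5, save_crops=False)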


def segment(image: Image.Image, text: str, threshold: float, mask_threshold: float):
    """
    Frontend segment function - with history saving.
    """
    result, info, metadata, cropped_images = segment_core(image, text, threshold, mask_threshold, save_crops=True)

    if metadata and metadata["n_masks"] > 0:
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Persist the input image; convert to RGB so JPEG can encode it.
        img_filename = f"img_{int(time.time())}.jpg"
        img_path = os.path.join(HISTORY_DIR, img_filename)
        image.convert("RGB").save(img_path)

        add_to_history(
            img_path,
            text,
            metadata["n_masks"],
            metadata["scores"],
            timestamp,
            metadata["crop_paths"]
        )

    return result, info, cropped_images
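
# A single daemon thread drains job_queue sequentially, so background jobs never
# contend with each other for the GPU; results are published into results_store
# keyed by job id.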


def background_worker():
    """
    Background worker thread - processes jobs independently.
    """
    while True:
        job = job_queue.get()
        if job is None:
            # Sentinel value: shut the worker down.
            break

        job_id, image, text, threshold, mask_threshold = job

        try:
            result, info, metadata, cropped_images = segment_core(image, text, threshold, mask_threshold, save_crops=True)
            results_store[job_id] = {
                "status": "completed",
                "result": result,
                "info": info,
                "metadata": metadata,
                "cropped_images": cropped_images,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }

            if metadata and metadata["n_masks"] > 0:
                img_filename = f"bg_img_{job_id}.jpg"
                img_path = os.path.join(HISTORY_DIR, img_filename)
                # Convert to RGB so JPEG can encode images with an alpha channel.
                image.convert("RGB").save(img_path)
                add_to_history(
                    img_path,
                    text,
                    metadata["n_masks"],
                    metadata["scores"],
                    results_store[job_id]["timestamp"],
                    metadata["crop_paths"]
                )

        except Exception as e:
            results_store[job_id] = {
                "status": "failed",
                "error": str(e),
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }

        job_queue.task_done()


worker_thread = Thread(target=background_worker, daemon=True)
worker_thread.start()
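
# Note: results_store grows without bound (completed jobs keep their PIL images
# in memory); a long-running deployment would likely want to evict old entries.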


def submit_background_job(image, text, threshold, mask_threshold):
    """Submit a job to background queue"""
    global job_counter
    if image is None or not text.strip():
        return "❌ Please provide an image and a text prompt.", ""

    job_counter += 1
    job_id = job_counter

    job_queue.put((job_id, image, text, threshold, mask_threshold))
    results_store[job_id] = {"status": "processing"}

    return f"✅ Job #{job_id} submitted to background queue.", f"{job_id}"


def check_background_job(job_id_str):
    """Check status of background job"""
    if not job_id_str.strip():
        return "❌ Please enter a job ID.", None, []

    try:
        job_id = int(job_id_str)
        if job_id not in results_store:
            return f"❌ Job #{job_id} not found.", None, []

        job_data = results_store[job_id]
        status = job_data["status"]

        if status == "processing":
            return f"⏳ Job #{job_id} is still processing...", None, []
        elif status == "completed":
            return (
                f"✅ Job #{job_id} completed!\n{job_data['info']}",
                job_data["result"],
                job_data.get("cropped_images", [])
            )
        else:
            return f"❌ Job #{job_id} failed: {job_data.get('error', 'Unknown error')}", None, []

    except ValueError:
        return "❌ Invalid job ID format.", None, []


def load_history_display():
    """Load and format history for display"""
    history = load_history()
    if not history:
        return "📝 No history yet. Start segmenting images!"

    display = "## Segmentation History\n\n"
    for entry in history[:20]:
        display += f"**#{entry['id']}** - {entry['timestamp']}\n"
        display += f"- Prompt: `{entry['prompt']}`\n"
        display += f"- Found: {entry['n_masks']} objects\n"
        display += f"- Cropped: {len(entry.get('crop_paths', []))} images\n"
        display += f"- Top scores: {', '.join([f'{s:.2f}' for s in entry['scores'][:3]])}\n\n"

    return display


def load_history_item(item_id):
    """Load a specific history item with cropped images"""
    try:
        item_id = int(item_id)
    except (TypeError, ValueError):
        return None, "", "❌ Please enter a valid numeric history ID.", []

    history = load_history()
    for entry in history:
        if entry['id'] == item_id:
            info = f"**History item #{entry['id']}**\n"
            info += f"Timestamp: {entry['timestamp']}\n"
            info += f"Prompt: `{entry['prompt']}`\n"
            info += f"Objects found: {entry['n_masks']}\n"
            info += f"Cropped images: {len(entry.get('crop_paths', []))}"

            image = None
            if os.path.exists(entry['image_path']):
                image = Image.open(entry['image_path'])

            cropped_images = []
            for crop_path in entry.get('crop_paths', [])[:10]:
                if os.path.exists(crop_path):
                    cropped_images.append(Image.open(crop_path))

            return image, entry['prompt'], info, cropped_images

    return None, "", f"❌ History item #{item_id} not found", []


def clear_all():
    """Clear all inputs and outputs"""
    return None, "", None, 0.5, 0.5, "📝 Enter a prompt and click **Segment** to start.", []


def segment_example(image_path: str, prompt: str):
    """Handle example clicks"""
    if image_path.startswith("http"):
        image = Image.open(requests.get(image_path, stream=True).raw).convert("RGB")
    else:
        image = Image.open(image_path).convert("RGB")
    return segment(image, prompt, 0.5, 0.5)
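

# Gradio UI: three tabs (interactive segmentation, background job queue, history).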
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="SAM3 - Promptable Concept Segmentation",
    css=".gradio-container {max-width: 1600px !important;}"
) as demo:
    gr.Markdown(
        """
        # SAM3 - Promptable Concept Segmentation (PCS)

        **SAM3** performs zero-shot instance segmentation using natural language prompts.
        Upload an image, enter a text prompt (e.g., "person", "car", "dog"), and get segmentation masks + cropped objects.

        Built with [anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
        """
    )

    with gr.Tabs():

        with gr.Tab("🎯 Segmentation"):
            gr.Markdown("### Inputs")
            with gr.Row(variant="panel"):
                image_input = gr.Image(
                    label="Input Image",
                    type="pil",
                    height=400,
                )
                image_output = gr.AnnotatedImage(
                    label="Output (Segmented Image)",
                    height=400,
                    show_legend=True,
                )

            with gr.Row():
                text_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="e.g., person, ear, cat, bicycle...",
                    scale=3
                )
                clear_btn = gr.Button("🔄 Clear", size="sm", variant="secondary")

            with gr.Row():
                thresh_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Detection Threshold",
                    info="Higher = fewer detections"
                )
                mask_thresh_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Mask Threshold",
                    info="Higher = sharper masks"
                )

            info_output = gr.Markdown(
                value="📝 Enter a prompt and click **Segment** to start.",
                label="Info / Results"
            )

            segment_btn = gr.Button("🎯 Segment Now", variant="primary", size="lg")

            gr.Markdown("### ✂️ Cropped Objects")
            cropped_gallery = gr.Gallery(
                label="Extracted Objects",
                columns=5,
                height=300,
                object_fit="contain"
            )

            gr.Examples(
                examples=[
                    ["http://images.cocodataset.org/val2017/000000077595.jpg", "cat"],
                ],
                inputs=[image_input, text_input],
                outputs=[image_output, info_output, cropped_gallery],
                fn=segment_example,
                cache_examples=False,
            )
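
            # Note: a URL string as an Image example assumes Gradio can fetch
            # remote example files; segment_example also handles the URL itself
            # when the example is run.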

        with gr.Tab("⚙️ Background Processing"):
            gr.Markdown(
                """
                ### Background Job Queue
                Submit segmentation jobs that run independently in the background.
                Useful for batch processing or when you want to keep working while a job runs.
                """
            )

            with gr.Row():
                bg_image_input = gr.Image(label="Image", type="pil", height=300)
                bg_status_output = gr.Markdown("📝 Submit a job to start background processing.")

            with gr.Row():
                bg_text_input = gr.Textbox(label="Text Prompt", placeholder="e.g., person, car...")
                bg_job_id_output = gr.Textbox(label="Job ID", interactive=False)

            with gr.Row():
                bg_thresh = gr.Slider(0.0, 1.0, 0.5, 0.01, label="Detection Threshold")
                bg_mask_thresh = gr.Slider(0.0, 1.0, 0.5, 0.01, label="Mask Threshold")

            bg_submit_btn = gr.Button("📤 Submit Background Job", variant="primary")

            gr.Markdown("---")
            gr.Markdown("### Check Job Status")

            with gr.Row():
                check_job_id = gr.Textbox(label="Enter Job ID", placeholder="e.g., 1")
                check_btn = gr.Button("🔍 Check Status", variant="secondary")

            check_status_output = gr.Markdown("Enter a job ID and click Check Status.")
            check_result_output = gr.AnnotatedImage(label="Result", height=400)

            gr.Markdown("### Cropped Objects from Job")
            check_cropped_gallery = gr.Gallery(
                label="Extracted Objects",
                columns=5,
                height=300,
                object_fit="contain"
            )
with gr.Tab("📚 History"): |
|
|
gr.Markdown("### Segmentation History") |
|
|
|
|
|
with gr.Row(): |
|
|
refresh_history_btn = gr.Button("🔄 Refresh History", variant="secondary") |
|
|
history_item_id = gr.Textbox(label="Load History Item #", placeholder="Enter ID") |
|
|
load_history_btn = gr.Button("📂 Load Item", variant="primary") |
|
|
|
|
|
history_display = gr.Markdown(load_history_display()) |
|
|
|
|
|
gr.Markdown("---") |
|
|
gr.Markdown("### Loaded History Item") |
|
|
|
|
|
with gr.Row(): |
|
|
history_image = gr.Image(label="Original Image", type="pil", height=300) |
|
|
history_info = gr.Markdown("Select a history item to view.") |
|
|
|
|
|
history_prompt = gr.Textbox(label="Prompt", interactive=False) |
|
|
|
|
|
gr.Markdown("### Cropped Objects from History") |
|
|
history_cropped_gallery = gr.Gallery( |
|
|
label="Extracted Objects", |
|
|
columns=5, |
|
|
height=300, |
|
|
object_fit="contain" |
|
|
) |
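
    # Event wiring: each handler's outputs list must match, in order, the tuple
    # its function returns.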

    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, text_input, image_output, thresh_slider, mask_thresh_slider, info_output, cropped_gallery]
    )

    segment_btn.click(
        fn=segment,
        inputs=[image_input, text_input, thresh_slider, mask_thresh_slider],
        outputs=[image_output, info_output, cropped_gallery]
    )

    bg_submit_btn.click(
        fn=submit_background_job,
        inputs=[bg_image_input, bg_text_input, bg_thresh, bg_mask_thresh],
        outputs=[bg_status_output, bg_job_id_output]
    )

    check_btn.click(
        fn=check_background_job,
        inputs=[check_job_id],
        outputs=[check_status_output, check_result_output, check_cropped_gallery]
    )

    refresh_history_btn.click(
        fn=load_history_display,
        outputs=[history_display]
    )

    load_history_btn.click(
        fn=load_history_item,
        inputs=[history_item_id],
        outputs=[history_image, history_prompt, history_info, history_cropped_gallery]
    )

    gr.Markdown(
        """
        ### Notes
        - **Model**: [facebook/sam3](https://huggingface.co/facebook/sam3) (this Space loads the `DiffusionWave/sam3` checkpoint)
        - Background jobs run independently and are tracked by Job ID
        - All segmented objects are automatically cropped and saved
        - Cropped images have transparent backgrounds (PNG format)
        - History is saved automatically and persists across sessions
        - GPU recommended for faster inference
        """
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)