Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved | |
| import json | |
| import os | |
| import torch | |
| from PIL import Image | |
| from sam3.model.box_ops import box_xyxy_to_xywh | |
| from sam3.train.masks_ops import rle_encode | |
| from .helpers.mask_overlap_removal import remove_overlapping_masks | |
| from .viz import visualize | |
| def sam3_inference(processor, image_path, text_prompt): | |
| """Run SAM 3 image inference with text prompts and format the outputs""" | |
| image = Image.open(image_path) | |
| orig_img_w, orig_img_h = image.size | |
| # model inference | |
| inference_state = processor.set_image(image) | |
| inference_state = processor.set_text_prompt( | |
| state=inference_state, prompt=text_prompt | |
| ) | |
| # format and assemble outputs | |
| pred_boxes_xyxy = torch.stack( | |
| [ | |
| inference_state["boxes"][:, 0] / orig_img_w, | |
| inference_state["boxes"][:, 1] / orig_img_h, | |
| inference_state["boxes"][:, 2] / orig_img_w, | |
| inference_state["boxes"][:, 3] / orig_img_h, | |
| ], | |
| dim=-1, | |
| ) # normalized in range [0, 1] | |
| pred_boxes_xywh = box_xyxy_to_xywh(pred_boxes_xyxy).tolist() | |
| pred_masks = rle_encode(inference_state["masks"].squeeze(1)) | |
| pred_masks = [m["counts"] for m in pred_masks] | |
| outputs = { | |
| "orig_img_h": orig_img_h, | |
| "orig_img_w": orig_img_w, | |
| "pred_boxes": pred_boxes_xywh, | |
| "pred_masks": pred_masks, | |
| "pred_scores": inference_state["scores"].tolist(), | |
| } | |
| return outputs | |
| def call_sam_service( | |
| sam3_processor, | |
| image_path: str, | |
| text_prompt: str, | |
| output_folder_path: str = "sam3_output", | |
| ): | |
| """ | |
| Loads an image, sends it with a text prompt to the service, | |
| saves the results, and renders the visualization. | |
| """ | |
| print(f"π Loading image '{image_path}' and sending with prompt '{text_prompt}'...") | |
| text_prompt_for_save_path = ( | |
| text_prompt.replace("/", "_") if "/" in text_prompt else text_prompt | |
| ) | |
| os.makedirs( | |
| os.path.join(output_folder_path, image_path.replace("/", "-")), exist_ok=True | |
| ) | |
| output_json_path = os.path.join( | |
| output_folder_path, | |
| image_path.replace("/", "-"), | |
| rf"{text_prompt_for_save_path}.json", | |
| ) | |
| output_image_path = os.path.join( | |
| output_folder_path, | |
| image_path.replace("/", "-"), | |
| rf"{text_prompt_for_save_path}.png", | |
| ) | |
| try: | |
| # Send the image and text prompt as a multipart/form-data request | |
| serialized_response = sam3_inference(sam3_processor, image_path, text_prompt) | |
| # 1. Prepare the response dictionary | |
| serialized_response = remove_overlapping_masks(serialized_response) | |
| serialized_response = { | |
| "original_image_path": image_path, | |
| "output_image_path": output_image_path, | |
| **serialized_response, | |
| } | |
| # 2. Reorder predictions by scores (highest to lowest) if scores are available | |
| if "pred_scores" in serialized_response and serialized_response["pred_scores"]: | |
| # Create indices sorted by scores in descending order | |
| score_indices = sorted( | |
| range(len(serialized_response["pred_scores"])), | |
| key=lambda i: serialized_response["pred_scores"][i], | |
| reverse=True, | |
| ) | |
| # Reorder all three lists based on the sorted indices | |
| serialized_response["pred_scores"] = [ | |
| serialized_response["pred_scores"][i] for i in score_indices | |
| ] | |
| serialized_response["pred_boxes"] = [ | |
| serialized_response["pred_boxes"][i] for i in score_indices | |
| ] | |
| serialized_response["pred_masks"] = [ | |
| serialized_response["pred_masks"][i] for i in score_indices | |
| ] | |
| # 3. Remove any invalid RLE masks that is too short (shorter than 5 characters) | |
| valid_masks = [] | |
| valid_boxes = [] | |
| valid_scores = [] | |
| for i, rle in enumerate(serialized_response["pred_masks"]): | |
| if len(rle) > 4: | |
| valid_masks.append(rle) | |
| valid_boxes.append(serialized_response["pred_boxes"][i]) | |
| valid_scores.append(serialized_response["pred_scores"][i]) | |
| serialized_response["pred_masks"] = valid_masks | |
| serialized_response["pred_boxes"] = valid_boxes | |
| serialized_response["pred_scores"] = valid_scores | |
| with open(output_json_path, "w") as f: | |
| json.dump(serialized_response, f, indent=4) | |
| print(f"β Raw JSON response saved to '{output_json_path}'") | |
| # 4. Render and save visualizations on the image and save it in the SAM3 output folder | |
| print("π Rendering visualizations on the image ...") | |
| viz_image = visualize(serialized_response) | |
| os.makedirs(os.path.dirname(output_image_path), exist_ok=True) | |
| viz_image.save(output_image_path) | |
| print("β Saved visualization at:", output_image_path) | |
| except Exception as e: | |
| print(f"β Error calling service: {e}") | |
| return output_json_path | |