# File size: 37,532 Bytes | 14114e8
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import os
from copy import deepcopy
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from sam3.model.model_misc import SAM3Output
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
from sam3.model.vl_combiner import SAM3VLBackbone
from sam3.perflib.nms import nms_masks
from sam3.train.data.collator import BatchedDatapoint
from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .geometry_encoders import Prompt
from .model_misc import inverse_sigmoid
def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True):
    """Store ``out_value`` in ``out`` under ``out_name``, optionally filling aux outputs.

    Args:
        out: Output dict, mutated in place.
        out_name: Key to write.
        out_value: When ``auxiliary`` is true, an indexable sequence whose last
            element is the main output and whose leading elements are the
            auxiliary outputs; otherwise stored as-is.
        auxiliary: Whether ``out_value`` carries a leading "per-step" axis.
        update_aux: Whether to also write the leading elements into
            ``out["aux_outputs"]`` (created on first use, one dict per element).
    """
    if not auxiliary:
        out[out_name] = out_value
        return

    # The final element is the main prediction.
    out[out_name] = out_value[-1]
    if not update_aux:
        return

    num_aux = len(out_value) - 1
    if "aux_outputs" not in out:
        out["aux_outputs"] = [{} for _ in range(num_aux)]
    aux_dicts = out["aux_outputs"]
    assert len(aux_dicts) == num_aux
    for aux_dict, step_value in zip(aux_dicts, out_value[:-1]):
        aux_dict[out_name] = step_value
class Sam3Image(torch.nn.Module):
    """Image-level SAM3 model.

    Wires together a vision-language backbone, a geometry (prompt) encoder, a
    transformer encoder/decoder and an optional segmentation head.

    * ``forward`` is the batched entry point (single frame only).
    * ``forward_grounding`` runs one prompt-conditioned encoder/decoder pass.
    * ``predict_inst`` / ``predict_inst_batch`` delegate to the optional
      interactive instance predictor using precomputed backbone features.
    """

    TEXT_ID_FOR_TEXT = 0
    TEXT_ID_FOR_VISUAL = 1
    TEXT_ID_FOR_GEOMETRIC = 2

    def __init__(
        self,
        backbone: "SAM3VLBackbone",
        transformer,
        input_geometry_encoder,
        segmentation_head=None,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=None,
        use_instance_query: bool = True,
        multimask_output: bool = True,
        use_act_checkpoint_seg_head: bool = True,
        interactivity_in_encoder: bool = True,
        matcher=None,
        use_dot_prod_scoring=True,
        supervise_joint_box_scores: bool = False,  # only relevant if using presence token/score
        detach_presence_in_joint_score: bool = False,  # only relevant if using presence token/score
        separate_scorer_for_instance: bool = False,
        num_interactive_steps_val: int = 0,
        inst_interactive_predictor: Optional["SAM3InteractiveImagePredictor"] = None,
        **kwargs,
    ):
        """Store the sub-modules and configuration flags.

        Extra ``**kwargs`` are accepted and ignored. When
        ``use_dot_prod_scoring`` is true, ``dot_prod_scoring`` must be given;
        otherwise a linear ``class_embed`` head is created. With
        ``separate_scorer_for_instance`` a deep copy of the active scoring head
        is kept for instance prompts.

        Raises:
            AssertionError: if ``use_dot_prod_scoring`` is set without a
                ``dot_prod_scoring`` module, or if the decoder's O2M query
                count is inconsistent with its ``dac`` flag.
        """
        super().__init__()
        self.backbone = backbone
        self.geometry_encoder = input_geometry_encoder
        self.transformer = transformer
        self.hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.segmentation_head = segmentation_head

        self.o2m_mask_predict = o2m_mask_predict

        # Always stored (even when unused) so the attribute layout is stable.
        self.dot_prod_scoring = dot_prod_scoring
        self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head
        self.interactivity_in_encoder = interactivity_in_encoder
        self.matcher = matcher

        self.num_interactive_steps_val = num_interactive_steps_val
        self.use_dot_prod_scoring = use_dot_prod_scoring

        if self.use_dot_prod_scoring:
            assert dot_prod_scoring is not None
            self.instance_dot_prod_scoring = None
            if separate_scorer_for_instance:
                self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring)
        else:
            self.class_embed = torch.nn.Linear(self.hidden_dim, 1)
            self.instance_class_embed = None
            if separate_scorer_for_instance:
                self.instance_class_embed = deepcopy(self.class_embed)

        self.supervise_joint_box_scores = supervise_joint_box_scores
        self.detach_presence_in_joint_score = detach_presence_in_joint_score

        # verify the number of queries for O2O and O2M
        num_o2o_static = self.transformer.decoder.num_queries
        num_o2m_static = self.transformer.decoder.num_o2m_queries
        assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0)
        self.dac = self.transformer.decoder.dac

        self.use_instance_query = use_instance_query
        self.multimask_output = multimask_output

        self.inst_interactive_predictor = inst_interactive_predictor

    @property
    def device(self):
        """Device of the first parameter, cached in ``_device`` until invalidated."""
        self._device = getattr(self, "_device", None) or next(self.parameters()).device
        return self._device

    def to(self, *args, **kwargs):
        """Move/cast the module, dropping the cached device first."""
        # clear cached _device in case the model is moved to a different device
        self._device = None
        return super().to(*args, **kwargs)

    def _apply(self, fn, *args, **kwargs):
        """Drop the cached device whenever parameters are transformed.

        ``cuda()``, ``cpu()`` and moves issued on a parent module reach this
        module through ``_apply`` rather than ``to``; without this hook the
        cached ``_device`` could go stale after such a move.
        """
        self._device = None
        return super()._apply(fn, *args, **kwargs)

    def _get_img_feats(self, backbone_out, img_ids):
        """Retrieve correct image features from backbone output.

        If ``backbone_out`` has no ``"backbone_fpn"`` entry, the backbone is run
        on the unique requested images from ``"img_batch_all_stages"``, an
        ``"id_mapping"`` (image id -> row in the computed features, ``-1`` for
        images not computed) is recorded, and the lookup is retried.

        Returns:
            ``(backbone_out, img_feats, img_pos_embeds, vis_feat_sizes)`` where
            the feature / positional lists cover the last
            ``num_feature_levels`` levels, flattened to sequence-first layout.
        """
        if "backbone_fpn" in backbone_out:
            if "id_mapping" in backbone_out and backbone_out["id_mapping"] is not None:
                img_ids = backbone_out["id_mapping"][img_ids]
                # If this assert fails, it likely means we're requesting different img_ids (perhaps a different frame?)
                # We currently don't expect this to happen. We could technically trigger a recompute here,
                # but likely at the cost of a cpu<->gpu sync point, which would deteriorate perf
                torch._assert_async((img_ids >= 0).all())

            vis_feats = backbone_out["backbone_fpn"][-self.num_feature_levels :]
            vis_pos_enc = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
            vis_feat_sizes = [x.shape[-2:] for x in vis_pos_enc]  # (H, W) shapes
            # index and flatten visual features NxCxHxW => HWxNxC (batch-first => seq-first)
            img_feats = [x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_feats]
            img_pos_embeds = [
                x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_pos_enc
            ]
            return backbone_out, img_feats, img_pos_embeds, vis_feat_sizes

        # Image features not available in backbone output, so we compute them on the fly
        # This case likely occurs for video. In that case, we want to forward only the current frame
        img_batch = backbone_out["img_batch_all_stages"]
        if img_ids.numel() > 1:
            # Only forward backbone on unique image ids to avoid repetitive computation
            unique_ids = torch.unique(img_ids)
        else:
            unique_ids = img_ids
        # Compute the image features on those unique image ids
        # note: we allow using a list (or other indexable types) of tensors as img_batch
        # (e.g. for async frame loading in demo). In this case we index img_batch.tensors directly
        if isinstance(img_batch, torch.Tensor):
            image = img_batch[unique_ids]
        elif unique_ids.numel() == 1:
            image = img_batch[unique_ids.item()].unsqueeze(0)
        else:
            image = torch.stack([img_batch[i] for i in unique_ids.tolist()])
        # `img_batch` might be fp16 and offloaded to CPU
        image = image.to(dtype=torch.float32, device=self.device)
        # Next time we call this function, we want to remember which indices we computed
        id_mapping = torch.full(
            (len(img_batch),), -1, dtype=torch.long, device=self.device
        )
        id_mapping[unique_ids] = torch.arange(len(unique_ids), device=self.device)
        backbone_out = {
            **backbone_out,
            **self.backbone.forward_image(image),
            "id_mapping": id_mapping,
        }
        assert "backbone_fpn" in backbone_out
        return self._get_img_feats(backbone_out, img_ids=img_ids)

    def _encode_prompt(
        self,
        backbone_out,
        find_input,
        geometric_prompt,
        visual_prompt_embed=None,
        visual_prompt_mask=None,
        encode_text=True,
        prev_mask_pred=None,
    ):
        """Build the prompt sequence and its padding mask.

        The prompt is the concatenation (along the sequence axis) of the text
        features (only if ``encode_text``), the encoded geometric prompt and
        the visual prompt embedding. When ``visual_prompt_embed`` is ``None``
        an empty (zero-length) embedding and mask are used instead.

        Returns:
            ``(prompt, prompt_mask, backbone_out)``; ``backbone_out`` may have
            been extended with freshly computed image features.
        """
        # index text features (note that regardless of early or late fusion, the batch size of
        # `txt_feats` is always the number of *prompts* in the encoder)
        txt_ids = find_input.text_ids
        txt_feats = backbone_out["language_features"][:, txt_ids]
        txt_masks = backbone_out["language_mask"][txt_ids]

        feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids)
        backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple

        if prev_mask_pred is not None:
            # Condition the last feature level on the previous mask prediction.
            img_feats = [img_feats[-1] + prev_mask_pred]

        # Encode geometry
        geo_feats, geo_masks = self.geometry_encoder(
            geo_prompt=geometric_prompt,
            img_feats=img_feats,
            img_sizes=vis_feat_sizes,
            img_pos_embeds=img_pos_embeds,
        )
        if visual_prompt_embed is None:
            visual_prompt_embed = torch.zeros(
                (0, *geo_feats.shape[1:]), device=geo_feats.device
            )
            visual_prompt_mask = torch.zeros(
                (*geo_masks.shape[:-1], 0),
                device=geo_masks.device,
                dtype=geo_masks.dtype,
            )
        if encode_text:
            prompt = torch.cat([txt_feats, geo_feats, visual_prompt_embed], dim=0)
            prompt_mask = torch.cat([txt_masks, geo_masks, visual_prompt_mask], dim=1)
        else:
            prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0)
            prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1)
        return prompt, prompt_mask, backbone_out

    def _run_encoder(
        self,
        backbone_out,
        find_input,
        prompt,
        prompt_mask,
        encoder_extra_kwargs: Optional[Dict] = None,
    ):
        """Run the transformer encoder on the image features and the prompt.

        Returns:
            ``(backbone_out, encoder_out, feat_tuple)`` where ``encoder_out``
            collects the encoder memory, its layout metadata and the prompt
            before/after encoding, and ``feat_tuple`` is the raw result of
            ``_get_img_feats``.
        """
        feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids)
        backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple

        # Run the encoder
        prompt_pos_embed = torch.zeros_like(prompt)
        # make a copy of the image feature lists since the encoder may modify these lists in-place
        memory = self.transformer.encoder(
            src=img_feats.copy(),
            src_key_padding_mask=None,
            src_pos=img_pos_embeds.copy(),
            prompt=prompt,
            prompt_pos=prompt_pos_embed,
            prompt_key_padding_mask=prompt_mask,
            feat_sizes=vis_feat_sizes,
            encoder_extra_kwargs=encoder_extra_kwargs,
        )
        encoder_out = {
            # encoded image features
            "encoder_hidden_states": memory["memory"],
            "pos_embed": memory["pos_embed"],
            "padding_mask": memory["padding_mask"],
            "level_start_index": memory["level_start_index"],
            "spatial_shapes": memory["spatial_shapes"],
            "valid_ratios": memory["valid_ratios"],
            "vis_feat_sizes": vis_feat_sizes,
            # encoded text features (or other prompts)
            "prompt_before_enc": prompt,
            "prompt_after_enc": memory.get("memory_text", prompt),
            "prompt_mask": prompt_mask,
        }
        return backbone_out, encoder_out, feat_tuple

    def _run_decoder(
        self,
        pos_embed,
        memory,
        src_mask,
        out,
        prompt,
        prompt_mask,
        encoder_out,
    ):
        """Run the transformer decoder and write scores/boxes into ``out``.

        Returns:
            ``(out, hs)`` with ``hs`` the decoder hidden states transposed to
            batch-first layout.
        """
        bs = memory.shape[1]
        query_embed = self.transformer.decoder.query_embed.weight
        tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)

        # DAC (extra one-to-many queries) is only applied while training.
        apply_dac = self.transformer.decoder.dac and self.training
        hs, reference_boxes, dec_presence_out, dec_presence_feats = (
            self.transformer.decoder(
                tgt=tgt,
                memory=memory,
                memory_key_padding_mask=src_mask,
                pos=pos_embed,
                reference_boxes=None,
                level_start_index=encoder_out["level_start_index"],
                spatial_shapes=encoder_out["spatial_shapes"],
                valid_ratios=encoder_out["valid_ratios"],
                tgt_mask=None,
                memory_text=prompt,
                text_attention_mask=prompt_mask,
                apply_dac=apply_dac,
            )
        )
        hs = hs.transpose(1, 2)  # seq-first to batch-first
        reference_boxes = reference_boxes.transpose(1, 2)  # seq-first to batch-first
        if dec_presence_out is not None:
            # seq-first to batch-first
            dec_presence_out = dec_presence_out.transpose(1, 2)

        out["presence_feats"] = dec_presence_feats
        self._update_scores_and_boxes(
            out,
            hs,
            reference_boxes,
            prompt,
            prompt_mask,
            dec_presence_out=dec_presence_out,
        )
        return out, hs

    def _update_scores_and_boxes(
        self,
        out,
        hs,
        reference_boxes,
        prompt,
        prompt_mask,
        dec_presence_out=None,
        is_instance_prompt=False,
    ):
        """Compute classification logits and boxes from ``hs`` and store them in ``out``.

        Writes ``queries``, ``pred_logits``, ``pred_boxes`` and
        ``pred_boxes_xyxy`` (plus ``presence_logit_dec`` when a presence output
        is given and the ``*_o2m`` variants when DAC is active in training).
        Auxiliary (per-layer) outputs are only recorded while training.

        Raises:
            AssertionError: if the O2O/O2M query split is inconsistent, or if
                ``supervise_joint_box_scores`` is set without a presence output.
        """
        apply_dac = self.transformer.decoder.dac and self.training
        num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2)
        num_o2m = hs.size(2) - num_o2o
        assert num_o2m == (num_o2o if apply_dac else 0)
        out["queries"] = hs[-1][:, :num_o2o]  # remove o2m queries if there are any

        # score prediction
        if self.use_dot_prod_scoring:
            dot_prod_scoring_head = self.dot_prod_scoring
            if is_instance_prompt and self.instance_dot_prod_scoring is not None:
                dot_prod_scoring_head = self.instance_dot_prod_scoring
            outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask)
        else:
            class_embed_head = self.class_embed
            if is_instance_prompt and self.instance_class_embed is not None:
                class_embed_head = self.instance_class_embed
            outputs_class = class_embed_head(hs)

        # box prediction
        box_head = self.transformer.decoder.bbox_embed
        if (
            is_instance_prompt
            and self.transformer.decoder.instance_bbox_embed is not None
        ):
            box_head = self.transformer.decoder.instance_bbox_embed
        anchor_box_offsets = box_head(hs)
        # Offsets are added in logit space, then squashed back with a sigmoid.
        reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
        outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid()
        outputs_boxes_xyxy = box_cxcywh_to_xyxy(outputs_coord)

        if dec_presence_out is not None:
            _update_out(
                out, "presence_logit_dec", dec_presence_out, update_aux=self.training
            )

        if self.supervise_joint_box_scores:
            assert dec_presence_out is not None
            prob_dec_presence_out = dec_presence_out.clone().sigmoid()
            if self.detach_presence_in_joint_score:
                prob_dec_presence_out = prob_dec_presence_out.detach()

            # Joint score = per-query probability * presence probability,
            # mapped back to (clamped) logit space.
            outputs_class = inverse_sigmoid(
                outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)
            ).clamp(min=-10.0, max=10.0)

        _update_out(
            out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=self.training
        )
        _update_out(
            out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=self.training
        )
        _update_out(
            out,
            "pred_boxes_xyxy",
            outputs_boxes_xyxy[:, :, :num_o2o],
            update_aux=self.training,
        )
        if num_o2m > 0 and self.training:
            _update_out(
                out,
                "pred_logits_o2m",
                outputs_class[:, :, num_o2o:],
                update_aux=self.training,
            )
            _update_out(
                out,
                "pred_boxes_o2m",
                outputs_coord[:, :, num_o2o:],
                update_aux=self.training,
            )
            _update_out(
                out,
                "pred_boxes_xyxy_o2m",
                outputs_boxes_xyxy[:, :, num_o2o:],
                update_aux=self.training,
            )

    def _run_segmentation_heads(
        self,
        out,
        backbone_out,
        img_ids,
        vis_feat_sizes,
        encoder_hidden_states,
        prompt,
        prompt_mask,
        hs,
    ):
        """Run the segmentation head (if any) and merge its outputs into ``out``.

        Outputs listed in ``segmentation_head.instance_keys`` are split into
        O2O (stored under their own key) and O2M (``"<key>_o2m"``) parts; all
        other outputs are copied through. Without a segmentation head,
        ``"backbone_fpn"`` is removed from ``backbone_out`` instead.
        ``vis_feat_sizes`` is accepted for interface compatibility and unused.
        """
        apply_dac = self.transformer.decoder.dac and self.training
        if self.segmentation_head is not None:
            num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2)
            num_o2m = hs.size(2) - num_o2o
            obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o]
            seg_head_outputs = activation_ckpt_wrapper(self.segmentation_head)(
                backbone_feats=backbone_out["backbone_fpn"],
                obj_queries=obj_queries,
                image_ids=img_ids,
                encoder_hidden_states=encoder_hidden_states,
                act_ckpt_enable=self.training and self.use_act_checkpoint_seg_head,
                prompt=prompt,
                prompt_mask=prompt_mask,
            )
            aux_masks = False  # self.aux_loss and self.segmentation_head.aux_masks
            for k, v in seg_head_outputs.items():
                if k in self.segmentation_head.instance_keys:
                    _update_out(out, k, v[:, :num_o2o], auxiliary=aux_masks)
                    if (
                        self.o2m_mask_predict and num_o2m > 0
                    ):  # handle o2m mask prediction
                        _update_out(
                            out, f"{k}_o2m", v[:, num_o2o:], auxiliary=aux_masks
                        )
                else:
                    out[k] = v
        else:
            backbone_out.pop("backbone_fpn", None)

    def _get_best_mask(self, out):
        """Return the highest-scoring predicted mask, downsampled and flattened.

        The mask with the largest ``pred_logits`` per batch element is selected,
        passed through the geometry encoder's mask downsampler and reshaped to
        sequence-first layout.
        """
        prev_mask_idx = out["pred_logits"].argmax(dim=1).squeeze(1)
        batch_idx = torch.arange(
            out["pred_logits"].shape[0], device=prev_mask_idx.device
        )
        prev_mask_pred = out["pred_masks"][batch_idx, prev_mask_idx][:, None]
        # Downsample mask to match image resolution.
        prev_mask_pred = self.geometry_encoder.mask_encoder.mask_downsampler(
            prev_mask_pred
        )
        prev_mask_pred = prev_mask_pred.flatten(-2).permute(2, 0, 1)
        return prev_mask_pred

    def forward_grounding(
        self,
        backbone_out,
        find_input,
        find_target,
        geometric_prompt: "Prompt",
    ):
        """Run one prompt-conditioned pass: prompt encoding, encoder, decoder, masks.

        Matching against ``find_target`` is computed when training or when
        ``num_interactive_steps_val > 0``.

        Returns:
            The output dict, including ``"prev_encoder_out"`` with the encoder
            and backbone outputs used for this pass.
        """
        with torch.profiler.record_function("SAM3Image._encode_prompt"):
            prompt, prompt_mask, backbone_out = self._encode_prompt(
                backbone_out, find_input, geometric_prompt
            )
        # Run the encoder
        with torch.profiler.record_function("SAM3Image._run_encoder"):
            backbone_out, encoder_out, _ = self._run_encoder(
                backbone_out, find_input, prompt, prompt_mask
            )
        out = {
            "encoder_hidden_states": encoder_out["encoder_hidden_states"],
            "prev_encoder_out": {
                "encoder_out": encoder_out,
                "backbone_out": backbone_out,
            },
        }

        # Run the decoder
        with torch.profiler.record_function("SAM3Image._run_decoder"):
            out, hs = self._run_decoder(
                memory=out["encoder_hidden_states"],
                pos_embed=encoder_out["pos_embed"],
                src_mask=encoder_out["padding_mask"],
                out=out,
                prompt=prompt,
                prompt_mask=prompt_mask,
                encoder_out=encoder_out,
            )

        # Run segmentation heads
        with torch.profiler.record_function("SAM3Image._run_segmentation_heads"):
            self._run_segmentation_heads(
                out=out,
                backbone_out=backbone_out,
                img_ids=find_input.img_ids,
                vis_feat_sizes=encoder_out["vis_feat_sizes"],
                encoder_hidden_states=out["encoder_hidden_states"],
                prompt=prompt,
                prompt_mask=prompt_mask,
                hs=hs,
            )

        if self.training or self.num_interactive_steps_val > 0:
            self._compute_matching(out, self.back_convert(find_target))
        return out

    def _postprocess_out(self, out: Dict, multimask_output: bool = False):
        """Reduce multi-mask outputs to the single best prediction during eval.

        For multimask output, during eval we return the single best mask with
        the dict keys expected by the evaluators, but also return the
        multimasks output with new ``multi_*`` keys. ``out`` is modified in
        place and returned; it is left untouched in training, when
        ``multimask_output`` is false, or when there is only one prediction.
        """
        num_mask_boxes = out["pred_boxes"].size(1)
        if not self.training and multimask_output and num_mask_boxes > 1:
            out["multi_pred_logits"] = out["pred_logits"]
            if "pred_masks" in out:
                out["multi_pred_masks"] = out["pred_masks"]
            out["multi_pred_boxes"] = out["pred_boxes"]
            out["multi_pred_boxes_xyxy"] = out["pred_boxes_xyxy"]

            best_mask_idx = out["pred_logits"].argmax(1).squeeze(1)
            batch_idx = torch.arange(len(best_mask_idx), device=best_mask_idx.device)

            out["pred_logits"] = out["pred_logits"][batch_idx, best_mask_idx].unsqueeze(
                1
            )
            if "pred_masks" in out:
                out["pred_masks"] = out["pred_masks"][
                    batch_idx, best_mask_idx
                ].unsqueeze(1)
            out["pred_boxes"] = out["pred_boxes"][batch_idx, best_mask_idx].unsqueeze(1)
            out["pred_boxes_xyxy"] = out["pred_boxes_xyxy"][
                batch_idx, best_mask_idx
            ].unsqueeze(1)

        return out

    def _get_dummy_prompt(self, num_prompts=1):
        """Return a geometric prompt containing zero boxes for ``num_prompts`` prompts."""
        device = self.device
        geometric_prompt = Prompt(
            box_embeddings=torch.zeros(0, num_prompts, 4, device=device),
            box_mask=torch.zeros(num_prompts, 0, device=device, dtype=torch.bool),
        )
        return geometric_prompt

    def forward(self, input: "BatchedDatapoint"):
        """Run the model on a batched datapoint containing exactly one frame.

        Image and text features are computed once, then ``forward_grounding``
        is called ``num_interactive_steps + 1`` times (interactive steps are
        only used in eval, controlled by ``num_interactive_steps_val``).

        Raises:
            AssertionError: if the input holds more than one frame.
            AttributeError: if interactive steps are requested but no
                ``interactive_prompt_sampler`` has been attached to the model.
        """
        device = self.device
        backbone_out = {"img_batch_all_stages": input.img_batch}
        backbone_out.update(self.backbone.forward_image(input.img_batch))
        num_frames = len(input.find_inputs)
        assert num_frames == 1

        text_outputs = self.backbone.forward_text(input.find_text_batch, device=device)
        backbone_out.update(text_outputs)

        previous_stages_out = SAM3Output(
            iter_mode=SAM3Output.IterMode.LAST_STEP_PER_STAGE
        )

        find_input = input.find_inputs[0]
        find_target = input.find_targets[0]

        if find_input.input_points is not None and find_input.input_points.numel() > 0:
            print("Warning: Point prompts are ignored in PCS.")

        num_interactive_steps = 0 if self.training else self.num_interactive_steps_val
        geometric_prompt = Prompt(
            box_embeddings=find_input.input_boxes,
            box_mask=find_input.input_boxes_mask,
            box_labels=find_input.input_boxes_label,
        )

        # Init vars that are shared across the loop.
        stage_outs = []
        for cur_step in range(num_interactive_steps + 1):
            if cur_step > 0:
                # `interactive_prompt_sampler` is not assigned in `__init__`;
                # fail with an actionable message if nobody attached one.
                sampler = getattr(self, "interactive_prompt_sampler", None)
                if sampler is None:
                    raise AttributeError(
                        "Interactive steps were requested "
                        "(num_interactive_steps_val > 0) but the model has no "
                        "`interactive_prompt_sampler` attribute set."
                    )
                # We sample interactive geometric prompts (boxes, points)
                geometric_prompt, _ = sampler.sample(
                    geo_prompt=geometric_prompt,
                    find_target=find_target,
                    previous_out=stage_outs[-1],
                )
            out = self.forward_grounding(
                backbone_out=backbone_out,
                find_input=find_input,
                find_target=find_target,
                geometric_prompt=geometric_prompt.clone(),
            )
            stage_outs.append(out)

        previous_stages_out.append(stage_outs)
        return previous_stages_out

    def _compute_matching(self, out, targets):
        """Store matcher indices for the main output and every auxiliary output."""
        out["indices"] = self.matcher(out, targets)
        for aux_out in out.get("aux_outputs", []):
            aux_out["indices"] = self.matcher(aux_out, targets)

    def back_convert(self, targets):
        """Repack a find-target object into the dict layout passed to the matcher."""
        batched_targets = {
            "boxes": targets.boxes.view(-1, 4),
            "boxes_xyxy": box_cxcywh_to_xyxy(targets.boxes.view(-1, 4)),
            "boxes_padded": targets.boxes_padded,
            "positive_map": targets.boxes.new_ones(len(targets.boxes), 1),
            "num_boxes": targets.num_boxes,
            "masks": targets.segments,
            "semantic_masks": targets.semantic_segments,
            "is_valid_mask": targets.is_valid_segment,
            "is_exhaustive": targets.is_exhaustive,
            "object_ids_packed": targets.object_ids,
            "object_ids_padded": targets.object_ids_padded,
        }
        return batched_targets

    def predict_inst(
        self,
        inference_state,
        **kwargs,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run the interactive instance predictor on one image.

        Backbone features are taken from
        ``inference_state["backbone_out"]["sam2_backbone_out"]`` and installed
        on the predictor for the duration of the call; ``kwargs`` are forwarded
        to ``predictor.predict``. The predictor's image state is always
        cleared afterwards, even if prediction raises.
        """
        predictor = self.inst_interactive_predictor
        orig_h, orig_w = (
            inference_state["original_height"],
            inference_state["original_width"],
        )
        backbone_out = inference_state["backbone_out"]["sam2_backbone_out"]
        (
            _,
            vision_feats,
            _,
            _,
        ) = predictor.model._prepare_backbone_features(backbone_out)
        # Add no_mem_embed, which is added to the lowest res feat. map during training on videos
        vision_feats[-1] = vision_feats[-1] + predictor.model.no_mem_embed
        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(
                vision_feats[::-1], predictor._bb_feat_sizes[::-1]
            )
        ][::-1]
        predictor._features = {
            "image_embed": feats[-1],
            "high_res_feats": feats[:-1],
        }
        predictor._is_image_set = True
        predictor._orig_hw = [(orig_h, orig_w)]
        try:
            return predictor.predict(**kwargs)
        finally:
            # Never leave stale features on the shared predictor.
            predictor._features = None
            predictor._is_image_set = False

    def predict_inst_batch(
        self,
        inference_state,
        *args,
        **kwargs,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
        """Batched variant of ``predict_inst``.

        Expects ``inference_state["original_heights"]`` /
        ``["original_widths"]`` with one entry per image in the batch;
        ``args`` / ``kwargs`` are forwarded to ``predictor.predict_batch``.
        The predictor's image/batch state is always cleared afterwards, even
        if prediction raises.

        Raises:
            AssertionError: if the feature batch size does not match the number
                of original heights/widths.
        """
        predictor = self.inst_interactive_predictor
        backbone_out = inference_state["backbone_out"]["sam2_backbone_out"]
        (
            _,
            vision_feats,
            _,
            _,
        ) = predictor.model._prepare_backbone_features(backbone_out)
        # Add no_mem_embed, which is added to the lowest res feat. map during training on videos
        vision_feats[-1] = vision_feats[-1] + predictor.model.no_mem_embed
        batch_size = vision_feats[-1].shape[1]
        orig_heights, orig_widths = (
            inference_state["original_heights"],
            inference_state["original_widths"],
        )
        assert (
            batch_size == len(orig_heights) == len(orig_widths)
        ), f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
        feats = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(
                vision_feats[::-1], predictor._bb_feat_sizes[::-1]
            )
        ][::-1]
        predictor._features = {
            "image_embed": feats[-1],
            "high_res_feats": feats[:-1],
        }
        predictor._is_image_set = True
        predictor._is_batch = True
        predictor._orig_hw = [
            (orig_h, orig_w) for orig_h, orig_w in zip(orig_heights, orig_widths)
        ]
        try:
            return predictor.predict_batch(*args, **kwargs)
        finally:
            # Never leave stale features on the shared predictor.
            predictor._features = None
            predictor._is_image_set = False
            predictor._is_batch = False
class Sam3ImageOnVideoMultiGPU(Sam3Image):
    """``Sam3Image`` variant that spreads per-frame detection over several GPUs.

    Frames are processed in chunks of ``world_size`` frames; each rank runs the
    detector on one frame of the chunk and the results are all-gathered into a
    caller-provided buffer.
    """

    def __init__(
        self, *args, async_all_gather=True, gather_backbone_out=None, **kwargs
    ):
        """Read ``RANK`` / ``WORLD_SIZE`` from the environment and store gather options."""
        super().__init__(*args, **kwargs)
        self.rank = int(os.getenv("RANK", "0"))
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))
        self.async_all_gather = async_all_gather

        # if gather_backbone is not set, default to gathering only for `SAM3VLBackbone`
        self.gather_backbone_out = (
            isinstance(self.backbone, SAM3VLBackbone)
            if gather_backbone_out is None
            else gather_backbone_out
        )

    def forward_video_grounding_multigpu(
        self,
        backbone_out,
        find_inputs,
        geometric_prompt: Prompt,
        frame_idx,
        num_frames,
        # `multigpu_buffer` is a dict to cache detector's outputs in a chunk between different calls
        multigpu_buffer,
        track_in_reverse=False,
        # whether to also return the SAM2 backbone features
        return_sam2_backbone_feats=False,
        # whether to perform NMS and suppress the scores of those detections removed by NMS
        run_nms=False,
        nms_prob_thresh=None,
        nms_iou_thresh=None,
        **kwargs,
    ):
        """
        Compute the detector's detection outputs in a distributed manner, where all GPUs process
        a chunk of frames (equal to the number of GPUs) at once and store them in cache.

        Returns ``(out, backbone_out)`` where ``out`` holds the cached results
        for ``frame_idx``. Extra ``**kwargs`` are accepted and ignored.
        """
        world_size = self.world_size
        # Arguments shared by every chunk-building call below.
        build_kwargs = dict(
            backbone_out=backbone_out,
            find_inputs=find_inputs,
            geometric_prompt=geometric_prompt,
            num_frames=num_frames,
            multigpu_buffer=multigpu_buffer,
            run_nms=run_nms,
            nms_prob_thresh=nms_prob_thresh,
            nms_iou_thresh=nms_iou_thresh,
        )

        # Step 1: fetch the detector outputs in the current chunk from buffer
        chunk_begin = frame_idx - frame_idx % world_size
        chunk_end = min(chunk_begin + world_size, num_frames)
        # in case the current frame's detection results are not in the buffer yet, build the current chunk
        # (this should only happen on the first chunk, since we are also building the next chunk below)
        if frame_idx not in multigpu_buffer:
            with torch.profiler.record_function("build_multigpu_buffer_next_chunk1"):
                self._build_multigpu_buffer_next_chunk(
                    frame_idx_begin=chunk_begin,
                    frame_idx_end=chunk_end,
                    **build_kwargs,
                )

        # read out the current frame's results from `multigpu_buffer`
        # NOTE(review): the buffer keys written by `_build_multigpu_buffer_next_chunk`
        # use the "tracker_backbone_" prefix, so this "sam2_backbone_" filter does not
        # match any of them as written -- confirm the intended prefix with callers.
        out = {}
        for key, (value, handle) in multigpu_buffer[frame_idx].items():
            if not return_sam2_backbone_feats and key.startswith("sam2_backbone_"):
                continue
            if handle is not None:
                handle.wait()  # wait for async all-gather to finish
            out[key] = value

        # (begin, end) ranges of the neighbouring chunks, or None at the video boundaries
        chunk_before = (
            (chunk_begin - world_size, chunk_begin)
            if chunk_begin - world_size >= 0
            else None
        )
        chunk_after = (
            (chunk_end, min(chunk_end + world_size, num_frames))
            if chunk_end < num_frames
            else None
        )
        if track_in_reverse:
            stale_chunk, upcoming_chunk = chunk_after, chunk_before
        else:
            stale_chunk, upcoming_chunk = chunk_before, chunk_after

        # Step 2: remove detection outputs of the previous chunk from cache to save GPU memory
        if stale_chunk is not None:
            for stale_frame_idx in range(*stale_chunk):
                multigpu_buffer.pop(stale_frame_idx, None)

        # Step 3: compute and cache detection outputs of the next chunk ahead of time
        # (so that we can overlap computation with all-gather transfer)
        if upcoming_chunk is not None and upcoming_chunk[0] not in multigpu_buffer:
            with torch.profiler.record_function("build_multigpu_buffer_next_chunk2"):
                self._build_multigpu_buffer_next_chunk(
                    frame_idx_begin=upcoming_chunk[0],
                    frame_idx_end=upcoming_chunk[1],
                    **build_kwargs,
                )

        return out, backbone_out

    def _build_multigpu_buffer_next_chunk(
        self,
        backbone_out,
        find_inputs,
        geometric_prompt: Prompt,
        frame_idx_begin,
        frame_idx_end,
        num_frames,
        multigpu_buffer,
        run_nms=False,
        nms_prob_thresh=None,
        nms_iou_thresh=None,
    ):
        """Compute detection outputs on a chunk of frames and store their results in multigpu_buffer."""
        # each GPU computes detections on one frame in the chunk (in a round-robin manner);
        # ranks past the end of the chunk recompute its last frame
        local_frame_idx = min(frame_idx_begin + self.rank, frame_idx_end - 1)

        # `forward_grounding` (from the base class `Sam3Image`) runs the detector on a single frame
        with torch.profiler.record_function("forward_grounding"):
            out_local = self.forward_grounding(
                backbone_out=backbone_out,
                find_input=find_inputs[local_frame_idx],
                find_target=None,
                geometric_prompt=geometric_prompt,
            )

        if run_nms:
            with torch.profiler.record_function("nms_masks"):
                # run NMS as a post-processing step on top of the detection outputs
                assert nms_prob_thresh is not None and nms_iou_thresh is not None
                pred_probs = out_local["pred_logits"].squeeze(-1).sigmoid()
                pred_masks = out_local["pred_masks"]
                # loop over text prompts (not an overhead for demo where there's only 1 prompt)
                num_prompts = pred_probs.size(0)
                for prompt_idx in range(num_prompts):
                    keep = nms_masks(
                        pred_probs=pred_probs[prompt_idx],
                        pred_masks=pred_masks[prompt_idx],
                        prob_threshold=nms_prob_thresh,
                        iou_threshold=nms_iou_thresh,
                    )
                    # push the logits of detections removed by NMS far down (in place)
                    suppressed = (~keep).float()
                    out_local["pred_logits"][prompt_idx, :, 0] -= 1e4 * suppressed

        fpn_keys = (
            "tracker_backbone_fpn_0",
            "tracker_backbone_fpn_1",
            "tracker_backbone_fpn_2",
        )
        if self.gather_backbone_out:
            # gather the SAM 2 backbone features across GPUs
            feats = out_local["prev_encoder_out"]["backbone_out"]["sam2_backbone_out"]
            assert len(feats["backbone_fpn"]) == 3  # SAM2 backbone always have 3 levels
            # cast the SAM2 backbone features to bfloat16 for all-gather (this is usually
            # a no-op, SAM2 backbone features are likely already in bfloat16 due to AMP)
            gathered_fpn = [
                self._gather_tensor(level.to(torch.bfloat16))
                for level in feats["backbone_fpn"]
            ]
            # vision_pos_enc is the same on all frames, so no need to all-gather them
            vision_pos_enc = feats["vision_pos_enc"]

        # trim the detector output to only include the necessary keys
        kept_keys = ("pred_logits", "pred_boxes", "pred_boxes_xyxy", "pred_masks")
        trimmed = {key: out_local[key] for key in kept_keys}

        # gather the results: after this step, each GPU will receive detector outputs on
        # all frames in the chunk and store them in `multigpu_buffer`
        out_gathered = {
            key: self._gather_tensor(value) for key, value in trimmed.items()
        }

        # only ranks whose frame index falls inside the video contribute an entry
        num_valid_ranks = min(self.world_size, num_frames - frame_idx_begin)
        for rank in range(num_valid_ranks):
            frame_buffer = {
                key: (tensors[rank], handle)
                for key, (tensors, handle) in out_gathered.items()
            }
            if self.gather_backbone_out:
                # also add gathered SAM 2 backbone features to frame_buffer
                for fpn_key, (tensors, handle) in zip(fpn_keys, gathered_fpn):
                    frame_buffer[fpn_key] = (tensors[rank], handle)
                frame_buffer["tracker_backbone_pos_enc"] = (vision_pos_enc, None)
            multigpu_buffer[frame_idx_begin + rank] = frame_buffer

    def _gather_tensor(self, x):
        """All-gather ``x`` across ranks.

        Returns ``(tensors, handle)``: a list with one tensor per rank and the
        handle returned by ``all_gather`` (``None`` when ``world_size == 1``,
        in which case ``x`` itself is returned without any communication).
        """
        if self.world_size == 1:
            return [x], None

        # here `.contiguous()` is required -- otherwise NCCL all_gather
        # sometimes gives wrong results
        x = x.contiguous()  # ensure contiguous memory for NCCL
        gathered = [torch.empty_like(x) for _ in range(self.world_size)]
        handle = torch.distributed.all_gather(
            gathered, x, async_op=self.async_all_gather
        )
        return gathered, handle
|