bellmake's picture
SAM3 Video Segmentation - Clean deployment
14114e8
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import datetime
import logging
import math
import os
from collections import defaultdict
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Set
import numpy as np
import numpy.typing as npt
import torch
import torch.distributed as dist
import torch.nn.functional as F
from sam3 import perflib
from sam3.logger import get_logger
from sam3.model.box_ops import fast_diag_box_iou
from sam3.model.data_misc import BatchedDatapoint
from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores, mask_to_box
from sam3.perflib.masks_ops import mask_iou
from sam3.train.masks_ops import rle_encode
from torch import nn, Tensor
logger = get_logger(__name__)
class MaskletConfirmationStatus(Enum):
UNCONFIRMED = 1 # newly added masklet, not confirmed by any detection yet
CONFIRMED = 2 # confirmed by at least one detection
class Sam3VideoBase(nn.Module):
def __init__(
self,
detector: nn.Module,
tracker: nn.Module,
# prob threshold for detection outputs -- only keep detections above this threshold
# enters NMS and det-to-track matching
score_threshold_detection=0.5,
# IoU threshold for detection NMS
det_nms_thresh=0.0,
# IoU threshold for det-to-track matching -- a detection is considered "matched" to a tracklet it
# overlaps with a tracklet above this threshold -- it is often a loose threshold like 0.1
assoc_iou_thresh=0.5,
# IoU threshold for det-to-track matching, which is used to determine whether a masklet is "unmatched"
# by any detections -- it is often a stricter threshold like 0.5
trk_assoc_iou_thresh=0.5,
# prob threshold for a detection to be added as a new object
new_det_thresh=0.0,
# hotstart parameters: we hold off the outputs for `hotstart_delay` frames and
# 1) remove those tracklets unmatched by any detections based on `hotstart_unmatch_thresh`
# 2) remove those tracklets overlapping with one another based on `hotstart_dup_thresh`
hotstart_delay=0,
hotstart_unmatch_thresh=3,
hotstart_dup_thresh=3,
# Whether to suppress masks only within hotstart. If False, we can suppress masks even if they start before hotstart period.
suppress_unmatched_only_within_hotstart=True,
init_trk_keep_alive=0,
max_trk_keep_alive=8,
min_trk_keep_alive=-4,
# Threshold for suppressing overlapping objects based on recent occlusion
suppress_overlapping_based_on_recent_occlusion_threshold=0.0,
decrease_trk_keep_alive_for_empty_masklets=False,
o2o_matching_masklets_enable=False, # Enable hungarian matching to match existing masklets
suppress_det_close_to_boundary=False,
fill_hole_area=16,
# The maximum number of objects (masklets) to track across all GPUs (for no limit, set it to -1)
max_num_objects=-1,
recondition_every_nth_frame=-1,
# masket confirmation status (to suppress unconfirmed masklets)
masklet_confirmation_enable=False,
# a masklet is confirmed after being consecutively detected and matched for
# `masklet_confirmation_consecutive_det_thresh`
masklet_confirmation_consecutive_det_thresh=3,
# bbox heuristic parameters
reconstruction_bbox_iou_thresh=0.0,
reconstruction_bbox_det_score=0.0,
):
super().__init__()
self.detector = detector
self.tracker = tracker
self.score_threshold_detection = score_threshold_detection
self.det_nms_thresh = det_nms_thresh
self.assoc_iou_thresh = assoc_iou_thresh
self.trk_assoc_iou_thresh = trk_assoc_iou_thresh
self.new_det_thresh = new_det_thresh
# hotstart parameters
if hotstart_delay > 0:
assert hotstart_unmatch_thresh <= hotstart_delay
assert hotstart_dup_thresh <= hotstart_delay
self.hotstart_delay = hotstart_delay
self.hotstart_unmatch_thresh = hotstart_unmatch_thresh
self.hotstart_dup_thresh = hotstart_dup_thresh
self.suppress_unmatched_only_within_hotstart = (
suppress_unmatched_only_within_hotstart
)
self.init_trk_keep_alive = init_trk_keep_alive
self.max_trk_keep_alive = max_trk_keep_alive
self.min_trk_keep_alive = min_trk_keep_alive
self.suppress_overlapping_based_on_recent_occlusion_threshold = (
suppress_overlapping_based_on_recent_occlusion_threshold
)
self.suppress_det_close_to_boundary = suppress_det_close_to_boundary
self.decrease_trk_keep_alive_for_empty_masklets = (
decrease_trk_keep_alive_for_empty_masklets
)
self.o2o_matching_masklets_enable = o2o_matching_masklets_enable
self.fill_hole_area = fill_hole_area
self.eval()
self.rank = int(os.getenv("RANK", "0"))
self.world_size = int(os.getenv("WORLD_SIZE", "1"))
self._dist_pg_cpu = None # CPU process group (lazy-initialized on first use)
# the maximum object number
if max_num_objects > 0:
num_obj_for_compile = math.ceil(max_num_objects / self.world_size)
else:
max_num_objects = 10000 # no limit
num_obj_for_compile = 16
logger.info(f"setting {max_num_objects=} and {num_obj_for_compile=}")
self.max_num_objects = max_num_objects
self.num_obj_for_compile = num_obj_for_compile
self.recondition_every_nth_frame = recondition_every_nth_frame
self.masklet_confirmation_enable = masklet_confirmation_enable
self.masklet_confirmation_consecutive_det_thresh = (
masklet_confirmation_consecutive_det_thresh
)
self.reconstruction_bbox_iou_thresh = reconstruction_bbox_iou_thresh
self.reconstruction_bbox_det_score = reconstruction_bbox_det_score
@property
def device(self):
self._device = getattr(self, "_device", None) or next(self.parameters()).device
return self._device
def _init_dist_pg_cpu(self):
# a short 3-min timeout to quickly detect any synchronization failures
timeout_sec = int(os.getenv("SAM3_COLLECTIVE_OP_TIMEOUT_SEC", "180"))
timeout = datetime.timedelta(seconds=timeout_sec)
self._dist_pg_cpu = dist.new_group(backend="gloo", timeout=timeout)
def broadcast_python_obj_cpu(self, python_obj_list, src):
if self._dist_pg_cpu is None:
self._init_dist_pg_cpu()
dist.broadcast_object_list(python_obj_list, src=src, group=self._dist_pg_cpu)
def _det_track_one_frame(
self,
frame_idx: int,
num_frames: int,
reverse: bool,
input_batch: BatchedDatapoint,
geometric_prompt: Any,
tracker_states_local: List[Any],
tracker_metadata_prev: Dict[str, Any],
feature_cache: Dict,
orig_vid_height: int,
orig_vid_width: int,
is_image_only: bool = False,
allow_new_detections: bool = True,
):
"""
This function handles one-step inference for the DenseTracking model in an SPMD manner.
At a high-level, all GPUs execute the same function calls as if it's done on a single GPU,
while under the hood, some function calls involve distributed computation based on sharded
SAM2 states.
- `input_batch` contains image and other inputs on the entire video; it should be identical across GPUs
- `tracker_states_local` holds the local masklet information in this GPU shard
- `tracker_metadata_prev` manages the metadata for SAM2 objects, such as which masklet is hold on which GPUs
it contains both global and local masklet information
"""
# Step 1: run backbone and detector in a distributed manner -- this is done via Sam3ImageOnVideoMultiGPU,
# a MultiGPU model (assigned to `self.detector`) that shards frames in a round-robin manner.
# It returns a "det_out" dict for `frame_idx` and fills SAM2 backbone features for `frame_idx`
# into `feature_cache`. Despite its distributed inference under the hood, the results would be
# the same as if it is running backbone and detector for every frame on a single GPU.
det_out = self.run_backbone_and_detection(
frame_idx=frame_idx,
num_frames=num_frames,
reverse=reverse,
input_batch=input_batch,
geometric_prompt=geometric_prompt,
feature_cache=feature_cache,
allow_new_detections=allow_new_detections,
)
# Step 2: each GPU propagates its local SAM2 states to get the SAM2 prediction masks.
# the returned `tracker_low_res_masks_global` contains the concatenated masklet predictions
# gathered from all GPUs (as if they are propagated on a single GPU). Note that this step only
# runs the SAM2 propagation step, but doesn't encode new memory for the predicted masks;
# we defer memory encoding to `run_tracker_update_execution_phase` after resolving all heuristics.
if tracker_metadata_prev == {}:
# initialize masklet metadata if it's uninitialized (empty dict)
tracker_metadata_prev.update(self._initialize_metadata())
tracker_low_res_masks_global, tracker_obj_scores_global = (
self.run_tracker_propagation(
frame_idx=frame_idx,
num_frames=num_frames,
reverse=reverse,
tracker_states_local=tracker_states_local,
tracker_metadata_prev=tracker_metadata_prev,
)
)
# Step 3: based on detection outputs and the propagated SAM2 prediction masks, we make plans
# for SAM2 masklet updates (i.e. which objects to add and remove, how to load-balance them, etc).
# We also run SAM2 memory encoder globally in this step to resolve non-overlapping constraints.
# **This step should involve all the heuristics needed for any updates.** Most of the update
# planning will be done on the master rank (GPU 0) and the resulting plan `tracker_update_plan` is
# broadcasted to other GPUs (to be executed in a distributed manner). This step also generates the
# new masklet metadata `tracker_metadata_new` (based on its previous version `tracker_metadata_prev`).
tracker_update_plan, tracker_metadata_new = (
self.run_tracker_update_planning_phase(
frame_idx=frame_idx,
num_frames=num_frames,
reverse=reverse,
det_out=det_out,
tracker_low_res_masks_global=tracker_low_res_masks_global,
tracker_obj_scores_global=tracker_obj_scores_global,
tracker_metadata_prev=tracker_metadata_prev,
tracker_states_local=tracker_states_local,
is_image_only=is_image_only,
)
)
# Get reconditioning info from the update plan
reconditioned_obj_ids = tracker_update_plan.get("reconditioned_obj_ids", set())
det_to_matched_trk_obj_ids = tracker_update_plan.get(
"det_to_matched_trk_obj_ids", {}
)
# Step 4: based on `tracker_update_plan`, each GPU executes the update w.r.t. its local SAM2 inference states
tracker_states_local_new = self.run_tracker_update_execution_phase(
frame_idx=frame_idx,
num_frames=num_frames,
reverse=reverse,
det_out=det_out,
tracker_states_local=tracker_states_local,
tracker_update_plan=tracker_update_plan,
orig_vid_height=orig_vid_height,
orig_vid_width=orig_vid_width,
feature_cache=feature_cache,
)
# Step 5: finally, build the outputs for this frame (it only needs to be done on GPU 0 since
# only GPU 0 will send outputs to the server).
if self.rank == 0:
obj_id_to_mask = self.build_outputs(
frame_idx=frame_idx,
num_frames=num_frames,
reverse=reverse,
det_out=det_out,
tracker_low_res_masks_global=tracker_low_res_masks_global,
tracker_obj_scores_global=tracker_obj_scores_global,
tracker_metadata_prev=tracker_metadata_prev,
tracker_update_plan=tracker_update_plan,
orig_vid_height=orig_vid_height,
orig_vid_width=orig_vid_width,
reconditioned_obj_ids=reconditioned_obj_ids,
det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
)
obj_id_to_score = tracker_metadata_new["obj_id_to_score"]
else:
obj_id_to_mask, obj_id_to_score = {}, {} # dummy outputs on other GPUs
# a few statistics for the current frame as a part of the output
frame_stats = {
"num_obj_tracked": np.sum(tracker_metadata_new["num_obj_per_gpu"]),
"num_obj_dropped": tracker_update_plan["num_obj_dropped_due_to_limit"],
}
# add tracker scores to metadata, it should be fired for frames except the first frame
if tracker_obj_scores_global.shape[0] > 0:
# Convert tracker_obj_scores_global to sigmoid scores before updating
tracker_obj_scores_global = tracker_obj_scores_global.sigmoid().tolist()
tracker_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"]
tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][
frame_idx
].update(dict(zip(tracker_obj_ids, tracker_obj_scores_global)))
return (
obj_id_to_mask, # a dict: obj_id --> output mask
obj_id_to_score, # a dict: obj_id --> output score (prob)
tracker_states_local_new,
tracker_metadata_new,
frame_stats,
tracker_obj_scores_global, # a dict: obj_id --> tracker frame-level scores
)
def _suppress_detections_close_to_boundary(self, boxes, margin=0.025):
"""
Suppress detections too close to image edges (for normalized boxes).
boxes: (N, 4) in xyxy format, normalized [0,1]
margin: fraction of image
"""
x_min, y_min, x_max, y_max = boxes.unbind(-1)
x_c = (x_min + x_max) / 2
y_c = (y_min + y_max) / 2
keep = (
(x_c > margin)
& (x_c < 1.0 - margin)
& (y_c > margin)
& (y_c < 1.0 - margin)
)
return keep
def run_backbone_and_detection(
self,
frame_idx: int,
num_frames: int,
input_batch: BatchedDatapoint,
geometric_prompt: Any,
feature_cache: Dict,
reverse: bool,
allow_new_detections: bool,
):
# Step 1: if text feature is not cached in `feature_cache`, compute and cache it
text_batch_key = tuple(input_batch.find_text_batch)
if "text" not in feature_cache or text_batch_key not in feature_cache["text"]:
text_outputs = self.detector.backbone.forward_text(
input_batch.find_text_batch, device=self.device
)
# note: we only cache the text feature of the most recent prompt
feature_cache["text"] = {text_batch_key: text_outputs}
else:
text_outputs = feature_cache["text"][text_batch_key]
# Step 2: run backbone, detector, and post-processing with NMS
if "multigpu_buffer" not in feature_cache:
# "multigpu_buffer" is a buffer cache used by `self.detector` and it needs
# to be passed to `forward_video_grounding_multigpu` for every call
feature_cache["multigpu_buffer"] = {}
# Extract max_frame_num_to_track from feature_cache if available
tracking_bounds = feature_cache.get("tracking_bounds", {})
max_frame_num_to_track = tracking_bounds.get("max_frame_num_to_track")
start_frame_idx = tracking_bounds.get("propagate_in_video_start_frame_idx")
sam3_image_out, _ = self.detector.forward_video_grounding_multigpu(
backbone_out={
"img_batch_all_stages": input_batch.img_batch,
**text_outputs,
},
find_inputs=input_batch.find_inputs,
geometric_prompt=geometric_prompt,
frame_idx=frame_idx,
num_frames=num_frames,
multigpu_buffer=feature_cache["multigpu_buffer"],
track_in_reverse=reverse,
# also get the SAM2 backbone features
return_tracker_backbone_feats=True,
# run NMS as a part of distributed computation
run_nms=self.det_nms_thresh > 0.0,
nms_prob_thresh=self.score_threshold_detection,
nms_iou_thresh=self.det_nms_thresh,
# pass max_frame_num_to_track to respect tracking limits
max_frame_num_to_track=max_frame_num_to_track,
propagate_in_video_start_frame_idx=start_frame_idx,
)
# note: detections in `sam3_image_out` has already gone through NMS
pred_probs = sam3_image_out["pred_logits"].squeeze(-1).sigmoid()
if not allow_new_detections:
pred_probs = pred_probs - 1e8 # make sure no detections are kept
pred_boxes_xyxy = sam3_image_out["pred_boxes_xyxy"]
pred_masks = sam3_image_out["pred_masks"]
# get the positive detection outputs above threshold
pos_pred_idx = torch.where(pred_probs > self.score_threshold_detection)
det_out = {
"bbox": pred_boxes_xyxy[pos_pred_idx[0], pos_pred_idx[1]],
"mask": pred_masks[pos_pred_idx[0], pos_pred_idx[1]],
"scores": pred_probs[pos_pred_idx[0], pos_pred_idx[1]],
}
# Step 3: build SAM2 backbone features and store them in `feature_cache`
backbone_cache = {}
sam_mask_decoder = self.tracker.sam_mask_decoder
tracker_backbone_fpn = [
sam_mask_decoder.conv_s0(sam3_image_out["tracker_backbone_fpn_0"]),
sam_mask_decoder.conv_s1(sam3_image_out["tracker_backbone_fpn_1"]),
sam3_image_out["tracker_backbone_fpn_2"], # fpn_2 doesn't need conv
]
tracker_backbone_out = {
"vision_features": tracker_backbone_fpn[-1], # top-level feature
"vision_pos_enc": sam3_image_out["tracker_backbone_pos_enc"],
"backbone_fpn": tracker_backbone_fpn,
}
backbone_cache["tracker_backbone_out"] = tracker_backbone_out
feature_cache[frame_idx] = (
input_batch.img_batch[frame_idx],
backbone_cache,
)
# remove from `feature_cache` old features to save GPU memory
feature_cache.pop(frame_idx - 1 if not reverse else frame_idx + 1, None)
return det_out
def run_tracker_propagation(
self,
frame_idx: int,
num_frames: int,
reverse: bool,
tracker_states_local: List[Any],
tracker_metadata_prev: Dict[str, npt.NDArray],
):
# Step 1: propagate the local SAM2 states to get the current frame's prediction
# `low_res_masks_local` of the existing masklets on this GPU
# - obj_ids_local: List[int] -- list of object IDs
# - low_res_masks_local: Tensor -- (num_local_obj, H_mask, W_mask)
obj_ids_local, low_res_masks_local, obj_scores_local = (
self._propogate_tracker_one_frame_local_gpu(
tracker_states_local, frame_idx=frame_idx, reverse=reverse
)
)
assert np.all(
obj_ids_local == tracker_metadata_prev["obj_ids_per_gpu"][self.rank]
), "{} != {}".format(
obj_ids_local, tracker_metadata_prev["obj_ids_per_gpu"][self.rank]
)
# Step 2: all-gather `low_res_masks_local` into `low_res_masks_global`
# - low_res_masks_global: Tensor -- (num_global_obj, H_mask, W_mask)
_, H_mask, W_mask = low_res_masks_local.shape
if self.world_size > 1:
# `low_res_masks_local` and `obj_scores_local` need to be contiguous and float32
# (they could be non-contiguous due to slicing and/or bfloat16 due to autocast)
low_res_masks_local = low_res_masks_local.float().contiguous()
obj_scores_local = obj_scores_local.float().contiguous()
num_obj_this_gpu = tracker_metadata_prev["num_obj_per_gpu"][self.rank]
assert low_res_masks_local.size(0) == num_obj_this_gpu
assert obj_scores_local.size(0) == num_obj_this_gpu
low_res_masks_peers = [
low_res_masks_local.new_empty(num_obj, H_mask, W_mask)
for num_obj in tracker_metadata_prev["num_obj_per_gpu"]
]
obj_scores_peers = [
obj_scores_local.new_empty(num_obj)
for num_obj in tracker_metadata_prev["num_obj_per_gpu"]
]
dist.all_gather(low_res_masks_peers, low_res_masks_local)
dist.all_gather(obj_scores_peers, obj_scores_local)
low_res_masks_global = torch.cat(low_res_masks_peers, dim=0)
obj_scores_global = torch.cat(obj_scores_peers, dim=0)
else:
low_res_masks_global = low_res_masks_local
obj_scores_global = obj_scores_local
return low_res_masks_global, obj_scores_global
def _recondition_masklets(
self,
frame_idx,
det_out: Dict[str, Tensor],
trk_id_to_max_iou_high_conf_det: List[int],
tracker_states_local: List[Any],
tracker_metadata: Dict[str, npt.NDArray],
tracker_obj_scores_global: Tensor,
):
# Recondition the masklets based on the new detections
for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items():
new_mask = det_out["mask"][det_idx : det_idx + 1]
input_mask_res = self.tracker.input_mask_size
new_mask_binary = (
F.interpolate(
new_mask.unsqueeze(1),
size=(input_mask_res, input_mask_res),
mode="bilinear",
align_corners=False,
).squeeze(1)[0]
> 0
)
HIGH_CONF_THRESH = 0.8
reconditioned_states_idx = set()
obj_idx = np.where(tracker_metadata["obj_ids_all_gpu"] == trk_obj_id)[
0
].item()
obj_score = tracker_obj_scores_global[obj_idx]
for state_idx, inference_state in enumerate(tracker_states_local):
if (
trk_obj_id in inference_state["obj_ids"]
# NOTE: Goal of this condition is to avoid reconditioning masks that are occluded/low qualiy.
# Unfortunately, these can get reconditioned anyway due to batching. We should consider removing these heuristics.
and obj_score > HIGH_CONF_THRESH
):
logger.debug(
f"Adding new mask for track {trk_obj_id} at frame {frame_idx}. Objects {inference_state['obj_ids']} are all reconditioned."
)
self.tracker.add_new_mask(
inference_state=inference_state,
frame_idx=frame_idx,
obj_id=trk_obj_id,
mask=new_mask_binary,
)
reconditioned_states_idx.add(state_idx)
for idx in reconditioned_states_idx:
self.tracker.propagate_in_video_preflight(
tracker_states_local[idx], run_mem_encoder=True
)
return tracker_states_local
def run_tracker_update_planning_phase(
self,
frame_idx: int,
num_frames: int,
reverse: bool,
det_out: Dict[str, Tensor],
tracker_low_res_masks_global: Tensor,
tracker_obj_scores_global: Tensor,
tracker_metadata_prev: Dict[str, npt.NDArray],
tracker_states_local: List[Any],
is_image_only: bool = False,
):
# initialize new metadata from previous metadata (its values will be updated later)
tracker_metadata_new = {
"obj_ids_per_gpu": deepcopy(tracker_metadata_prev["obj_ids_per_gpu"]),
"obj_ids_all_gpu": None, # will be filled later
"num_obj_per_gpu": deepcopy(tracker_metadata_prev["num_obj_per_gpu"]),
"obj_id_to_score": deepcopy(tracker_metadata_prev["obj_id_to_score"]),
"obj_id_to_tracker_score_frame_wise": deepcopy(
tracker_metadata_prev["obj_id_to_tracker_score_frame_wise"]
),
"obj_id_to_last_occluded": {}, # will be filled later
"max_obj_id": deepcopy(tracker_metadata_prev["max_obj_id"]),
}
# Initialize reconditioned_obj_ids early to avoid UnboundLocalError
reconditioned_obj_ids = set()
# Step 1: make the update plan and resolve heuristics on GPU 0
det_mask_preds: Tensor = det_out["mask"] # low-res mask logits
det_scores_np: npt.NDArray = det_out["scores"].float().cpu().numpy()
det_bbox_xyxy: Tensor = det_out["bbox"]
if self.rank == 0:
# a) match detector and tracker masks and find new objects
(
new_det_fa_inds,
unmatched_trk_obj_ids,
det_to_matched_trk_obj_ids,
trk_id_to_max_iou_high_conf_det,
empty_trk_obj_ids,
) = self._associate_det_trk(
det_masks=det_mask_preds,
det_scores_np=det_scores_np,
trk_masks=tracker_low_res_masks_global,
trk_obj_ids=tracker_metadata_prev["obj_ids_all_gpu"],
)
if self.suppress_det_close_to_boundary:
keep = self._suppress_detections_close_to_boundary(
det_bbox_xyxy[new_det_fa_inds]
)
new_det_fa_inds = new_det_fa_inds[keep.cpu().numpy()]
# check whether we've hit the maximum number of objects we can track (and if so, drop some detections)
prev_obj_num = np.sum(tracker_metadata_prev["num_obj_per_gpu"])
new_det_num = len(new_det_fa_inds)
num_obj_dropped_due_to_limit = 0
if not is_image_only and prev_obj_num + new_det_num > self.max_num_objects:
logger.warning(
f"hitting {self.max_num_objects=} with {new_det_num=} and {prev_obj_num=}"
)
new_det_num_to_keep = self.max_num_objects - prev_obj_num
num_obj_dropped_due_to_limit = new_det_num - new_det_num_to_keep
new_det_fa_inds = self._drop_new_det_with_obj_limit(
new_det_fa_inds, det_scores_np, new_det_num_to_keep
)
assert len(new_det_fa_inds) == new_det_num_to_keep
new_det_num = len(new_det_fa_inds)
# assign object IDs to new detections and decide which GPU to place them
new_det_start_obj_id = tracker_metadata_prev["max_obj_id"] + 1
new_det_obj_ids = new_det_start_obj_id + np.arange(new_det_num)
prev_workload_per_gpu = tracker_metadata_prev["num_obj_per_gpu"]
new_det_gpu_ids = self._assign_new_det_to_gpus(
new_det_num=new_det_num,
prev_workload_per_gpu=prev_workload_per_gpu,
)
# b) handle hotstart heuristics to remove objects
# here `rank0_metadata` contains metadata stored on (and only accessible to) GPU 0;
# we avoid broadcasting them to other GPUs to save communication cost, assuming
# that `rank0_metadata` is not needed by other GPUs
rank0_metadata_new = deepcopy(tracker_metadata_prev["rank0_metadata"])
if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
obj_ids_newly_removed, rank0_metadata_new = self._process_hotstart(
frame_idx=frame_idx,
num_frames=num_frames,
reverse=reverse,
det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
new_det_obj_ids=new_det_obj_ids,
empty_trk_obj_ids=empty_trk_obj_ids,
unmatched_trk_obj_ids=unmatched_trk_obj_ids,
rank0_metadata=rank0_metadata_new,
tracker_metadata=tracker_metadata_prev,
)
else:
# if warm-up is not complete, we don't remove any objects
obj_ids_newly_removed = set()
tracker_metadata_new["rank0_metadata"] = rank0_metadata_new
# Step 2: broadcast the update plan to other GPUs
NUM_BROADCAST_ITEMS = 9
if self.rank == 0 and self.world_size > 1:
# `num_obj_per_gpu_on_rank0` is used for metadata consistency check on other GPUs
# (it's a small array with length==self.world_size, so broadcasting it is cheap)
num_obj_per_gpu_on_rank0 = tracker_metadata_prev["num_obj_per_gpu"]
update_plan = [
new_det_fa_inds,
new_det_obj_ids,
new_det_gpu_ids,
num_obj_per_gpu_on_rank0,
unmatched_trk_obj_ids,
det_to_matched_trk_obj_ids,
obj_ids_newly_removed,
num_obj_dropped_due_to_limit,
trk_id_to_max_iou_high_conf_det,
]
assert (
len(update_plan) == NUM_BROADCAST_ITEMS
), f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
self.broadcast_python_obj_cpu(update_plan, src=0)
elif self.rank > 0 and self.world_size > 1:
update_plan = [
None
] * NUM_BROADCAST_ITEMS # other ranks receive the plan from rank 0
self.broadcast_python_obj_cpu(update_plan, src=0)
(
new_det_fa_inds,
new_det_obj_ids,
new_det_gpu_ids,
num_obj_per_gpu_on_rank0,
unmatched_trk_obj_ids,
det_to_matched_trk_obj_ids,
obj_ids_newly_removed,
num_obj_dropped_due_to_limit,
trk_id_to_max_iou_high_conf_det,
) = update_plan
# metadata consistency check: verify that the received `num_obj_per_gpu_on_rank0` is consistent with the local metadata
# it's critical that all GPUs agree on the previous number of objects (otherwise the inference might hang or fail silently)
if not np.all(
num_obj_per_gpu_on_rank0 == tracker_metadata_prev["num_obj_per_gpu"]
):
raise RuntimeError(
f"{self.rank=} received {num_obj_per_gpu_on_rank0=}, which is inconsistent with local record "
f"{tracker_metadata_prev['num_obj_per_gpu']=}. There's likely a bug in update planning or execution."
)
# `tracker_update_plan` should be identical on all GPUs after broadcasting
tracker_update_plan = {
"new_det_fa_inds": new_det_fa_inds, # npt.NDArray
"new_det_obj_ids": new_det_obj_ids, # npt.NDArray
"new_det_gpu_ids": new_det_gpu_ids, # npt.NDArray
"unmatched_trk_obj_ids": unmatched_trk_obj_ids, # npt.NDArray
"det_to_matched_trk_obj_ids": det_to_matched_trk_obj_ids, # dict
"obj_ids_newly_removed": obj_ids_newly_removed, # set
"num_obj_dropped_due_to_limit": num_obj_dropped_due_to_limit, # int
"trk_id_to_max_iou_high_conf_det": trk_id_to_max_iou_high_conf_det, # dict
"reconditioned_obj_ids": reconditioned_obj_ids, # set
}
# Step 3 (optional): recondition masklets based on high-confidence detections before memory encoding
# NOTE: Running this in execution phase (after memory encoding) can lead to suboptimal results
should_recondition_iou = False
# Evaluate tracklets for reconditioning based on bbox IoU mismatch with detections
if (
self.reconstruction_bbox_iou_thresh > 0
and len(trk_id_to_max_iou_high_conf_det) > 0
):
for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items():
det_box = det_out["bbox"][det_idx]
det_score = det_out["scores"][det_idx]
try:
trk_idx = list(tracker_metadata_prev["obj_ids_all_gpu"]).index(
trk_obj_id
)
except ValueError:
continue # Skip if tracklet not found
tracker_mask = tracker_low_res_masks_global[trk_idx]
mask_binary = tracker_mask > 0
mask_area = mask_binary.sum().item()
if mask_area == 0:
continue # Skip tracklets with zero mask area
# Get bounding box from SAM2 mask and convert to normalized coordinates
tracker_box_pixels = (
mask_to_box(mask_binary.unsqueeze(0).unsqueeze(0))
.squeeze(0)
.squeeze(0)
)
mask_height, mask_width = tracker_mask.shape[-2:]
tracker_box_normalized = torch.tensor(
[
tracker_box_pixels[0] / mask_width,
tracker_box_pixels[1] / mask_height,
tracker_box_pixels[2] / mask_width,
tracker_box_pixels[3] / mask_height,
],
device=tracker_box_pixels.device,
)
# Compute IoU between detection and SAM2 tracklet bounding boxes
det_box_batch = det_box.unsqueeze(0)
tracker_box_batch = tracker_box_normalized.unsqueeze(0)
iou = fast_diag_box_iou(det_box_batch, tracker_box_batch)[0]
if (
iou < self.reconstruction_bbox_iou_thresh
and det_score >= self.reconstruction_bbox_det_score
):
should_recondition_iou = True
reconditioned_obj_ids.add(trk_obj_id)
should_recondition_periodic = (
self.recondition_every_nth_frame > 0
and frame_idx % self.recondition_every_nth_frame == 0
and len(trk_id_to_max_iou_high_conf_det) > 0
)
# Recondition if periodic or IoU condition met
if should_recondition_periodic or should_recondition_iou:
self._recondition_masklets(
frame_idx,
det_out,
trk_id_to_max_iou_high_conf_det,
tracker_states_local,
tracker_metadata_prev,
tracker_obj_scores_global,
)
# Step 4: Run SAM2 memory encoder on the current frame's prediction masks
# This is done on all GPUs
batch_size = tracker_low_res_masks_global.size(0)
if batch_size > 0:
if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
if self.suppress_overlapping_based_on_recent_occlusion_threshold > 0.0:
# NOTE: tracker_low_res_masks_global is updated in-place then returned
tracker_low_res_masks_global = (
self._suppress_overlapping_based_on_recent_occlusion(
frame_idx,
tracker_low_res_masks_global,
tracker_metadata_prev,
tracker_metadata_new,
obj_ids_newly_removed,
reverse,
)
)
self._tracker_update_memories(
tracker_states_local,
frame_idx,
tracker_metadata=tracker_metadata_prev,
low_res_masks=tracker_low_res_masks_global,
)
# Step 4: update the SAM2 metadata based on the update plan
# note: except for "rank0_metadata" (that is only available on GPU 0),
# the updated `tracker_metadata_new` should be identical on all GPUs
for rank in range(self.world_size):
new_det_obj_ids_this_gpu = new_det_obj_ids[new_det_gpu_ids == rank]
updated_obj_ids_this_gpu = tracker_metadata_new["obj_ids_per_gpu"][rank]
if len(new_det_obj_ids_this_gpu) > 0:
updated_obj_ids_this_gpu = np.concatenate(
[updated_obj_ids_this_gpu, new_det_obj_ids_this_gpu]
)
if len(obj_ids_newly_removed) > 0:
is_removed = np.isin(
updated_obj_ids_this_gpu, list(obj_ids_newly_removed)
)
updated_obj_ids_this_gpu = updated_obj_ids_this_gpu[~is_removed]
tracker_metadata_new["obj_ids_per_gpu"][rank] = updated_obj_ids_this_gpu
tracker_metadata_new["num_obj_per_gpu"][rank] = len(
updated_obj_ids_this_gpu
)
tracker_metadata_new["obj_ids_all_gpu"] = np.concatenate(
tracker_metadata_new["obj_ids_per_gpu"]
)
# update object scores and the maximum object ID assigned so far
if len(new_det_obj_ids) > 0:
tracker_metadata_new["obj_id_to_score"].update(
zip(new_det_obj_ids, det_scores_np[new_det_fa_inds])
)
# tracker scores are not available for new objects, use det score instead.
tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][
frame_idx
].update(zip(new_det_obj_ids, det_scores_np[new_det_fa_inds]))
tracker_metadata_new["max_obj_id"] = max(
tracker_metadata_new["max_obj_id"],
np.max(new_det_obj_ids),
)
# for removed objects, we set their scores to a very low value (-1e4) but still
# keep them in "obj_id_to_score" (it's easier to handle outputs this way)
for obj_id in obj_ids_newly_removed:
tracker_metadata_new["obj_id_to_score"][obj_id] = -1e4
tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx][
obj_id
] = -1e4
tracker_metadata_new["obj_id_to_last_occluded"].pop(obj_id, None)
# check that "rank0_metadata" is in tracker_metadata_new if and only if it's GPU 0
assert ("rank0_metadata" in tracker_metadata_new) == (self.rank == 0)
if self.rank == 0 and self.masklet_confirmation_enable:
rank0_metadata = self.update_masklet_confirmation_status(
rank0_metadata=tracker_metadata_new["rank0_metadata"],
obj_ids_all_gpu_prev=tracker_metadata_prev["obj_ids_all_gpu"],
obj_ids_all_gpu_updated=tracker_metadata_new["obj_ids_all_gpu"],
det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
new_det_obj_ids=new_det_obj_ids,
)
tracker_metadata_new["rank0_metadata"] = rank0_metadata
return tracker_update_plan, tracker_metadata_new
def _suppress_overlapping_based_on_recent_occlusion(
self,
frame_idx: int,
tracker_low_res_masks_global: Tensor,
tracker_metadata_prev: Dict[str, Any],
tracker_metadata_new: Dict[str, Any],
obj_ids_newly_removed: Set[int],
reverse: bool = False,
):
"""
Suppress overlapping masks based on the most recent occlusion information. If an object is removed by hotstart, we always suppress it if it overlaps with any other object.
Args:
frame_idx (int): The current frame index.
tracker_low_res_masks_global (Tensor): The low-resolution masks for the current frame.
tracker_metadata_prev (Dict[str, Any]): The metadata from the previous frame.
tracker_metadata_new (Dict[str, Any]): The metadata for the current frame.
obj_ids_newly_removed (Set[int]): The object IDs that have been removed.
Return:
Tensor: The updated low-resolution masks with some objects suppressed.
"""
obj_ids_global = tracker_metadata_prev["obj_ids_all_gpu"]
binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0
batch_size = tracker_low_res_masks_global.size(0)
if batch_size > 0:
assert (
len(obj_ids_global) == batch_size
), f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
NEVER_OCCLUDED = -1
ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic
last_occluded_prev = torch.cat(
[
tracker_metadata_prev["obj_id_to_last_occluded"].get(
obj_id,
torch.full(
(1,),
fill_value=(
NEVER_OCCLUDED
if obj_id not in obj_ids_newly_removed
else ALWAYS_OCCLUDED
),
device=binary_tracker_low_res_masks_global.device,
dtype=torch.long,
),
)
for obj_id in obj_ids_global
],
dim=0,
)
to_suppress = self._get_objects_to_suppress_based_on_most_recently_occluded(
binary_tracker_low_res_masks_global,
last_occluded_prev,
obj_ids_global,
frame_idx,
reverse,
)
# Update metadata with occlusion information
is_obj_occluded = ~(binary_tracker_low_res_masks_global.any(dim=(-1, -2)))
is_obj_occluded_or_suppressed = is_obj_occluded | to_suppress
last_occluded_new = last_occluded_prev.clone()
last_occluded_new[is_obj_occluded_or_suppressed] = frame_idx
# Slice out the last occluded frame for each object
tracker_metadata_new["obj_id_to_last_occluded"] = {
obj_id: last_occluded_new[obj_idx : obj_idx + 1]
for obj_idx, obj_id in enumerate(obj_ids_global)
}
# Zero out suppressed masks before memory encoding
NO_OBJ_LOGIT = -10
tracker_low_res_masks_global[to_suppress] = NO_OBJ_LOGIT
return tracker_low_res_masks_global
def run_tracker_update_execution_phase(
self,
frame_idx: int,
num_frames: int,
reverse: bool,
det_out: Dict[str, Tensor],
tracker_states_local: List[Any],
tracker_update_plan: Dict[str, npt.NDArray],
orig_vid_height: int,
orig_vid_width: int,
feature_cache: Dict,
):
# initialize tracking scores with detection scores
new_det_fa_inds: npt.NDArray = tracker_update_plan["new_det_fa_inds"]
new_det_obj_ids: npt.NDArray = tracker_update_plan["new_det_obj_ids"]
new_det_gpu_ids: npt.NDArray = tracker_update_plan["new_det_gpu_ids"]
is_on_this_gpu: npt.NDArray = new_det_gpu_ids == self.rank
new_det_obj_ids_local: npt.NDArray = new_det_obj_ids[is_on_this_gpu]
new_det_fa_inds_local: npt.NDArray = new_det_fa_inds[is_on_this_gpu]
obj_ids_newly_removed: Set[int] = tracker_update_plan["obj_ids_newly_removed"]
# Step 1: add new objects from the detector to SAM2 inference states
if len(new_det_fa_inds_local) > 0:
new_det_fa_inds_local_t = torch.from_numpy(new_det_fa_inds_local)
new_det_masks: Tensor = det_out["mask"][new_det_fa_inds_local_t]
# initialize SAM2 with new object masks
tracker_states_local = self._tracker_add_new_objects(
frame_idx=frame_idx,
num_frames=num_frames,
new_obj_ids=new_det_obj_ids_local,
new_obj_masks=new_det_masks,
tracker_states_local=tracker_states_local,
orig_vid_height=orig_vid_height,
orig_vid_width=orig_vid_width,
feature_cache=feature_cache,
)
# Step 2: remove from SAM2 inference states those objects removed by heuristics
if len(obj_ids_newly_removed) > 0:
self._tracker_remove_objects(tracker_states_local, obj_ids_newly_removed)
return tracker_states_local
def build_outputs(
self,
frame_idx: int,
num_frames: int,
reverse: bool,
det_out: Dict[str, Tensor],
tracker_low_res_masks_global: Tensor,
tracker_obj_scores_global: Tensor,
tracker_metadata_prev: Dict[str, npt.NDArray],
tracker_update_plan: Dict[str, npt.NDArray],
orig_vid_height: int,
orig_vid_width: int,
reconditioned_obj_ids: set = None,
det_to_matched_trk_obj_ids: dict = None,
):
new_det_fa_inds: npt.NDArray = tracker_update_plan["new_det_fa_inds"]
new_det_obj_ids: npt.NDArray = tracker_update_plan["new_det_obj_ids"]
obj_id_to_mask = {} # obj_id --> output mask tensor
# Part 1: masks from previous SAM2 propagation
existing_masklet_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"]
existing_masklet_video_res_masks = F.interpolate(
tracker_low_res_masks_global.unsqueeze(1),
size=(orig_vid_height, orig_vid_width),
mode="bilinear",
align_corners=False,
) # (num_obj, 1, H_video, W_video)
existing_masklet_binary = existing_masklet_video_res_masks > 0
assert len(existing_masklet_obj_ids) == len(existing_masklet_binary)
for obj_id, mask in zip(existing_masklet_obj_ids, existing_masklet_binary):
obj_id_to_mask[obj_id] = mask # (1, H_video, W_video)
# Part 2: masks from new detections
new_det_fa_inds_t = torch.from_numpy(new_det_fa_inds)
new_det_low_res_masks = det_out["mask"][new_det_fa_inds_t].unsqueeze(1)
new_det_low_res_masks = fill_holes_in_mask_scores(
new_det_low_res_masks,
max_area=self.fill_hole_area,
fill_holes=True,
remove_sprinkles=True,
)
new_masklet_video_res_masks = F.interpolate(
new_det_low_res_masks,
size=(orig_vid_height, orig_vid_width),
mode="bilinear",
align_corners=False,
) # (num_obj, 1, H_video, W_video)
new_masklet_binary = new_masklet_video_res_masks > 0
assert len(new_det_obj_ids) == len(new_masklet_video_res_masks)
for obj_id, mask in zip(new_det_obj_ids, new_masklet_binary):
obj_id_to_mask[obj_id] = mask # (1, H_video, W_video)
# Part 3: Override masks for reconditioned objects using detection masks
if reconditioned_obj_ids is not None and len(reconditioned_obj_ids) > 0:
trk_id_to_max_iou_high_conf_det = tracker_update_plan.get(
"trk_id_to_max_iou_high_conf_det", {}
)
for obj_id in reconditioned_obj_ids:
det_idx = trk_id_to_max_iou_high_conf_det.get(obj_id)
if det_idx is not None:
det_mask = det_out["mask"][det_idx]
det_mask = det_mask.unsqueeze(0).unsqueeze(0)
det_mask_resized = (
F.interpolate(
det_mask.float(),
size=(orig_vid_height, orig_vid_width),
mode="bilinear",
align_corners=False,
)
> 0
)
det_mask_final = det_mask_resized.squeeze(0)
obj_id_to_mask[obj_id] = det_mask_final
return obj_id_to_mask
def _get_objects_to_suppress_based_on_most_recently_occluded(
self,
binary_low_res_masks: Tensor,
last_occluded: List[int],
obj_ids: List[int],
frame_idx: int = None,
reverse: bool = False,
):
# Suppress overlapping masks for objects that were most recently occluded
assert (
binary_low_res_masks.dtype == torch.bool
), f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
to_suppress = torch.zeros(
binary_low_res_masks.size(0),
device=binary_low_res_masks.device,
dtype=torch.bool,
)
if len(obj_ids) <= 1:
return to_suppress
iou = mask_iou(binary_low_res_masks, binary_low_res_masks) # [N,N]
# Create masks for upper triangular matrix (i < j) and IoU threshold
mask_iou_thresh = (
iou >= self.suppress_overlapping_based_on_recent_occlusion_threshold
)
overlapping_pairs = torch.triu(mask_iou_thresh, diagonal=1) # [N,N]
last_occ_expanded_i = last_occluded.unsqueeze(1) # (N, 1)
last_occ_expanded_j = last_occluded.unsqueeze(0) # (1, N)
# Suppress most recently occluded
cmp_op = torch.gt if not reverse else torch.lt
suppress_i_mask = (
overlapping_pairs
& cmp_op(
last_occ_expanded_i, last_occ_expanded_j
) # (last_occ_expanded_i > last_occ_expanded_j)
& (
last_occ_expanded_j > -1
) # j can suppress i only if i was previously occluded
)
suppress_j_mask = (
overlapping_pairs
& cmp_op(last_occ_expanded_j, last_occ_expanded_i)
& (
last_occ_expanded_i > -1
) # i can suppress j only if j was previously occluded
)
# Apply suppression
to_suppress = suppress_i_mask.any(dim=1) | suppress_j_mask.any(dim=0)
# Log for debugging
if (
self.rank == 0
and logger.isEnabledFor(logging.DEBUG)
and frame_idx is not None
):
suppress_i_mask = suppress_i_mask.cpu().numpy()
suppress_j_mask = suppress_j_mask.cpu().numpy()
last_occluded = last_occluded.cpu().numpy()
# Find all suppression pairs without using torch.where
batch_size = suppress_i_mask.shape[0]
# Log i-suppression cases (where i gets suppressed in favor of j)
for i in range(batch_size):
for j in range(batch_size):
if suppress_i_mask[i, j]:
logger.debug(
f"{frame_idx=}: Suppressing obj {obj_ids[i]} last occluded {last_occluded[i]} in favor of {obj_ids[j]} last occluded {last_occluded[j]}"
)
# Log j-suppression cases (where j gets suppressed in favor of i)
for i in range(batch_size):
for j in range(batch_size):
if suppress_j_mask[i, j]:
logger.debug(
f"{frame_idx=}: Suppressing obj {obj_ids[j]} last occluded {last_occluded[j]} in favor of {obj_ids[i]} last occluded {last_occluded[i]}"
)
return to_suppress
def _propogate_tracker_one_frame_local_gpu(
self,
inference_states: List[Any],
frame_idx: int,
reverse: bool,
# by default, we disable memory encoding until we gather all outputs
run_mem_encoder: bool = False,
):
"""
inference_states: List of inference states, each state corresponds to a different set of objects.
"""
obj_ids_local = []
low_res_masks_list = []
obj_scores_list = []
for inference_state in inference_states:
if len(inference_state["obj_ids"]) == 0:
continue # skip propagation on empty inference states
# propagate one frame
num_frames_propagated = 0
for out in self.tracker.propagate_in_video(
inference_state,
start_frame_idx=frame_idx,
# end_frame_idx = start_frame_idx + max_frame_num_to_track
# (i.e. propagating 1 frame since end_frame_idx is inclusive)
max_frame_num_to_track=0,
reverse=reverse,
tqdm_disable=True,
run_mem_encoder=run_mem_encoder,
):
out_frame_idx, out_obj_ids, out_low_res_masks, _, out_obj_scores = out
num_frames_propagated += 1
# only 1 frames should be propagated
assert (
num_frames_propagated == 1 and out_frame_idx == frame_idx
), f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
assert isinstance(out_obj_ids, list)
obj_ids_local.extend(out_obj_ids)
low_res_masks_list.append(out_low_res_masks.squeeze(1))
obj_scores_list.append(out_obj_scores.squeeze(1))
# concatenate the output masklets from all local inference states
H_mask = W_mask = self.tracker.low_res_mask_size
if len(low_res_masks_list) > 0:
low_res_masks_local = torch.cat(low_res_masks_list, dim=0)
obj_scores_local = torch.cat(obj_scores_list, dim=0)
assert low_res_masks_local.shape[1:] == (H_mask, W_mask)
# Apply hole filling to the masks
low_res_masks_local = fill_holes_in_mask_scores(
low_res_masks_local.unsqueeze(1),
max_area=self.fill_hole_area,
fill_holes=True,
remove_sprinkles=True,
)
low_res_masks_local = low_res_masks_local.squeeze(1)
else:
low_res_masks_local = torch.zeros(0, H_mask, W_mask, device=self.device)
obj_scores_local = torch.zeros(0, device=self.device)
return obj_ids_local, low_res_masks_local, obj_scores_local
def _associate_det_trk(
self,
det_masks: Tensor,
det_scores_np: npt.NDArray,
trk_masks: Tensor,
trk_obj_ids: npt.NDArray,
):
"""
Match detections on the current frame with the existing masklets.
Args:
- det_masks: (N, H, W) tensor of predicted masks
- det_scores_np: (N,) array of detection scores
- trk_masks: (M, H, W) tensor of track masks
- trk_obj_ids: (M,) array of object IDs corresponding to trk_masks
Returns:
- new_det_fa_inds: array of new object indices.
- unmatched_trk_obj_ids: array of existing masklet object IDs that are not matched
to any detections on this frame (for unmatched, we only count masklets with >0 area)
- det_to_matched_trk_obj_ids: dict[int, npt.NDArray]: mapping from detector's detection indices
to the list of matched tracklet object IDs
- empty_trk_obj_ids: array of existing masklet object IDs with zero area in SAM2 prediction
"""
iou_threshold = self.assoc_iou_thresh
iou_threshold_trk = self.trk_assoc_iou_thresh
new_det_thresh = self.new_det_thresh
assert det_masks.is_floating_point(), "float tensor expected (do not binarize)"
assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)"
assert (
trk_masks.size(0) == len(trk_obj_ids)
), f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
if trk_masks.size(0) == 0:
# all detections are new
new_det_fa_inds = np.arange(det_masks.size(0))
unmatched_trk_obj_ids = np.array([], np.int64)
empty_trk_obj_ids = np.array([], np.int64)
det_to_matched_trk_obj_ids = {}
trk_id_to_max_iou_high_conf_det = {}
return (
new_det_fa_inds,
unmatched_trk_obj_ids,
det_to_matched_trk_obj_ids,
trk_id_to_max_iou_high_conf_det,
empty_trk_obj_ids,
)
elif det_masks.size(0) == 0:
# all previous tracklets are unmatched if they have a non-zero area
new_det_fa_inds = np.array([], np.int64)
trk_is_nonempty = (trk_masks > 0).any(dim=(1, 2)).cpu().numpy()
unmatched_trk_obj_ids = trk_obj_ids[trk_is_nonempty]
empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty]
det_to_matched_trk_obj_ids = {}
trk_id_to_max_iou_high_conf_det = {}
return (
new_det_fa_inds,
unmatched_trk_obj_ids,
det_to_matched_trk_obj_ids,
trk_id_to_max_iou_high_conf_det,
empty_trk_obj_ids,
)
if det_masks.shape[-2:] != trk_masks.shape[-2:]:
# resize to the smaller size to save GPU memory
if np.prod(det_masks.shape[-2:]) < np.prod(trk_masks.shape[-2:]):
trk_masks = F.interpolate(
trk_masks.unsqueeze(1),
size=det_masks.shape[-2:],
mode="bilinear",
align_corners=False,
).squeeze(1)
else:
# resize detections to track size
det_masks = F.interpolate(
det_masks.unsqueeze(1),
size=trk_masks.shape[-2:],
mode="bilinear",
align_corners=False,
).squeeze(1)
det_masks_binary = det_masks > 0
trk_masks_binary = trk_masks > 0
ious = mask_iou(det_masks_binary, trk_masks_binary) # (N, M)
ious_np = ious.cpu().numpy()
if self.o2o_matching_masklets_enable:
from scipy.optimize import linear_sum_assignment
# Hungarian matching for tracks (one-to-one: each track matches at most one detection)
cost_matrix = 1 - ious_np # Hungarian solves for minimum cost
row_ind, col_ind = linear_sum_assignment(cost_matrix)
trk_is_matched = np.zeros(trk_masks.size(0), dtype=bool)
for d, t in zip(row_ind, col_ind):
if ious_np[d, t] >= iou_threshold_trk:
trk_is_matched[t] = True
else:
trk_is_matched = (ious_np >= iou_threshold_trk).any(axis=0)
# Non-empty tracks not matched by Hungarian assignment above threshold are unmatched
trk_is_nonempty = trk_masks_binary.any(dim=(1, 2)).cpu().numpy()
trk_is_unmatched = np.logical_and(trk_is_nonempty, ~trk_is_matched)
unmatched_trk_obj_ids = trk_obj_ids[trk_is_unmatched]
# also record masklets that have zero area in SAM 2 prediction
empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty]
# For detections: allow many tracks to match to the same detection (many-to-one)
# So, a detection is 'new' if it does not match any track above threshold
is_new_det = np.logical_and(
det_scores_np >= new_det_thresh,
np.logical_not(np.any(ious_np >= iou_threshold, axis=1)),
)
new_det_fa_inds = np.nonzero(is_new_det)[0]
# for each detection, which tracks it matched to (above threshold)
det_to_matched_trk_obj_ids = {}
trk_id_to_max_iou_high_conf_det = {} # trk id --> exactly one detection idx
HIGH_CONF_THRESH = 0.8
HIGH_IOU_THRESH = 0.8
det_to_max_iou_trk_idx = np.argmax(ious_np, axis=1)
det_is_high_conf = (det_scores_np >= HIGH_CONF_THRESH) & ~is_new_det
det_is_high_iou = np.max(ious_np, axis=1) >= HIGH_IOU_THRESH
det_is_high_conf_and_iou = set(
np.nonzero(det_is_high_conf & det_is_high_iou)[0]
)
for d in range(det_masks.size(0)):
det_to_matched_trk_obj_ids[d] = trk_obj_ids[ious_np[d, :] >= iou_threshold]
if d in det_is_high_conf_and_iou:
trk_obj_id = trk_obj_ids[det_to_max_iou_trk_idx[d]].item()
trk_id_to_max_iou_high_conf_det[trk_obj_id] = d
return (
new_det_fa_inds,
unmatched_trk_obj_ids,
det_to_matched_trk_obj_ids,
trk_id_to_max_iou_high_conf_det,
empty_trk_obj_ids,
)
def _assign_new_det_to_gpus(self, new_det_num, prev_workload_per_gpu):
"""Distribute the new objects to the GPUs with the least workload."""
workload_per_gpu: npt.NDArray = prev_workload_per_gpu.copy()
new_det_gpu_ids = np.zeros(new_det_num, np.int64)
# assign the objects one by one
for i in range(len(new_det_gpu_ids)):
# find the GPU with the least workload
min_gpu = np.argmin(workload_per_gpu)
new_det_gpu_ids[i] = min_gpu
workload_per_gpu[min_gpu] += 1
return new_det_gpu_ids
def _process_hotstart(
self,
frame_idx: int,
num_frames: int,
reverse: bool,
det_to_matched_trk_obj_ids: Dict[int, npt.NDArray],
new_det_obj_ids: npt.NDArray,
empty_trk_obj_ids: npt.NDArray,
unmatched_trk_obj_ids: npt.NDArray,
rank0_metadata: Dict[str, Any],
tracker_metadata: Dict[str, Any],
):
"""Handle hotstart heuristics to remove unmatched or duplicated objects."""
# obj_id --> first frame index where the object was detected
obj_first_frame_idx = rank0_metadata["obj_first_frame_idx"]
# obj_id --> [mismatched frame indices]
unmatched_frame_inds = rank0_metadata["unmatched_frame_inds"]
trk_keep_alive = rank0_metadata["trk_keep_alive"]
# (first_appear_obj_id, obj_id) --> [overlap frame indices]
overlap_pair_to_frame_inds = rank0_metadata["overlap_pair_to_frame_inds"]
# removed_obj_ids: object IDs that are suppressed via hot-start
removed_obj_ids = rank0_metadata["removed_obj_ids"]
suppressed_obj_ids = rank0_metadata["suppressed_obj_ids"][frame_idx]
obj_ids_newly_removed = set() # object IDs to be newly removed on this frame
hotstart_diff = (
frame_idx - self.hotstart_delay
if not reverse
else frame_idx + self.hotstart_delay
)
# Step 1: log the frame index where each object ID first appears
for obj_id in new_det_obj_ids:
if obj_id not in obj_first_frame_idx:
obj_first_frame_idx[obj_id] = frame_idx
assert obj_id not in trk_keep_alive
trk_keep_alive[obj_id] = self.init_trk_keep_alive
matched_trks = set()
# We use the det-->tracks list to check for matched objects. Otherwise, we need to compute areas to decide whether they're occluded
for matched_trks_per_det in det_to_matched_trk_obj_ids.values():
matched_trks.update(matched_trks_per_det)
for obj_id in matched_trks:
# NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the max value of trk_keep_alive
trk_keep_alive[obj_id] = min(
self.max_trk_keep_alive, trk_keep_alive[obj_id] + 1
)
for obj_id in unmatched_trk_obj_ids:
unmatched_frame_inds[obj_id].append(frame_idx)
# NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive
# The max keep alive is 2x the min, means the model prefers to keep the prediction rather than suppress it if it was matched long enough.
trk_keep_alive[obj_id] = max(
self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1
)
if self.decrease_trk_keep_alive_for_empty_masklets:
for obj_id in empty_trk_obj_ids:
# NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive
trk_keep_alive[obj_id] = max(
self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1
)
# Step 2: removed tracks that has not matched with detections for `hotstart_unmatch_thresh` frames with hotstart period
# a) add unmatched frame indices for each existing object ID
# note that `unmatched_trk_obj_ids` contains those frames where the SAM2 output mask
# doesn't match any detection; it excludes those frames where SAM2 gives an empty mask
# b) remove a masklet if it first appears after `hotstart_diff` and is unmatched for more
# than `self.hotstart_unmatch_thresh` frames
for obj_id, frame_indices in unmatched_frame_inds.items():
if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed:
continue # skip if the object is already removed
if len(frame_indices) >= self.hotstart_unmatch_thresh:
is_within_hotstart = (
obj_first_frame_idx[obj_id] > hotstart_diff and not reverse
) or (obj_first_frame_idx[obj_id] < hotstart_diff and reverse)
if is_within_hotstart:
obj_ids_newly_removed.add(obj_id)
logger.debug(
f"Removing object {obj_id} at frame {frame_idx} "
f"since it is unmatched for frames: {frame_indices}"
)
if (
trk_keep_alive[obj_id] <= 0 # Object has not been matched for too long
and not self.suppress_unmatched_only_within_hotstart
and obj_id not in removed_obj_ids
and obj_id not in obj_ids_newly_removed
):
logger.debug(
f"Suppressing object {obj_id} at frame {frame_idx}, due to being unmatched"
)
suppressed_obj_ids.add(obj_id)
# Step 3: removed tracks that overlaps with another track for `hotstart_dup_thresh` frames
# a) find overlaps tracks -- we consider overlap if they match to the same detection
for _, matched_trk_obj_ids in det_to_matched_trk_obj_ids.items():
if len(matched_trk_obj_ids) < 2:
continue # only count detections that are matched to multiple (>=2) masklets
# if there are multiple matched track ids, we need to find the one that appeared first;
# these later appearing ids may be removed since they may be considered as duplicates
first_appear_obj_id = (
min(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x])
if not reverse
else max(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x])
)
for obj_id in matched_trk_obj_ids:
if obj_id != first_appear_obj_id:
key = (first_appear_obj_id, obj_id)
overlap_pair_to_frame_inds[key].append(frame_idx)
# b) remove a masklet if it first appears after `hotstart_diff` and it overlaps with another
# masklet (that appears earlier) for more than `self.hotstart_dup_thresh` frames
for (first_obj_id, obj_id), frame_indices in overlap_pair_to_frame_inds.items():
if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed:
continue # skip if the object is already removed
if (obj_first_frame_idx[obj_id] > hotstart_diff and not reverse) or (
obj_first_frame_idx[obj_id] < hotstart_diff and reverse
):
if len(frame_indices) >= self.hotstart_dup_thresh:
obj_ids_newly_removed.add(obj_id)
logger.debug(
f"Removing object {obj_id} at frame {frame_idx} "
f"since it overlaps with another track {first_obj_id} at frames: {frame_indices}"
)
removed_obj_ids.update(obj_ids_newly_removed)
return obj_ids_newly_removed, rank0_metadata
def _tracker_update_memories(
self,
tracker_inference_states: List[Any],
frame_idx: int,
tracker_metadata: Dict[str, Any],
low_res_masks: Tensor,
):
"""
Run Sam2 memory encoder, enforcing non-overlapping constraints globally.
"""
if len(tracker_inference_states) == 0:
return
# Avoid an extra interpolation step by directly interpolating to `interpol_size`
high_res_H, high_res_W = (
self.tracker.maskmem_backbone.mask_downsampler.interpol_size
)
# NOTE: inspect this part if we observe OOMs in the demo
high_res_masks = F.interpolate(
low_res_masks.unsqueeze(1),
size=(high_res_H, high_res_W),
mode="bilinear",
align_corners=False,
)
# We first apply non-overlapping constraints before memory encoding. This may include some suppression heuristics.
if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
high_res_masks = self.tracker._suppress_object_pw_area_shrinkage(
high_res_masks
)
# Instead of gathering the predicted object scores, we use mask areas as a proxy.
object_score_logits = torch.where(
(high_res_masks > 0).any(dim=(-1, -2)), 10.0, -10.0
)
# Run the memory encoder on local slices for each GPU
start_idx_gpu = sum(tracker_metadata["num_obj_per_gpu"][: self.rank])
start_idx_state = start_idx_gpu
for tracker_state in tracker_inference_states:
num_obj_per_state = len(tracker_state["obj_ids"])
if num_obj_per_state == 0:
continue
# Get the local high-res masks and object score logits for this inference state
end_idx_state = start_idx_state + num_obj_per_state
local_high_res_masks = high_res_masks[start_idx_state:end_idx_state]
local_object_score_logits = object_score_logits[
start_idx_state:end_idx_state
]
local_batch_size = local_high_res_masks.size(0)
# Run Sam2 memory encoder. Note that we do not re-enforce the non-overlapping constraint as it is turned off by default
encoded_mem = self.tracker._run_memory_encoder(
tracker_state,
frame_idx,
local_batch_size,
local_high_res_masks,
local_object_score_logits,
is_mask_from_pts=False,
)
local_maskmem_features, local_maskmem_pos_enc = encoded_mem
# Store encoded memories in the local inference state
output_dict = tracker_state["output_dict"]
for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]:
if frame_idx not in output_dict[storage_key]:
continue
output_dict[storage_key][frame_idx]["maskmem_features"] = (
local_maskmem_features
)
output_dict[storage_key][frame_idx]["maskmem_pos_enc"] = [
pos for pos in local_maskmem_pos_enc
]
# for batched inference state, we also need to add per-object
# memory slides to support instance interactivity
self.tracker._add_output_per_object(
inference_state=tracker_state,
frame_idx=frame_idx,
current_out=output_dict[storage_key][frame_idx],
storage_key=storage_key,
)
start_idx_state += num_obj_per_state
def _tracker_add_new_objects(
self,
frame_idx: int,
num_frames: int,
new_obj_ids: List[int],
new_obj_masks: Tensor,
tracker_states_local: List[Any],
orig_vid_height: int,
orig_vid_width: int,
feature_cache: Dict,
):
"""Add a new object to SAM2 inference states."""
prev_tracker_state = (
tracker_states_local[0] if len(tracker_states_local) > 0 else None
)
# prepare inference_state
# batch objects that first appear on the same frame together
# Clear inference state. Keep the cached image features if available.
new_tracker_state = self.tracker.init_state(
cached_features=feature_cache,
video_height=orig_vid_height,
video_width=orig_vid_width,
num_frames=num_frames,
)
new_tracker_state["backbone_out"] = (
prev_tracker_state.get("backbone_out", None)
if prev_tracker_state is not None
else None
)
assert len(new_obj_ids) == new_obj_masks.size(0)
assert new_obj_masks.is_floating_point()
input_mask_res = self.tracker.input_mask_size
new_obj_masks = F.interpolate(
new_obj_masks.unsqueeze(1),
size=(input_mask_res, input_mask_res),
mode="bilinear",
align_corners=False,
).squeeze(1)
new_obj_masks = new_obj_masks > 0
# add object one by one
for new_obj_id, new_mask in zip(new_obj_ids, new_obj_masks):
self.tracker.add_new_mask(
inference_state=new_tracker_state,
frame_idx=frame_idx,
obj_id=new_obj_id,
mask=new_mask,
add_mask_to_memory=True,
)
# NOTE: we skip enforcing the non-overlapping constraint **globally** when adding new objects.
self.tracker.propagate_in_video_preflight(
new_tracker_state, run_mem_encoder=True
)
tracker_states_local.append(new_tracker_state)
return tracker_states_local
def _tracker_remove_object(self, tracker_states_local: List[Any], obj_id: int):
"""
Remove an object from SAM2 inference states. This would remove the object from
all frames in the video.
"""
tracker_states_local_before_removal = tracker_states_local.copy()
tracker_states_local.clear()
for tracker_inference_state in tracker_states_local_before_removal:
# we try to remove `obj_id` on every inference state with `strict=False`
# it will not do anything if an inference state doesn't contain `obj_id`
new_obj_ids, _ = self.tracker.remove_object(
tracker_inference_state, obj_id, strict=False, need_output=False
)
# only keep an inference state if it's non-empty after object removal
if len(new_obj_ids) > 0:
tracker_states_local.append(tracker_inference_state)
def _tracker_remove_objects(
self, tracker_states_local: List[Any], obj_ids: list[int]
):
"""
Remove an object from SAM2 inference states. This would remove the object from
all frames in the video.
"""
for obj_id in obj_ids:
self._tracker_remove_object(tracker_states_local, obj_id)
def _initialize_metadata(self):
"""Initialize metadata for the masklets."""
tracker_metadata = {
"obj_ids_per_gpu": [np.array([], np.int64) for _ in range(self.world_size)],
"obj_ids_all_gpu": np.array([], np.int64),
"num_obj_per_gpu": np.zeros(self.world_size, np.int64),
"max_obj_id": -1,
"obj_id_to_score": {},
"obj_id_to_tracker_score_frame_wise": defaultdict(dict),
"obj_id_to_last_occluded": {},
}
if self.rank == 0:
# "rank0_metadata" contains metadata that is only stored on (and accessible to) GPU 0
# - obj_first_frame_idx: obj_id --> first frame index where the object was detected
# - unmatched_frame_inds: obj_id --> [mismatched frame indices]
# - overlap_pair_to_frame_inds: (first_appear_obj_id, obj_id) --> [overlap frame indices]
# - removed_obj_ids: object IDs that are suppressed via hot-start
rank0_metadata = {
"obj_first_frame_idx": {},
"unmatched_frame_inds": defaultdict(list),
"trk_keep_alive": defaultdict(
int
), # This is used only for object suppression not for removal
"overlap_pair_to_frame_inds": defaultdict(list),
"removed_obj_ids": set(),
"suppressed_obj_ids": defaultdict(
set
), # frame_idx --> set of objects with suppressed outputs, but still continue to be tracked
}
if self.masklet_confirmation_enable:
# all the following are npt.NDArray with the same shape as `obj_ids_all_gpu`
rank0_metadata["masklet_confirmation"] = {
# "status" is the confirmation status of each masklet (in `MaskletConfirmationStatus`)
"status": np.array([], np.int64),
# "consecutive_det_num" is the number of consecutive frames where the masklet is
# detected by the detector (with a matched detection)
"consecutive_det_num": np.array([], np.int64),
}
tracker_metadata["rank0_metadata"] = rank0_metadata
return tracker_metadata
def update_masklet_confirmation_status(
self,
rank0_metadata: Dict[str, Any],
obj_ids_all_gpu_prev: npt.NDArray,
obj_ids_all_gpu_updated: npt.NDArray,
det_to_matched_trk_obj_ids: Dict[int, npt.NDArray],
new_det_obj_ids: npt.NDArray,
):
confirmation_data = rank0_metadata["masklet_confirmation"]
# a) first, expand "confirmation_data" to include new masklets added in this frame
status_prev = confirmation_data["status"]
consecutive_det_num_prev = confirmation_data["consecutive_det_num"]
assert (
status_prev.shape == obj_ids_all_gpu_prev.shape
), f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
obj_id_to_updated_idx = {
obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated)
}
prev_elem_is_in_updated = np.isin(obj_ids_all_gpu_prev, obj_ids_all_gpu_updated)
prev_elem_obj_ids_in_updated = obj_ids_all_gpu_prev[prev_elem_is_in_updated]
prev_elem_inds_in_updated = np.array(
[obj_id_to_updated_idx[obj_id] for obj_id in prev_elem_obj_ids_in_updated],
dtype=np.int64,
)
# newly added masklets are initialized to "UNCONFIRMED" status
unconfirmed_val = MaskletConfirmationStatus.UNCONFIRMED.value
status = np.full_like(obj_ids_all_gpu_updated, fill_value=unconfirmed_val)
status[prev_elem_inds_in_updated] = status_prev[prev_elem_is_in_updated]
consecutive_det_num = np.zeros_like(obj_ids_all_gpu_updated)
consecutive_det_num[prev_elem_inds_in_updated] = consecutive_det_num_prev[
prev_elem_is_in_updated
]
# b) update the confirmation status of all masklets based on the current frame
# b.1) update "consecutive_det_num"
# "is_matched": whether a masklet is matched to a detection on this frame
is_matched = np.isin(obj_ids_all_gpu_updated, new_det_obj_ids)
for matched_trk_obj_ids in det_to_matched_trk_obj_ids.values():
is_matched |= np.isin(obj_ids_all_gpu_updated, matched_trk_obj_ids)
consecutive_det_num = np.where(is_matched, consecutive_det_num + 1, 0)
# b.2) update "status"
change_to_confirmed = (
consecutive_det_num >= self.masklet_confirmation_consecutive_det_thresh
)
status[change_to_confirmed] = MaskletConfirmationStatus.CONFIRMED.value
confirmation_data["status"] = status
confirmation_data["consecutive_det_num"] = consecutive_det_num
return rank0_metadata
def forward(self, input: BatchedDatapoint, is_inference: bool = False):
raise NotImplementedError("Evaluation outside demo is not implemented yet")
def _load_checkpoint(self, ckpt_path: str, strict: bool = True):
sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=strict)
if len(missing_keys) > 0 or len(unexpected_keys) > 0:
logger.warning(f"Loaded ckpt with {missing_keys=}, {unexpected_keys=}")
else:
logger.info("Loaded ckpt successfully without missing or unexpected keys")
def prep_for_evaluator(self, video_frames, tracking_res, scores_labels):
"""This method is only used for benchmark eval (not used in the demo)."""
num_frames = len(video_frames)
w, h = video_frames[0].size
zero_mask = torch.zeros((1, h, w), dtype=torch.bool)
object_ids = list(scores_labels.keys())
preds = {"scores": [], "labels": [], "boxes": [], "masks_rle": []}
for oid in object_ids:
o_masks = []
o_score = scores_labels[oid][0].item()
o_label = scores_labels[oid][1]
for frame_idx in range(num_frames):
if frame_idx not in tracking_res:
o_masks.append(zero_mask)
else:
o_masks.append(tracking_res[frame_idx].get(oid, zero_mask))
o_masks = torch.cat(o_masks, dim=0) # (n_frames, H, W)
preds["scores"].append(o_score)
preds["labels"].append(o_label)
preds["boxes"].append(mask_to_box(o_masks.unsqueeze(1)).squeeze())
preds["masks_rle"].append(rle_encode(o_masks, return_areas=True))
preds["boxes"] = (
torch.stack(preds["boxes"], dim=0)
if len(preds["boxes"]) > 0
else torch.empty(
(0, num_frames, 4), dtype=torch.float32, device=self.device
)
)
preds["scores"] = (
torch.tensor(preds["scores"], device=self.device)
if len(preds["scores"]) > 0
else torch.empty((0,), device=self.device)
)
preds["per_frame_scores"] = preds["scores"]
preds["labels"] = (
torch.tensor(preds["labels"], device=self.device)
if len(preds["labels"]) > 0
else torch.empty((0,), device=self.device)
)
return preds
def _encode_prompt(self, **kwargs):
return self.detector._encode_prompt(**kwargs)
def _drop_new_det_with_obj_limit(self, new_det_fa_inds, det_scores_np, num_to_keep):
"""
Drop a few new detections based on the maximum number of objects. We drop new objects based
on their detection scores, keeping the high-scoring ones and dropping the low-scoring ones.
"""
assert 0 <= num_to_keep <= len(new_det_fa_inds)
if num_to_keep == 0:
return np.array([], np.int64) # keep none
if num_to_keep == len(new_det_fa_inds):
return new_det_fa_inds # keep all
# keep the top-scoring detections
score_order = np.argsort(det_scores_np[new_det_fa_inds])[::-1]
new_det_fa_inds = new_det_fa_inds[score_order[:num_to_keep]]
return new_det_fa_inds