| | |
| | |
| |
|
| | |
| | |
| |
|
| | import warnings |
| | from collections import OrderedDict |
| |
|
| | import torch |
| |
|
| | from tqdm import tqdm |
| |
|
| | from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base |
| | from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames |
| |
|
| |
|
| | class SAM2VideoPredictor(SAM2Base): |
| | """The predictor class to handle user interactions and manage inference states.""" |
| |
|
def __init__(
    self,
    fill_hole_area=0,
    non_overlap_masks=False,
    clear_non_cond_mem_around_input=False,
    clear_non_cond_mem_for_multi_obj=False,
    **kwargs,
):
    """Build the predictor and remember its interaction options.

    Args:
        fill_hole_area: Stored on the instance as ``fill_hole_area``.
        non_overlap_masks: When true, ``_get_orig_video_res_output`` applies
            non-overlapping constraints to the video-resolution masks.
        clear_non_cond_mem_around_input: When true, non-conditioning memory
            around input frames is cleared (see the propagation methods);
            this only takes effect for a single object unless
            ``clear_non_cond_mem_for_multi_obj`` is also true.
        clear_non_cond_mem_for_multi_obj: Extends the clearing above to
            sessions tracking more than one object.
        **kwargs: Forwarded unchanged to the parent class constructor.
    """
    super().__init__(**kwargs)
    # Mask post-processing options.
    self.fill_hole_area, self.non_overlap_masks = fill_hole_area, non_overlap_masks
    # Memory-clearing options consulted when propagation starts.
    self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
    self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
| |
|
@torch.inference_mode()
def init_state(
    self,
    video_path,
    offload_video_to_cpu=False,
    offload_state_to_cpu=False,
    async_loading_frames=False,
):
    """Load a video and create a fresh inference state for it.

    Args:
        video_path: Passed to ``load_video_frames`` to locate the frames.
        offload_video_to_cpu: Forwarded to ``load_video_frames`` and recorded
            in the state.
        offload_state_to_cpu: When true the state's ``storage_device`` is the
            CPU; otherwise it is the model's device.
        async_loading_frames: Forwarded to ``load_video_frames``.

    Returns:
        dict: The inference state. Image features for frame 0 have already
        been requested once through ``_get_image_feature``.
    """
    model_device = self.device
    frames, frame_height, frame_width = load_video_frames(
        video_path=video_path,
        image_size=self.image_size,
        offload_video_to_cpu=offload_video_to_cpu,
        async_loading_frames=async_loading_frames,
        compute_device=model_device,
    )
    storage_device = torch.device("cpu") if offload_state_to_cpu else model_device

    state = {
        "images": frames,
        "num_frames": len(frames),
        "offload_video_to_cpu": offload_video_to_cpu,
        "offload_state_to_cpu": offload_state_to_cpu,
        # original video size, used when resizing output masks
        "video_height": frame_height,
        "video_width": frame_width,
        "device": model_device,
        "storage_device": storage_device,
        # prompts received on each frame, keyed by object index
        "point_inputs_per_obj": {},
        "mask_inputs_per_obj": {},
        "cached_features": {},
        "constants": {},
        # two-way mapping between client object ids and model object indices
        "obj_id_to_idx": OrderedDict(),
        "obj_idx_to_id": OrderedDict(),
        "obj_ids": [],
        # consolidated results for all objects, per frame
        "output_dict": {
            "cond_frame_outputs": {},
            "non_cond_frame_outputs": {},
        },
        # per-object results, and per-object results not yet consolidated
        "output_dict_per_obj": {},
        "temp_output_dict_per_obj": {},
        # frames whose outputs were consolidated from user inputs
        "consolidated_frame_inds": {
            "cond_frame_outputs": set(),
            "non_cond_frame_outputs": set(),
        },
        "tracking_has_started": False,
        "frames_already_tracked": {},
    }
    # Request the features of frame 0 once up front.
    self._get_image_feature(state, frame_idx=0, batch_size=1)
    return state
| |
|
def init_state_images(
    self,
    images,
    video_height,
    video_width,
    offload_video_to_cpu=False,
    offload_state_to_cpu=False,
    async_loading_frames=False,
):
    """Create an inference state from frames that are already in memory.

    Same state layout as ``init_state``, but nothing is loaded from disk:
    ``images[0]`` is used as the frame sequence and the video size is given
    by the caller.

    NOTE(review): unlike ``init_state`` this method is not wrapped in
    ``torch.inference_mode()`` -- confirm that is intended.
    NOTE(review): ``async_loading_frames`` is accepted but never read.

    Args:
        images: Indexable whose first element is the frame sequence.
        video_height: Original video height, stored in the state.
        video_width: Original video width, stored in the state.
        offload_video_to_cpu: Recorded in the state.
        offload_state_to_cpu: When true the state's ``storage_device`` is the
            CPU; otherwise it is the model's device.
        async_loading_frames: Unused.

    Returns:
        dict: The inference state. Image features for frame 0 have already
        been requested once through ``_get_image_feature``.
    """
    model_device = self.device
    frames = images[0]
    storage_device = torch.device("cpu") if offload_state_to_cpu else model_device

    state = {
        "images": frames,
        "num_frames": len(frames),
        "offload_video_to_cpu": offload_video_to_cpu,
        "offload_state_to_cpu": offload_state_to_cpu,
        # original video size, used when resizing output masks
        "video_height": video_height,
        "video_width": video_width,
        "device": model_device,
        "storage_device": storage_device,
        # prompts received on each frame, keyed by object index
        "point_inputs_per_obj": {},
        "mask_inputs_per_obj": {},
        "cached_features": {},
        "constants": {},
        # two-way mapping between client object ids and model object indices
        "obj_id_to_idx": OrderedDict(),
        "obj_idx_to_id": OrderedDict(),
        "obj_ids": [],
        # consolidated results for all objects, per frame
        "output_dict": {
            "cond_frame_outputs": {},
            "non_cond_frame_outputs": {},
        },
        # per-object results, and per-object results not yet consolidated
        "output_dict_per_obj": {},
        "temp_output_dict_per_obj": {},
        # frames whose outputs were consolidated from user inputs
        "consolidated_frame_inds": {
            "cond_frame_outputs": set(),
            "non_cond_frame_outputs": set(),
        },
        "tracking_has_started": False,
        "frames_already_tracked": {},
    }
    # Request the features of frame 0 once up front.
    self._get_image_feature(state, frame_idx=0, batch_size=1)
    return state
| |
|
| |
|
| |
|
@classmethod
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
    """
    Build a predictor from a Hugging Face hub checkpoint.

    Arguments:
      model_id (str): The Hugging Face repository ID.
      **kwargs: Extra keyword arguments handed to the builder function.

    Returns:
      (SAM2VideoPredictor): The model returned by the builder.
    """
    # Imported at call time, as in the original implementation.
    from sam2.build_sam import build_sam2_video_predictor_hf

    return build_sam2_video_predictor_hf(model_id, **kwargs)
| |
|
| | def _obj_id_to_idx(self, inference_state, obj_id): |
| | """Map client-side object id to model-side object index.""" |
| | obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) |
| | if obj_idx is not None: |
| | return obj_idx |
| |
|
| | |
| | |
| | allow_new_object = not inference_state["tracking_has_started"] |
| | if allow_new_object: |
| | |
| | obj_idx = len(inference_state["obj_id_to_idx"]) |
| | inference_state["obj_id_to_idx"][obj_id] = obj_idx |
| | inference_state["obj_idx_to_id"][obj_idx] = obj_id |
| | inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) |
| | |
| | inference_state["point_inputs_per_obj"][obj_idx] = {} |
| | inference_state["mask_inputs_per_obj"][obj_idx] = {} |
| | inference_state["output_dict_per_obj"][obj_idx] = { |
| | "cond_frame_outputs": {}, |
| | "non_cond_frame_outputs": {}, |
| | } |
| | inference_state["temp_output_dict_per_obj"][obj_idx] = { |
| | "cond_frame_outputs": {}, |
| | "non_cond_frame_outputs": {}, |
| | } |
| | return obj_idx |
| | else: |
| | raise RuntimeError( |
| | f"Cannot add new object id {obj_id} after tracking starts. " |
| | f"All existing object ids: {inference_state['obj_ids']}. " |
| | f"Please call 'reset_state' to restart from scratch." |
| | ) |
| |
|
| | def _obj_idx_to_id(self, inference_state, obj_idx): |
| | """Map model-side object index to client-side object id.""" |
| | return inference_state["obj_idx_to_id"][obj_idx] |
| |
|
| | def _get_obj_num(self, inference_state): |
| | """Get the total number of unique object ids received so far in this session.""" |
| | return len(inference_state["obj_idx_to_id"]) |
| |
|
@torch.inference_mode()
def add_new_points_or_box(
    self,
    inference_state,
    frame_idx,
    obj_id,
    points=None,
    labels=None,
    clear_old_points=True,
    normalize_coords=True,
    box=None,
):
    """Register point and/or box prompts for one object on one frame.

    The prompts are stored in ``point_inputs_per_obj`` (dropping any mask
    prompt kept for the same frame), a single-frame inference is run for that
    object without the memory encoder, and its result is kept in
    ``temp_output_dict_per_obj``.

    Args:
        inference_state: The inference state dict.
        frame_idx: Index of the frame receiving the prompts.
        obj_id: Client-side object id (resolved through ``_obj_id_to_idx``).
        points: Point coordinates, tensor or array-like; a 2-D input gains a
            leading dimension.
        labels: Labels for ``points``; must be given exactly when ``points``
            is given. A 1-D input gains a leading dimension.
        clear_old_points: When false, the new points are combined with those
            already stored for this frame via ``concat_points``.
        normalize_coords: When true, coordinates are divided by the video
            (width, height) before being scaled by ``self.image_size``.
        box: Optional box, reshaped to two points and placed before the
            other points with labels 2 and 3.

    Returns:
        tuple: ``(frame_idx, obj_ids, video_res_masks)`` where the masks come
        from ``_get_orig_video_res_output``.

    Raises:
        ValueError: If only one of ``points``/``labels`` is given, if neither
            points nor a box is given, or if a box is combined with
            ``clear_old_points=False``.
    """
    # The object is resolved (and possibly registered) before the arguments
    # are validated.
    obj_idx = self._obj_id_to_idx(inference_state, obj_id)
    frame_points = inference_state["point_inputs_per_obj"][obj_idx]
    frame_masks = inference_state["mask_inputs_per_obj"][obj_idx]

    has_points = points is not None
    has_labels = labels is not None
    if has_points != has_labels:
        raise ValueError("points and labels must be provided together")
    if not has_points and box is None:
        raise ValueError("at least one of points or box must be provided as input")

    # Bring points and labels to batched tensors.
    if not has_points:
        points = torch.zeros(0, 2, dtype=torch.bfloat16)
    elif not isinstance(points, torch.Tensor):
        points = torch.tensor(points, dtype=torch.bfloat16)
    if not has_labels:
        labels = torch.zeros(0, dtype=torch.int32)
    elif not isinstance(labels, torch.Tensor):
        labels = torch.tensor(labels, dtype=torch.int32)
    if points.dim() == 2:
        points = points[None]
    if labels.dim() == 1:
        labels = labels[None]

    # A box becomes two extra points (labels 2 and 3) in front of the others.
    if box is not None:
        if not clear_old_points:
            raise ValueError(
                "cannot add box without clearing old points, since "
                "box prompt must be provided before any point prompt "
                "(please use clear_old_points=True instead)"
            )
        if inference_state["tracking_has_started"]:
            warnings.warn(
                "You are adding a box after tracking starts. SAM 2 may not always be "
                "able to incorporate a box prompt for *refinement*. If you intend to "
                "use box prompt as an *initial* input before tracking, please call "
                "'reset_state' on the inference state to restart from scratch.",
                category=UserWarning,
                stacklevel=2,
            )
        if not isinstance(box, torch.Tensor):
            box = torch.tensor(box, dtype=torch.bfloat16, device=points.device)
        corner_points = box.reshape(1, 2, 2)
        corner_labels = torch.tensor(
            [2, 3], dtype=torch.int32, device=labels.device
        ).reshape(1, 2)
        points = torch.cat([corner_points, points], dim=1)
        labels = torch.cat([corner_labels, labels], dim=1)

    if normalize_coords:
        frame_h = inference_state["video_height"]
        frame_w = inference_state["video_width"]
        points = points / torch.tensor([frame_w, frame_h]).to(points.device)
    # Scale the coordinates by the model's image size.
    points = points * self.image_size
    target_device = inference_state["device"]
    points = points.to(target_device)
    labels = labels.to(target_device)

    earlier_points = None if clear_old_points else frame_points.get(frame_idx, None)
    point_inputs = concat_points(earlier_points, points, labels)
    frame_points[frame_idx] = point_inputs
    # Point prompts replace any mask prompt stored for this frame.
    frame_masks.pop(frame_idx, None)

    # A frame that has not been tracked yet is an initial conditioning frame.
    tracked_frames = inference_state["frames_already_tracked"]
    is_init_cond_frame = frame_idx not in tracked_frames
    reverse = False if is_init_cond_frame else tracked_frames[frame_idx]["reverse"]
    obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
    obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
    is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
    storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

    # Look for an earlier prediction on this frame: temporary output first,
    # then the conditioning and non-conditioning outputs.
    prev_out = None
    for candidates in (
        obj_temp_output_dict[storage_key],
        obj_output_dict["cond_frame_outputs"],
        obj_output_dict["non_cond_frame_outputs"],
    ):
        prev_out = candidates.get(frame_idx)
        if prev_out is not None:
            break

    prev_sam_mask_logits = None
    if prev_out is not None and prev_out["pred_masks"] is not None:
        prev_sam_mask_logits = prev_out["pred_masks"].to(
            inference_state["device"], non_blocking=True
        )
        # Keep the earlier logits within [-32, 32].
        prev_sam_mask_logits = prev_sam_mask_logits.clamp(-32.0, 32.0)

    current_out, _ = self._run_single_frame_inference(
        inference_state=inference_state,
        output_dict=obj_output_dict,
        frame_idx=frame_idx,
        batch_size=1,
        is_init_cond_frame=is_init_cond_frame,
        point_inputs=point_inputs,
        mask_inputs=None,
        reverse=reverse,
        run_mem_encoder=False,
        prev_sam_mask_logits=prev_sam_mask_logits,
    )
    obj_temp_output_dict[storage_key][frame_idx] = current_out

    # Merge this frame's outputs across objects at video resolution.
    obj_ids = inference_state["obj_ids"]
    consolidated_out = self._consolidate_temp_output_across_obj(
        inference_state,
        frame_idx,
        is_cond=is_cond,
        run_mem_encoder=False,
        consolidate_at_video_res=True,
    )
    _, video_res_masks = self._get_orig_video_res_output(
        inference_state, consolidated_out["pred_masks_video_res"]
    )
    return frame_idx, obj_ids, video_res_masks
| |
|
@torch.inference_mode()
def add_new_box_embeding(
    self,
    inference_state,
    frame_idx,
    obj_id,
    box_embeding=None,
    points=None,
    labels=None,
    clear_old_points=True,
    normalize_coords=True,
    box=None,
):
    """Register point/box prompts together with a box embedding on a frame.

    Follows the same steps as ``add_new_points_or_box``; the difference is
    that the single-frame inference goes through
    ``_run_single_frame_inference_embed`` and receives ``box_embeding`` as
    its ``box_embed`` argument.

    Args:
        inference_state: The inference state dict.
        frame_idx: Index of the frame receiving the prompts.
        obj_id: Client-side object id (resolved through ``_obj_id_to_idx``).
        box_embeding: Value forwarded as ``box_embed`` to the inference call.
        points: Point coordinates, tensor or array-like; a 2-D input gains a
            leading dimension.
        labels: Labels for ``points``; must be given exactly when ``points``
            is given. A 1-D input gains a leading dimension.
        clear_old_points: When false, the new points are combined with those
            already stored for this frame via ``concat_points``.
        normalize_coords: When true, coordinates are divided by the video
            (width, height) before being scaled by ``self.image_size``.
        box: Optional box, reshaped to two points and placed before the
            other points with labels 2 and 3.

    Returns:
        tuple: ``(frame_idx, obj_ids, video_res_masks)`` where the masks come
        from ``_get_orig_video_res_output``.

    Raises:
        ValueError: If only one of ``points``/``labels`` is given, if neither
            points nor a box is given, or if a box is combined with
            ``clear_old_points=False``.
    """
    # The object is resolved (and possibly registered) before the arguments
    # are validated.
    obj_idx = self._obj_id_to_idx(inference_state, obj_id)
    frame_points = inference_state["point_inputs_per_obj"][obj_idx]
    frame_masks = inference_state["mask_inputs_per_obj"][obj_idx]

    has_points = points is not None
    has_labels = labels is not None
    if has_points != has_labels:
        raise ValueError("points and labels must be provided together")
    if not has_points and box is None:
        raise ValueError("at least one of points or box must be provided as input")

    # Bring points and labels to batched tensors.
    if not has_points:
        points = torch.zeros(0, 2, dtype=torch.bfloat16)
    elif not isinstance(points, torch.Tensor):
        points = torch.tensor(points, dtype=torch.bfloat16)
    if not has_labels:
        labels = torch.zeros(0, dtype=torch.int32)
    elif not isinstance(labels, torch.Tensor):
        labels = torch.tensor(labels, dtype=torch.int32)
    if points.dim() == 2:
        points = points[None]
    if labels.dim() == 1:
        labels = labels[None]

    # A box becomes two extra points (labels 2 and 3) in front of the others.
    if box is not None:
        if not clear_old_points:
            raise ValueError(
                "cannot add box without clearing old points, since "
                "box prompt must be provided before any point prompt "
                "(please use clear_old_points=True instead)"
            )
        if inference_state["tracking_has_started"]:
            warnings.warn(
                "You are adding a box after tracking starts. SAM 2 may not always be "
                "able to incorporate a box prompt for *refinement*. If you intend to "
                "use box prompt as an *initial* input before tracking, please call "
                "'reset_state' on the inference state to restart from scratch.",
                category=UserWarning,
                stacklevel=2,
            )
        if not isinstance(box, torch.Tensor):
            box = torch.tensor(box, dtype=torch.bfloat16, device=points.device)
        corner_points = box.reshape(1, 2, 2)
        corner_labels = torch.tensor(
            [2, 3], dtype=torch.int32, device=labels.device
        ).reshape(1, 2)
        points = torch.cat([corner_points, points], dim=1)
        labels = torch.cat([corner_labels, labels], dim=1)

    if normalize_coords:
        frame_h = inference_state["video_height"]
        frame_w = inference_state["video_width"]
        points = points / torch.tensor([frame_w, frame_h]).to(points.device)
    # Scale the coordinates by the model's image size.
    points = points * self.image_size
    target_device = inference_state["device"]
    points = points.to(target_device)
    labels = labels.to(target_device)

    earlier_points = None if clear_old_points else frame_points.get(frame_idx, None)
    point_inputs = concat_points(earlier_points, points, labels)
    frame_points[frame_idx] = point_inputs
    # Point prompts replace any mask prompt stored for this frame.
    frame_masks.pop(frame_idx, None)

    # A frame that has not been tracked yet is an initial conditioning frame.
    tracked_frames = inference_state["frames_already_tracked"]
    is_init_cond_frame = frame_idx not in tracked_frames
    reverse = False if is_init_cond_frame else tracked_frames[frame_idx]["reverse"]
    obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
    obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
    is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
    storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

    # Look for an earlier prediction on this frame: temporary output first,
    # then the conditioning and non-conditioning outputs.
    prev_out = None
    for candidates in (
        obj_temp_output_dict[storage_key],
        obj_output_dict["cond_frame_outputs"],
        obj_output_dict["non_cond_frame_outputs"],
    ):
        prev_out = candidates.get(frame_idx)
        if prev_out is not None:
            break

    prev_sam_mask_logits = None
    if prev_out is not None and prev_out["pred_masks"] is not None:
        prev_sam_mask_logits = prev_out["pred_masks"].to(
            inference_state["device"], non_blocking=True
        )
        # Keep the earlier logits within [-32, 32].
        prev_sam_mask_logits = prev_sam_mask_logits.clamp(-32.0, 32.0)

    current_out, _ = self._run_single_frame_inference_embed(
        inference_state=inference_state,
        output_dict=obj_output_dict,
        frame_idx=frame_idx,
        batch_size=1,
        is_init_cond_frame=is_init_cond_frame,
        box_embed=box_embeding,
        point_inputs=point_inputs,
        mask_inputs=None,
        reverse=reverse,
        run_mem_encoder=False,
        prev_sam_mask_logits=prev_sam_mask_logits,
    )
    obj_temp_output_dict[storage_key][frame_idx] = current_out

    # Merge this frame's outputs across objects at video resolution.
    obj_ids = inference_state["obj_ids"]
    consolidated_out = self._consolidate_temp_output_across_obj(
        inference_state,
        frame_idx,
        is_cond=is_cond,
        run_mem_encoder=False,
        consolidate_at_video_res=True,
    )
    _, video_res_masks = self._get_orig_video_res_output(
        inference_state, consolidated_out["pred_masks_video_res"]
    )
    return frame_idx, obj_ids, video_res_masks
| |
|
def get_prompt_embeding(
    self,
    inference_state,
    points=None,
    labels=None,
    normalize_coords=False,
    box=None,
    device=None,
):
    """Encode a box (plus optional points) with the SAM prompt encoder.

    The box is reshaped to two points labelled 2 and 3 and placed before any
    user points, the coordinates are scaled by ``self.image_size``, and the
    result is passed to ``self.sam_prompt_encoder``.

    Args:
        inference_state: Inference state dict; supplies the target device
            and, when ``normalize_coords`` is true, the video size.
        points: Optional point coordinates (tensor or array-like); a 2-D
            input gains a leading dimension.
        labels: Labels for ``points``; required when ``points`` is given and
            ignored when it is not. A 1-D input gains a leading dimension.
        normalize_coords: When true, coordinates are divided by the video
            (width, height) before being scaled by ``self.image_size``.
        box: The box prompt (tensor or array-like with four values).
            Required despite the ``None`` default.
        device: Device on which tensors built here are created.

    Returns:
        The sparse embeddings returned by ``self.sam_prompt_encoder``.

    Raises:
        ValueError: If ``box`` is missing, or ``points`` is given without
            ``labels``.
    """
    # A missing box used to surface as an obscure tensor-construction error.
    if box is None:
        raise ValueError("box must be provided to compute a prompt embedding")
    if not isinstance(box, torch.Tensor):
        box = torch.tensor(box, dtype=torch.bfloat16, device=device)
    box_coords = box.reshape(1, 2, 2)
    box_labels = torch.tensor([2, 3], dtype=torch.int32, device=device)
    box_labels = box_labels.reshape(1, 2)

    if points is None:
        # No user points: use empty placeholders (any labels are ignored).
        points = torch.zeros(0, 2, dtype=torch.bfloat16, device=device)
        labels = torch.zeros(0, dtype=torch.int32, device=device)
    else:
        if labels is None:
            raise ValueError("points and labels must be provided together")
        # Accept array-likes, as add_new_points_or_box does.
        if not isinstance(points, torch.Tensor):
            points = torch.tensor(points, dtype=torch.bfloat16, device=device)
        if not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels, dtype=torch.int32, device=device)
    if points.dim() == 2:
        points = points.unsqueeze(0)
    if labels.dim() == 1:
        labels = labels.unsqueeze(0)

    # The box corners go in front of the user points.
    points = torch.cat([box_coords, points], dim=1)
    labels = torch.cat([box_labels, labels], dim=1)
    if normalize_coords:
        video_H = inference_state["video_height"]
        video_W = inference_state["video_width"]
        points = points / torch.tensor([video_W, video_H]).to(points.device)
    # Scale the coordinates by the model's image size.
    points = points * self.image_size
    points = points.to(inference_state["device"])
    labels = labels.to(inference_state["device"])
    point_inputs = concat_points(None, points, labels)
    # Only the sparse embeddings are returned; the dense ones are discarded.
    sparse_embeddings, _ = self.sam_prompt_encoder(
        points=(point_inputs["point_coords"], point_inputs["point_labels"]),
        boxes=None,
        masks=None,
    )
    return sparse_embeddings
| |
|
| |
|
| |
|
def add_new_points(self, *args, **kwargs):
    """Deprecated alias of `add_new_points_or_box`; forwards every argument."""
    replacement = self.add_new_points_or_box
    return replacement(*args, **kwargs)
| |
|
@torch.inference_mode()
def add_new_mask(
    self,
    inference_state,
    frame_idx,
    obj_id,
    mask,
):
    """Register a mask prompt for one object on one frame.

    The mask is resized to the model's square input size when needed, stored
    in ``mask_inputs_per_obj`` (dropping any point prompt kept for the same
    frame), and a single-frame inference is run for that object without the
    memory encoder. Its result is kept in ``temp_output_dict_per_obj``.

    Args:
        inference_state: The inference state dict.
        frame_idx: Index of the frame receiving the mask.
        obj_id: Client-side object id (resolved through ``_obj_id_to_idx``).
        mask: 2-D mask, tensor or array-like. Array-likes are converted to a
            boolean tensor.

    Returns:
        tuple: ``(frame_idx, obj_ids, video_res_masks)`` where the masks come
        from ``_get_orig_video_res_output``.
    """
    obj_idx = self._obj_id_to_idx(inference_state, obj_id)
    point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
    mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]

    if not isinstance(mask, torch.Tensor):
        mask = torch.tensor(mask, dtype=torch.bool)
    assert mask.dim() == 2
    mask_H, mask_W = mask.shape
    mask_inputs_orig = mask[None, None]  # add batch and channel dimensions
    mask_inputs_orig = mask_inputs_orig.to(inference_state["device"])

    # Resize the mask if it does not match the model's input size.
    if mask_H != self.image_size or mask_W != self.image_size:
        # Bilinear interpolation is not implemented for boolean tensors, so
        # a boolean mask is cast to float for the resize only; the threshold
        # below turns the result back into a boolean mask.
        resize_source = mask_inputs_orig
        if resize_source.dtype == torch.bool:
            resize_source = resize_source.float()
        mask_inputs = torch.nn.functional.interpolate(
            resize_source,
            size=(self.image_size, self.image_size),
            align_corners=False,
            mode="bilinear",
            antialias=True,
        )
        mask_inputs = mask_inputs >= 0.5
    else:
        mask_inputs = mask_inputs_orig

    mask_inputs_per_frame[frame_idx] = mask_inputs
    # A mask prompt replaces any point prompt stored for this frame.
    point_inputs_per_frame.pop(frame_idx, None)

    # A frame that has not been tracked yet is an initial conditioning frame.
    is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
    if is_init_cond_frame:
        reverse = False
    else:
        reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
    obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
    obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
    is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
    storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

    current_out, _ = self._run_single_frame_inference(
        inference_state=inference_state,
        output_dict=obj_output_dict,
        frame_idx=frame_idx,
        batch_size=1,
        is_init_cond_frame=is_init_cond_frame,
        point_inputs=None,
        mask_inputs=mask_inputs,
        reverse=reverse,
        run_mem_encoder=False,
    )
    obj_temp_output_dict[storage_key][frame_idx] = current_out

    # Merge this frame's outputs across objects at video resolution.
    obj_ids = inference_state["obj_ids"]
    consolidated_out = self._consolidate_temp_output_across_obj(
        inference_state,
        frame_idx,
        is_cond=is_cond,
        run_mem_encoder=False,
        consolidate_at_video_res=True,
    )
    _, video_res_masks = self._get_orig_video_res_output(
        inference_state, consolidated_out["pred_masks_video_res"]
    )
    return frame_idx, obj_ids, video_res_masks
| |
|
| | def _get_orig_video_res_output(self, inference_state, any_res_masks): |
| | """ |
| | Resize the object scores to the original video resolution (video_res_masks) |
| | and apply non-overlapping constraints for final output. |
| | """ |
| | device = inference_state["device"] |
| | video_H = inference_state["video_height"] |
| | video_W = inference_state["video_width"] |
| | any_res_masks = any_res_masks.to(device, non_blocking=True) |
| | if any_res_masks.shape[-2:] == (video_H, video_W): |
| | video_res_masks = any_res_masks |
| | else: |
| | video_res_masks = torch.nn.functional.interpolate( |
| | any_res_masks, |
| | size=(video_H, video_W), |
| | mode="bilinear", |
| | align_corners=False, |
| | ) |
| | if self.non_overlap_masks: |
| | video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) |
| | return any_res_masks, video_res_masks |
| |
|
def _consolidate_temp_output_across_obj(
    self,
    inference_state,
    frame_idx,
    is_cond,
    run_mem_encoder,
    consolidate_at_video_res=False,
):
    """
    Merge the per-object temporary outputs of one frame into a single output
    covering every object.

    For each object the output is taken from `temp_output_dict_per_obj`, then
    from `output_dict_per_obj` (conditioning first, non-conditioning second).
    Objects with no output on this frame keep the placeholder scores; when
    `run_mem_encoder` is set they receive a dummy object pointer from
    `_get_empty_mask_ptr`. With `run_mem_encoder` the memory encoder is
    re-run on the merged masks (after the non-overlapping constraints if
    `self.non_overlap_masks_for_mem_enc` is set).
    """
    num_objs = self._get_obj_num(inference_state)
    storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

    # Pick the resolution of the merged masks and the key they are stored at.
    if consolidate_at_video_res:
        assert not run_mem_encoder, "memory encoder cannot run at video resolution"
        merged_h = inference_state["video_height"]
        merged_w = inference_state["video_width"]
        mask_key = "pred_masks_video_res"
    else:
        merged_h = merged_w = self.image_size // 4
        mask_key = "pred_masks"

    # Placeholders: every object starts at NO_OBJ_SCORE and is overwritten
    # below when an actual output exists.
    merged_masks = torch.full(
        size=(num_objs, 1, merged_h, merged_w),
        fill_value=NO_OBJ_SCORE,
        dtype=torch.bfloat16,
        device=inference_state["storage_device"],
    )
    merged_ptrs = torch.full(
        size=(num_objs, self.hidden_dim),
        fill_value=NO_OBJ_SCORE,
        dtype=torch.bfloat16,
        device=inference_state["device"],
    )
    consolidated_out = {
        "maskmem_features": None,
        "maskmem_pos_enc": None,
        mask_key: merged_masks,
        "obj_ptr": merged_ptrs,
    }

    empty_mask_ptr = None
    for obj_idx in range(num_objs):
        row = slice(obj_idx, obj_idx + 1)
        temp_outputs = inference_state["temp_output_dict_per_obj"][obj_idx]
        kept_outputs = inference_state["output_dict_per_obj"][obj_idx]

        # Temporary output first, then the stored outputs.
        out = temp_outputs[storage_key].get(frame_idx, None)
        if out is None:
            out = kept_outputs["cond_frame_outputs"].get(frame_idx, None)
        if out is None:
            out = kept_outputs["non_cond_frame_outputs"].get(frame_idx, None)

        if out is None:
            # Nothing for this object on this frame: keep the placeholder
            # mask, and fill a dummy pointer only when building memory.
            if run_mem_encoder:
                if empty_mask_ptr is None:
                    empty_mask_ptr = self._get_empty_mask_ptr(
                        inference_state, frame_idx
                    )
                merged_ptrs[row] = empty_mask_ptr
            continue

        obj_mask = out["pred_masks"]
        if obj_mask.shape[-2:] != merged_masks.shape[-2:]:
            # Resize the object's mask to the merged resolution.
            obj_mask = torch.nn.functional.interpolate(
                obj_mask,
                size=merged_masks.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )
        merged_masks[row] = obj_mask
        merged_ptrs[row] = out["obj_ptr"]

    if run_mem_encoder:
        compute_device = inference_state["device"]
        high_res_masks = torch.nn.functional.interpolate(
            consolidated_out["pred_masks"].to(compute_device, non_blocking=True),
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )
        if self.non_overlap_masks_for_mem_enc:
            high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
        mem_features, mem_pos_enc = self._run_memory_encoder(
            inference_state=inference_state,
            frame_idx=frame_idx,
            batch_size=num_objs,
            high_res_masks=high_res_masks,
            is_mask_from_pts=True,
        )
        consolidated_out["maskmem_features"] = mem_features
        consolidated_out["maskmem_pos_enc"] = mem_pos_enc

    return consolidated_out
| |
|
def _get_empty_mask_ptr(self, inference_state, frame_idx):
    """Return the object pointer produced for an all-zero mask on a frame."""
    # A single all-zero mask at the model's input size.
    num_masks = 1
    zero_mask = torch.zeros(
        (num_masks, 1, self.image_size, self.image_size),
        dtype=torch.bfloat16,
        device=inference_state["device"],
    )

    # Only the last three of the five returned values are used here.
    _first, _second, vision_feats, vision_pos_embeds, feat_sizes = (
        self._get_image_feature(inference_state, frame_idx, num_masks)
    )

    # Run a tracking step with the zero mask as the mask input.
    step_out = self.track_step(
        frame_idx=frame_idx,
        is_init_cond_frame=True,
        current_vision_feats=vision_feats,
        current_vision_pos_embeds=vision_pos_embeds,
        feat_sizes=feat_sizes,
        point_inputs=None,
        mask_inputs=zero_mask,
        output_dict={},
        num_frames=inference_state["num_frames"],
        track_in_reverse=False,
        run_mem_encoder=False,
        prev_sam_mask_logits=None,
    )
    return step_out["obj_ptr"]
| |
|
@torch.inference_mode()
def propagate_in_video_preflight(self, inference_state):
    """Mark tracking as started and merge temporary outputs into the state.

    Every frame that has temporary per-object outputs is consolidated across
    objects (with the memory encoder), written to ``output_dict`` and to the
    per-object outputs, and the temporary storage is emptied. Frames that end
    up with a conditioning output lose any non-conditioning output.
    """
    inference_state["tracking_has_started"] = True
    num_objs = self._get_obj_num(inference_state)

    temp_per_obj = inference_state["temp_output_dict_per_obj"]
    output_dict = inference_state["output_dict"]
    consolidated_frame_inds = inference_state["consolidated_frame_inds"]
    # Whether to clear non-conditioning memory around input frames.
    clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
        self.clear_non_cond_mem_for_multi_obj or num_objs <= 1
    )

    # Non-conditioning outputs are consolidated first, conditioning second.
    for storage_key, is_cond in (
        ("non_cond_frame_outputs", False),
        ("cond_frame_outputs", True),
    ):
        # Frames holding a temporary output for at least one object.
        pending_frames = set().union(
            *(obj_temp[storage_key] for obj_temp in temp_per_obj.values())
        )
        consolidated_frame_inds[storage_key].update(pending_frames)

        for frame_idx in pending_frames:
            merged = self._consolidate_temp_output_across_obj(
                inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
            )
            # Store the merged output and its per-object slices.
            output_dict[storage_key][frame_idx] = merged
            self._add_output_per_object(
                inference_state, frame_idx, merged, storage_key
            )
            if clear_non_cond_mem:
                self._clear_non_cond_mem_around_input(inference_state, frame_idx)

        # The temporary outputs for this key are now merged.
        for obj_temp in temp_per_obj.values():
            obj_temp[storage_key].clear()

    # A frame with a conditioning output must not keep a non-conditioning one.
    stale_outputs = output_dict["non_cond_frame_outputs"]
    for frame_idx in output_dict["cond_frame_outputs"]:
        stale_outputs.pop(frame_idx, None)
    for per_obj in inference_state["output_dict_per_obj"].values():
        for frame_idx in per_obj["cond_frame_outputs"]:
            per_obj["non_cond_frame_outputs"].pop(frame_idx, None)
    for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
        assert frame_idx in output_dict["cond_frame_outputs"]
        consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)

    # Consolidated frames must be exactly the frames that received inputs.
    all_consolidated_frame_inds = (
        consolidated_frame_inds["cond_frame_outputs"]
        | consolidated_frame_inds["non_cond_frame_outputs"]
    )
    input_frames_inds = set()
    for inputs_key in ("point_inputs_per_obj", "mask_inputs_per_obj"):
        for inputs_per_frame in inference_state[inputs_key].values():
            input_frames_inds.update(inputs_per_frame.keys())
    assert all_consolidated_frame_inds == input_frames_inds
| |
|
@torch.inference_mode()
def propagate_in_video(
    self,
    inference_state,
    start_frame_idx=None,
    max_frame_num_to_track=None,
    reverse=False,
):
    """Track the prompted objects through the video, one frame at a time.

    Args:
        inference_state: The inference state dict.
        start_frame_idx: First frame to process; defaults to the smallest
            frame index that has a conditioning output.
        max_frame_num_to_track: Maximum number of frames to advance from the
            start frame; defaults to the number of frames in the video.
        reverse: When true, frames are visited in decreasing order (nothing
            is visited if the start frame is 0).

    Yields:
        tuple: ``(frame_idx, obj_ids, video_res_masks)`` for each visited
        frame.

    Raises:
        RuntimeError: If no conditioning output exists after the preflight.
    """
    self.propagate_in_video_preflight(inference_state)

    output_dict = inference_state["output_dict"]
    consolidated_frame_inds = inference_state["consolidated_frame_inds"]
    obj_ids = inference_state["obj_ids"]
    num_frames = inference_state["num_frames"]
    num_objs = self._get_obj_num(inference_state)
    if len(output_dict["cond_frame_outputs"]) == 0:
        raise RuntimeError("No points are provided; please add points first")
    clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
        self.clear_non_cond_mem_for_multi_obj or num_objs <= 1
    )

    # Fill in defaults, then work out which frames to visit and in what order.
    if start_frame_idx is None:
        start_frame_idx = min(output_dict["cond_frame_outputs"])
    if max_frame_num_to_track is None:
        max_frame_num_to_track = num_frames
    if not reverse:
        last_frame_idx = min(
            start_frame_idx + max_frame_num_to_track, num_frames - 1
        )
        frame_order = range(start_frame_idx, last_frame_idx + 1)
    elif start_frame_idx > 0:
        last_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
        frame_order = range(start_frame_idx, last_frame_idx - 1, -1)
    else:
        # Reverse tracking from frame 0 has nothing to visit.
        frame_order = []

    for frame_idx in tqdm(frame_order, desc="propagate in video"):
        if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
            # Reuse the consolidated conditioning output of this frame.
            storage_key = "cond_frame_outputs"
            current_out = output_dict[storage_key][frame_idx]
            pred_masks = current_out["pred_masks"]
            if clear_non_cond_mem:
                self._clear_non_cond_mem_around_input(inference_state, frame_idx)
        elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
            # Reuse the consolidated non-conditioning output of this frame.
            storage_key = "non_cond_frame_outputs"
            current_out = output_dict[storage_key][frame_idx]
            pred_masks = current_out["pred_masks"]
        else:
            # No stored output: run the model on this frame.
            storage_key = "non_cond_frame_outputs"
            current_out, pred_masks = self._run_single_frame_inference(
                inference_state=inference_state,
                output_dict=output_dict,
                frame_idx=frame_idx,
                batch_size=num_objs,
                is_init_cond_frame=False,
                point_inputs=None,
                mask_inputs=None,
                reverse=reverse,
                run_mem_encoder=True,
            )
            output_dict[storage_key][frame_idx] = current_out

        # Record the per-object slices and the direction this frame was
        # tracked in.
        self._add_output_per_object(
            inference_state, frame_idx, current_out, storage_key
        )
        inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}

        # Hand back the masks at the original video resolution.
        _, video_res_masks = self._get_orig_video_res_output(
            inference_state, pred_masks
        )
        yield frame_idx, obj_ids, video_res_masks
| |
|
| | def _add_output_per_object( |
| | self, inference_state, frame_idx, current_out, storage_key |
| | ): |
| | """ |
| | Split a multi-object output into per-object output slices and add them into |
| | `output_dict_per_obj`. The resulting slices share the same tensor storage. |
| | """ |
| | maskmem_features = current_out["maskmem_features"] |
| | assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) |
| |
|
| | maskmem_pos_enc = current_out["maskmem_pos_enc"] |
| | assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) |
| |
|
| | output_dict_per_obj = inference_state["output_dict_per_obj"] |
| | for obj_idx, obj_output_dict in output_dict_per_obj.items(): |
| | obj_slice = slice(obj_idx, obj_idx + 1) |
| | obj_out = { |
| | "maskmem_features": None, |
| | "maskmem_pos_enc": None, |
| | "pred_masks": current_out["pred_masks"][obj_slice], |
| | "obj_ptr": current_out["obj_ptr"][obj_slice], |
| | } |
| | if maskmem_features is not None: |
| | obj_out["maskmem_features"] = maskmem_features[obj_slice] |
| | if maskmem_pos_enc is not None: |
| | obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] |
| | obj_output_dict[storage_key][frame_idx] = obj_out |
| |
|
@torch.inference_mode()
def reset_state(self, inference_state):
    """Clear all tracking results, then empty every per-object registry."""
    self._reset_tracking_results(inference_state)
    # With the results cleared, forget the objects themselves: the id/index
    # mappings and every container that is keyed by object.
    for registry_key in (
        "obj_id_to_idx",
        "obj_idx_to_id",
        "obj_ids",
        "point_inputs_per_obj",
        "mask_inputs_per_obj",
        "output_dict_per_obj",
        "temp_output_dict_per_obj",
    ):
        inference_state[registry_key].clear()
| |
|
| | def _reset_tracking_results(self, inference_state): |
| | """Reset all tracking inputs and results across the videos.""" |
| | for v in inference_state["point_inputs_per_obj"].values(): |
| | v.clear() |
| | for v in inference_state["mask_inputs_per_obj"].values(): |
| | v.clear() |
| | for v in inference_state["output_dict_per_obj"].values(): |
| | v["cond_frame_outputs"].clear() |
| | v["non_cond_frame_outputs"].clear() |
| | for v in inference_state["temp_output_dict_per_obj"].values(): |
| | v["cond_frame_outputs"].clear() |
| | v["non_cond_frame_outputs"].clear() |
| | inference_state["output_dict"]["cond_frame_outputs"].clear() |
| | inference_state["output_dict"]["non_cond_frame_outputs"].clear() |
| | inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() |
| | inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() |
| | inference_state["tracking_has_started"] = False |
| | inference_state["frames_already_tracked"].clear() |
| |
|
| | def _get_image_feature(self, inference_state, frame_idx, batch_size): |
| | """Compute the image features on a given frame.""" |
| | |
| | image, backbone_out = inference_state["cached_features"].get( |
| | frame_idx, (None, None) |
| | ) |
| | if backbone_out is None: |
| | |
| | device = inference_state["device"] |
| | |
| | |
| | image = inference_state["images"][frame_idx].to(device).unsqueeze(0) |
| | |
| | backbone_out = self.forward_image(image) |
| | |
| | |
| | inference_state["cached_features"] = {frame_idx: (image, backbone_out)} |
| |
|
| | |
| | expanded_image = image.expand(batch_size, -1, -1, -1) |
| | expanded_backbone_out = { |
| | "backbone_fpn": backbone_out["backbone_fpn"].copy(), |
| | "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), |
| | } |
| | for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): |
| | expanded_backbone_out["backbone_fpn"][i] = feat.expand( |
| | batch_size, -1, -1, -1 |
| | ) |
| | for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): |
| | pos = pos.expand(batch_size, -1, -1, -1) |
| | expanded_backbone_out["vision_pos_enc"][i] = pos |
| |
|
| | features = self._prepare_backbone_features(expanded_backbone_out) |
| | features = (expanded_image,) + features |
| | return features |
| |
|
| | def _run_single_frame_inference( |
| | self, |
| | inference_state, |
| | output_dict, |
| | frame_idx, |
| | batch_size, |
| | is_init_cond_frame, |
| | point_inputs, |
| | mask_inputs, |
| | reverse, |
| | run_mem_encoder, |
| | prev_sam_mask_logits=None, |
| | ): |
| | """Run tracking on a single frame based on current inputs and previous memory.""" |
| | |
| | ( |
| | _, |
| | _, |
| | current_vision_feats, |
| | current_vision_pos_embeds, |
| | feat_sizes, |
| | ) = self._get_image_feature(inference_state, frame_idx, batch_size) |
| |
|
| | |
| | assert point_inputs is None or mask_inputs is None |
| | current_out = self.track_step( |
| | frame_idx=frame_idx, |
| | is_init_cond_frame=is_init_cond_frame, |
| | current_vision_feats=current_vision_feats, |
| | current_vision_pos_embeds=current_vision_pos_embeds, |
| | feat_sizes=feat_sizes, |
| | point_inputs=point_inputs, |
| | mask_inputs=mask_inputs, |
| | output_dict=output_dict, |
| | num_frames=inference_state["num_frames"], |
| | track_in_reverse=reverse, |
| | run_mem_encoder=run_mem_encoder, |
| | prev_sam_mask_logits=prev_sam_mask_logits, |
| | ) |
| |
|
| | |
| | storage_device = inference_state["storage_device"] |
| | maskmem_features = current_out["maskmem_features"] |
| | if maskmem_features is not None: |
| | maskmem_features = maskmem_features.to(torch.bfloat16) |
| | maskmem_features = maskmem_features.to(storage_device, non_blocking=True) |
| | pred_masks_gpu = current_out["pred_masks"] |
| | |
| | if self.fill_hole_area > 0: |
| | pred_masks_gpu = fill_holes_in_mask_scores( |
| | pred_masks_gpu, self.fill_hole_area |
| | ) |
| | pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) |
| | |
| | maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) |
| | |
| | obj_ptr = current_out["obj_ptr"] |
| | |
| | compact_current_out = { |
| | "maskmem_features": maskmem_features, |
| | "maskmem_pos_enc": maskmem_pos_enc, |
| | "pred_masks": pred_masks, |
| | "obj_ptr": obj_ptr, |
| | } |
| | return compact_current_out, pred_masks_gpu |
| |
|
| |
|
| | def _run_single_frame_inference_embed( |
| | self, |
| | inference_state, |
| | output_dict, |
| | frame_idx, |
| | batch_size, |
| | is_init_cond_frame, |
| | box_embed, |
| | point_inputs, |
| | mask_inputs, |
| | reverse, |
| | run_mem_encoder, |
| | prev_sam_mask_logits=None, |
| | ): |
| | """Run tracking on a single frame based on current inputs and previous memory.""" |
| | |
| | ( |
| | _, |
| | _, |
| | current_vision_feats, |
| | current_vision_pos_embeds, |
| | feat_sizes, |
| | ) = self._get_image_feature(inference_state, frame_idx, batch_size) |
| |
|
| | |
| | assert point_inputs is None or mask_inputs is None |
| | current_out = self.track_step_embed( |
| | frame_idx=frame_idx, |
| | is_init_cond_frame=is_init_cond_frame, |
| | current_vision_feats=current_vision_feats, |
| | current_vision_pos_embeds=current_vision_pos_embeds, |
| | feat_sizes=feat_sizes, |
| | box_embed=box_embed, |
| | point_inputs=point_inputs, |
| | mask_inputs=mask_inputs, |
| | output_dict=output_dict, |
| | num_frames=inference_state["num_frames"], |
| | track_in_reverse=reverse, |
| | run_mem_encoder=run_mem_encoder, |
| | prev_sam_mask_logits=prev_sam_mask_logits, |
| | ) |
| |
|
| | |
| | storage_device = inference_state["storage_device"] |
| | maskmem_features = current_out["maskmem_features"] |
| | if maskmem_features is not None: |
| | maskmem_features = maskmem_features.to(torch.bfloat16) |
| | maskmem_features = maskmem_features.to(storage_device, non_blocking=True) |
| | pred_masks_gpu = current_out["pred_masks"] |
| | |
| | if self.fill_hole_area > 0: |
| | pred_masks_gpu = fill_holes_in_mask_scores( |
| | pred_masks_gpu, self.fill_hole_area |
| | ) |
| | pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) |
| | |
| | maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) |
| | |
| | obj_ptr = current_out["obj_ptr"] |
| | |
| | compact_current_out = { |
| | "maskmem_features": maskmem_features, |
| | "maskmem_pos_enc": maskmem_pos_enc, |
| | "pred_masks": pred_masks, |
| | "obj_ptr": obj_ptr, |
| | } |
| | return compact_current_out, pred_masks_gpu |
| |
|
| |
|
| | def _run_memory_encoder( |
| | self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts |
| | ): |
| | """ |
| | Run the memory encoder on `high_res_masks`. This is usually after applying |
| | non-overlapping constraints to object scores. Since their scores changed, their |
| | memory also need to be computed again with the memory encoder. |
| | """ |
| | |
| | _, _, current_vision_feats, _, feat_sizes = self._get_image_feature( |
| | inference_state, frame_idx, batch_size |
| | ) |
| | maskmem_features, maskmem_pos_enc = self._encode_new_memory( |
| | current_vision_feats=current_vision_feats, |
| | feat_sizes=feat_sizes, |
| | pred_masks_high_res=high_res_masks, |
| | is_mask_from_pts=is_mask_from_pts, |
| | ) |
| |
|
| | |
| | storage_device = inference_state["storage_device"] |
| | maskmem_features = maskmem_features.to(torch.bfloat16) |
| | maskmem_features = maskmem_features.to(storage_device, non_blocking=True) |
| | |
| | maskmem_pos_enc = self._get_maskmem_pos_enc( |
| | inference_state, {"maskmem_pos_enc": maskmem_pos_enc} |
| | ) |
| | return maskmem_features, maskmem_pos_enc |
| |
|
| | def _get_maskmem_pos_enc(self, inference_state, current_out): |
| | """ |
| | `maskmem_pos_enc` is the same across frames and objects, so we cache it as |
| | a constant in the inference session to reduce session storage size. |
| | """ |
| | model_constants = inference_state["constants"] |
| | |
| | out_maskmem_pos_enc = current_out["maskmem_pos_enc"] |
| | if out_maskmem_pos_enc is not None: |
| | if "maskmem_pos_enc" not in model_constants: |
| | assert isinstance(out_maskmem_pos_enc, list) |
| | |
| | maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] |
| | model_constants["maskmem_pos_enc"] = maskmem_pos_enc |
| | else: |
| | maskmem_pos_enc = model_constants["maskmem_pos_enc"] |
| | |
| | batch_size = out_maskmem_pos_enc[0].size(0) |
| | expanded_maskmem_pos_enc = [ |
| | x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc |
| | ] |
| | else: |
| | expanded_maskmem_pos_enc = None |
| | return expanded_maskmem_pos_enc |
| |
|
| | def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): |
| | """ |
| | Remove the non-conditioning memory around the input frame. When users provide |
| | correction clicks, the surrounding frames' non-conditioning memories can still |
| | contain outdated object appearance information and could confuse the model. |
| | |
| | This method clears those non-conditioning memories surrounding the interacted |
| | frame to avoid giving the model both old and new information about the object. |
| | """ |
| | r = self.memory_temporal_stride_for_eval |
| | frame_idx_begin = frame_idx - r * self.num_maskmem |
| | frame_idx_end = frame_idx + r * self.num_maskmem |
| | output_dict = inference_state["output_dict"] |
| | non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] |
| | for t in range(frame_idx_begin, frame_idx_end + 1): |
| | non_cond_frame_outputs.pop(t, None) |
| | for obj_output_dict in inference_state["output_dict_per_obj"].values(): |
| | obj_output_dict["non_cond_frame_outputs"].pop(t, None) |
| |
|