Spaces:
Runtime error
Runtime error
File size: 84,091 Bytes
14114e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 |
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import datetime
import logging
import math
import os
from collections import defaultdict
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Set
import numpy as np
import numpy.typing as npt
import torch
import torch.distributed as dist
import torch.nn.functional as F
from sam3 import perflib
from sam3.logger import get_logger
from sam3.model.box_ops import fast_diag_box_iou
from sam3.model.data_misc import BatchedDatapoint
from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores, mask_to_box
from sam3.perflib.masks_ops import mask_iou
from sam3.train.masks_ops import rle_encode
from torch import nn, Tensor
logger = get_logger(__name__)
class MaskletConfirmationStatus(Enum):
    """Lifecycle state of a masklet with respect to detector confirmation."""

    # freshly spawned masklet that no detection has matched yet
    UNCONFIRMED = 1
    # masklet matched by at least one detection
    CONFIRMED = 2
class Sam3VideoBase(nn.Module):
def __init__(
    self,
    detector: nn.Module,
    tracker: nn.Module,
    # prob threshold for detection outputs -- only keep detections above this threshold
    # enters NMS and det-to-track matching
    score_threshold_detection=0.5,
    # IoU threshold for detection NMS
    det_nms_thresh=0.0,
    # IoU threshold for det-to-track matching -- a detection is considered "matched" to a tracklet it
    # overlaps with a tracklet above this threshold -- it is often a loose threshold like 0.1
    assoc_iou_thresh=0.5,
    # IoU threshold for det-to-track matching, which is used to determine whether a masklet is "unmatched"
    # by any detections -- it is often a stricter threshold like 0.5
    trk_assoc_iou_thresh=0.5,
    # prob threshold for a detection to be added as a new object
    new_det_thresh=0.0,
    # hotstart parameters: we hold off the outputs for `hotstart_delay` frames and
    # 1) remove those tracklets unmatched by any detections based on `hotstart_unmatch_thresh`
    # 2) remove those tracklets overlapping with one another based on `hotstart_dup_thresh`
    hotstart_delay=0,
    hotstart_unmatch_thresh=3,
    hotstart_dup_thresh=3,
    # Whether to suppress masks only within hotstart. If False, we can suppress masks even if they start before hotstart period.
    suppress_unmatched_only_within_hotstart=True,
    # keep-alive counters bounding how long a masklet survives (initial / max / min)
    init_trk_keep_alive=0,
    max_trk_keep_alive=8,
    min_trk_keep_alive=-4,
    # Threshold for suppressing overlapping objects based on recent occlusion
    suppress_overlapping_based_on_recent_occlusion_threshold=0.0,
    decrease_trk_keep_alive_for_empty_masklets=False,
    o2o_matching_masklets_enable=False,  # Enable hungarian matching to match existing masklets
    suppress_det_close_to_boundary=False,
    fill_hole_area=16,
    # The maximum number of objects (masklets) to track across all GPUs (for no limit, set it to -1)
    max_num_objects=-1,
    recondition_every_nth_frame=-1,
    # masklet confirmation status (to suppress unconfirmed masklets)
    masklet_confirmation_enable=False,
    # a masklet is confirmed after being consecutively detected and matched for
    # `masklet_confirmation_consecutive_det_thresh`
    masklet_confirmation_consecutive_det_thresh=3,
    # bbox heuristic parameters
    reconstruction_bbox_iou_thresh=0.0,
    reconstruction_bbox_det_score=0.0,
):
    """Initialize the shared base for SAM3 video detection-and-tracking.

    `detector` produces per-frame detections (run distributedly via
    `forward_video_grounding_multigpu`) and `tracker` is the SAM2-style
    propagation model. All remaining keyword arguments are heuristics
    thresholds documented inline above each parameter. The module is put
    in eval mode at construction (inference-only).
    """
    super().__init__()
    self.detector = detector
    self.tracker = tracker
    self.score_threshold_detection = score_threshold_detection
    self.det_nms_thresh = det_nms_thresh
    self.assoc_iou_thresh = assoc_iou_thresh
    self.trk_assoc_iou_thresh = trk_assoc_iou_thresh
    self.new_det_thresh = new_det_thresh
    # hotstart parameters
    if hotstart_delay > 0:
        # both suppression windows must fit inside the hotstart delay
        assert hotstart_unmatch_thresh <= hotstart_delay
        assert hotstart_dup_thresh <= hotstart_delay
    self.hotstart_delay = hotstart_delay
    self.hotstart_unmatch_thresh = hotstart_unmatch_thresh
    self.hotstart_dup_thresh = hotstart_dup_thresh
    self.suppress_unmatched_only_within_hotstart = (
        suppress_unmatched_only_within_hotstart
    )
    self.init_trk_keep_alive = init_trk_keep_alive
    self.max_trk_keep_alive = max_trk_keep_alive
    self.min_trk_keep_alive = min_trk_keep_alive
    self.suppress_overlapping_based_on_recent_occlusion_threshold = (
        suppress_overlapping_based_on_recent_occlusion_threshold
    )
    self.suppress_det_close_to_boundary = suppress_det_close_to_boundary
    self.decrease_trk_keep_alive_for_empty_masklets = (
        decrease_trk_keep_alive_for_empty_masklets
    )
    self.o2o_matching_masklets_enable = o2o_matching_masklets_enable
    self.fill_hole_area = fill_hole_area
    # inference-only module: start (and stay) in eval mode
    self.eval()
    # rank / world size come from the standard torch.distributed env vars
    self.rank = int(os.getenv("RANK", "0"))
    self.world_size = int(os.getenv("WORLD_SIZE", "1"))
    self._dist_pg_cpu = None  # CPU process group (lazy-initialized on first use)
    # the maximum object number
    if max_num_objects > 0:
        # objects are sharded across GPUs, so each GPU compiles for its share
        num_obj_for_compile = math.ceil(max_num_objects / self.world_size)
    else:
        max_num_objects = 10000  # no limit
        num_obj_for_compile = 16
    logger.info(f"setting {max_num_objects=} and {num_obj_for_compile=}")
    self.max_num_objects = max_num_objects
    self.num_obj_for_compile = num_obj_for_compile
    self.recondition_every_nth_frame = recondition_every_nth_frame
    self.masklet_confirmation_enable = masklet_confirmation_enable
    self.masklet_confirmation_consecutive_det_thresh = (
        masklet_confirmation_consecutive_det_thresh
    )
    self.reconstruction_bbox_iou_thresh = reconstruction_bbox_iou_thresh
    self.reconstruction_bbox_det_score = reconstruction_bbox_det_score
@property
def device(self):
self._device = getattr(self, "_device", None) or next(self.parameters()).device
return self._device
def _init_dist_pg_cpu(self):
    """Create the gloo (CPU) process group used for Python-object collectives.

    A short timeout (3 minutes by default, overridable through the
    SAM3_COLLECTIVE_OP_TIMEOUT_SEC env var) makes synchronization failures
    surface quickly instead of hanging.
    """
    timeout_sec = os.getenv("SAM3_COLLECTIVE_OP_TIMEOUT_SEC", "180")
    self._dist_pg_cpu = dist.new_group(
        backend="gloo",
        timeout=datetime.timedelta(seconds=int(timeout_sec)),
    )
def broadcast_python_obj_cpu(self, python_obj_list, src):
    """Broadcast a list of picklable Python objects from rank `src`.

    Lazily creates the CPU (gloo) process group on first use, then runs the
    broadcast over it (in-place into `python_obj_list` on receiving ranks).
    """
    pg = self._dist_pg_cpu
    if pg is None:
        self._init_dist_pg_cpu()
        pg = self._dist_pg_cpu
    dist.broadcast_object_list(python_obj_list, src=src, group=pg)
def _det_track_one_frame(
    self,
    frame_idx: int,
    num_frames: int,
    reverse: bool,
    input_batch: BatchedDatapoint,
    geometric_prompt: Any,
    tracker_states_local: List[Any],
    tracker_metadata_prev: Dict[str, Any],
    feature_cache: Dict,
    orig_vid_height: int,
    orig_vid_width: int,
    is_image_only: bool = False,
    allow_new_detections: bool = True,
):
    """
    Handle one-step inference for the DenseTracking model in an SPMD manner.

    At a high level, all GPUs execute the same function calls as if it's done on
    a single GPU, while under the hood, some function calls involve distributed
    computation based on sharded SAM2 states.

    - `input_batch` contains image and other inputs on the entire video; it
      should be identical across GPUs
    - `tracker_states_local` holds the local masklet information in this GPU shard
    - `tracker_metadata_prev` manages the metadata for SAM2 objects, such as
      which masklet is held on which GPU; it contains both global and local
      masklet information

    Returns a 6-tuple of:
    - obj_id_to_mask: dict obj_id -> output mask (populated on rank 0 only)
    - obj_id_to_score: dict obj_id -> output score / prob (rank 0 only)
    - tracker_states_local_new: updated local SAM2 states on this GPU
    - tracker_metadata_new: updated masklet metadata
    - frame_stats: per-frame statistics (tracked / dropped object counts)
    - tracker_obj_scores_global: per-object tracker scores for this frame
    """
    # Step 1: run backbone and detector in a distributed manner -- this is done via Sam3ImageOnVideoMultiGPU,
    # a MultiGPU model (assigned to `self.detector`) that shards frames in a round-robin manner.
    # It returns a "det_out" dict for `frame_idx` and fills SAM2 backbone features for `frame_idx`
    # into `feature_cache`. Despite its distributed inference under the hood, the results would be
    # the same as if it is running backbone and detector for every frame on a single GPU.
    det_out = self.run_backbone_and_detection(
        frame_idx=frame_idx,
        num_frames=num_frames,
        reverse=reverse,
        input_batch=input_batch,
        geometric_prompt=geometric_prompt,
        feature_cache=feature_cache,
        allow_new_detections=allow_new_detections,
    )
    # Step 2: each GPU propagates its local SAM2 states to get the SAM2 prediction masks.
    # the returned `tracker_low_res_masks_global` contains the concatenated masklet predictions
    # gathered from all GPUs (as if they are propagated on a single GPU). Note that this step only
    # runs the SAM2 propagation step, but doesn't encode new memory for the predicted masks;
    # we defer memory encoding to `run_tracker_update_execution_phase` after resolving all heuristics.
    if tracker_metadata_prev == {}:
        # initialize masklet metadata if it's uninitialized (empty dict)
        tracker_metadata_prev.update(self._initialize_metadata())
    tracker_low_res_masks_global, tracker_obj_scores_global = (
        self.run_tracker_propagation(
            frame_idx=frame_idx,
            num_frames=num_frames,
            reverse=reverse,
            tracker_states_local=tracker_states_local,
            tracker_metadata_prev=tracker_metadata_prev,
        )
    )
    # Step 3: based on detection outputs and the propagated SAM2 prediction masks, we make plans
    # for SAM2 masklet updates (i.e. which objects to add and remove, how to load-balance them, etc).
    # We also run SAM2 memory encoder globally in this step to resolve non-overlapping constraints.
    # **This step should involve all the heuristics needed for any updates.** Most of the update
    # planning will be done on the master rank (GPU 0) and the resulting plan `tracker_update_plan` is
    # broadcasted to other GPUs (to be executed in a distributed manner). This step also generates the
    # new masklet metadata `tracker_metadata_new` (based on its previous version `tracker_metadata_prev`).
    tracker_update_plan, tracker_metadata_new = (
        self.run_tracker_update_planning_phase(
            frame_idx=frame_idx,
            num_frames=num_frames,
            reverse=reverse,
            det_out=det_out,
            tracker_low_res_masks_global=tracker_low_res_masks_global,
            tracker_obj_scores_global=tracker_obj_scores_global,
            tracker_metadata_prev=tracker_metadata_prev,
            tracker_states_local=tracker_states_local,
            is_image_only=is_image_only,
        )
    )
    # Get reconditioning info from the update plan
    reconditioned_obj_ids = tracker_update_plan.get("reconditioned_obj_ids", set())
    det_to_matched_trk_obj_ids = tracker_update_plan.get(
        "det_to_matched_trk_obj_ids", {}
    )
    # Step 4: based on `tracker_update_plan`, each GPU executes the update w.r.t. its local SAM2 inference states
    tracker_states_local_new = self.run_tracker_update_execution_phase(
        frame_idx=frame_idx,
        num_frames=num_frames,
        reverse=reverse,
        det_out=det_out,
        tracker_states_local=tracker_states_local,
        tracker_update_plan=tracker_update_plan,
        orig_vid_height=orig_vid_height,
        orig_vid_width=orig_vid_width,
        feature_cache=feature_cache,
    )
    # Step 5: finally, build the outputs for this frame (it only needs to be done on GPU 0 since
    # only GPU 0 will send outputs to the server).
    if self.rank == 0:
        obj_id_to_mask = self.build_outputs(
            frame_idx=frame_idx,
            num_frames=num_frames,
            reverse=reverse,
            det_out=det_out,
            tracker_low_res_masks_global=tracker_low_res_masks_global,
            tracker_obj_scores_global=tracker_obj_scores_global,
            tracker_metadata_prev=tracker_metadata_prev,
            tracker_update_plan=tracker_update_plan,
            orig_vid_height=orig_vid_height,
            orig_vid_width=orig_vid_width,
            reconditioned_obj_ids=reconditioned_obj_ids,
            det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
        )
        obj_id_to_score = tracker_metadata_new["obj_id_to_score"]
    else:
        obj_id_to_mask, obj_id_to_score = {}, {}  # dummy outputs on other GPUs
    # a few statistics for the current frame as a part of the output
    frame_stats = {
        "num_obj_tracked": np.sum(tracker_metadata_new["num_obj_per_gpu"]),
        "num_obj_dropped": tracker_update_plan["num_obj_dropped_due_to_limit"],
    }
    # add tracker scores to metadata; this fires for all frames except the first
    # one (no masklets exist yet when the first frame is processed)
    if tracker_obj_scores_global.shape[0] > 0:
        # Convert tracker_obj_scores_global to sigmoid scores before updating
        tracker_obj_scores_global = tracker_obj_scores_global.sigmoid().tolist()
        tracker_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"]
        tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][
            frame_idx
        ].update(dict(zip(tracker_obj_ids, tracker_obj_scores_global)))
    return (
        obj_id_to_mask,  # a dict: obj_id --> output mask
        obj_id_to_score,  # a dict: obj_id --> output score (prob)
        tracker_states_local_new,
        tracker_metadata_new,
        frame_stats,
        # per-object tracker scores for this frame: a list of sigmoid probs when
        # masklets exist, otherwise the original (empty) tensor
        tracker_obj_scores_global,
    )
def _suppress_detections_close_to_boundary(self, boxes, margin=0.025):
"""
Suppress detections too close to image edges (for normalized boxes).
boxes: (N, 4) in xyxy format, normalized [0,1]
margin: fraction of image
"""
x_min, y_min, x_max, y_max = boxes.unbind(-1)
x_c = (x_min + x_max) / 2
y_c = (y_min + y_max) / 2
keep = (
(x_c > margin)
& (x_c < 1.0 - margin)
& (y_c > margin)
& (y_c < 1.0 - margin)
)
return keep
def run_backbone_and_detection(
    self,
    frame_idx: int,
    num_frames: int,
    input_batch: BatchedDatapoint,
    geometric_prompt: Any,
    feature_cache: Dict,
    reverse: bool,
    allow_new_detections: bool,
):
    """
    Run the (distributed) backbone + detector for one frame.

    Returns `det_out`, a dict of post-NMS detections above
    `self.score_threshold_detection`:
    - "bbox": selected boxes in xyxy format
    - "mask": corresponding low-res mask logits
    - "scores": corresponding detection probabilities

    Side effects on `feature_cache`:
    - caches text features of the most recent prompt under "text"
    - maintains the detector's "multigpu_buffer"
    - stores the SAM2 tracker backbone features for `frame_idx` and evicts the
      previous frame's entry to save GPU memory
    """
    # Step 1: if text feature is not cached in `feature_cache`, compute and cache it
    text_batch_key = tuple(input_batch.find_text_batch)
    if "text" not in feature_cache or text_batch_key not in feature_cache["text"]:
        text_outputs = self.detector.backbone.forward_text(
            input_batch.find_text_batch, device=self.device
        )
        # note: we only cache the text feature of the most recent prompt
        feature_cache["text"] = {text_batch_key: text_outputs}
    else:
        text_outputs = feature_cache["text"][text_batch_key]
    # Step 2: run backbone, detector, and post-processing with NMS
    if "multigpu_buffer" not in feature_cache:
        # "multigpu_buffer" is a buffer cache used by `self.detector` and it needs
        # to be passed to `forward_video_grounding_multigpu` for every call
        feature_cache["multigpu_buffer"] = {}
    # Extract max_frame_num_to_track from feature_cache if available
    tracking_bounds = feature_cache.get("tracking_bounds", {})
    max_frame_num_to_track = tracking_bounds.get("max_frame_num_to_track")
    start_frame_idx = tracking_bounds.get("propagate_in_video_start_frame_idx")
    sam3_image_out, _ = self.detector.forward_video_grounding_multigpu(
        backbone_out={
            "img_batch_all_stages": input_batch.img_batch,
            **text_outputs,
        },
        find_inputs=input_batch.find_inputs,
        geometric_prompt=geometric_prompt,
        frame_idx=frame_idx,
        num_frames=num_frames,
        multigpu_buffer=feature_cache["multigpu_buffer"],
        track_in_reverse=reverse,
        # also get the SAM2 backbone features
        return_tracker_backbone_feats=True,
        # run NMS as a part of distributed computation
        run_nms=self.det_nms_thresh > 0.0,
        nms_prob_thresh=self.score_threshold_detection,
        nms_iou_thresh=self.det_nms_thresh,
        # pass max_frame_num_to_track to respect tracking limits
        max_frame_num_to_track=max_frame_num_to_track,
        propagate_in_video_start_frame_idx=start_frame_idx,
    )
    # note: detections in `sam3_image_out` has already gone through NMS
    pred_probs = sam3_image_out["pred_logits"].squeeze(-1).sigmoid()
    if not allow_new_detections:
        pred_probs = pred_probs - 1e8  # make sure no detections are kept
    pred_boxes_xyxy = sam3_image_out["pred_boxes_xyxy"]
    pred_masks = sam3_image_out["pred_masks"]
    # get the positive detection outputs above threshold
    pos_pred_idx = torch.where(pred_probs > self.score_threshold_detection)
    det_out = {
        "bbox": pred_boxes_xyxy[pos_pred_idx[0], pos_pred_idx[1]],
        "mask": pred_masks[pos_pred_idx[0], pos_pred_idx[1]],
        "scores": pred_probs[pos_pred_idx[0], pos_pred_idx[1]],
    }
    # Step 3: build SAM2 backbone features and store them in `feature_cache`
    backbone_cache = {}
    sam_mask_decoder = self.tracker.sam_mask_decoder
    # project the two lower FPN levels through the mask decoder's convs
    tracker_backbone_fpn = [
        sam_mask_decoder.conv_s0(sam3_image_out["tracker_backbone_fpn_0"]),
        sam_mask_decoder.conv_s1(sam3_image_out["tracker_backbone_fpn_1"]),
        sam3_image_out["tracker_backbone_fpn_2"],  # fpn_2 doesn't need conv
    ]
    tracker_backbone_out = {
        "vision_features": tracker_backbone_fpn[-1],  # top-level feature
        "vision_pos_enc": sam3_image_out["tracker_backbone_pos_enc"],
        "backbone_fpn": tracker_backbone_fpn,
    }
    backbone_cache["tracker_backbone_out"] = tracker_backbone_out
    feature_cache[frame_idx] = (
        input_batch.img_batch[frame_idx],
        backbone_cache,
    )
    # remove from `feature_cache` old features to save GPU memory
    # (the "previous" frame is frame_idx - 1 forward, frame_idx + 1 in reverse)
    feature_cache.pop(frame_idx - 1 if not reverse else frame_idx + 1, None)
    return det_out
def run_tracker_propagation(
    self,
    frame_idx: int,
    num_frames: int,
    reverse: bool,
    tracker_states_local: List[Any],
    tracker_metadata_prev: Dict[str, npt.NDArray],
):
    """
    Propagate each GPU's local SAM2 states one frame and all-gather the results.

    Returns:
    - low_res_masks_global: Tensor (num_global_obj, H_mask, W_mask) -- mask
      predictions for every masklet across all GPUs, concatenated in rank order
    - obj_scores_global: Tensor (num_global_obj,) -- per-object scores, in the
      same order
    """
    # Step 1: propagate the local SAM2 states to get the current frame's prediction
    # `low_res_masks_local` of the existing masklets on this GPU
    # - obj_ids_local: List[int] -- list of object IDs
    # - low_res_masks_local: Tensor -- (num_local_obj, H_mask, W_mask)
    obj_ids_local, low_res_masks_local, obj_scores_local = (
        self._propogate_tracker_one_frame_local_gpu(
            tracker_states_local, frame_idx=frame_idx, reverse=reverse
        )
    )
    # sanity check: local object IDs must match this rank's slot in the metadata
    assert np.all(
        obj_ids_local == tracker_metadata_prev["obj_ids_per_gpu"][self.rank]
    ), "{} != {}".format(
        obj_ids_local, tracker_metadata_prev["obj_ids_per_gpu"][self.rank]
    )
    # Step 2: all-gather `low_res_masks_local` into `low_res_masks_global`
    # - low_res_masks_global: Tensor -- (num_global_obj, H_mask, W_mask)
    _, H_mask, W_mask = low_res_masks_local.shape
    if self.world_size > 1:
        # `low_res_masks_local` and `obj_scores_local` need to be contiguous and float32
        # (they could be non-contiguous due to slicing and/or bfloat16 due to autocast)
        low_res_masks_local = low_res_masks_local.float().contiguous()
        obj_scores_local = obj_scores_local.float().contiguous()
        num_obj_this_gpu = tracker_metadata_prev["num_obj_per_gpu"][self.rank]
        assert low_res_masks_local.size(0) == num_obj_this_gpu
        assert obj_scores_local.size(0) == num_obj_this_gpu
        # pre-allocate one receive buffer per peer, sized from the metadata
        low_res_masks_peers = [
            low_res_masks_local.new_empty(num_obj, H_mask, W_mask)
            for num_obj in tracker_metadata_prev["num_obj_per_gpu"]
        ]
        obj_scores_peers = [
            obj_scores_local.new_empty(num_obj)
            for num_obj in tracker_metadata_prev["num_obj_per_gpu"]
        ]
        dist.all_gather(low_res_masks_peers, low_res_masks_local)
        dist.all_gather(obj_scores_peers, obj_scores_local)
        # concatenate in rank order so the global order is consistent on all GPUs
        low_res_masks_global = torch.cat(low_res_masks_peers, dim=0)
        obj_scores_global = torch.cat(obj_scores_peers, dim=0)
    else:
        # single-GPU case: local predictions are already global
        low_res_masks_global = low_res_masks_local
        obj_scores_global = obj_scores_local
    return low_res_masks_global, obj_scores_global
def _recondition_masklets(
    self,
    frame_idx,
    det_out: Dict[str, Tensor],
    # mapping trk_obj_id -> index of its max-IoU high-confidence detection
    # (iterated with .items() below, so this is a dict, not a list)
    trk_id_to_max_iou_high_conf_det: Dict[int, int],
    tracker_states_local: List[Any],
    tracker_metadata: Dict[str, npt.NDArray],
    tracker_obj_scores_global: Tensor,
):
    """
    Re-prompt (recondition) tracked masklets with matching high-confidence
    detection masks, then re-run the tracker preflight on the touched states.

    Returns `tracker_states_local` (mutated in place by the tracker calls).
    """
    # Recondition the masklets based on the new detections
    for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items():
        new_mask = det_out["mask"][det_idx : det_idx + 1]
        input_mask_res = self.tracker.input_mask_size
        # resize the detection mask logits to the tracker's input mask
        # resolution and binarize at logit 0
        new_mask_binary = (
            F.interpolate(
                new_mask.unsqueeze(1),
                size=(input_mask_res, input_mask_res),
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)[0]
            > 0
        )
        HIGH_CONF_THRESH = 0.8
        reconditioned_states_idx = set()
        # locate this object's position in the global (all-GPU) object order
        obj_idx = np.where(tracker_metadata["obj_ids_all_gpu"] == trk_obj_id)[
            0
        ].item()
        obj_score = tracker_obj_scores_global[obj_idx]
        for state_idx, inference_state in enumerate(tracker_states_local):
            if (
                trk_obj_id in inference_state["obj_ids"]
                # NOTE: Goal of this condition is to avoid reconditioning masks that are occluded/low quality.
                # Unfortunately, these can get reconditioned anyway due to batching. We should consider removing these heuristics.
                and obj_score > HIGH_CONF_THRESH
            ):
                logger.debug(
                    f"Adding new mask for track {trk_obj_id} at frame {frame_idx}. Objects {inference_state['obj_ids']} are all reconditioned."
                )
                self.tracker.add_new_mask(
                    inference_state=inference_state,
                    frame_idx=frame_idx,
                    obj_id=trk_obj_id,
                    mask=new_mask_binary,
                )
                reconditioned_states_idx.add(state_idx)
        # re-run preflight (with memory encoding) on every state we touched
        for idx in reconditioned_states_idx:
            self.tracker.propagate_in_video_preflight(
                tracker_states_local[idx], run_mem_encoder=True
            )
    return tracker_states_local
def run_tracker_update_planning_phase(
    self,
    frame_idx: int,
    num_frames: int,
    reverse: bool,
    det_out: Dict[str, Tensor],
    tracker_low_res_masks_global: Tensor,
    tracker_obj_scores_global: Tensor,
    tracker_metadata_prev: Dict[str, npt.NDArray],
    tracker_states_local: List[Any],
    is_image_only: bool = False,
):
    """
    Plan the per-frame tracker update and run the steps shared by all GPUs.

    Rank 0 makes every heuristic decision (detector/tracker association,
    boundary suppression, object-count limiting, hotstart removal) and
    broadcasts the resulting plan so all ranks stay consistent. All ranks
    then optionally recondition masklets, run the memory-encoding step, and
    derive the new (identical-on-all-ranks, except "rank0_metadata")
    tracker metadata.

    Args:
        frame_idx: current frame index.
        num_frames: total number of frames in the video.
        reverse: True when propagating backwards in time.
        det_out: detector outputs; uses keys "mask" (low-res logits),
            "scores" and "bbox".
        tracker_low_res_masks_global: low-res mask logits for all objects
            across all GPUs (gathered before this call).
        tracker_obj_scores_global: object scores across all GPUs.
        tracker_metadata_prev: metadata carried over from the previous frame.
        tracker_states_local: inference states held on this GPU.
        is_image_only: if True, skip the max-object-count limit.

    Returns:
        (tracker_update_plan, tracker_metadata_new)
    """
    # initialize new metadata from previous metadata (its values will be updated later)
    tracker_metadata_new = {
        "obj_ids_per_gpu": deepcopy(tracker_metadata_prev["obj_ids_per_gpu"]),
        "obj_ids_all_gpu": None,  # will be filled later
        "num_obj_per_gpu": deepcopy(tracker_metadata_prev["num_obj_per_gpu"]),
        "obj_id_to_score": deepcopy(tracker_metadata_prev["obj_id_to_score"]),
        "obj_id_to_tracker_score_frame_wise": deepcopy(
            tracker_metadata_prev["obj_id_to_tracker_score_frame_wise"]
        ),
        "obj_id_to_last_occluded": {},  # will be filled later
        "max_obj_id": deepcopy(tracker_metadata_prev["max_obj_id"]),
    }
    # Initialize reconditioned_obj_ids early to avoid UnboundLocalError
    # (stored by reference in `tracker_update_plan` below and mutated in Step 3)
    reconditioned_obj_ids = set()
    # Step 1: make the update plan and resolve heuristics on GPU 0
    det_mask_preds: Tensor = det_out["mask"]  # low-res mask logits
    det_scores_np: npt.NDArray = det_out["scores"].float().cpu().numpy()
    det_bbox_xyxy: Tensor = det_out["bbox"]
    if self.rank == 0:
        # a) match detector and tracker masks and find new objects
        (
            new_det_fa_inds,
            unmatched_trk_obj_ids,
            det_to_matched_trk_obj_ids,
            trk_id_to_max_iou_high_conf_det,
            empty_trk_obj_ids,
        ) = self._associate_det_trk(
            det_masks=det_mask_preds,
            det_scores_np=det_scores_np,
            trk_masks=tracker_low_res_masks_global,
            trk_obj_ids=tracker_metadata_prev["obj_ids_all_gpu"],
        )
        if self.suppress_det_close_to_boundary:
            # drop new detections whose boxes hug the frame boundary
            keep = self._suppress_detections_close_to_boundary(
                det_bbox_xyxy[new_det_fa_inds]
            )
            new_det_fa_inds = new_det_fa_inds[keep.cpu().numpy()]
        # check whether we've hit the maximum number of objects we can track (and if so, drop some detections)
        prev_obj_num = np.sum(tracker_metadata_prev["num_obj_per_gpu"])
        new_det_num = len(new_det_fa_inds)
        num_obj_dropped_due_to_limit = 0
        if not is_image_only and prev_obj_num + new_det_num > self.max_num_objects:
            logger.warning(
                f"hitting {self.max_num_objects=} with {new_det_num=} and {prev_obj_num=}"
            )
            new_det_num_to_keep = self.max_num_objects - prev_obj_num
            num_obj_dropped_due_to_limit = new_det_num - new_det_num_to_keep
            new_det_fa_inds = self._drop_new_det_with_obj_limit(
                new_det_fa_inds, det_scores_np, new_det_num_to_keep
            )
            assert len(new_det_fa_inds) == new_det_num_to_keep
            new_det_num = len(new_det_fa_inds)
        # assign object IDs to new detections and decide which GPU to place them
        new_det_start_obj_id = tracker_metadata_prev["max_obj_id"] + 1
        new_det_obj_ids = new_det_start_obj_id + np.arange(new_det_num)
        prev_workload_per_gpu = tracker_metadata_prev["num_obj_per_gpu"]
        new_det_gpu_ids = self._assign_new_det_to_gpus(
            new_det_num=new_det_num,
            prev_workload_per_gpu=prev_workload_per_gpu,
        )
        # b) handle hotstart heuristics to remove objects
        # here `rank0_metadata` contains metadata stored on (and only accessible to) GPU 0;
        # we avoid broadcasting them to other GPUs to save communication cost, assuming
        # that `rank0_metadata` is not needed by other GPUs
        rank0_metadata_new = deepcopy(tracker_metadata_prev["rank0_metadata"])
        if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
            obj_ids_newly_removed, rank0_metadata_new = self._process_hotstart(
                frame_idx=frame_idx,
                num_frames=num_frames,
                reverse=reverse,
                det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
                new_det_obj_ids=new_det_obj_ids,
                empty_trk_obj_ids=empty_trk_obj_ids,
                unmatched_trk_obj_ids=unmatched_trk_obj_ids,
                rank0_metadata=rank0_metadata_new,
                tracker_metadata=tracker_metadata_prev,
            )
        else:
            # if warm-up is not complete, we don't remove any objects
            obj_ids_newly_removed = set()
        tracker_metadata_new["rank0_metadata"] = rank0_metadata_new
    # Step 2: broadcast the update plan to other GPUs
    NUM_BROADCAST_ITEMS = 9
    if self.rank == 0 and self.world_size > 1:
        # `num_obj_per_gpu_on_rank0` is used for metadata consistency check on other GPUs
        # (it's a small array with length==self.world_size, so broadcasting it is cheap)
        num_obj_per_gpu_on_rank0 = tracker_metadata_prev["num_obj_per_gpu"]
        update_plan = [
            new_det_fa_inds,
            new_det_obj_ids,
            new_det_gpu_ids,
            num_obj_per_gpu_on_rank0,
            unmatched_trk_obj_ids,
            det_to_matched_trk_obj_ids,
            obj_ids_newly_removed,
            num_obj_dropped_due_to_limit,
            trk_id_to_max_iou_high_conf_det,
        ]
        assert (
            len(update_plan) == NUM_BROADCAST_ITEMS
        ), f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
        self.broadcast_python_obj_cpu(update_plan, src=0)
    elif self.rank > 0 and self.world_size > 1:
        update_plan = [
            None
        ] * NUM_BROADCAST_ITEMS  # other ranks receive the plan from rank 0
        self.broadcast_python_obj_cpu(update_plan, src=0)
        (
            new_det_fa_inds,
            new_det_obj_ids,
            new_det_gpu_ids,
            num_obj_per_gpu_on_rank0,
            unmatched_trk_obj_ids,
            det_to_matched_trk_obj_ids,
            obj_ids_newly_removed,
            num_obj_dropped_due_to_limit,
            trk_id_to_max_iou_high_conf_det,
        ) = update_plan
        # metadata consistency check: verify that the received `num_obj_per_gpu_on_rank0` is consistent with the local metadata
        # it's critical that all GPUs agree on the previous number of objects (otherwise the inference might hang or fail silently)
        if not np.all(
            num_obj_per_gpu_on_rank0 == tracker_metadata_prev["num_obj_per_gpu"]
        ):
            raise RuntimeError(
                f"{self.rank=} received {num_obj_per_gpu_on_rank0=}, which is inconsistent with local record "
                f"{tracker_metadata_prev['num_obj_per_gpu']=}. There's likely a bug in update planning or execution."
            )
    # `tracker_update_plan` should be identical on all GPUs after broadcasting
    tracker_update_plan = {
        "new_det_fa_inds": new_det_fa_inds,  # npt.NDArray
        "new_det_obj_ids": new_det_obj_ids,  # npt.NDArray
        "new_det_gpu_ids": new_det_gpu_ids,  # npt.NDArray
        "unmatched_trk_obj_ids": unmatched_trk_obj_ids,  # npt.NDArray
        "det_to_matched_trk_obj_ids": det_to_matched_trk_obj_ids,  # dict
        "obj_ids_newly_removed": obj_ids_newly_removed,  # set
        "num_obj_dropped_due_to_limit": num_obj_dropped_due_to_limit,  # int
        "trk_id_to_max_iou_high_conf_det": trk_id_to_max_iou_high_conf_det,  # dict
        "reconditioned_obj_ids": reconditioned_obj_ids,  # set (mutated in Step 3 below)
    }
    # Step 3 (optional): recondition masklets based on high-confidence detections before memory encoding
    # NOTE: Running this in execution phase (after memory encoding) can lead to suboptimal results
    should_recondition_iou = False
    # Evaluate tracklets for reconditioning based on bbox IoU mismatch with detections
    if (
        self.reconstruction_bbox_iou_thresh > 0
        and len(trk_id_to_max_iou_high_conf_det) > 0
    ):
        for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items():
            det_box = det_out["bbox"][det_idx]
            det_score = det_out["scores"][det_idx]
            try:
                trk_idx = list(tracker_metadata_prev["obj_ids_all_gpu"]).index(
                    trk_obj_id
                )
            except ValueError:
                continue  # Skip if tracklet not found
            tracker_mask = tracker_low_res_masks_global[trk_idx]
            mask_binary = tracker_mask > 0
            mask_area = mask_binary.sum().item()
            if mask_area == 0:
                continue  # Skip tracklets with zero mask area
            # Get bounding box from SAM2 mask and convert to normalized coordinates
            tracker_box_pixels = (
                mask_to_box(mask_binary.unsqueeze(0).unsqueeze(0))
                .squeeze(0)
                .squeeze(0)
            )
            mask_height, mask_width = tracker_mask.shape[-2:]
            # normalize to [0, 1] so it is comparable with the detector's boxes
            # (det_out["bbox"] is presumably normalized — the comparison below relies on it)
            tracker_box_normalized = torch.tensor(
                [
                    tracker_box_pixels[0] / mask_width,
                    tracker_box_pixels[1] / mask_height,
                    tracker_box_pixels[2] / mask_width,
                    tracker_box_pixels[3] / mask_height,
                ],
                device=tracker_box_pixels.device,
            )
            # Compute IoU between detection and SAM2 tracklet bounding boxes
            det_box_batch = det_box.unsqueeze(0)
            tracker_box_batch = tracker_box_normalized.unsqueeze(0)
            iou = fast_diag_box_iou(det_box_batch, tracker_box_batch)[0]
            # recondition when the tracker box drifted away from a confident detection
            if (
                iou < self.reconstruction_bbox_iou_thresh
                and det_score >= self.reconstruction_bbox_det_score
            ):
                should_recondition_iou = True
                reconditioned_obj_ids.add(trk_obj_id)
    should_recondition_periodic = (
        self.recondition_every_nth_frame > 0
        and frame_idx % self.recondition_every_nth_frame == 0
        and len(trk_id_to_max_iou_high_conf_det) > 0
    )
    # Recondition if periodic or IoU condition met
    if should_recondition_periodic or should_recondition_iou:
        self._recondition_masklets(
            frame_idx,
            det_out,
            trk_id_to_max_iou_high_conf_det,
            tracker_states_local,
            tracker_metadata_prev,
            tracker_obj_scores_global,
        )
    # Step 4: Run SAM2 memory encoder on the current frame's prediction masks
    # This is done on all GPUs
    batch_size = tracker_low_res_masks_global.size(0)
    if batch_size > 0:
        if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
            if self.suppress_overlapping_based_on_recent_occlusion_threshold > 0.0:
                # NOTE: tracker_low_res_masks_global is updated in-place then returned
                tracker_low_res_masks_global = (
                    self._suppress_overlapping_based_on_recent_occlusion(
                        frame_idx,
                        tracker_low_res_masks_global,
                        tracker_metadata_prev,
                        tracker_metadata_new,
                        obj_ids_newly_removed,
                        reverse,
                    )
                )
        self._tracker_update_memories(
            tracker_states_local,
            frame_idx,
            tracker_metadata=tracker_metadata_prev,
            low_res_masks=tracker_low_res_masks_global,
        )
    # Step 5: update the SAM2 metadata based on the update plan
    # note: except for "rank0_metadata" (that is only available on GPU 0),
    # the updated `tracker_metadata_new` should be identical on all GPUs
    for rank in range(self.world_size):
        new_det_obj_ids_this_gpu = new_det_obj_ids[new_det_gpu_ids == rank]
        updated_obj_ids_this_gpu = tracker_metadata_new["obj_ids_per_gpu"][rank]
        if len(new_det_obj_ids_this_gpu) > 0:
            updated_obj_ids_this_gpu = np.concatenate(
                [updated_obj_ids_this_gpu, new_det_obj_ids_this_gpu]
            )
        if len(obj_ids_newly_removed) > 0:
            is_removed = np.isin(
                updated_obj_ids_this_gpu, list(obj_ids_newly_removed)
            )
            updated_obj_ids_this_gpu = updated_obj_ids_this_gpu[~is_removed]
        tracker_metadata_new["obj_ids_per_gpu"][rank] = updated_obj_ids_this_gpu
        tracker_metadata_new["num_obj_per_gpu"][rank] = len(
            updated_obj_ids_this_gpu
        )
    tracker_metadata_new["obj_ids_all_gpu"] = np.concatenate(
        tracker_metadata_new["obj_ids_per_gpu"]
    )
    # update object scores and the maximum object ID assigned so far
    if len(new_det_obj_ids) > 0:
        tracker_metadata_new["obj_id_to_score"].update(
            zip(new_det_obj_ids, det_scores_np[new_det_fa_inds])
        )
        # tracker scores are not available for new objects, use det score instead.
        tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][
            frame_idx
        ].update(zip(new_det_obj_ids, det_scores_np[new_det_fa_inds]))
        tracker_metadata_new["max_obj_id"] = max(
            tracker_metadata_new["max_obj_id"],
            np.max(new_det_obj_ids),
        )
    # for removed objects, we set their scores to a very low value (-1e4) but still
    # keep them in "obj_id_to_score" (it's easier to handle outputs this way)
    for obj_id in obj_ids_newly_removed:
        tracker_metadata_new["obj_id_to_score"][obj_id] = -1e4
        tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx][
            obj_id
        ] = -1e4
        tracker_metadata_new["obj_id_to_last_occluded"].pop(obj_id, None)
    # check that "rank0_metadata" is in tracker_metadata_new if and only if it's GPU 0
    assert ("rank0_metadata" in tracker_metadata_new) == (self.rank == 0)
    if self.rank == 0 and self.masklet_confirmation_enable:
        rank0_metadata = self.update_masklet_confirmation_status(
            rank0_metadata=tracker_metadata_new["rank0_metadata"],
            obj_ids_all_gpu_prev=tracker_metadata_prev["obj_ids_all_gpu"],
            obj_ids_all_gpu_updated=tracker_metadata_new["obj_ids_all_gpu"],
            det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
            new_det_obj_ids=new_det_obj_ids,
        )
        tracker_metadata_new["rank0_metadata"] = rank0_metadata
    return tracker_update_plan, tracker_metadata_new
def _suppress_overlapping_based_on_recent_occlusion(
    self,
    frame_idx: int,
    tracker_low_res_masks_global: Tensor,
    tracker_metadata_prev: Dict[str, Any],
    tracker_metadata_new: Dict[str, Any],
    obj_ids_newly_removed: Set[int],
    reverse: bool = False,
):
    """
    Suppress overlapping masks based on the most recent occlusion information. If an object is removed by hotstart, we always suppress it if it overlaps with any other object.

    Args:
        frame_idx (int): The current frame index.
        tracker_low_res_masks_global (Tensor): The low-resolution masks for the current frame.
        tracker_metadata_prev (Dict[str, Any]): The metadata from the previous frame.
        tracker_metadata_new (Dict[str, Any]): The metadata for the current frame
            (its "obj_id_to_last_occluded" entry is written here).
        obj_ids_newly_removed (Set[int]): The object IDs that have been removed.
        reverse (bool): True when propagating backwards in time (flips the
            "more recently occluded" comparison downstream).

    Return:
        Tensor: The updated low-resolution masks with some objects suppressed
        (modified in place and returned).
    """
    obj_ids_global = tracker_metadata_prev["obj_ids_all_gpu"]
    binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0
    batch_size = tracker_low_res_masks_global.size(0)
    if batch_size > 0:
        assert (
            len(obj_ids_global) == batch_size
        ), f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
        NEVER_OCCLUDED = -1
        ALWAYS_OCCLUDED = 100000  # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic
        # per-object "last occluded frame" vector; objects without a record get
        # NEVER_OCCLUDED, hotstart-removed objects get ALWAYS_OCCLUDED
        last_occluded_prev = torch.cat(
            [
                tracker_metadata_prev["obj_id_to_last_occluded"].get(
                    obj_id,
                    torch.full(
                        (1,),
                        fill_value=(
                            NEVER_OCCLUDED
                            if obj_id not in obj_ids_newly_removed
                            else ALWAYS_OCCLUDED
                        ),
                        device=binary_tracker_low_res_masks_global.device,
                        dtype=torch.long,
                    ),
                )
                for obj_id in obj_ids_global
            ],
            dim=0,
        )
        to_suppress = self._get_objects_to_suppress_based_on_most_recently_occluded(
            binary_tracker_low_res_masks_global,
            last_occluded_prev,
            obj_ids_global,
            frame_idx,
            reverse,
        )
        # Update metadata with occlusion information
        is_obj_occluded = ~(binary_tracker_low_res_masks_global.any(dim=(-1, -2)))
        is_obj_occluded_or_suppressed = is_obj_occluded | to_suppress
        last_occluded_new = last_occluded_prev.clone()
        last_occluded_new[is_obj_occluded_or_suppressed] = frame_idx
        # Slice out the last occluded frame for each object
        tracker_metadata_new["obj_id_to_last_occluded"] = {
            obj_id: last_occluded_new[obj_idx : obj_idx + 1]
            for obj_idx, obj_id in enumerate(obj_ids_global)
        }
        # Zero out suppressed masks before memory encoding
        NO_OBJ_LOGIT = -10
        tracker_low_res_masks_global[to_suppress] = NO_OBJ_LOGIT
    return tracker_low_res_masks_global
def run_tracker_update_execution_phase(
    self,
    frame_idx: int,
    num_frames: int,
    reverse: bool,
    det_out: Dict[str, Tensor],
    tracker_states_local: List[Any],
    tracker_update_plan: Dict[str, npt.NDArray],
    orig_vid_height: int,
    orig_vid_width: int,
    feature_cache: Dict,
):
    """
    Execute this GPU's share of an already-broadcast tracker update plan.

    Adds the new detections assigned to this rank as fresh objects in the
    local inference states, then removes any objects that the planning phase
    marked for deletion. Returns the (possibly replaced) list of local
    inference states.
    """
    plan = tracker_update_plan
    # select only the new detections that the planner assigned to this rank
    assigned_here: npt.NDArray = plan["new_det_gpu_ids"] == self.rank
    local_obj_ids: npt.NDArray = plan["new_det_obj_ids"][assigned_here]
    local_fa_inds: npt.NDArray = plan["new_det_fa_inds"][assigned_here]
    removed_ids: Set[int] = plan["obj_ids_newly_removed"]
    # Step 1: add new objects from the detector to SAM2 inference states
    if len(local_fa_inds) > 0:
        mask_sel = torch.from_numpy(local_fa_inds)
        tracker_states_local = self._tracker_add_new_objects(
            frame_idx=frame_idx,
            num_frames=num_frames,
            new_obj_ids=local_obj_ids,
            new_obj_masks=det_out["mask"][mask_sel],
            tracker_states_local=tracker_states_local,
            orig_vid_height=orig_vid_height,
            orig_vid_width=orig_vid_width,
            feature_cache=feature_cache,
        )
    # Step 2: remove from SAM2 inference states those objects removed by heuristics
    if len(removed_ids) > 0:
        self._tracker_remove_objects(tracker_states_local, removed_ids)
    return tracker_states_local
def build_outputs(
    self,
    frame_idx: int,
    num_frames: int,
    reverse: bool,
    det_out: Dict[str, Tensor],
    tracker_low_res_masks_global: Tensor,
    tracker_obj_scores_global: Tensor,
    tracker_metadata_prev: Dict[str, npt.NDArray],
    tracker_update_plan: Dict[str, npt.NDArray],
    orig_vid_height: int,
    orig_vid_width: int,
    reconditioned_obj_ids: set = None,
    det_to_matched_trk_obj_ids: dict = None,
):
    """
    Assemble the per-object output masks for the current frame.

    Combines (1) propagated masklets, (2) newly detected objects, and
    (3) detection masks that override reconditioned objects, all upsampled
    to the original video resolution and binarized at logit 0.
    Note: `det_to_matched_trk_obj_ids` is currently unused in this method.

    Returns:
        dict mapping obj_id -> boolean mask tensor of shape (1, H_video, W_video).
    """
    out_size = (orig_vid_height, orig_vid_width)
    obj_id_to_mask = {}  # obj_id --> output mask tensor
    # Part 1: masks from previous SAM2 propagation
    prev_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"]
    prev_masks_binary = (
        F.interpolate(
            tracker_low_res_masks_global.unsqueeze(1),
            size=out_size,
            mode="bilinear",
            align_corners=False,
        )  # (num_obj, 1, H_video, W_video)
        > 0
    )
    assert len(prev_obj_ids) == len(prev_masks_binary)
    # each value is a (1, H_video, W_video) boolean slice
    obj_id_to_mask.update(zip(prev_obj_ids, prev_masks_binary))
    # Part 2: masks from new detections
    new_det_fa_inds: npt.NDArray = tracker_update_plan["new_det_fa_inds"]
    new_det_obj_ids: npt.NDArray = tracker_update_plan["new_det_obj_ids"]
    new_det_logits = det_out["mask"][torch.from_numpy(new_det_fa_inds)].unsqueeze(1)
    # clean up the raw detection logits before upsampling
    new_det_logits = fill_holes_in_mask_scores(
        new_det_logits,
        max_area=self.fill_hole_area,
        fill_holes=True,
        remove_sprinkles=True,
    )
    new_masks_video = F.interpolate(
        new_det_logits,
        size=out_size,
        mode="bilinear",
        align_corners=False,
    )  # (num_obj, 1, H_video, W_video)
    assert len(new_det_obj_ids) == len(new_masks_video)
    obj_id_to_mask.update(zip(new_det_obj_ids, new_masks_video > 0))
    # Part 3: Override masks for reconditioned objects using detection masks
    if reconditioned_obj_ids:
        trk_to_det = tracker_update_plan.get("trk_id_to_max_iou_high_conf_det", {})
        for obj_id in reconditioned_obj_ids:
            det_idx = trk_to_det.get(obj_id)
            if det_idx is None:
                continue
            det_logit = det_out["mask"][det_idx][None, None]  # (1, 1, h, w)
            resized_binary = (
                F.interpolate(
                    det_logit.float(),
                    size=out_size,
                    mode="bilinear",
                    align_corners=False,
                )
                > 0
            )
            obj_id_to_mask[obj_id] = resized_binary.squeeze(0)
    return obj_id_to_mask
def _get_objects_to_suppress_based_on_most_recently_occluded(
    self,
    binary_low_res_masks: Tensor,
    # NOTE(review): previously annotated List[int], but this is a Tensor —
    # it is unsqueezed/compared below and `.cpu().numpy()` is called on it.
    last_occluded: Tensor,
    # array of object IDs aligned with the mask batch (indexed positionally)
    obj_ids: npt.NDArray,
    frame_idx: int = None,
    reverse: bool = False,
):
    """
    Return a boolean (N,) tensor marking which objects to suppress.

    For every pair of masks whose IoU exceeds
    `self.suppress_overlapping_based_on_recent_occlusion_threshold`, the
    object that was occluded more recently (later frame when tracking
    forward, earlier frame when `reverse`) is suppressed — but only if it
    was occluded at least once before (last_occluded > -1).
    """
    # Suppress overlapping masks for objects that were most recently occluded
    assert (
        binary_low_res_masks.dtype == torch.bool
    ), f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
    to_suppress = torch.zeros(
        binary_low_res_masks.size(0),
        device=binary_low_res_masks.device,
        dtype=torch.bool,
    )
    if len(obj_ids) <= 1:
        # nothing can overlap with a single (or no) object
        return to_suppress
    iou = mask_iou(binary_low_res_masks, binary_low_res_masks)  # [N,N]
    # Create masks for upper triangular matrix (i < j) and IoU threshold
    mask_iou_thresh = (
        iou >= self.suppress_overlapping_based_on_recent_occlusion_threshold
    )
    # keep each unordered pair only once (i < j)
    overlapping_pairs = torch.triu(mask_iou_thresh, diagonal=1)  # [N,N]
    last_occ_expanded_i = last_occluded.unsqueeze(1)  # (N, 1)
    last_occ_expanded_j = last_occluded.unsqueeze(0)  # (1, N)
    # Suppress most recently occluded
    # (when tracking in reverse, a smaller frame index means "more recent")
    cmp_op = torch.gt if not reverse else torch.lt
    suppress_i_mask = (
        overlapping_pairs
        & cmp_op(
            last_occ_expanded_i, last_occ_expanded_j
        )  # (last_occ_expanded_i > last_occ_expanded_j)
        & (
            last_occ_expanded_j > -1
        )  # j can suppress i only if i was previously occluded
    )
    suppress_j_mask = (
        overlapping_pairs
        & cmp_op(last_occ_expanded_j, last_occ_expanded_i)
        & (
            last_occ_expanded_i > -1
        )  # i can suppress j only if j was previously occluded
    )
    # Apply suppression
    to_suppress = suppress_i_mask.any(dim=1) | suppress_j_mask.any(dim=0)
    # Log for debugging (only on rank 0 and only when DEBUG logging is on)
    if (
        self.rank == 0
        and logger.isEnabledFor(logging.DEBUG)
        and frame_idx is not None
    ):
        suppress_i_mask = suppress_i_mask.cpu().numpy()
        suppress_j_mask = suppress_j_mask.cpu().numpy()
        last_occluded = last_occluded.cpu().numpy()
        # Find all suppression pairs without using torch.where
        batch_size = suppress_i_mask.shape[0]
        # Log i-suppression cases (where i gets suppressed in favor of j)
        for i in range(batch_size):
            for j in range(batch_size):
                if suppress_i_mask[i, j]:
                    logger.debug(
                        f"{frame_idx=}: Suppressing obj {obj_ids[i]} last occluded {last_occluded[i]} in favor of {obj_ids[j]} last occluded {last_occluded[j]}"
                    )
        # Log j-suppression cases (where j gets suppressed in favor of i)
        for i in range(batch_size):
            for j in range(batch_size):
                if suppress_j_mask[i, j]:
                    logger.debug(
                        f"{frame_idx=}: Suppressing obj {obj_ids[j]} last occluded {last_occluded[j]} in favor of {obj_ids[i]} last occluded {last_occluded[i]}"
                    )
    return to_suppress
def _propogate_tracker_one_frame_local_gpu(
self,
inference_states: List[Any],
frame_idx: int,
reverse: bool,
# by default, we disable memory encoding until we gather all outputs
run_mem_encoder: bool = False,
):
"""
inference_states: List of inference states, each state corresponds to a different set of objects.
"""
obj_ids_local = []
low_res_masks_list = []
obj_scores_list = []
for inference_state in inference_states:
if len(inference_state["obj_ids"]) == 0:
continue # skip propagation on empty inference states
# propagate one frame
num_frames_propagated = 0
for out in self.tracker.propagate_in_video(
inference_state,
start_frame_idx=frame_idx,
# end_frame_idx = start_frame_idx + max_frame_num_to_track
# (i.e. propagating 1 frame since end_frame_idx is inclusive)
max_frame_num_to_track=0,
reverse=reverse,
tqdm_disable=True,
run_mem_encoder=run_mem_encoder,
):
out_frame_idx, out_obj_ids, out_low_res_masks, _, out_obj_scores = out
num_frames_propagated += 1
# only 1 frames should be propagated
assert (
num_frames_propagated == 1 and out_frame_idx == frame_idx
), f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
assert isinstance(out_obj_ids, list)
obj_ids_local.extend(out_obj_ids)
low_res_masks_list.append(out_low_res_masks.squeeze(1))
obj_scores_list.append(out_obj_scores.squeeze(1))
# concatenate the output masklets from all local inference states
H_mask = W_mask = self.tracker.low_res_mask_size
if len(low_res_masks_list) > 0:
low_res_masks_local = torch.cat(low_res_masks_list, dim=0)
obj_scores_local = torch.cat(obj_scores_list, dim=0)
assert low_res_masks_local.shape[1:] == (H_mask, W_mask)
# Apply hole filling to the masks
low_res_masks_local = fill_holes_in_mask_scores(
low_res_masks_local.unsqueeze(1),
max_area=self.fill_hole_area,
fill_holes=True,
remove_sprinkles=True,
)
low_res_masks_local = low_res_masks_local.squeeze(1)
else:
low_res_masks_local = torch.zeros(0, H_mask, W_mask, device=self.device)
obj_scores_local = torch.zeros(0, device=self.device)
return obj_ids_local, low_res_masks_local, obj_scores_local
def _associate_det_trk(
    self,
    det_masks: Tensor,
    det_scores_np: npt.NDArray,
    trk_masks: Tensor,
    trk_obj_ids: npt.NDArray,
):
    """
    Match detections on the current frame with the existing masklets.

    Args:
    - det_masks: (N, H, W) tensor of predicted masks
    - det_scores_np: (N,) array of detection scores
    - trk_masks: (M, H, W) tensor of track masks
    - trk_obj_ids: (M,) array of object IDs corresponding to trk_masks

    Returns:
    - new_det_fa_inds: array of new object indices.
    - unmatched_trk_obj_ids: array of existing masklet object IDs that are not matched
      to any detections on this frame (for unmatched, we only count masklets with >0 area)
    - det_to_matched_trk_obj_ids: dict[int, npt.NDArray]: mapping from detector's detection indices
      to the list of matched tracklet object IDs
    - trk_id_to_max_iou_high_conf_det: dict mapping a tracklet object ID to the single
      highest-IoU, high-confidence detection index (used for reconditioning)
    - empty_trk_obj_ids: array of existing masklet object IDs with zero area in SAM2 prediction
    """
    iou_threshold = self.assoc_iou_thresh
    iou_threshold_trk = self.trk_assoc_iou_thresh
    new_det_thresh = self.new_det_thresh
    assert det_masks.is_floating_point(), "float tensor expected (do not binarize)"
    assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)"
    assert (
        trk_masks.size(0) == len(trk_obj_ids)
    ), f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
    if trk_masks.size(0) == 0:
        # no existing tracks: all detections are new
        new_det_fa_inds = np.arange(det_masks.size(0))
        unmatched_trk_obj_ids = np.array([], np.int64)
        empty_trk_obj_ids = np.array([], np.int64)
        det_to_matched_trk_obj_ids = {}
        trk_id_to_max_iou_high_conf_det = {}
        return (
            new_det_fa_inds,
            unmatched_trk_obj_ids,
            det_to_matched_trk_obj_ids,
            trk_id_to_max_iou_high_conf_det,
            empty_trk_obj_ids,
        )
    elif det_masks.size(0) == 0:
        # no detections: all previous tracklets are unmatched if they have a non-zero area
        new_det_fa_inds = np.array([], np.int64)
        trk_is_nonempty = (trk_masks > 0).any(dim=(1, 2)).cpu().numpy()
        unmatched_trk_obj_ids = trk_obj_ids[trk_is_nonempty]
        empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty]
        det_to_matched_trk_obj_ids = {}
        trk_id_to_max_iou_high_conf_det = {}
        return (
            new_det_fa_inds,
            unmatched_trk_obj_ids,
            det_to_matched_trk_obj_ids,
            trk_id_to_max_iou_high_conf_det,
            empty_trk_obj_ids,
        )
    if det_masks.shape[-2:] != trk_masks.shape[-2:]:
        # resize to the smaller size to save GPU memory
        if np.prod(det_masks.shape[-2:]) < np.prod(trk_masks.shape[-2:]):
            trk_masks = F.interpolate(
                trk_masks.unsqueeze(1),
                size=det_masks.shape[-2:],
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)
        else:
            # resize detections to track size
            det_masks = F.interpolate(
                det_masks.unsqueeze(1),
                size=trk_masks.shape[-2:],
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)
    det_masks_binary = det_masks > 0
    trk_masks_binary = trk_masks > 0
    ious = mask_iou(det_masks_binary, trk_masks_binary)  # (N, M)
    ious_np = ious.cpu().numpy()
    if self.o2o_matching_masklets_enable:
        from scipy.optimize import linear_sum_assignment

        # Hungarian matching for tracks (one-to-one: each track matches at most one detection)
        cost_matrix = 1 - ious_np  # Hungarian solves for minimum cost
        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        trk_is_matched = np.zeros(trk_masks.size(0), dtype=bool)
        for d, t in zip(row_ind, col_ind):
            # only accept assignments above the track-matching IoU threshold
            if ious_np[d, t] >= iou_threshold_trk:
                trk_is_matched[t] = True
    else:
        # greedy: a track is matched if any detection overlaps it enough
        trk_is_matched = (ious_np >= iou_threshold_trk).any(axis=0)
    # Non-empty tracks not matched by Hungarian assignment above threshold are unmatched
    trk_is_nonempty = trk_masks_binary.any(dim=(1, 2)).cpu().numpy()
    trk_is_unmatched = np.logical_and(trk_is_nonempty, ~trk_is_matched)
    unmatched_trk_obj_ids = trk_obj_ids[trk_is_unmatched]
    # also record masklets that have zero area in SAM 2 prediction
    empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty]
    # For detections: allow many tracks to match to the same detection (many-to-one)
    # So, a detection is 'new' if it does not match any track above threshold
    # (and its score clears `new_det_thresh`)
    is_new_det = np.logical_and(
        det_scores_np >= new_det_thresh,
        np.logical_not(np.any(ious_np >= iou_threshold, axis=1)),
    )
    new_det_fa_inds = np.nonzero(is_new_det)[0]
    # for each detection, which tracks it matched to (above threshold)
    det_to_matched_trk_obj_ids = {}
    trk_id_to_max_iou_high_conf_det = {}  # trk id --> exactly one detection idx
    HIGH_CONF_THRESH = 0.8
    HIGH_IOU_THRESH = 0.8
    det_to_max_iou_trk_idx = np.argmax(ious_np, axis=1)
    det_is_high_conf = (det_scores_np >= HIGH_CONF_THRESH) & ~is_new_det
    det_is_high_iou = np.max(ious_np, axis=1) >= HIGH_IOU_THRESH
    det_is_high_conf_and_iou = set(
        np.nonzero(det_is_high_conf & det_is_high_iou)[0]
    )
    for d in range(det_masks.size(0)):
        det_to_matched_trk_obj_ids[d] = trk_obj_ids[ious_np[d, :] >= iou_threshold]
        if d in det_is_high_conf_and_iou:
            # remember the best-overlapping track for this confident detection
            # (later detections with the same best track overwrite earlier ones)
            trk_obj_id = trk_obj_ids[det_to_max_iou_trk_idx[d]].item()
            trk_id_to_max_iou_high_conf_det[trk_obj_id] = d
    return (
        new_det_fa_inds,
        unmatched_trk_obj_ids,
        det_to_matched_trk_obj_ids,
        trk_id_to_max_iou_high_conf_det,
        empty_trk_obj_ids,
    )
def _assign_new_det_to_gpus(self, new_det_num, prev_workload_per_gpu):
"""Distribute the new objects to the GPUs with the least workload."""
workload_per_gpu: npt.NDArray = prev_workload_per_gpu.copy()
new_det_gpu_ids = np.zeros(new_det_num, np.int64)
# assign the objects one by one
for i in range(len(new_det_gpu_ids)):
# find the GPU with the least workload
min_gpu = np.argmin(workload_per_gpu)
new_det_gpu_ids[i] = min_gpu
workload_per_gpu[min_gpu] += 1
return new_det_gpu_ids
def _process_hotstart(
    self,
    frame_idx: int,
    num_frames: int,
    reverse: bool,
    det_to_matched_trk_obj_ids: Dict[int, npt.NDArray],
    new_det_obj_ids: npt.NDArray,
    empty_trk_obj_ids: npt.NDArray,
    unmatched_trk_obj_ids: npt.NDArray,
    rank0_metadata: Dict[str, Any],
    tracker_metadata: Dict[str, Any],
):
    """Handle hotstart heuristics to remove unmatched or duplicated objects.

    Only runs on rank 0. Mutates `rank0_metadata` in place (first-frame
    bookkeeping, keep-alive counters, unmatched/overlap frame histories,
    removed/suppressed sets) and returns
    (obj_ids_newly_removed, rank0_metadata).

    Note: `num_frames` and `tracker_metadata` are accepted but not read here.
    """
    # obj_id --> first frame index where the object was detected
    obj_first_frame_idx = rank0_metadata["obj_first_frame_idx"]
    # obj_id --> [mismatched frame indices]
    # (presumably a defaultdict(list) — entries are appended without init; TODO confirm)
    unmatched_frame_inds = rank0_metadata["unmatched_frame_inds"]
    trk_keep_alive = rank0_metadata["trk_keep_alive"]
    # (first_appear_obj_id, obj_id) --> [overlap frame indices]
    overlap_pair_to_frame_inds = rank0_metadata["overlap_pair_to_frame_inds"]
    # removed_obj_ids: object IDs that are suppressed via hot-start
    removed_obj_ids = rank0_metadata["removed_obj_ids"]
    suppressed_obj_ids = rank0_metadata["suppressed_obj_ids"][frame_idx]
    obj_ids_newly_removed = set()  # object IDs to be newly removed on this frame
    # frame index marking the start of the hotstart window
    # (objects first seen after this point are still "on probation")
    hotstart_diff = (
        frame_idx - self.hotstart_delay
        if not reverse
        else frame_idx + self.hotstart_delay
    )
    # Step 1: log the frame index where each object ID first appears
    for obj_id in new_det_obj_ids:
        if obj_id not in obj_first_frame_idx:
            obj_first_frame_idx[obj_id] = frame_idx
            assert obj_id not in trk_keep_alive
            trk_keep_alive[obj_id] = self.init_trk_keep_alive
    matched_trks = set()
    # We use the det-->tracks list to check for matched objects. Otherwise, we need to compute areas to decide whether they're occluded
    for matched_trks_per_det in det_to_matched_trk_obj_ids.values():
        matched_trks.update(matched_trks_per_det)
    for obj_id in matched_trks:
        # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the max value of trk_keep_alive
        trk_keep_alive[obj_id] = min(
            self.max_trk_keep_alive, trk_keep_alive[obj_id] + 1
        )
    for obj_id in unmatched_trk_obj_ids:
        unmatched_frame_inds[obj_id].append(frame_idx)
        # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive
        # The max keep alive is 2x the min, means the model prefers to keep the prediction rather than suppress it if it was matched long enough.
        trk_keep_alive[obj_id] = max(
            self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1
        )
    if self.decrease_trk_keep_alive_for_empty_masklets:
        for obj_id in empty_trk_obj_ids:
            # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive
            trk_keep_alive[obj_id] = max(
                self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1
            )
    # Step 2: remove tracks that have not matched with detections for `hotstart_unmatch_thresh` frames within the hotstart period
    # a) add unmatched frame indices for each existing object ID
    # note that `unmatched_trk_obj_ids` contains those frames where the SAM2 output mask
    # doesn't match any detection; it excludes those frames where SAM2 gives an empty mask
    # b) remove a masklet if it first appears after `hotstart_diff` and is unmatched for more
    # than `self.hotstart_unmatch_thresh` frames
    for obj_id, frame_indices in unmatched_frame_inds.items():
        if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed:
            continue  # skip if the object is already removed
        if len(frame_indices) >= self.hotstart_unmatch_thresh:
            is_within_hotstart = (
                obj_first_frame_idx[obj_id] > hotstart_diff and not reverse
            ) or (obj_first_frame_idx[obj_id] < hotstart_diff and reverse)
            if is_within_hotstart:
                obj_ids_newly_removed.add(obj_id)
                logger.debug(
                    f"Removing object {obj_id} at frame {frame_idx} "
                    f"since it is unmatched for frames: {frame_indices}"
                )
        # suppression (not removal) for long-unmatched objects outside hotstart
        if (
            trk_keep_alive[obj_id] <= 0  # Object has not been matched for too long
            and not self.suppress_unmatched_only_within_hotstart
            and obj_id not in removed_obj_ids
            and obj_id not in obj_ids_newly_removed
        ):
            logger.debug(
                f"Suppressing object {obj_id} at frame {frame_idx}, due to being unmatched"
            )
            suppressed_obj_ids.add(obj_id)
    # Step 3: remove tracks that overlap with another track for `hotstart_dup_thresh` frames
    # a) find overlapping tracks -- we consider overlap if they match to the same detection
    for _, matched_trk_obj_ids in det_to_matched_trk_obj_ids.items():
        if len(matched_trk_obj_ids) < 2:
            continue  # only count detections that are matched to multiple (>=2) masklets
        # if there are multiple matched track ids, we need to find the one that appeared first;
        # these later appearing ids may be removed since they may be considered as duplicates
        first_appear_obj_id = (
            min(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x])
            if not reverse
            else max(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x])
        )
        for obj_id in matched_trk_obj_ids:
            if obj_id != first_appear_obj_id:
                key = (first_appear_obj_id, obj_id)
                overlap_pair_to_frame_inds[key].append(frame_idx)
    # b) remove a masklet if it first appears after `hotstart_diff` and it overlaps with another
    # masklet (that appears earlier) for more than `self.hotstart_dup_thresh` frames
    for (first_obj_id, obj_id), frame_indices in overlap_pair_to_frame_inds.items():
        if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed:
            continue  # skip if the object is already removed
        if (obj_first_frame_idx[obj_id] > hotstart_diff and not reverse) or (
            obj_first_frame_idx[obj_id] < hotstart_diff and reverse
        ):
            if len(frame_indices) >= self.hotstart_dup_thresh:
                obj_ids_newly_removed.add(obj_id)
                logger.debug(
                    f"Removing object {obj_id} at frame {frame_idx} "
                    f"since it overlaps with another track {first_obj_id} at frames: {frame_indices}"
                )
    # persist this frame's removals for future frames
    removed_obj_ids.update(obj_ids_newly_removed)
    return obj_ids_newly_removed, rank0_metadata
def _tracker_update_memories(
    self,
    tracker_inference_states: List[Any],
    frame_idx: int,
    tracker_metadata: Dict[str, Any],
    low_res_masks: Tensor,
):
    """
    Run Sam2 memory encoder, enforcing non-overlapping constraints globally.

    Encodes this frame's predicted masks into memory features and writes them
    into each local inference state's output dict for `frame_idx`.

    Args:
        tracker_inference_states: SAM2 inference states local to this GPU.
            Their objects are assumed to be laid out contiguously in
            `low_res_masks`, after the objects owned by lower ranks
            (per `tracker_metadata["num_obj_per_gpu"]`) -- TODO confirm with caller.
        frame_idx: index of the frame whose memories are being updated.
        tracker_metadata: global tracker metadata; only "num_obj_per_gpu"
            is read here.
        low_res_masks: per-object mask logits; presumably shaped
            (num_obj_all_gpus, H, W) given the `unsqueeze(1)` below -- TODO confirm.
    """
    if len(tracker_inference_states) == 0:
        return
    # Avoid an extra interpolation step by directly interpolating to `interpol_size`
    high_res_H, high_res_W = (
        self.tracker.maskmem_backbone.mask_downsampler.interpol_size
    )
    # NOTE: inspect this part if we observe OOMs in the demo
    high_res_masks = F.interpolate(
        low_res_masks.unsqueeze(1),
        size=(high_res_H, high_res_W),
        mode="bilinear",
        align_corners=False,
    )
    # We first apply non-overlapping constraints before memory encoding. This may include some suppression heuristics.
    # Suppression is skipped while a warm-up phase is in progress (when the
    # `_warm_up_complete` attribute exists and is still False).
    if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
        high_res_masks = self.tracker._suppress_object_pw_area_shrinkage(
            high_res_masks
        )
    # Instead of gathering the predicted object scores, we use mask areas as a proxy.
    # +10 logits for objects with at least one positive mask pixel, -10 otherwise.
    object_score_logits = torch.where(
        (high_res_masks > 0).any(dim=(-1, -2)), 10.0, -10.0
    )
    # Run the memory encoder on local slices for each GPU
    # This rank's slice of the global mask batch starts after all objects
    # owned by lower-ranked GPUs.
    start_idx_gpu = sum(tracker_metadata["num_obj_per_gpu"][: self.rank])
    start_idx_state = start_idx_gpu
    for tracker_state in tracker_inference_states:
        num_obj_per_state = len(tracker_state["obj_ids"])
        if num_obj_per_state == 0:
            continue
        # Get the local high-res masks and object score logits for this inference state
        end_idx_state = start_idx_state + num_obj_per_state
        local_high_res_masks = high_res_masks[start_idx_state:end_idx_state]
        local_object_score_logits = object_score_logits[
            start_idx_state:end_idx_state
        ]
        local_batch_size = local_high_res_masks.size(0)
        # Run Sam2 memory encoder. Note that we do not re-enforce the non-overlapping constraint as it is turned off by default
        encoded_mem = self.tracker._run_memory_encoder(
            tracker_state,
            frame_idx,
            local_batch_size,
            local_high_res_masks,
            local_object_score_logits,
            is_mask_from_pts=False,
        )
        local_maskmem_features, local_maskmem_pos_enc = encoded_mem
        # Store encoded memories in the local inference state
        # (only for whichever storage buckets already hold this frame).
        output_dict = tracker_state["output_dict"]
        for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]:
            if frame_idx not in output_dict[storage_key]:
                continue
            output_dict[storage_key][frame_idx]["maskmem_features"] = (
                local_maskmem_features
            )
            # Shallow-copy the list so the stored value is independent of
            # the encoder's returned container.
            output_dict[storage_key][frame_idx]["maskmem_pos_enc"] = [
                pos for pos in local_maskmem_pos_enc
            ]
            # for batched inference state, we also need to add per-object
            # memory slides to support instance interactivity
            self.tracker._add_output_per_object(
                inference_state=tracker_state,
                frame_idx=frame_idx,
                current_out=output_dict[storage_key][frame_idx],
                storage_key=storage_key,
            )
        # Advance past this state's contiguous span of objects.
        start_idx_state += num_obj_per_state
def _tracker_add_new_objects(
self,
frame_idx: int,
num_frames: int,
new_obj_ids: List[int],
new_obj_masks: Tensor,
tracker_states_local: List[Any],
orig_vid_height: int,
orig_vid_width: int,
feature_cache: Dict,
):
"""Add a new object to SAM2 inference states."""
prev_tracker_state = (
tracker_states_local[0] if len(tracker_states_local) > 0 else None
)
# prepare inference_state
# batch objects that first appear on the same frame together
# Clear inference state. Keep the cached image features if available.
new_tracker_state = self.tracker.init_state(
cached_features=feature_cache,
video_height=orig_vid_height,
video_width=orig_vid_width,
num_frames=num_frames,
)
new_tracker_state["backbone_out"] = (
prev_tracker_state.get("backbone_out", None)
if prev_tracker_state is not None
else None
)
assert len(new_obj_ids) == new_obj_masks.size(0)
assert new_obj_masks.is_floating_point()
input_mask_res = self.tracker.input_mask_size
new_obj_masks = F.interpolate(
new_obj_masks.unsqueeze(1),
size=(input_mask_res, input_mask_res),
mode="bilinear",
align_corners=False,
).squeeze(1)
new_obj_masks = new_obj_masks > 0
# add object one by one
for new_obj_id, new_mask in zip(new_obj_ids, new_obj_masks):
self.tracker.add_new_mask(
inference_state=new_tracker_state,
frame_idx=frame_idx,
obj_id=new_obj_id,
mask=new_mask,
add_mask_to_memory=True,
)
# NOTE: we skip enforcing the non-overlapping constraint **globally** when adding new objects.
self.tracker.propagate_in_video_preflight(
new_tracker_state, run_mem_encoder=True
)
tracker_states_local.append(new_tracker_state)
return tracker_states_local
def _tracker_remove_object(self, tracker_states_local: List[Any], obj_id: int):
"""
Remove an object from SAM2 inference states. This would remove the object from
all frames in the video.
"""
tracker_states_local_before_removal = tracker_states_local.copy()
tracker_states_local.clear()
for tracker_inference_state in tracker_states_local_before_removal:
# we try to remove `obj_id` on every inference state with `strict=False`
# it will not do anything if an inference state doesn't contain `obj_id`
new_obj_ids, _ = self.tracker.remove_object(
tracker_inference_state, obj_id, strict=False, need_output=False
)
# only keep an inference state if it's non-empty after object removal
if len(new_obj_ids) > 0:
tracker_states_local.append(tracker_inference_state)
def _tracker_remove_objects(
self, tracker_states_local: List[Any], obj_ids: list[int]
):
"""
Remove an object from SAM2 inference states. This would remove the object from
all frames in the video.
"""
for obj_id in obj_ids:
self._tracker_remove_object(tracker_states_local, obj_id)
def _initialize_metadata(self):
"""Initialize metadata for the masklets."""
tracker_metadata = {
"obj_ids_per_gpu": [np.array([], np.int64) for _ in range(self.world_size)],
"obj_ids_all_gpu": np.array([], np.int64),
"num_obj_per_gpu": np.zeros(self.world_size, np.int64),
"max_obj_id": -1,
"obj_id_to_score": {},
"obj_id_to_tracker_score_frame_wise": defaultdict(dict),
"obj_id_to_last_occluded": {},
}
if self.rank == 0:
# "rank0_metadata" contains metadata that is only stored on (and accessible to) GPU 0
# - obj_first_frame_idx: obj_id --> first frame index where the object was detected
# - unmatched_frame_inds: obj_id --> [mismatched frame indices]
# - overlap_pair_to_frame_inds: (first_appear_obj_id, obj_id) --> [overlap frame indices]
# - removed_obj_ids: object IDs that are suppressed via hot-start
rank0_metadata = {
"obj_first_frame_idx": {},
"unmatched_frame_inds": defaultdict(list),
"trk_keep_alive": defaultdict(
int
), # This is used only for object suppression not for removal
"overlap_pair_to_frame_inds": defaultdict(list),
"removed_obj_ids": set(),
"suppressed_obj_ids": defaultdict(
set
), # frame_idx --> set of objects with suppressed outputs, but still continue to be tracked
}
if self.masklet_confirmation_enable:
# all the following are npt.NDArray with the same shape as `obj_ids_all_gpu`
rank0_metadata["masklet_confirmation"] = {
# "status" is the confirmation status of each masklet (in `MaskletConfirmationStatus`)
"status": np.array([], np.int64),
# "consecutive_det_num" is the number of consecutive frames where the masklet is
# detected by the detector (with a matched detection)
"consecutive_det_num": np.array([], np.int64),
}
tracker_metadata["rank0_metadata"] = rank0_metadata
return tracker_metadata
def update_masklet_confirmation_status(
    self,
    rank0_metadata: Dict[str, Any],
    obj_ids_all_gpu_prev: npt.NDArray,
    obj_ids_all_gpu_updated: npt.NDArray,
    det_to_matched_trk_obj_ids: Dict[int, npt.NDArray],
    new_det_obj_ids: npt.NDArray,
):
    """
    Update per-masklet confirmation state after processing a frame.

    A masklet starts UNCONFIRMED and is promoted to CONFIRMED once it has
    been matched to a detection for
    `self.masklet_confirmation_consecutive_det_thresh` consecutive frames;
    an unmatched frame resets its streak to zero.

    Args:
        rank0_metadata: rank-0 metadata; its "masklet_confirmation" entry
            holds "status" and "consecutive_det_num" arrays aligned with
            `obj_ids_all_gpu_prev` on entry (asserted below).
        obj_ids_all_gpu_prev: object ids (across all GPUs) before this
            frame's additions/removals.
        obj_ids_all_gpu_updated: object ids after this frame's updates; the
            stored arrays are realigned to this ordering.
        det_to_matched_trk_obj_ids: detection index --> matched tracked ids.
        new_det_obj_ids: ids of objects newly created from detections on
            this frame; they count as "matched" for streak purposes.

    Returns:
        `rank0_metadata`, with "masklet_confirmation" updated in place.
    """
    confirmation_data = rank0_metadata["masklet_confirmation"]
    # a) first, expand "confirmation_data" to include new masklets added in this frame
    status_prev = confirmation_data["status"]
    consecutive_det_num_prev = confirmation_data["consecutive_det_num"]
    assert (
        status_prev.shape == obj_ids_all_gpu_prev.shape
    ), f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
    obj_id_to_updated_idx = {
        obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated)
    }
    # Boolean mask over the *prev* array: which old masklets still exist.
    prev_elem_is_in_updated = np.isin(obj_ids_all_gpu_prev, obj_ids_all_gpu_updated)
    prev_elem_obj_ids_in_updated = obj_ids_all_gpu_prev[prev_elem_is_in_updated]
    # Destination indices (in the updated ordering) for the survivors; the
    # order matches `prev_elem_obj_ids_in_updated`, so the scatter writes
    # below stay aligned with the boolean gathers.
    prev_elem_inds_in_updated = np.array(
        [obj_id_to_updated_idx[obj_id] for obj_id in prev_elem_obj_ids_in_updated],
        dtype=np.int64,
    )
    # newly added masklets are initialized to "UNCONFIRMED" status
    unconfirmed_val = MaskletConfirmationStatus.UNCONFIRMED.value
    status = np.full_like(obj_ids_all_gpu_updated, fill_value=unconfirmed_val)
    status[prev_elem_inds_in_updated] = status_prev[prev_elem_is_in_updated]
    # Detection streaks start at 0 for new masklets; carried over for survivors.
    consecutive_det_num = np.zeros_like(obj_ids_all_gpu_updated)
    consecutive_det_num[prev_elem_inds_in_updated] = consecutive_det_num_prev[
        prev_elem_is_in_updated
    ]
    # b) update the confirmation status of all masklets based on the current frame
    # b.1) update "consecutive_det_num"
    # "is_matched": whether a masklet is matched to a detection on this frame
    # (either newly created from a detection, or matched to an existing one)
    is_matched = np.isin(obj_ids_all_gpu_updated, new_det_obj_ids)
    for matched_trk_obj_ids in det_to_matched_trk_obj_ids.values():
        is_matched |= np.isin(obj_ids_all_gpu_updated, matched_trk_obj_ids)
    # Extend the streak for matched masklets; reset it for unmatched ones.
    consecutive_det_num = np.where(is_matched, consecutive_det_num + 1, 0)
    # b.2) update "status"
    change_to_confirmed = (
        consecutive_det_num >= self.masklet_confirmation_consecutive_det_thresh
    )
    # NOTE(review): promotion is one-way here — a CONFIRMED masklet never
    # reverts to UNCONFIRMED unless it is removed and later re-added.
    status[change_to_confirmed] = MaskletConfirmationStatus.CONFIRMED.value
    confirmation_data["status"] = status
    confirmation_data["consecutive_det_num"] = consecutive_det_num
    return rank0_metadata
def forward(self, input: BatchedDatapoint, is_inference: bool = False):
    """Forward pass placeholder; only the demo pipeline is supported."""
    raise NotImplementedError("Evaluation outside demo is not implemented yet")
def _load_checkpoint(self, ckpt_path: str, strict: bool = True):
    """
    Load the "model" weights from the checkpoint at `ckpt_path` into this
    module and log whether any keys were missing or unexpected.
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    missing_keys, unexpected_keys = self.load_state_dict(
        checkpoint["model"], strict=strict
    )
    if missing_keys or unexpected_keys:
        logger.warning(f"Loaded ckpt with {missing_keys=}, {unexpected_keys=}")
    else:
        logger.info("Loaded ckpt successfully without missing or unexpected keys")
def prep_for_evaluator(self, video_frames, tracking_res, scores_labels):
    """
    Convert per-frame tracking results into the prediction dict expected by
    the benchmark evaluator (not used in the demo).

    For each object we assemble a (n_frames, H, W) mask stack (all-zero
    masks for frames without a result), then derive boxes, RLE masks, and
    score/label tensors; empty tensors are produced when no objects exist.
    """
    num_frames = len(video_frames)
    frame_w, frame_h = video_frames[0].size
    empty_mask = torch.zeros((1, frame_h, frame_w), dtype=torch.bool)
    preds = {"scores": [], "labels": [], "boxes": [], "masks_rle": []}
    for obj_id in scores_labels:
        obj_score = scores_labels[obj_id][0].item()
        obj_label = scores_labels[obj_id][1]
        # One mask per frame, falling back to the all-zero mask.
        per_frame_masks = [
            tracking_res[t].get(obj_id, empty_mask) if t in tracking_res else empty_mask
            for t in range(num_frames)
        ]
        obj_masks = torch.cat(per_frame_masks, dim=0)  # (n_frames, H, W)
        preds["scores"].append(obj_score)
        preds["labels"].append(obj_label)
        # NOTE(review): .squeeze() drops *all* singleton dims — with a
        # single-frame video this gives shape (4,) instead of (1, 4);
        # confirm callers never pass num_frames == 1.
        preds["boxes"].append(mask_to_box(obj_masks.unsqueeze(1)).squeeze())
        preds["masks_rle"].append(rle_encode(obj_masks, return_areas=True))
    # Convert the accumulated lists to tensors, with empty fallbacks.
    if preds["boxes"]:
        preds["boxes"] = torch.stack(preds["boxes"], dim=0)
    else:
        preds["boxes"] = torch.empty(
            (0, num_frames, 4), dtype=torch.float32, device=self.device
        )
    if preds["scores"]:
        preds["scores"] = torch.tensor(preds["scores"], device=self.device)
    else:
        preds["scores"] = torch.empty((0,), device=self.device)
    preds["per_frame_scores"] = preds["scores"]
    if preds["labels"]:
        preds["labels"] = torch.tensor(preds["labels"], device=self.device)
    else:
        preds["labels"] = torch.empty((0,), device=self.device)
    return preds
def _encode_prompt(self, **kwargs):
return self.detector._encode_prompt(**kwargs)
def _drop_new_det_with_obj_limit(self, new_det_fa_inds, det_scores_np, num_to_keep):
"""
Drop a few new detections based on the maximum number of objects. We drop new objects based
on their detection scores, keeping the high-scoring ones and dropping the low-scoring ones.
"""
assert 0 <= num_to_keep <= len(new_det_fa_inds)
if num_to_keep == 0:
return np.array([], np.int64) # keep none
if num_to_keep == len(new_det_fa_inds):
return new_det_fa_inds # keep all
# keep the top-scoring detections
score_order = np.argsort(det_scores_np[new_det_fa_inds])[::-1]
new_det_fa_inds = new_det_fa_inds[score_order[:num_to_keep]]
return new_det_fa_inds
|