| | import os |
| | import numpy as np |
| | import sam3 |
| | from sam3 import build_sam3_image_model |
| | from sam3.model.sam3_image_processor import Sam3Processor |
| | from .masks import rle_to_mask |
| |
|
def get_index(dataset, image_id):
    """Return the ``dataset.metadata`` index label of the row with ``image_id``.

    Parameters
    ----------
    dataset
        Object exposing a ``metadata`` DataFrame with an ``image_id`` column.
    image_id
        Value to look up in ``metadata['image_id']``.

    Returns
    -------
    The index label of the single matching row.

    Raises
    ------
    ValueError
        If no row, or more than one row, has the given ``image_id``.
    """
    metadata = dataset.metadata
    is_match = metadata['image_id'] == image_id
    match_count = is_match.sum()
    if match_count != 1:
        raise ValueError('image_id not found or found multiple times.')
    matching_rows = metadata[is_match]
    return matching_rows.index[0]
| |
|
def mask_centroid(mask):
    """Return the centroid of a 2-D mask's non-zero pixels as ``[x, y]``.

    ``x`` is the mean column index and ``y`` the mean row index of all
    non-zero entries. An all-zero mask yields ``[nan, nan]`` (NumPy emits a
    "mean of empty slice" RuntimeWarning in that case).
    """
    row_idx, col_idx = np.nonzero(mask)
    centroid_x = col_idx.mean()
    centroid_y = row_idx.mean()
    return np.array([centroid_x, centroid_y])
| |
|
def rle_centroid(rle):
    """Decode ``rle`` with ``rle_to_mask`` and return its ``[x, y]`` centroid."""
    decoded_mask = rle_to_mask(rle)
    return mask_centroid(decoded_mask)
| |
|
def _split_left_right(idxs, lateral_proj):
    """Return the (left, right) members of ``idxs`` by lateral projection.

    Uses argsort rather than argmin/argmax so the two results are always
    distinct, even when the projections tie.
    """
    ordered = idxs[np.argsort(lateral_proj[idxs], kind='stable')]
    return ordered[0], ordered[-1]


def _flipper_labels(head_center, flipper_centers):
    """Infer a positional label for each flipper from centroid geometry.

    Parameters
    ----------
    head_center : np.ndarray
        ``[x, y]`` centroid of the head mask.
    flipper_centers : np.ndarray
        ``(n, 2)`` array of flipper ``[x, y]`` centroids.

    Returns
    -------
    list
        One entry per flipper (same order as ``flipper_centers``): the new
        label string, or ``None`` when the flipper should keep its label.
    """
    n_flippers = len(flipper_centers)
    labels = [None] * n_flippers
    # A single flipper coincides with the estimated body centre, so its side
    # cannot be inferred; more than four flippers is not handled.
    if not 2 <= n_flippers <= 4:
        return labels

    turtle_center = flipper_centers.mean(axis=0)
    forward_vec = head_center - turtle_center
    norm = np.linalg.norm(forward_vec)
    # Empty masks give NaN centroids, and a head centroid on the flipper mean
    # gives a zero vector: either way there is no usable body axis.
    if not np.isfinite(norm) or norm == 0:
        return labels
    forward_vec = forward_vec / norm
    # forward_vec rotated by 90 degrees. Smaller projection is treated as
    # "left" (convention kept from the original implementation).
    # NOTE(review): with image y pointing down this matches the turtle's left
    # in a view from above — confirm against the annotations.
    lateral_vec = np.array([-forward_vec[1], forward_vec[0]])

    # Project offsets from the body centre so the sign of the lateral
    # projection tells which side of the body axis a flipper lies on.
    offsets = flipper_centers - turtle_center
    forward_proj = offsets @ forward_vec
    lateral_proj = offsets @ lateral_vec

    if n_flippers == 2:
        # Two visible flippers are assumed to be the front pair.
        left, right = _split_left_right(np.arange(2), lateral_proj)
        labels[left] = 'flipper_fl'
        labels[right] = 'flipper_fr'
        return labels

    # Three or four flippers: the two furthest forward form the front pair,
    # the remaining one or two are rear flippers (no overlap between sets).
    order_fwd = np.argsort(forward_proj, kind='stable')
    n_rear = n_flippers - 2
    rear_idxs = order_fwd[:n_rear]
    front_idxs = order_fwd[n_rear:]

    left, right = _split_left_right(front_idxs, lateral_proj)
    labels[left] = 'flipper_fl'
    labels[right] = 'flipper_fr'

    if n_rear == 2:
        left, right = _split_left_right(rear_idxs, lateral_proj)
        labels[left] = 'flipper_rl'
        labels[right] = 'flipper_rr'
    else:
        # Lone rear flipper: its side is the sign of its offset from the axis.
        idx = rear_idxs[0]
        side = 'l' if lateral_proj[idx] < 0 else 'r'
        labels[idx] = f'flipper_r{side}'
    return labels


def assign_flippers(df):
    """Relabel generic ``'flipper'`` rows as front/rear, left/right flippers.

    Expects ``df`` to have a ``'label'`` column and a ``'mask'`` column whose
    values are accepted by ``rle_centroid`` (presumably RLE-encoded masks).
    The body axis runs from the mean flipper centroid to the head centroid:

    * 2 flippers -> ``flipper_fl`` / ``flipper_fr``;
    * 3 flippers -> front pair plus one ``flipper_rl`` or ``flipper_rr``;
    * 4 flippers -> ``flipper_fl``, ``flipper_fr``, ``flipper_rl``,
      ``flipper_rr``.

    Flipper labels are left unchanged when there is not exactly one
    ``'head'`` row, when there is a single flipper or more than four, or
    when the geometry is degenerate (NaN centroids, head on the body centre).

    Returns
    -------
    pandas.DataFrame
        A relabelled copy; the input ``df`` is not modified.
    """
    df = df.copy()

    head_rows = df[df['label'] == 'head']
    if len(head_rows) != 1:
        return df

    # Positional indices so duplicate index labels cannot relabel extra rows.
    flipper_pos = np.flatnonzero((df['label'] == 'flipper').to_numpy())
    if len(flipper_pos) == 0:
        return df

    head_center = rle_centroid(head_rows.iloc[0]['mask'])
    flipper_centers = np.vstack(
        [rle_centroid(rle) for rle in df['mask'].iloc[flipper_pos]]
    )

    label_col = df.columns.get_loc('label')
    new_labels = _flipper_labels(head_center, flipper_centers)
    for pos, new_label in zip(flipper_pos, new_labels):
        if new_label is not None:
            df.iloc[pos, label_col] = new_label
    return df
| |
|
def initialize_sam3(confidence_threshold=0.5):
    """Build the SAM3 image model and a ``Sam3Processor`` wrapping it.

    Parameters
    ----------
    confidence_threshold : float, optional
        Forwarded to ``Sam3Processor``; defaults to 0.5, the value that was
        previously hard-coded.

    Returns
    -------
    tuple
        ``(model, processor)``.
    """
    # The BPE vocabulary lives in the installed sam3 package's assets
    # directory; resolve it from the package dir directly instead of via
    # "<pkg>/../sam3", and join with os.path rather than literal slashes.
    package_dir = os.path.dirname(sam3.__file__)
    bpe_path = os.path.join(package_dir, "assets", "bpe_simple_vocab_16e6.txt.gz")
    model = build_sam3_image_model(bpe_path=bpe_path)
    processor = Sam3Processor(model, confidence_threshold=confidence_threshold)
    return model, processor