import os

import numpy as np
import sam3
from sam3 import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

from .masks import rle_to_mask


def get_index(dataset, image_id):
    """Return the metadata index label of the row whose image_id matches.

    Args:
        dataset: Object exposing a ``metadata`` table with an ``'image_id'``
            column (presumably a pandas DataFrame -- it is indexed with a
            boolean mask and its ``.index`` is read).
        image_id: Identifier to look up.

    Returns:
        The index label of the single matching metadata row.

    Raises:
        ValueError: If zero or more than one row matches ``image_id``.
    """
    matches = dataset.metadata['image_id'] == image_id
    if matches.sum() != 1:
        raise ValueError('image_id not found or found multiple times.')
    return dataset.metadata[matches].index[0]


def mask_centroid(mask):
    """Return the ``[x, y]`` centroid of the non-zero entries of a 2-D mask.

    An all-zero mask has no non-zero entries, so both coordinates are NaN.
    """
    ys, xs = np.nonzero(mask)
    return np.array([xs.mean(), ys.mean()])


def rle_centroid(rle):
    """Return the ``[x, y]`` centroid of the mask decoded from ``rle``."""
    return mask_centroid(rle_to_mask(rle))


def _label_left_right(df, row_labels, members, lateral, left_label, right_label):
    """Label the member with the smallest lateral value left, the largest right.

    A stable sort is used so that two members with exactly equal lateral
    values still receive distinct labels.  With a single member the same row
    is written twice and ends up with ``right_label``.
    """
    ordered = members[np.argsort(lateral[members], kind='stable')]
    df.loc[row_labels[ordered[0]], 'label'] = left_label
    df.loc[row_labels[ordered[-1]], 'label'] = right_label


def assign_flippers(df):
    """Relabel generic ``'flipper'`` rows as front/rear, left/right flippers.

    The "forward" direction is the unit vector from the mean flipper centroid
    to the head centroid; the lateral axis is perpendicular to it.  Flippers
    are ranked along both axes, measured relative to the mean flipper
    centroid, and relabelled ``flipper_fl``, ``flipper_fr``, ``flipper_rl``
    or ``flipper_rr``.

    Args:
        df: Table with a ``'label'`` column and a ``'mask'`` column holding
            encoded masks accepted by ``rle_to_mask``.

    Returns:
        A copy of ``df`` with flipper labels refined.  The copy is returned
        unchanged when there is not exactly one ``'head'`` row, when there
        are no flippers or more than four, or when the forward direction
        cannot be determined (head centroid coincides with the mean flipper
        centroid, or a centroid is NaN because a mask is empty).
    """
    df = df.copy()

    # The head anchors the forward direction; without exactly one, give up.
    head_rows = df[df['label'] == 'head']
    if len(head_rows) != 1:
        return df
    head_center = rle_centroid(head_rows.iloc[0]['mask'])

    flippers = df[df['label'] == 'flipper']
    n_flippers = len(flippers)
    if n_flippers == 0:
        return df

    flipper_centers = np.vstack([rle_centroid(rle) for rle in flippers['mask']])

    # Vector from the turtle center (mean of flippers) to the head is "forward".
    turtle_center = flipper_centers.mean(axis=0)
    forward_vec = head_center - turtle_center
    norm = np.linalg.norm(forward_vec)
    if not np.isfinite(norm) or norm == 0:
        # Degenerate geometry: normalising would yield NaN projections and
        # arbitrary labels, so leave the labels untouched instead.
        return df
    forward_vec = forward_vec / norm

    # Perpendicular axis; smaller projections are treated as the left side.
    left_vec = np.array([-forward_vec[1], forward_vec[0]])

    # Project relative to the turtle center so that the SIGN of the lateral
    # projection says which side of the head/body axis a flipper lies on
    # (rankings are unaffected by this shift).
    offsets = flipper_centers - turtle_center
    forward_proj = offsets @ forward_vec
    lateral_proj = offsets @ left_vec

    row_labels = flippers.index

    if n_flippers <= 2:
        # One or two flippers are always treated as front flippers.  A single
        # flipper lies on the axis by construction and ends up 'flipper_fr'.
        _label_left_right(
            df, row_labels, np.arange(n_flippers), lateral_proj,
            'flipper_fl', 'flipper_fr',
        )
        return df

    if n_flippers <= 4:
        # The two most-forward flippers are the front pair; whatever remains
        # (one or two flippers) is the rear, so no flipper is labelled twice.
        order_fwd = np.argsort(forward_proj)
        front_idxs = order_fwd[-2:]
        rear_idxs = order_fwd[:n_flippers - 2]

        _label_left_right(
            df, row_labels, front_idxs, lateral_proj,
            'flipper_fl', 'flipper_fr',
        )

        if len(rear_idxs) == 2:
            _label_left_right(
                df, row_labels, rear_idxs, lateral_proj,
                'flipper_rl', 'flipper_rr',
            )
        else:
            # 3 flippers: only one rear flipper; pick its side from the sign
            # of its lateral offset from the head/body axis.
            idx = rear_idxs[0]
            side = 'l' if lateral_proj[idx] < 0 else 'r'
            df.loc[row_labels[idx], 'label'] = f'flipper_r{side}'

    # More than four flippers: ambiguous, labels are left as 'flipper'.
    return df


def initialize_sam3():
    """Build the SAM3 image model and its processor.

    The BPE vocabulary path is resolved relative to the parent directory of
    the installed ``sam3`` package.

    Returns:
        A ``(model, processor)`` tuple; the processor is created with a
        confidence threshold of 0.5.
    """
    sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
    bpe_path = f"{sam3_root}/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
    model = build_sam3_image_model(bpe_path=bpe_path)
    processor = Sam3Processor(model, confidence_threshold=0.5)
    return model, processor