WildlifeDatasets's picture
Added training scripts
bb14d6a unverified
import os
import numpy as np
import sam3
from sam3 import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from .masks import rle_to_mask
def get_index(dataset, image_id):
    """Return the index label of the single metadata row matching ``image_id``.

    Args:
        dataset: Object exposing a ``metadata`` DataFrame that has an
            ``'image_id'`` column.
        image_id: Value compared (with ``==``) against the ``'image_id'``
            column.

    Returns:
        The index label of the one matching row in ``dataset.metadata``.

    Raises:
        ValueError: If no row or more than one row matches.
    """
    metadata = dataset.metadata
    matches = metadata['image_id'] == image_id
    # Exactly one hit is required; zero or several hits are both rejected.
    if matches.sum() == 1:
        return metadata.loc[matches].index[0]
    raise ValueError('image_id not found or found multiple times.')
def mask_centroid(mask):
    """Return the centroid of the non-zero entries of a 2-D mask.

    Args:
        mask: Array-like for which ``np.nonzero`` yields two index arrays
            (row indices first, column indices second).

    Returns:
        ``np.ndarray`` of the form ``[mean column index, mean row index]``,
        i.e. ``[x, y]``.
    """
    row_idx, col_idx = np.nonzero(mask)
    # Column index comes first in the result (x), row index second (y).
    centroid_x = col_idx.mean()
    centroid_y = row_idx.mean()
    return np.array([centroid_x, centroid_y])
def rle_centroid(rle):
    """Decode ``rle`` with ``rle_to_mask`` and return the mask's centroid.

    The result is whatever ``mask_centroid`` returns for the decoded mask
    (an ``[x, y]`` array).
    """
    decoded_mask = rle_to_mask(rle)
    return mask_centroid(decoded_mask)
def assign_flippers(df):
    """Relabel generic ``'flipper'`` rows as front/rear and left/right.

    The vector from the mean flipper centroid to the head centroid defines
    the "forward" axis. Flippers are ranked along that axis (front vs. rear)
    and along its perpendicular (lateral). Within a pair, the flipper with the
    smaller lateral coordinate gets the ``'l'`` label and the other one ``'r'``.

    Labels written: ``'flipper_fl'``, ``'flipper_fr'``, ``'flipper_rl'`` and
    ``'flipper_rr'``.

    * 1-2 flippers: treated as front flippers. NOTE: a single flipper lies on
      the forward axis by construction, so its side cannot be determined; it
      ends up labelled ``'flipper_fr'``.
    * 3 flippers: the two most forward ones become front left/right; only the
      rearmost one is labelled rear, with its side taken from which side of
      the forward axis it lies on.
    * 4 flippers: the two most forward ones are front, the other two rear.
    * More than 4 flippers: nothing is relabelled.

    Args:
        df: DataFrame with a ``'label'`` column and a ``'mask'`` column whose
            entries are accepted by ``rle_centroid``.

    Returns:
        A copy of ``df`` with flipper labels updated. The copy is returned
        unchanged when there is not exactly one ``'head'`` row, when there are
        no flippers or more than four, or when the forward direction cannot
        be computed (zero-length or non-finite vector).
    """
    df = df.copy()

    # Exactly one head is needed to orient the animal.
    head_rows = df[df['label'] == 'head']
    if len(head_rows) != 1:
        return df
    head_center = rle_centroid(head_rows.iloc[0]['mask'])

    flippers = df[df['label'] == 'flipper']
    n_flippers = len(flippers)
    if n_flippers == 0 or n_flippers > 4:
        return df

    flipper_centers = np.vstack([
        rle_centroid(rle) for rle in flippers['mask']
    ])

    # Vector from the flipper centroid ("turtle center") to the head defines
    # the forward direction.
    turtle_center = flipper_centers.mean(axis=0)
    forward_vec = head_center - turtle_center
    norm = np.linalg.norm(forward_vec)
    if not np.isfinite(norm) or norm == 0:
        # Orientation is undefined (e.g. coincident centroids or an empty
        # mask); leave the labels untouched rather than assign arbitrary ones.
        return df
    forward_vec = forward_vec / norm

    # Perpendicular axis used to separate the two sides.
    left_vec = np.array([-forward_vec[1], forward_vec[0]])

    # Project relative to the turtle center so that the SIGN of the lateral
    # coordinate tells which side of the forward axis a flipper is on.
    # (Rankings via argsort/argmin/argmax are unaffected by the shift.)
    offsets = flipper_centers - turtle_center
    forward_proj = offsets @ forward_vec
    lateral_proj = offsets @ left_vec

    row_ids = flippers.index

    if n_flippers <= 2:
        # Always front flippers.
        order = np.argsort(lateral_proj)
        left_idx, right_idx = order[0], order[-1]
        df.loc[row_ids[left_idx], 'label'] = 'flipper_fl'
        df.loc[row_ids[right_idx], 'label'] = 'flipper_fr'
        return df

    # 3 or 4 flippers: the two most forward are front, the remainder is rear.
    # Slicing with [:-2] keeps front and rear disjoint for 3 flippers.
    order_fwd = np.argsort(forward_proj)
    front_idxs = order_fwd[-2:]
    rear_idxs = order_fwd[:-2]

    front_l = front_idxs[np.argmin(lateral_proj[front_idxs])]
    front_r = front_idxs[np.argmax(lateral_proj[front_idxs])]
    df.loc[row_ids[front_l], 'label'] = 'flipper_fl'
    df.loc[row_ids[front_r], 'label'] = 'flipper_fr'

    if len(rear_idxs) == 2:
        rear_l = rear_idxs[np.argmin(lateral_proj[rear_idxs])]
        rear_r = rear_idxs[np.argmax(lateral_proj[rear_idxs])]
        df.loc[row_ids[rear_l], 'label'] = 'flipper_rl'
        df.loc[row_ids[rear_r], 'label'] = 'flipper_rr'
    else:
        # 3 flippers: assign only the most rear one, by side of the axis.
        idx = rear_idxs[0]
        side = 'l' if lateral_proj[idx] < 0 else 'r'
        df.loc[row_ids[idx], 'label'] = f'flipper_r{side}'

    return df
def initialize_sam3(*, confidence_threshold=0.5):
    """Build the SAM3 image model and a processor wrapping it.

    The BPE vocabulary file is located relative to the imported ``sam3``
    package: ``<dir of sam3.__file__>/../sam3/assets/bpe_simple_vocab_16e6.txt.gz``.

    Args:
        confidence_threshold: Passed to ``Sam3Processor`` as its
            ``confidence_threshold``. Defaults to ``0.5``, the value that was
            previously hard-coded.

    Returns:
        Tuple ``(model, processor)``.
    """
    sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
    # Build the whole path with os.path.join instead of mixing it with "/".
    bpe_path = os.path.join(
        sam3_root, "sam3", "assets", "bpe_simple_vocab_16e6.txt.gz"
    )
    model = build_sam3_image_model(bpe_path=bpe_path)
    processor = Sam3Processor(model, confidence_threshold=confidence_threshold)
    return model, processor