bellmake's picture
SAM3 Video Segmentation - Clean deployment
14114e8
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Transforms and data augmentation for both image + bbox.
"""
import logging
import numbers
import random
from collections.abc import Sequence
from typing import Iterable
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as Fv2
from PIL import Image as PILImage
from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes
from sam3.train.data.sam3_image_dataset import Datapoint
from torchvision.transforms import InterpolationMode
def crop(
    datapoint,
    index,
    region,
    v2=False,
    check_validity=True,
    check_input_validity=True,
    recompute_box_from_mask=False,
):
    """Crop image ``index`` of ``datapoint`` and its annotations in place.

    Args:
        datapoint: datapoint whose ``images[index]`` is cropped, together with
            that image's objects and the find queries targeting it.
        index: index of the image to crop.
        region: ``(top, left, height, width)`` of the crop window. Values may
            be floats; they are rounded to integer pixels once, and that same
            integer window is applied to the image, masks, boxes and points.
        v2: use ``torchvision.transforms.v2`` functional ops for the image.
        check_validity: assert that every object box keeps a positive area.
        check_input_validity: currently unused (the matching check on input
            boxes is disabled); kept for API compatibility.
        recompute_box_from_mask: recompute object boxes from the cropped masks
            instead of cropping the boxes directly.

    Returns:
        The same (mutated) ``datapoint``.
    """
    # Round once so that image, masks and box/point offsets all use exactly
    # the same pixel window (truncating for masks while rounding for the image
    # could misalign them by one pixel for fractional regions).
    top, left, height, width = (int(round(r)) for r in region)

    image = datapoint.images[index]
    if v2:
        image.data = Fv2.crop(
            image.data, top=top, left=left, height=height, width=width
        )
    else:
        image.data = F.crop(image.data, top, left, height, width)
    # should we do something wrt the original size?
    image.size = (height, width)

    box_offset = torch.as_tensor([left, top, left, top], dtype=torch.float32)
    box_max = torch.as_tensor([width, height], dtype=torch.float32)

    def _crop_boxes(boxes):
        # Shift XYXY boxes into the crop frame and clip both corners to
        # [0, width] x [0, height]; returns shape (N, 2, 2).
        shifted = (boxes - box_offset).reshape(-1, 2, 2)
        return torch.min(shifted, box_max).clamp(min=0)

    for obj in image.objects:
        if obj.segment is not None:
            obj.segment = F.crop(obj.segment, top, left, height, width)
        if recompute_box_from_mask and obj.segment is not None:
            # here the boxes are still in XYXY format with absolute coordinates
            # (they are converted to CxCyWH with relative coordinates in
            # NormalizeAPI)
            obj.bbox, obj.area = get_bbox_xyxy_abs_coords_from_mask(obj.segment)
            continue
        if recompute_box_from_mask and obj.segment is None and obj.area > 0:
            logging.warning(
                "Cannot recompute bounding box from mask since `obj.segment` is None. "
                "Falling back to directly cropping from the input bounding box."
            )
        cropped_boxes = _crop_boxes(obj.bbox.view(1, 4))
        obj.area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        obj.bbox = cropped_boxes.reshape(-1, 4)

    for query in datapoint.find_queries:
        # Only queries on this image are affected by this image's crop.
        if query.image_id != index:
            continue
        if query.semantic_target is not None:
            query.semantic_target = F.crop(
                query.semantic_target, top, left, height, width
            )
        if query.input_bbox is not None:
            query.input_bbox = _crop_boxes(query.input_bbox).reshape(-1, 4)
        if query.input_points is not None:
            logging.warning(
                "Point cropping with this function may lead to unexpected results"
            )
            # Unlike right-lower box edges, which are exclusive, the
            # point must be in [0, length-1], hence the -1
            point_max = torch.as_tensor([width, height], dtype=torch.float32) - 1
            cropped_points = query.input_points - torch.as_tensor(
                [left, top, 0], dtype=torch.float32
            )
            cropped_points[:, :, :2] = torch.min(cropped_points[:, :, :2], point_max)
            cropped_points[:, :, :2] = cropped_points[:, :, :2].clamp(min=0)
            query.input_points = cropped_points

    if check_validity:
        # Check that all boxes are still valid
        for obj in image.objects:
            assert obj.area > 0, "Box {} has no area".format(obj.bbox)
    return datapoint
def hflip(datapoint, index):
    """Horizontally flip image ``index`` of ``datapoint`` and its annotations.

    Boxes are XYXY in absolute pixels and are mapped ``x -> w - x`` (with the
    two x coordinates swapped so that ``x0 <= x1`` still holds). Points are
    mapped with the same formula on their x coordinate.

    Args:
        datapoint: datapoint mutated in place.
        index: index of the image to flip.

    Returns:
        The same (mutated) ``datapoint``.
    """
    image = datapoint.images[index]
    image.data = F.hflip(image.data)
    # get_dimensions works for PIL images and tensors alike (``.size`` is a
    # method on tensors, so unpacking it only worked for PIL inputs).
    _, _, w = F.get_dimensions(image.data)

    box_sign = torch.as_tensor([-1, 1, -1, 1])
    box_shift = torch.as_tensor([w, 0, w, 0])

    for obj in image.objects:
        boxes = obj.bbox.view(1, 4)
        obj.bbox = boxes[:, [2, 1, 0, 3]] * box_sign + box_shift
        if obj.segment is not None:
            obj.segment = F.hflip(obj.segment)

    for query in datapoint.find_queries:
        # Only queries on this image are affected; flipping every query's
        # semantic target once per image would undo it for an even image count.
        if query.image_id != index:
            continue
        if query.semantic_target is not None:
            query.semantic_target = F.hflip(query.semantic_target)
        if query.input_bbox is not None:
            boxes = query.input_bbox
            query.input_bbox = boxes[:, [2, 1, 0, 3]] * box_sign + box_shift
        if query.input_points is not None:
            query.input_points = query.input_points * torch.as_tensor(
                [-1, 1, 1]
            ) + torch.as_tensor([w, 0, 0])
    return datapoint
def get_size_with_aspect_ratio(image_size, size, max_size=None):
    """Compute an aspect-preserving output size whose short side is ``size``.

    Args:
        image_size: ``(width, height)`` of the input image.
        size: desired length of the shorter side.
        max_size: optional cap on the longer side; if honoring ``size`` would
            exceed it, the short side is reduced so the long side fits.

    Returns:
        ``(height, width)`` of the output. The input dimensions are returned
        unchanged when the short side already equals the requested size.
    """
    width, height = image_size
    if max_size is not None:
        short_side = float(min(width, height))
        long_side = float(max(width, height))
        if long_side / short_side * size > max_size:
            size = max_size * short_side / long_side

    width_is_short_and_matches = width <= height and width == size
    height_is_short_and_matches = height <= width and height == size
    if width_is_short_and_matches or height_is_short_and_matches:
        return (height, width)

    if width >= height:
        new_h = int(round(size))
        new_w = int(round(size * width / height))
    else:
        new_w = int(round(size))
        new_h = int(round(size * height / width))
    return (new_h, new_w)
def resize(datapoint, index, size, max_size=None, square=False, v2=False):
    """Resize image ``index`` of ``datapoint`` and rescale its annotations.

    Args:
        datapoint: datapoint mutated in place.
        index: index of the image to resize.
        size: either a scalar short-side length (aspect ratio preserved, long
            side capped by ``max_size``), a ``(w, h)`` list/tuple, or, with
            ``square=True``, the side of a square output.
        max_size: cap on the long side when ``size`` is a scalar.
        square: resize to ``(size, size)`` regardless of aspect ratio.
        v2: use ``torchvision.transforms.v2`` functional ops; image data is
            then read as a tensor (``data.size()``) instead of a PIL image.

    Returns:
        The same (mutated) ``datapoint``.
    """
    image = datapoint.images[index]

    def _wh(data):
        # (width, height) of a tensor (v2) or PIL image.
        return data.size()[-2:][::-1] if v2 else data.size

    old_size = _wh(image.data)
    if square:
        size = size, size
    elif isinstance(size, (list, tuple)):
        size = size[::-1]  # (w, h) -> (h, w)
    else:
        size = get_size_with_aspect_ratio(old_size, size, max_size)

    if v2:
        image.data = Fv2.resize(image.data, size, antialias=True)
    else:
        image.data = F.resize(image.data, size)
    new_size = _wh(image.data)

    ratio_width, ratio_height = (
        float(s) / float(s_orig) for s, s_orig in zip(new_size, old_size)
    )
    box_scale = torch.as_tensor(
        [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32
    )

    for obj in image.objects:
        obj.bbox = obj.bbox.view(1, 4) * box_scale
        obj.area *= ratio_width * ratio_height
        if obj.segment is not None:
            obj.segment = F.resize(obj.segment[None, None], size).squeeze()

    for query in datapoint.find_queries:
        # Only queries on this image are affected by this image's resize.
        if query.image_id != index:
            continue
        if query.semantic_target is not None:
            query.semantic_target = F.resize(
                query.semantic_target[None, None], size
            ).squeeze()
        if query.input_bbox is not None:
            query.input_bbox = query.input_bbox * box_scale
        if query.input_points is not None:
            query.input_points = query.input_points * torch.as_tensor(
                [ratio_width, ratio_height, 1],
                dtype=torch.float32,
            )

    h, w = size
    image.size = (h, w)
    return datapoint
def pad(datapoint, index, padding, v2=False):
    """Pad image ``index`` of ``datapoint`` and shift its annotations.

    Args:
        datapoint: datapoint mutated in place.
        index: index of the image to pad.
        padding: either ``(pad_right, pad_bottom)``, which pads only the
            bottom/right so coordinates do not move, or
            ``(left, top, right, bottom)``.
        v2: use ``torchvision.transforms.v2`` functional ops.

    Returns:
        The same (mutated) ``datapoint``; ``images[index].size`` is updated to
        the padded ``(h, w)``.
    """
    image = datapoint.images[index]
    bottom_right_only = len(padding) == 2
    if bottom_right_only:
        # assumes that we only pad on the bottom right corners
        ltrb = (0, 0, padding[0], padding[1])
    else:
        ltrb = (padding[0], padding[1], padding[2], padding[3])
    left, top, right, bottom = ltrb
    pad_fn = Fv2.pad if v2 else F.pad

    old_h, old_w = image.size
    image.data = pad_fn(image.data, ltrb)
    image.size = (old_h + top + bottom, old_w + left + right)

    box_shift = torch.as_tensor([left, top, left, top], dtype=torch.float32)
    point_shift = torch.as_tensor([left, top, 0], dtype=torch.float32)

    for obj in image.objects:
        if not bottom_right_only:
            obj.bbox += box_shift
        if obj.segment is not None:
            if v2:
                obj.segment = Fv2.pad(obj.segment[None], ltrb).squeeze(0)
            else:
                obj.segment = F.pad(obj.segment, ltrb)

    for query in datapoint.find_queries:
        # Only queries on this image are affected by this image's padding.
        if query.image_id != index:
            continue
        if query.semantic_target is not None:
            query.semantic_target = pad_fn(
                query.semantic_target[None, None], ltrb
            ).squeeze()
        if bottom_right_only:
            continue
        if query.input_bbox is not None:
            query.input_bbox += box_shift
        if query.input_points is not None:
            query.input_points += point_shift
    return datapoint
class RandomSizeCropAPI:
    """Randomly crop images to a width/height in ``[min_size, max_size]``.

    Args:
        min_size: lower bound on the sampled crop width/height (clipped to the
            image size).
        max_size: upper bound on the sampled crop width/height (clipped to the
            image size).
        respect_boxes: if True, the crop window is sampled so that no object
            box is cropped out entirely (parts of boxes may still be removed).
        consistent_transform: if True, one crop window is sampled and applied
            to every image of the datapoint (all images must share a size).
        respect_input_boxes: like ``respect_boxes``, for query input boxes.
        v2: forwarded to :func:`crop` (torchvision v2 ops / tensor images).
        recompute_box_from_mask: forwarded to :func:`crop`.

    When either ``respect_*`` flag is set, query input points are also kept
    inside the crop.
    """

    def __init__(
        self,
        min_size: int,
        max_size: int,
        respect_boxes: bool,
        consistent_transform: bool,
        respect_input_boxes: bool = True,
        v2: bool = False,
        recompute_box_from_mask: bool = False,
    ):
        self.min_size = min_size
        self.max_size = max_size
        self.respect_boxes = respect_boxes  # if True we can't crop a box out
        self.respect_input_boxes = respect_input_boxes
        self.consistent_transform = consistent_transform
        self.v2 = v2
        self.recompute_box_from_mask = recompute_box_from_mask

    @staticmethod
    def _image_wh(img):
        """Return ``(width, height)`` of a PIL image or a ``[..., H, W]`` tensor."""
        if isinstance(img, torch.Tensor):
            height, width = img.shape[-2:]
            return int(width), int(height)
        return img.size

    def _sample_no_respect_boxes(self, img):
        img_width, img_height = self._image_wh(img)
        # Clamp the lower bound (as _sample_respect_boxes does) so an image
        # smaller than ``min_size`` does not give randint an empty range.
        w = random.randint(
            min(self.min_size, img_width), min(img_width, self.max_size)
        )
        h = random.randint(
            min(self.min_size, img_height), min(img_height, self.max_size)
        )
        return T.RandomCrop.get_params(img, (h, w))

    def _sample_respect_boxes(self, img, boxes, points, min_box_size=10.0):
        """
        Assure that no box or point is dropped via cropping, though portions
        of boxes may be removed.

        Args:
            img: image (PIL or tensor) the window is sampled for.
            boxes: ``(N, 4)`` XYXY absolute boxes that must stay visible.
            points: ``(M, 2)`` absolute xy points that must stay inside.
            min_box_size: minimum extent of each box kept inside the crop.

        Returns:
            ``[top, left, height, width]`` (floats).
        """
        if len(boxes) == 0 and len(points) == 0:
            return self._sample_no_respect_boxes(img)
        img_width, img_height = self._image_wh(img)
        minW, minH, maxW, maxH = (
            min(img_width, self.min_size),
            min(img_height, self.min_size),
            min(img_width, self.max_size),
            min(img_height, self.max_size),
        )
        # The crop box must extend one pixel beyond points to the bottom/right
        # to assure the exclusive box contains the points.
        minX = (
            torch.cat([boxes[:, 0] + min_box_size, points[:, 0] + 1], dim=0)
            .max()
            .item()
        )
        minY = (
            torch.cat([boxes[:, 1] + min_box_size, points[:, 1] + 1], dim=0)
            .max()
            .item()
        )
        minX = min(img_width, minX)
        minY = min(img_height, minY)
        maxX = torch.cat([boxes[:, 2] - min_box_size, points[:, 0]], dim=0).min().item()
        maxY = torch.cat([boxes[:, 3] - min_box_size, points[:, 1]], dim=0).min().item()
        maxX = max(0.0, maxX)
        maxY = max(0.0, maxY)
        minW = max(minW, minX - maxX)
        minH = max(minH, minY - maxY)
        w = random.uniform(minW, max(minW, maxW))
        h = random.uniform(minH, max(minH, maxH))
        if minX > maxX:
            left = random.uniform(max(0, minX - w), max(maxX, max(0, minX - w)))
        else:
            left = random.uniform(
                max(0, minX - w + 1), max(maxX - 1, max(0, minX - w + 1))
            )
        if minY > maxY:
            top = random.uniform(max(0, minY - h), max(maxY, max(0, minY - h)))
        else:
            top = random.uniform(
                max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
            )
        return [top, left, h, w]

    def _collect_constraints(self, datapoint, image_idx=None):
        """Gather the boxes ``(N, 4)`` and points ``(M, 2)`` the crop must keep.

        With ``image_idx=None`` everything in the datapoint is collected (one
        crop shared by all images); otherwise only the objects of that image
        and the queries whose ``image_id`` matches it are collected.
        """
        if image_idx is None:
            images = datapoint.images
        else:
            images = [datapoint.images[image_idx]]

        def on_image(query):
            return image_idx is None or query.image_id == image_idx

        boxes = []
        if self.respect_boxes:
            boxes += [obj.bbox.view(-1, 4) for img in images for obj in img.objects]
        if self.respect_input_boxes:
            boxes += [
                q.input_bbox.view(-1, 4)
                for q in datapoint.find_queries
                if on_image(q) and q.input_bbox is not None
            ]
        points = [
            q.input_points.view(-1, 3)[:, :2]
            for q in datapoint.find_queries
            if on_image(q) and q.input_points is not None
        ]
        all_boxes = torch.cat(boxes, 0) if boxes else torch.empty(0, 4)
        all_points = torch.cat(points, 0) if points else torch.empty(0, 2)
        return all_boxes, all_points

    def _apply_crop(self, datapoint, index, crop_param):
        """Crop image ``index`` with this transform's :func:`crop` options."""
        return crop(
            datapoint,
            index,
            crop_param,
            v2=self.v2,
            check_validity=self.respect_boxes,
            check_input_validity=self.respect_input_boxes,
            recompute_box_from_mask=self.recompute_box_from_mask,
        )

    def __call__(self, datapoint, **kwargs):
        respect_any = self.respect_boxes or self.respect_input_boxes
        if self.consistent_transform:
            # Check that all the images are the same size
            first = datapoint.images[0].data
            w, h = self._image_wh(first)
            for img in datapoint.images:
                assert self._image_wh(img.data) == (w, h)
            if respect_any:
                boxes, points = self._collect_constraints(datapoint)
                crop_param = self._sample_respect_boxes(first, boxes, points)
            else:
                crop_param = self._sample_no_respect_boxes(first)
            for i in range(len(datapoint.images)):
                datapoint = self._apply_crop(datapoint, i, crop_param)
            return datapoint

        for i in range(len(datapoint.images)):
            img_data = datapoint.images[i].data
            if respect_any:
                # Only this image's boxes/points constrain its own crop.
                boxes, points = self._collect_constraints(datapoint, i)
                crop_param = self._sample_respect_boxes(img_data, boxes, points)
            else:
                crop_param = self._sample_no_respect_boxes(img_data)
            datapoint = self._apply_crop(datapoint, i, crop_param)
        return datapoint
class CenterCropAPI:
    """Center-crop every image of a datapoint (PIL images) to ``size``.

    Args:
        size: ``(height, width)`` of the crop, or a single number for a square.
        consistent_transform: if True, all images must share a size and get
            the same window; otherwise each image is cropped independently.
        recompute_box_from_mask: forwarded to :func:`crop`.
    """

    def __init__(self, size, consistent_transform, recompute_box_from_mask=False):
        self.size = size
        self.consistent_transform = consistent_transform
        self.recompute_box_from_mask = recompute_box_from_mask

    def _sample_crop(self, image_width, image_height):
        """Return ``(top, left, height, width)`` of the centered window."""
        if isinstance(self.size, numbers.Number):
            crop_height = crop_width = int(self.size)
        else:
            crop_height, crop_width = self.size
        crop_top = int(round((image_height - crop_height) / 2.0))
        crop_left = int(round((image_width - crop_width) / 2.0))
        return crop_top, crop_left, crop_height, crop_width

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            # Check that all the images are the same size. Compare the pixel
            # data's (w, h): the ``img.size`` metadata is stored as (h, w) by
            # the functional transforms, so it must not be compared to (w, h).
            w, h = datapoint.images[0].data.size
            for img in datapoint.images:
                assert img.data.size == (w, h)
            crop_param = self._sample_crop(w, h)
            for i in range(len(datapoint.images)):
                datapoint = crop(
                    datapoint,
                    i,
                    crop_param,
                    recompute_box_from_mask=self.recompute_box_from_mask,
                )
            return datapoint

        for i in range(len(datapoint.images)):
            w, h = datapoint.images[i].data.size
            datapoint = crop(
                datapoint,
                i,
                self._sample_crop(w, h),
                recompute_box_from_mask=self.recompute_box_from_mask,
            )
        return datapoint
class RandomHorizontalFlip:
    """Flip images (and their annotations) horizontally with probability ``p``.

    With ``consistent_transform`` a single coin toss decides for all images of
    the datapoint; otherwise every image gets its own independent toss.
    """

    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform

    def __call__(self, datapoint, **kwargs):
        if not self.consistent_transform:
            # Independent coin toss per image.
            for idx, _ in enumerate(datapoint.images):
                if random.random() < self.p:
                    datapoint = hflip(datapoint, idx)
            return datapoint

        # One shared coin toss for every image.
        flip_all = random.random() < self.p
        if flip_all:
            for idx, _ in enumerate(datapoint.images):
                datapoint = hflip(datapoint, idx)
        return datapoint
class RandomResizeAPI:
    """Resize images to a size picked at random from ``sizes``.

    Args:
        sizes: a single int or an iterable of candidate sizes, each forwarded
            to :func:`resize` (short side, ``(w, h)`` pair, or square side).
        consistent_transform: if True, one size is drawn for all images;
            otherwise a size is drawn per image.
        max_size: cap on the long side, forwarded to :func:`resize`.
        square: forwarded to :func:`resize`.
        v2: forwarded to :func:`resize`.
    """

    def __init__(
        self, sizes, consistent_transform, max_size=None, square=False, v2=False
    ):
        candidates = (sizes,) if isinstance(sizes, int) else sizes
        assert isinstance(candidates, Iterable)
        self.sizes = list(candidates)
        self.max_size = max_size
        self.square = square
        self.consistent_transform = consistent_transform
        self.v2 = v2

    def __call__(self, datapoint, **kwargs):
        shared_size = (
            random.choice(self.sizes) if self.consistent_transform else None
        )
        for idx in range(len(datapoint.images)):
            if self.consistent_transform:
                chosen = shared_size
            else:
                chosen = random.choice(self.sizes)
            datapoint = resize(
                datapoint, idx, chosen, self.max_size, square=self.square, v2=self.v2
            )
        return datapoint
class ScheduledRandomResizeAPI(RandomResizeAPI):
    """RandomResizeAPI whose ``sizes``/``max_size`` follow an epoch schedule.

    ``size_scheduler`` is called with an epoch number and must return a
    mapping with ``"sizes"`` and ``"max_size"`` entries. ``v2`` is forwarded to
    the parent (defaults to False, as before).
    """

    def __init__(self, size_scheduler, consistent_transform, square=False, v2=False):
        self.size_scheduler = size_scheduler
        # Just a meaningful init value for super
        params = self.size_scheduler(epoch_num=0)
        sizes, max_size = params["sizes"], params["max_size"]
        super().__init__(
            sizes, consistent_transform, max_size=max_size, square=square, v2=v2
        )

    def __call__(self, datapoint, **kwargs):
        assert "epoch" in kwargs, "Param scheduler needs to know the current epoch"
        params = self.size_scheduler(kwargs["epoch"])
        self.sizes = params["sizes"]
        self.max_size = params["max_size"]
        return super().__call__(datapoint, **kwargs)
class RandomPadAPI:
    """Pad images on the bottom/right by random amounts in ``[0, max_pad]``.

    With ``consistent_transform`` one ``(pad_x, pad_y)`` is drawn and shared by
    all images; otherwise each image draws its own.
    """

    def __init__(self, max_pad, consistent_transform):
        self.max_pad = max_pad
        self.consistent_transform = consistent_transform

    def _sample_pad(self):
        # Draw x first, then y.
        return tuple(random.randint(0, self.max_pad) for _ in range(2))

    def __call__(self, datapoint, **kwargs):
        shared = self._sample_pad() if self.consistent_transform else None
        for idx in range(len(datapoint.images)):
            padding = shared if self.consistent_transform else self._sample_pad()
            datapoint = pad(datapoint, idx, padding)
        return datapoint
class PadToSizeAPI:
    """Pad images up to a ``size`` x ``size`` square.

    Args:
        size: target side length; images must not exceed it.
        consistent_transform: if True, all images must share a size and get
            the same padding; otherwise padding is chosen per image.
        bottom_right: pad only on the bottom/right instead of splitting the
            padding randomly between both sides of each axis.
        v2: forwarded to :func:`pad`.
    """

    def __init__(self, size, consistent_transform, bottom_right=False, v2=False):
        self.size = size
        self.consistent_transform = consistent_transform
        self.v2 = v2
        self.bottom_right = bottom_right

    @staticmethod
    def _data_wh(data):
        """Return ``(width, height)`` of a PIL image or a ``[..., H, W]`` tensor."""
        if isinstance(data, torch.Tensor):
            height, width = data.shape[-2:]
            return int(width), int(height)
        return data.size

    def _sample_pad(self, w, h):
        pad_x = self.size - w
        pad_y = self.size - h
        assert pad_x >= 0 and pad_y >= 0
        pad_left = random.randint(0, pad_x)
        pad_right = pad_x - pad_left
        pad_top = random.randint(0, pad_y)
        pad_bottom = pad_y - pad_top
        return pad_left, pad_top, pad_right, pad_bottom

    def _get_padding(self, w, h):
        """Padding for a ``w`` x ``h`` image, in the format :func:`pad` takes."""
        if self.bottom_right:
            return (self.size - w, self.size - h)
        return self._sample_pad(w, h)

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            # Check that all the images are the same size. Compare pixel data
            # dimensions: ``img.size`` is stored as (h, w) by the functional
            # transforms, so comparing it to (w, h) fails for non-square images.
            w, h = self._data_wh(datapoint.images[0].data)
            for img in datapoint.images:
                assert self._data_wh(img.data) == (w, h)
            padding = self._get_padding(w, h)
            for i in range(len(datapoint.images)):
                datapoint = pad(datapoint, i, padding, v2=self.v2)
            return datapoint

        for i, img in enumerate(datapoint.images):
            w, h = self._data_wh(img.data)
            datapoint = pad(datapoint, i, self._get_padding(w, h), v2=self.v2)
        return datapoint
class RandomMosaicVideoAPI:
    """With probability ``prob``, turn every frame into a grid mosaic.

    One target cell (receiving the downsized object masks) and, optionally,
    per-cell horizontal-flip decisions are drawn once and shared by all
    frames; each frame is then processed by :func:`random_mosaic_frame`.
    """

    def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
        self.prob = prob
        self.grid_h = grid_h
        self.grid_w = grid_w
        self.use_random_hflip = use_random_hflip

    def __call__(self, datapoint, **kwargs):
        apply_mosaic = not (random.random() > self.prob)
        if not apply_mosaic:
            return datapoint

        # Cell that receives the target masks (same for every frame).
        target_y = random.randint(0, self.grid_h - 1)
        target_x = random.randint(0, self.grid_w - 1)

        grid_shape = (self.grid_h, self.grid_w)
        if not self.use_random_hflip:
            flip_cells = torch.zeros(*grid_shape, dtype=torch.bool)
        else:
            flip_cells = torch.rand(*grid_shape) < 0.5

        for frame_idx in range(len(datapoint.images)):
            datapoint = random_mosaic_frame(
                datapoint,
                frame_idx,
                grid_h=self.grid_h,
                grid_w=self.grid_w,
                target_grid_y=target_y,
                target_grid_x=target_x,
                should_hflip=flip_cells,
            )
        return datapoint
def random_mosaic_frame(
    datapoint,
    index,
    grid_h,
    grid_w,
    target_grid_y,
    target_grid_x,
    should_hflip,
):
    """Replace frame ``index`` by a ``grid_h`` x ``grid_w`` mosaic of itself.

    Every cell holds a bilinearly downsized copy of the frame, horizontally
    flipped where ``should_hflip[y, x]`` is true. Each object's mask is
    downsized into the target cell ``(target_grid_y, target_grid_x)`` only and
    is zero elsewhere. Boxes are not updated.

    Args:
        datapoint: datapoint mutated in place.
        index: index of the frame to process.
        grid_h: number of mosaic rows.
        grid_w: number of mosaic columns.
        target_grid_y: row of the cell that receives the masks.
        target_grid_x: column of the cell that receives the masks.
        should_hflip: ``(grid_h, grid_w)`` boolean tensor of per-cell flips.

    Returns:
        The same (mutated) ``datapoint``.
    """
    frame = datapoint.images[index]
    src = frame.data
    is_pil = isinstance(src, PILImage.Image)
    if is_pil:
        H_im, W_im = src.height, src.width
        mosaic = PILImage.new("RGB", (W_im, H_im))
    else:
        H_im, W_im = src.size(-2), src.size(-1)
        mosaic = torch.zeros_like(src)

    def cell_bounds(gy, gx):
        # (top, bottom, left, right) of a cell; integer division makes the
        # cells tile the frame exactly even when H/W are not grid multiples.
        return (
            gy * H_im // grid_h,
            (gy + 1) * H_im // grid_h,
            gx * W_im // grid_w,
            (gx + 1) * W_im // grid_w,
        )

    # Step 1: downsize the frame into every cell. Cells of equal size reuse
    # the same downsized copy.
    tiles_by_shape = {}
    for gy in range(grid_h):
        for gx in range(grid_w):
            top, bottom, left, right = cell_bounds(gy, gx)
            cell_shape = (bottom - top, right - left)
            tile = tiles_by_shape.get(cell_shape)
            if tile is None:
                tile = F.resize(
                    src,
                    size=cell_shape,
                    interpolation=InterpolationMode.BILINEAR,
                    antialias=True,  # antialiasing for downsizing
                )
                tiles_by_shape[cell_shape] = tile
            if should_hflip[gy, gx].item():
                tile = F.hflip(tile)
            if is_pil:
                mosaic.paste(tile, (left, top))
            else:
                mosaic[:, top:bottom, left:right] = tile
    frame.data = mosaic

    # Step 2: downsize the masks into the target cell only
    # (note that we don't scale input/target boxes since they are not used in TA)
    t_top, t_bottom, t_left, t_right = cell_bounds(target_grid_y, target_grid_x)
    for obj in frame.objects:
        if obj.segment is None:
            continue
        assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
        mask_out = torch.zeros_like(obj.segment)
        small_mask = F.resize(
            obj.segment[None, None],
            size=(t_bottom - t_top, t_right - t_left),
            interpolation=InterpolationMode.BILINEAR,
            antialias=True,  # antialiasing for downsizing
        )[0, 0]
        if should_hflip[target_grid_y, target_grid_x].item():
            small_mask = F.hflip(small_mask[None, None])[0, 0]
        mask_out[t_top:t_bottom, t_left:t_right] = small_mask
        obj.segment = mask_out
    return datapoint
class ScheduledPadToSizeAPI(PadToSizeAPI):
    """PadToSizeAPI whose target size follows an epoch schedule.

    ``size_scheduler`` is called with an epoch number and must return a
    mapping whose ``"resolution"`` entry is the pad target; at construction a
    ``"sizes"`` entry is accepted as a fallback. ``bottom_right`` and ``v2``
    are forwarded to the parent (same defaults as before).
    """

    def __init__(self, size_scheduler, consistent_transform, bottom_right=False, v2=False):
        self.size_scheduler = size_scheduler
        params = self.size_scheduler(epoch_num=0)
        # __call__ reads "resolution"; use the same key for the initial value
        # so it has the same meaning, falling back to "sizes" if it's missing.
        try:
            size = params["resolution"]
        except KeyError:
            size = params["sizes"]
        super().__init__(size, consistent_transform, bottom_right=bottom_right, v2=v2)

    def __call__(self, datapoint, **kwargs):
        assert "epoch" in kwargs, "Param scheduler needs to know the current epoch"
        params = self.size_scheduler(kwargs["epoch"])
        self.size = params["resolution"]
        return super().__call__(datapoint, **kwargs)
class IdentityAPI:
    """No-op transform: hands the datapoint back untouched."""

    def __call__(self, datapoint, **kwargs):
        del kwargs  # accepted only for interface compatibility
        return datapoint
class RandomSelectAPI:
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2.

    A missing (falsy) transform is replaced by ``IdentityAPI()``.
    """

    def __init__(self, transforms1=None, transforms2=None, p=0.5):
        self.transforms1 = transforms1 if transforms1 else IdentityAPI()
        self.transforms2 = transforms2 if transforms2 else IdentityAPI()
        self.p = p

    def __call__(self, datapoint, **kwargs):
        chosen = self.transforms1 if random.random() < self.p else self.transforms2
        return chosen(datapoint, **kwargs)
class ToTensorAPI:
    """Convert every image of a datapoint to a tensor.

    With ``v2`` the torchvision v2 image conversion is used (dtype preserved);
    otherwise ``torchvision.transforms.functional.to_tensor``.
    """

    def __init__(self, v2=False):
        self.v2 = v2

    def __call__(self, datapoint: Datapoint, **kwargs):
        if self.v2:
            # ``to_image_tensor`` was renamed ``to_image`` in newer torchvision
            # releases; use whichever this installation provides.
            convert = getattr(Fv2, "to_image", None) or Fv2.to_image_tensor
        else:
            convert = F.to_tensor
        for img in datapoint.images:
            img.data = convert(img.data)
        return datapoint
class NormalizeAPI:
    """Normalize image pixels and convert boxes/points to relative coordinates.

    Object and query input boxes go from absolute XYXY to CxCyWH divided by
    the image width/height. Query input points have their x/y divided by the
    image width/height (third component untouched).

    Args:
        mean: per-channel mean passed to ``normalize``.
        std: per-channel std passed to ``normalize``.
        v2: use torchvision v2 ops (image is first converted to float32).
    """

    def __init__(self, mean, std, v2=False):
        self.mean = mean
        self.std = std
        self.v2 = v2

    @staticmethod
    def _to_float32(data):
        # Prefer the original ``convert_image_dtype``; newer torchvision v2
        # releases only provide ``to_dtype(..., scale=True)`` for the same job.
        convert = getattr(Fv2, "convert_image_dtype", None)
        if convert is not None:
            return convert(data, torch.float32)
        return Fv2.to_dtype(data, torch.float32, scale=True)

    def __call__(self, datapoint: Datapoint, **kwargs):
        for img in datapoint.images:
            if self.v2:
                img.data = self._to_float32(img.data)
                img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
            else:
                img.data = F.normalize(img.data, mean=self.mean, std=self.std)
            cur_h, cur_w = img.data.shape[-2:]
            box_scale = torch.tensor([cur_w, cur_h, cur_w, cur_h], dtype=torch.float32)
            for obj in img.objects:
                obj.bbox = box_xyxy_to_cxcywh(obj.bbox) / box_scale

        for query in datapoint.find_queries:
            if query.input_bbox is not None:
                cur_h, cur_w = datapoint.images[query.image_id].data.shape[-2:]
                query.input_bbox = box_xyxy_to_cxcywh(query.input_bbox) / torch.tensor(
                    [cur_w, cur_h, cur_w, cur_h], dtype=torch.float32
                )
            if query.input_points is not None:
                cur_h, cur_w = datapoint.images[query.image_id].data.shape[-2:]
                query.input_points = query.input_points / torch.tensor(
                    [cur_w, cur_h, 1.0], dtype=torch.float32
                )
        return datapoint
class ComposeAPI:
    """Chain datapoint transforms, applying them in order.

    Every keyword argument (e.g. ``epoch``) is forwarded to each transform.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, datapoint, **kwargs):
        result = datapoint
        for transform in self.transforms:
            result = transform(result, **kwargs)
        return result

    def __repr__(self):
        lines = [self.__class__.__name__ + "("]
        lines.extend(" {0}".format(t) for t in self.transforms)
        return "\n".join(lines) + "\n)"
class RandomGrayscale:
    """Convert images to 3-channel grayscale with probability ``p``.

    With ``consistent_transform`` one coin toss decides for every image;
    otherwise each image is tossed independently.
    """

    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform
        self.Grayscale = T.Grayscale(num_output_channels=3)

    def __call__(self, datapoint: Datapoint, **kwargs):
        if not self.consistent_transform:
            for image in datapoint.images:
                if random.random() < self.p:
                    image.data = self.Grayscale(image.data)
            return datapoint

        if random.random() < self.p:
            for image in datapoint.images:
                image.data = self.Grayscale(image.data)
        return datapoint
class ColorJitter:
    """Randomly jitter brightness, contrast, saturation and hue of images.

    Args:
        consistent_transform: if True, one set of jitter parameters is drawn
            and applied to every image; otherwise each image draws its own.
        brightness: strength ``b`` (-> factor range ``[max(0, 1-b), 1+b]``),
            an explicit ``[min, max]`` sequence, or None to disable.
        contrast: same convention as ``brightness``.
        saturation: same convention as ``brightness``.
        hue: strength ``h`` (-> ``[-h, h]``), an explicit sequence, or None.
    """

    def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
        self.consistent_transform = consistent_transform
        self.brightness = self._factor_range(brightness)
        self.contrast = self._factor_range(contrast)
        self.saturation = self._factor_range(saturation)
        self.hue = hue if hue is None or isinstance(hue, Sequence) else [-hue, hue]

    @staticmethod
    def _factor_range(value):
        """Map a jitter strength to a multiplicative factor range.

        Explicit ranges (any sequence, e.g. list/tuple) and None (disabled, as
        ``T.ColorJitter.get_params`` accepts) pass through unchanged.
        """
        if value is None or isinstance(value, Sequence):
            return value
        return [max(0, 1 - value), 1 + value]

    def _sample_params(self):
        # (fn_idx, brightness, contrast, saturation, hue) factors.
        return T.ColorJitter.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

    @staticmethod
    def _apply(
        data, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor
    ):
        # Apply the enabled adjustments in the randomly drawn order ``fn_idx``.
        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                data = F.adjust_brightness(data, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                data = F.adjust_contrast(data, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                data = F.adjust_saturation(data, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                data = F.adjust_hue(data, hue_factor)
        return data

    def __call__(self, datapoint: Datapoint, **kwargs):
        if self.consistent_transform:
            params = self._sample_params()
        for img in datapoint.images:
            if not self.consistent_transform:
                params = self._sample_params()
            img.data = self._apply(img.data, *params)
        return datapoint
class RandomAffine:
    """Random affine augmentation of frames and their object masks.

    Boxes are recomputed from the transformed masks. If the transform makes an
    object's mask empty in the first frame, the attempt is discarded and
    retried (up to ``num_tentatives`` times) before giving up and returning
    the datapoint unchanged.
    """

    def __init__(
        self,
        degrees,
        consistent_transform,
        scale=None,
        translate=None,
        shear=None,
        image_mean=(123, 116, 103),
        log_warning=True,
        num_tentatives=1,
        image_interpolation="bicubic",
    ):
        """
        The mask is required for this transform.
        If consistent_transform is True, then the same random affine is applied
        to all frames and masks.

        Args:
            degrees: rotation range; a number ``d`` means ``[-d, d]``, any
                sequence (list/tuple) is used as an explicit range.
            consistent_transform: share one affine across all frames.
            scale: scale range forwarded to ``T.RandomAffine.get_params``.
            translate: translate fractions forwarded to ``get_params``.
            shear: like ``degrees``; a falsy number disables shearing.
            image_mean: fill value for image pixels outside the source.
            log_warning: log when every tentative failed.
            num_tentatives: number of attempts before giving up.
            image_interpolation: "bicubic" or "bilinear" for the image.

        Raises:
            NotImplementedError: for an unsupported ``image_interpolation``.
        """
        self.degrees = degrees if isinstance(degrees, Sequence) else [-degrees, degrees]
        self.scale = scale
        self.shear = (
            shear
            if isinstance(shear, Sequence)
            else ([-shear, shear] if shear else None)
        )
        self.translate = translate
        self.fill_img = image_mean
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives
        if image_interpolation == "bicubic":
            self.image_interpolation = InterpolationMode.BICUBIC
        elif image_interpolation == "bilinear":
            self.image_interpolation = InterpolationMode.BILINEAR
        else:
            raise NotImplementedError

    def __call__(self, datapoint: Datapoint, **kwargs):
        for _tentative in range(self.num_tentatives):
            res = self.transform_datapoint(datapoint)
            if res is not None:
                return res
        if self.log_warning:
            logging.warning(
                f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
            )
        return datapoint

    def _sample_params(self, img_size):
        # (angle, translations, scale, shear) for F.affine.
        return T.RandomAffine.get_params(
            degrees=self.degrees,
            translate=self.translate,
            scale_ranges=self.scale,
            shears=self.shear,
            img_size=img_size,
        )

    def transform_datapoint(self, datapoint: Datapoint):
        """Apply one random affine attempt.

        Returns the mutated datapoint, or None (with the datapoint untouched)
        if some object's mask becomes empty in the first frame. The check runs
        on frame 0 before any frame is modified.
        """
        _, height, width = F.get_dimensions(datapoint.images[0].data)
        img_size = [width, height]
        if self.consistent_transform:
            # Create a random affine transformation
            affine_params = self._sample_params(img_size)
        for img_idx, img in enumerate(datapoint.images):
            this_masks = [
                obj.segment.unsqueeze(0) if obj.segment is not None else None
                for obj in img.objects
            ]
            if not self.consistent_transform:
                # if not consistent we create a new affine params for every
                # frame&mask pair
                affine_params = self._sample_params(img_size)
            transformed_bboxes, transformed_masks = [], []
            for i in range(len(img.objects)):
                if this_masks[i] is None:
                    transformed_masks.append(None)
                    # Dummy bbox for a dummy target
                    transformed_bboxes.append(torch.tensor([[0, 0, 0, 0]]))
                else:
                    transformed_mask = F.affine(
                        this_masks[i],
                        *affine_params,
                        interpolation=InterpolationMode.NEAREST,
                        fill=0.0,
                    )
                    if img_idx == 0 and transformed_mask.max() == 0:
                        # We are dealing with a video and the object is not visible in the first frame
                        # Return the datapoint without transformation
                        return None
                    transformed_bbox = masks_to_boxes(transformed_mask)
                    transformed_bboxes.append(transformed_bbox)
                    transformed_masks.append(transformed_mask.squeeze())
            for i in range(len(img.objects)):
                img.objects[i].bbox = transformed_bboxes[i]
                img.objects[i].segment = transformed_masks[i]
            img.data = F.affine(
                img.data,
                *affine_params,
                interpolation=self.image_interpolation,
                fill=self.fill_img,
            )
        return datapoint
class RandomResizedCrop:
    """Random-resized-crop applied jointly to the frames and object masks of a datapoint.

    A crop window is drawn with ``T.RandomResizedCrop.get_params``. Every frame is
    cropped and resized to ``size`` with bilinear interpolation, and every object
    mask with nearest-neighbour interpolation. Boxes are recomputed from the
    transformed masks with ``masks_to_boxes``. Objects without a mask receive a
    dummy ``[[0, 0, 0, 0]]`` box.

    If an object's mask becomes empty in the first frame, the whole attempt is
    discarded and a new window is drawn, up to ``num_tentatives`` times. After
    that the datapoint is returned untouched.
    """

    def __init__(
        self,
        consistent_transform,
        size,
        scale=None,
        ratio=None,
        log_warning=True,
        num_tentatives=4,
        keep_aspect_ratio=False,
    ):
        """
        Args:
            consistent_transform: if True, a single crop window is shared by all
                frames and masks; otherwise a new window is drawn per frame.
            size: output size; a number ``s`` means ``(s, s)``, a 1-element
                sequence ``[s]`` means ``(s, s)``, otherwise two values ``(h, w)``.
            scale: area range forwarded to ``get_params`` (default ``(0.08, 1.0)``).
            ratio: aspect-ratio range forwarded to ``get_params``
                (default ``(3/4, 4/3)``).
            log_warning: log a warning when all tentatives fail.
            num_tentatives: how many windows to try before giving up.
            keep_aspect_ratio: multiply ``ratio`` by the first frame's
                ``size[1] / size[0]`` before sampling.

        Raises:
            ValueError: if ``size`` has neither one nor two dimensions.
        """
        if isinstance(size, numbers.Number):
            normalized_size = (int(size), int(size))
        elif isinstance(size, Sequence) and len(size) == 1:
            normalized_size = (size[0], size[0])
        elif len(size) != 2:
            raise ValueError("Please provide only two dimensions (h, w) for size.")
        else:
            normalized_size = size
        self.size = normalized_size
        self.scale = (0.08, 1.0) if scale is None else scale
        self.ratio = (3.0 / 4.0, 4.0 / 3.0) if ratio is None else ratio
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives
        self.keep_aspect_ratio = keep_aspect_ratio

    def __call__(self, datapoint: Datapoint, **kwargs):
        attempt = 0
        while attempt < self.num_tentatives:
            outcome = self.transform_datapoint(datapoint)
            if outcome is not None:
                return outcome
            attempt += 1
        if self.log_warning:
            logging.warning(
                f"Skip RandomResizeCrop for zero-area mask in first frame after {self.num_tentatives} tentatives"
            )
        return datapoint

    def _draw_window(self, image, ratio):
        # (top, left, height, width) sampled by torchvision's RandomResizedCrop.
        return T.RandomResizedCrop.get_params(img=image, scale=self.scale, ratio=ratio)

    def transform_datapoint(self, datapoint: Datapoint):
        """Try one crop; return the mutated datapoint, or None if it must be retried."""
        if not self.keep_aspect_ratio:
            ratio = self.ratio
        else:
            first_size = datapoint.images[0].size
            frame_aspect = first_size[1] / first_size[0]
            ratio = [r * frame_aspect for r in self.ratio]

        shared_window = (
            self._draw_window(datapoint.images[0].data, ratio)
            if self.consistent_transform
            else None
        )

        for frame_idx, frame in enumerate(datapoint.images):
            window = (
                shared_window
                if self.consistent_transform
                else self._draw_window(frame.data, ratio)
            )

            new_boxes, new_masks = [], []
            for obj in frame.objects:
                if obj.segment is None:
                    # Dummy bbox for a dummy target
                    new_masks.append(None)
                    new_boxes.append(torch.tensor([[0, 0, 0, 0]]))
                    continue
                cropped = F.resized_crop(
                    obj.segment.unsqueeze(0),
                    *window,
                    size=self.size,
                    interpolation=InterpolationMode.NEAREST,
                )
                if frame_idx == 0 and cropped.max() == 0:
                    # The object vanished from the first frame: abort before
                    # touching anything so the caller can retry.
                    return None
                new_masks.append(cropped.squeeze())
                new_boxes.append(masks_to_boxes(cropped))

            # Every object of this frame is valid, so commit the new targets.
            for obj, box, mask in zip(frame.objects, new_boxes, new_masks):
                obj.bbox = box
                obj.segment = mask
            frame.data = F.resized_crop(
                frame.data,
                *window,
                size=self.size,
                interpolation=InterpolationMode.BILINEAR,
            )
        return datapoint
class ResizeToMaxIfAbove:
    """Downscale all frames of a datapoint so that no side exceeds ``max_size``.

    The target size is computed from the first frame: its longer side becomes
    ``max_size`` and the other side is scaled to preserve the aspect ratio. That
    size is applied to every frame and to every object mask. If the first frame
    already fits within ``max_size`` on both sides, the datapoint is returned
    unchanged.
    """

    def __init__(
        self,
        max_size=None,
    ):
        # Upper bound (in pixels) for both image sides; ``None`` disables the transform.
        self.max_size = max_size

    def __call__(self, datapoint: Datapoint, **kwargs):
        # Nothing to do without a limit or without frames. Previously these cases
        # raised TypeError (None comparison) and IndexError respectively.
        if self.max_size is None or not datapoint.images:
            return datapoint
        _, height, width = F.get_dimensions(datapoint.images[0].data)
        if height <= self.max_size and width <= self.max_size:
            # The original frames are small enough
            return datapoint
        if height >= width:
            new_height = self.max_size
            new_width = int(round(self.max_size * width / height))
        else:
            new_height = int(round(self.max_size * height / width))
            new_width = self.max_size
        size = (new_height, new_width)
        for image in datapoint.images:
            image.data = F.resize(image.data, size)
            for obj in image.objects:
                # Objects without a mask (dummy targets elsewhere in this file use
                # ``segment=None``) have nothing to resize.
                if obj.segment is None:
                    continue
                # Add two leading dims for the resize and squeeze them away again;
                # nearest-neighbour keeps mask values unmixed.
                obj.segment = F.resize(
                    obj.segment[None, None],
                    size,
                    interpolation=InterpolationMode.NEAREST,
                ).squeeze()
            # NOTE(review): obj.bbox / obj.area are not rescaled here. Confirm they are
            # either recomputed from the masks downstream or resolution-independent
            # at this point of the pipeline.
            image.size = size  # (h, w)
        return datapoint
def get_bbox_xyxy_abs_coords_from_mask(mask):
    """Return the tight XYXY box (absolute pixel coordinates) around a 2-D mask.

    Args:
        mask: 2-D tensor; every truthy entry counts as foreground.

    Returns:
        ``(bbox, area)``:
        - ``bbox`` is a float32 tensor of shape (1, 4) holding
          ``[x_min, y_min, x_max, y_max]``, where the max values are the last
          occupied column/row index.
        - ``area`` is the Python float ``(y_max - y_min) * (x_max - x_min)``.

        An empty mask yields an all-zero box and an area of ``0.0``.
    """
    assert mask.dim() == 2
    # Ascending indices of the rows / columns containing at least one foreground pixel.
    occupied_rows = torch.any(mask, dim=1).nonzero().view(-1)
    occupied_cols = torch.any(mask, dim=0).nonzero().view(-1)
    if occupied_rows.numel() == 0:
        return torch.zeros(1, 4, dtype=torch.float32), 0.0
    # nonzero() returns sorted indices, so the first/last entries are min/max.
    top, bottom = occupied_rows[0], occupied_rows[-1]
    left, right = occupied_cols[0], occupied_cols[-1]
    box = torch.tensor([left, top, right, bottom], dtype=torch.float32).unsqueeze(0)
    return box, float((bottom - top) * (right - left))
class MotionBlur:
    """Randomly apply a linear motion-blur kernel to every frame of a datapoint.

    With probability ``p`` each image is convolved channel-wise with a normalized
    ``kernel_size x kernel_size`` line kernel. The line is horizontal, vertical,
    or along the main diagonal, chosen at random. Object masks and boxes are not
    modified.
    """

    def __init__(self, kernel_size=5, consistent_transform=True, p=0.5):
        """
        Args:
            kernel_size: odd side length of the square blur kernel.
            consistent_transform: if True, one kernel is shared by all frames;
                otherwise a new kernel is drawn per frame.
            p: probability of applying the blur to the datapoint.
        """
        assert kernel_size % 2 == 1, "Kernel size must be odd."
        self.kernel_size = kernel_size
        self.consistent_transform = consistent_transform
        self.p = p

    def __call__(self, datapoint: Datapoint, **kwargs):
        if random.random() >= self.p:
            return datapoint
        if self.consistent_transform:
            # Generate a single motion blur kernel for all images
            kernel = self._generate_motion_blur_kernel()
        for img in datapoint.images:
            if not self.consistent_transform:
                # Generate a new motion blur kernel for each image
                kernel = self._generate_motion_blur_kernel()
            img.data = self._apply_motion_blur(img.data, kernel)
        return datapoint

    def _generate_motion_blur_kernel(self):
        """Return a (kernel_size, kernel_size) float kernel of ones along a random line, summing to 1."""
        kernel = torch.zeros((self.kernel_size, self.kernel_size))
        direction = random.choice(["horizontal", "vertical", "diagonal"])
        if direction == "horizontal":
            kernel[self.kernel_size // 2, :] = 1.0
        elif direction == "vertical":
            kernel[:, self.kernel_size // 2] = 1.0
        elif direction == "diagonal":
            for i in range(self.kernel_size):
                kernel[i, i] = 1.0
        kernel /= kernel.sum()
        return kernel

    def _apply_motion_blur(self, image, kernel):
        """Convolve ``image`` channel-wise with ``kernel`` (zero padding, same spatial size).

        PIL images are returned as PIL images. Tensor images are returned as
        tensors of the same dtype (previously they were silently converted to
        PIL, and integer tensors failed inside ``conv2d``). A tensor image is
        assumed to be (C, H, W); TODO confirm for all pipelines.
        """
        is_pil = isinstance(image, PILImage.Image)
        if is_pil:
            image = F.to_tensor(image)
        original_dtype = image.dtype
        # conv2d needs a floating-point input; integer images are blurred in float32.
        work = image if image.is_floating_point() else image.float()
        channels = work.shape[0]
        weight = kernel.to(device=work.device, dtype=work.dtype).unsqueeze(0).unsqueeze(0)
        blurred = torch.nn.functional.conv2d(
            work.unsqueeze(0),
            weight.repeat(channels, 1, 1, 1),
            padding=self.kernel_size // 2,
            groups=channels,
        ).squeeze(0)
        if is_pil:
            return F.to_pil_image(blurred)
        if not original_dtype.is_floating_point:
            # The kernel is non-negative and sums to 1, so values stay within the
            # original range; round back to the input's integer dtype.
            blurred = blurred.round().to(original_dtype)
        return blurred
class LargeScaleJitter:
    """Large-scale jitter: crop a region of random area and aspect ratio, then resize it.

    With probability ``p`` each frame is cropped to a region whose area is
    ``scale_factor`` times the frame area and whose width/height ratio is
    ``aspect_ratio``. The crop is then resized with ``resize`` to ``crop_size``.
    Regions that extend past the frame (``scale_factor > 1``) are passed to
    ``crop`` as-is.
    """

    def __init__(
        self,
        scale_range=(0.1, 2.0),
        aspect_ratio_range=(0.75, 1.33),
        crop_size=(640, 640),
        consistent_transform=True,
        p=0.5,
    ):
        """
        Args:
            scale_range (tuple): Range of scaling factors (min_scale, max_scale),
                sampled uniformly and applied to the frame *area*.
            aspect_ratio_range (tuple): Range of aspect ratios (min_aspect_ratio, max_aspect_ratio),
                sampled log-uniformly; the ratio is crop width / crop height.
            crop_size (tuple): Target size of the cropped region (width, height),
                passed unchanged to ``resize``.
            consistent_transform (bool): Whether to apply the same crop region
                (size and position) across all frames.
            p (float): Probability of applying the transformation.
        """
        self.scale_range = scale_range
        self.aspect_ratio_range = aspect_ratio_range
        self.crop_size = crop_size
        self.consistent_transform = consistent_transform
        self.p = p

    def _sample_region(self, image):
        """Sample a crop region ``(top, left, height, width)`` for ``image``.

        This is the argument order expected by ``crop`` (``i, j, h, w = region``).
        """
        log_ratio = torch.log(torch.tensor(self.aspect_ratio_range))
        scale_factor = torch.empty(1).uniform_(*self.scale_range).item()
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()
        # get_dimensions works for both PIL images and tensors (unlike `.size`).
        _, height, width = F.get_dimensions(image)
        target_area = width * height * scale_factor
        crop_width = int(round((target_area * aspect_ratio) ** 0.5))
        crop_height = int(round((target_area / aspect_ratio) ** 0.5))
        # Randomly select the top-left corner; when the crop is larger than the
        # frame it is anchored at 0 on that axis.
        left = random.randint(0, max(0, width - crop_width))
        top = random.randint(0, max(0, height - crop_height))
        return top, left, crop_height, crop_width

    def __call__(self, datapoint: Datapoint, **kwargs):
        if random.random() >= self.p:
            return datapoint
        shared_region = None
        if self.consistent_transform and datapoint.images:
            # One region (size AND position) for every frame, derived from the first frame.
            shared_region = self._sample_region(datapoint.images[0].data)
        for idx, img in enumerate(datapoint.images):
            region = (
                shared_region
                if shared_region is not None
                else self._sample_region(img.data)
            )
            # Extract the cropped region
            datapoint = crop(datapoint, idx, region)
            # Resize the cropped region to the target crop size
            datapoint = resize(datapoint, idx, self.crop_size)
        return datapoint