# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved """ Transforms and data augmentation for both image + bbox. """ import logging import numbers import random from collections.abc import Sequence from typing import Iterable import torch import torchvision.transforms as T import torchvision.transforms.functional as F import torchvision.transforms.v2.functional as Fv2 from PIL import Image as PILImage from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes from sam3.train.data.sam3_image_dataset import Datapoint from torchvision.transforms import InterpolationMode def crop( datapoint, index, region, v2=False, check_validity=True, check_input_validity=True, recompute_box_from_mask=False, ): if v2: rtop, rleft, rheight, rwidth = (int(round(r)) for r in region) datapoint.images[index].data = Fv2.crop( datapoint.images[index].data, top=rtop, left=rleft, height=rheight, width=rwidth, ) else: datapoint.images[index].data = F.crop(datapoint.images[index].data, *region) i, j, h, w = region # should we do something wrt the original size? datapoint.images[index].size = (h, w) for obj in datapoint.images[index].objects: # crop the mask if obj.segment is not None: obj.segment = F.crop(obj.segment, int(i), int(j), int(h), int(w)) # crop the bounding box if recompute_box_from_mask and obj.segment is not None: # here the boxes are still in XYXY format with absolute coordinates (they are # converted to CxCyWH with relative coordinates in basic_for_api.NormalizeAPI) obj.bbox, obj.area = get_bbox_xyxy_abs_coords_from_mask(obj.segment) else: if recompute_box_from_mask and obj.segment is None and obj.area > 0: logging.warning( "Cannot recompute bounding box from mask since `obj.segment` is None. " "Falling back to directly cropping from the input bounding box." ) boxes = obj.bbox.view(1, 4) max_size = torch.as_tensor([w, h], dtype=torch.float32) cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32) cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) cropped_boxes = cropped_boxes.clamp(min=0) obj.area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) obj.bbox = cropped_boxes.reshape(-1, 4) for query in datapoint.find_queries: if query.semantic_target is not None: query.semantic_target = F.crop( query.semantic_target, int(i), int(j), int(h), int(w) ) if query.image_id == index and query.input_bbox is not None: boxes = query.input_bbox max_size = torch.as_tensor([w, h], dtype=torch.float32) cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32) cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) cropped_boxes = cropped_boxes.clamp(min=0) # cur_area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) # if check_input_validity: # assert ( # (cur_area > 0).all().item() # ), "Some input box got cropped out by the crop transform" query.input_bbox = cropped_boxes.reshape(-1, 4) if query.image_id == index and query.input_points is not None: print( "Warning! Point cropping with this function may lead to unexpected results" ) points = query.input_points # Unlike right-lower box edges, which are exclusive, the # point must be in [0, length-1], hence the -1 max_size = torch.as_tensor([w, h], dtype=torch.float32) - 1 cropped_points = points - torch.as_tensor([j, i, 0], dtype=torch.float32) cropped_points[:, :, :2] = torch.min(cropped_points[:, :, :2], max_size) cropped_points[:, :, :2] = cropped_points[:, :, :2].clamp(min=0) query.input_points = cropped_points if check_validity: # Check that all boxes are still valid for obj in datapoint.images[index].objects: assert obj.area > 0, "Box {} has no area".format(obj.bbox) return datapoint def hflip(datapoint, index): datapoint.images[index].data = F.hflip(datapoint.images[index].data) w, h = datapoint.images[index].data.size for obj in datapoint.images[index].objects: boxes = obj.bbox.view(1, 4) boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( [-1, 1, -1, 1] ) + torch.as_tensor([w, 0, w, 0]) obj.bbox = boxes if obj.segment is not None: obj.segment = F.hflip(obj.segment) for query in datapoint.find_queries: if query.semantic_target is not None: query.semantic_target = F.hflip(query.semantic_target) if query.image_id == index and query.input_bbox is not None: boxes = query.input_bbox boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor( [-1, 1, -1, 1] ) + torch.as_tensor([w, 0, w, 0]) query.input_bbox = boxes if query.image_id == index and query.input_points is not None: points = query.input_points points = points * torch.as_tensor([-1, 1, 1]) + torch.as_tensor([w, 0, 0]) query.input_points = points return datapoint def get_size_with_aspect_ratio(image_size, size, max_size=None): w, h = image_size if max_size is not None: min_original_size = float(min((w, h))) max_original_size = float(max((w, h))) if max_original_size / min_original_size * size > max_size: size = max_size * min_original_size / max_original_size if (w <= h and w == size) or (h <= w and h == size): return (h, w) if w < h: ow = int(round(size)) oh = int(round(size * h / w)) else: oh = int(round(size)) ow = int(round(size * w / h)) return (oh, ow) def resize(datapoint, index, size, max_size=None, square=False, v2=False): # size can be min_size (scalar) or (w, h) tuple def get_size(image_size, size, max_size=None): if isinstance(size, (list, tuple)): return size[::-1] else: return get_size_with_aspect_ratio(image_size, size, max_size) if square: size = size, size else: cur_size = ( datapoint.images[index].data.size()[-2:][::-1] if v2 else datapoint.images[index].data.size ) size = get_size(cur_size, size, max_size) old_size = ( datapoint.images[index].data.size()[-2:][::-1] if v2 else datapoint.images[index].data.size ) if v2: datapoint.images[index].data = Fv2.resize( datapoint.images[index].data, size, antialias=True ) else: datapoint.images[index].data = F.resize(datapoint.images[index].data, size) new_size = ( datapoint.images[index].data.size()[-2:][::-1] if v2 else datapoint.images[index].data.size ) ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, old_size)) ratio_width, ratio_height = ratios for obj in datapoint.images[index].objects: boxes = obj.bbox.view(1, 4) scaled_boxes = boxes * torch.as_tensor( [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32 ) obj.bbox = scaled_boxes obj.area *= ratio_width * ratio_height if obj.segment is not None: obj.segment = F.resize(obj.segment[None, None], size).squeeze() for query in datapoint.find_queries: if query.semantic_target is not None: query.semantic_target = F.resize( query.semantic_target[None, None], size ).squeeze() if query.image_id == index and query.input_bbox is not None: boxes = query.input_bbox scaled_boxes = boxes * torch.as_tensor( [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32, ) query.input_bbox = scaled_boxes if query.image_id == index and query.input_points is not None: points = query.input_points scaled_points = points * torch.as_tensor( [ratio_width, ratio_height, 1], dtype=torch.float32, ) query.input_points = scaled_points h, w = size datapoint.images[index].size = (h, w) return datapoint def pad(datapoint, index, padding, v2=False): old_h, old_w = datapoint.images[index].size h, w = old_h, old_w if len(padding) == 2: # assumes that we only pad on the bottom right corners if v2: datapoint.images[index].data = Fv2.pad( datapoint.images[index].data, (0, 0, padding[0], padding[1]) ) else: datapoint.images[index].data = F.pad( datapoint.images[index].data, (0, 0, padding[0], padding[1]) ) h += padding[1] w += padding[0] else: if v2: # left, top, right, bottom datapoint.images[index].data = Fv2.pad( datapoint.images[index].data, (padding[0], padding[1], padding[2], padding[3]), ) else: # left, top, right, bottom datapoint.images[index].data = F.pad( datapoint.images[index].data, (padding[0], padding[1], padding[2], padding[3]), ) h += padding[1] + padding[3] w += padding[0] + padding[2] datapoint.images[index].size = (h, w) for obj in datapoint.images[index].objects: if len(padding) != 2: obj.bbox += torch.as_tensor( [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32 ) if obj.segment is not None: if v2: if len(padding) == 2: obj.segment = Fv2.pad( obj.segment[None], (0, 0, padding[0], padding[1]) ).squeeze(0) else: obj.segment = Fv2.pad(obj.segment[None], tuple(padding)).squeeze(0) else: if len(padding) == 2: obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1])) else: obj.segment = F.pad(obj.segment, tuple(padding)) for query in datapoint.find_queries: if query.semantic_target is not None: if v2: if len(padding) == 2: query.semantic_target = Fv2.pad( query.semantic_target[None, None], (0, 0, padding[0], padding[1]), ).squeeze() else: query.semantic_target = Fv2.pad( query.semantic_target[None, None], tuple(padding) ).squeeze() else: if len(padding) == 2: query.semantic_target = F.pad( query.semantic_target[None, None], (0, 0, padding[0], padding[1]), ).squeeze() else: query.semantic_target = F.pad( query.semantic_target[None, None], tuple(padding) ).squeeze() if query.image_id == index and query.input_bbox is not None: if len(padding) != 2: query.input_bbox += torch.as_tensor( [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32, ) if query.image_id == index and query.input_points is not None: if len(padding) != 2: query.input_points += torch.as_tensor( [padding[0], padding[1], 0], dtype=torch.float32 ) return datapoint class RandomSizeCropAPI: def __init__( self, min_size: int, max_size: int, respect_boxes: bool, consistent_transform: bool, respect_input_boxes: bool = True, v2: bool = False, recompute_box_from_mask: bool = False, ): self.min_size = min_size self.max_size = max_size self.respect_boxes = respect_boxes # if True we can't crop a box out self.respect_input_boxes = respect_input_boxes self.consistent_transform = consistent_transform self.v2 = v2 self.recompute_box_from_mask = recompute_box_from_mask def _sample_no_respect_boxes(self, img): w = random.randint(self.min_size, min(img.width, self.max_size)) h = random.randint(self.min_size, min(img.height, self.max_size)) return T.RandomCrop.get_params(img, (h, w)) def _sample_respect_boxes(self, img, boxes, points, min_box_size=10.0): """ Assure that no box or point is dropped via cropping, though portions of boxes may be removed. """ if len(boxes) == 0 and len(points) == 0: return self._sample_no_respect_boxes(img) if self.v2: img_height, img_width = img.size()[-2:] else: img_width, img_height = img.size minW, minH, maxW, maxH = ( min(img_width, self.min_size), min(img_height, self.min_size), min(img_width, self.max_size), min(img_height, self.max_size), ) # The crop box must extend one pixel beyond points to the bottom/right # to assure the exclusive box contains the points. minX = ( torch.cat([boxes[:, 0] + min_box_size, points[:, 0] + 1], dim=0) .max() .item() ) minY = ( torch.cat([boxes[:, 1] + min_box_size, points[:, 1] + 1], dim=0) .max() .item() ) minX = min(img_width, minX) minY = min(img_height, minY) maxX = torch.cat([boxes[:, 2] - min_box_size, points[:, 0]], dim=0).min().item() maxY = torch.cat([boxes[:, 3] - min_box_size, points[:, 1]], dim=0).min().item() maxX = max(0.0, maxX) maxY = max(0.0, maxY) minW = max(minW, minX - maxX) minH = max(minH, minY - maxY) w = random.uniform(minW, max(minW, maxW)) h = random.uniform(minH, max(minH, maxH)) if minX > maxX: # i = random.uniform(max(0, minX - w + 1), max(maxX, max(0, minX - w + 1))) i = random.uniform(max(0, minX - w), max(maxX, max(0, minX - w))) else: i = random.uniform( max(0, minX - w + 1), max(maxX - 1, max(0, minX - w + 1)) ) if minY > maxY: # j = random.uniform(max(0, minY - h + 1), max(maxY, max(0, minY - h + 1))) j = random.uniform(max(0, minY - h), max(maxY, max(0, minY - h))) else: j = random.uniform( max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1)) ) return [j, i, h, w] def __call__(self, datapoint, **kwargs): if self.respect_boxes or self.respect_input_boxes: if self.consistent_transform: # Check that all the images are the same size w, h = datapoint.images[0].data.size for img in datapoint.images: assert img.data.size == (w, h) all_boxes = [] # Getting all boxes in all the images if self.respect_boxes: all_boxes += [ obj.bbox.view(-1, 4) for img in datapoint.images for obj in img.objects ] # Get all the boxes in the find queries if self.respect_input_boxes: all_boxes += [ q.input_bbox.view(-1, 4) for q in datapoint.find_queries if q.input_bbox is not None ] if all_boxes: all_boxes = torch.cat(all_boxes, 0) else: all_boxes = torch.empty(0, 4) all_points = [ q.input_points.view(-1, 3)[:, :2] for q in datapoint.find_queries if q.input_points is not None ] if all_points: all_points = torch.cat(all_points, 0) else: all_points = torch.empty(0, 2) crop_param = self._sample_respect_boxes( datapoint.images[0].data, all_boxes, all_points ) for i in range(len(datapoint.images)): datapoint = crop( datapoint, i, crop_param, v2=self.v2, check_validity=self.respect_boxes, check_input_validity=self.respect_input_boxes, recompute_box_from_mask=self.recompute_box_from_mask, ) return datapoint else: for i in range(len(datapoint.images)): all_boxes = [] # Get all boxes in the current image if self.respect_boxes: all_boxes += [ obj.bbox.view(-1, 4) for obj in datapoint.images[i].objects ] # Get all the boxes in the find queries that correspond to this image if self.respect_input_boxes: all_boxes += [ q.input_bbox.view(-1, 4) for q in datapoint.find_queries if q.image_id == i and q.input_bbox is not None ] if all_boxes: all_boxes = torch.cat(all_boxes, 0) else: all_boxes = torch.empty(0, 4) all_points = [ q.input_points.view(-1, 3)[:, :2] for q in datapoint.find_queries if q.input_points is not None ] if all_points: all_points = torch.cat(all_points, 0) else: all_points = torch.empty(0, 2) crop_param = self._sample_respect_boxes( datapoint.images[i].data, all_boxes, all_points ) datapoint = crop( datapoint, i, crop_param, v2=self.v2, check_validity=self.respect_boxes, check_input_validity=self.respect_input_boxes, recompute_box_from_mask=self.recompute_box_from_mask, ) return datapoint else: if self.consistent_transform: # Check that all the images are the same size w, h = datapoint.images[0].data.size for img in datapoint.images: assert img.data.size == (w, h) crop_param = self._sample_no_respect_boxes(datapoint.images[0].data) for i in range(len(datapoint.images)): datapoint = crop( datapoint, i, crop_param, v2=self.v2, check_validity=self.respect_boxes, check_input_validity=self.respect_input_boxes, recompute_box_from_mask=self.recompute_box_from_mask, ) return datapoint else: for i in range(len(datapoint.images)): crop_param = self._sample_no_respect_boxes(datapoint.images[i].data) datapoint = crop( datapoint, i, crop_param, v2=self.v2, check_validity=self.respect_boxes, check_input_validity=self.respect_input_boxes, recompute_box_from_mask=self.recompute_box_from_mask, ) return datapoint class CenterCropAPI: def __init__(self, size, consistent_transform, recompute_box_from_mask=False): self.size = size self.consistent_transform = consistent_transform self.recompute_box_from_mask = recompute_box_from_mask def _sample_crop(self, image_width, image_height): crop_height, crop_width = self.size crop_top = int(round((image_height - crop_height) / 2.0)) crop_left = int(round((image_width - crop_width) / 2.0)) return crop_top, crop_left, crop_height, crop_width def __call__(self, datapoint, **kwargs): if self.consistent_transform: # Check that all the images are the same size w, h = datapoint.images[0].data.size for img in datapoint.images: assert img.size == (w, h) crop_top, crop_left, crop_height, crop_width = self._sample_crop(w, h) for i in range(len(datapoint.images)): datapoint = crop( datapoint, i, (crop_top, crop_left, crop_height, crop_width), recompute_box_from_mask=self.recompute_box_from_mask, ) return datapoint for i in range(len(datapoint.images)): w, h = datapoint.images[i].data.size crop_top, crop_left, crop_height, crop_width = self._sample_crop(w, h) datapoint = crop( datapoint, i, (crop_top, crop_left, crop_height, crop_width), recompute_box_from_mask=self.recompute_box_from_mask, ) return datapoint class RandomHorizontalFlip: def __init__(self, consistent_transform, p=0.5): self.p = p self.consistent_transform = consistent_transform def __call__(self, datapoint, **kwargs): if self.consistent_transform: if random.random() < self.p: for i in range(len(datapoint.images)): datapoint = hflip(datapoint, i) return datapoint for i in range(len(datapoint.images)): if random.random() < self.p: datapoint = hflip(datapoint, i) return datapoint class RandomResizeAPI: def __init__( self, sizes, consistent_transform, max_size=None, square=False, v2=False ): if isinstance(sizes, int): sizes = (sizes,) assert isinstance(sizes, Iterable) self.sizes = list(sizes) self.max_size = max_size self.square = square self.consistent_transform = consistent_transform self.v2 = v2 def __call__(self, datapoint, **kwargs): if self.consistent_transform: size = random.choice(self.sizes) for i in range(len(datapoint.images)): datapoint = resize( datapoint, i, size, self.max_size, square=self.square, v2=self.v2 ) return datapoint for i in range(len(datapoint.images)): size = random.choice(self.sizes) datapoint = resize( datapoint, i, size, self.max_size, square=self.square, v2=self.v2 ) return datapoint class ScheduledRandomResizeAPI(RandomResizeAPI): def __init__(self, size_scheduler, consistent_transform, square=False): self.size_scheduler = size_scheduler # Just a meaningful init value for super params = self.size_scheduler(epoch_num=0) sizes, max_size = params["sizes"], params["max_size"] super().__init__(sizes, consistent_transform, max_size=max_size, square=square) def __call__(self, datapoint, **kwargs): assert "epoch" in kwargs, "Param scheduler needs to know the current epoch" params = self.size_scheduler(kwargs["epoch"]) sizes, max_size = params["sizes"], params["max_size"] self.sizes = sizes self.max_size = max_size datapoint = super(ScheduledRandomResizeAPI, self).__call__(datapoint, **kwargs) return datapoint class RandomPadAPI: def __init__(self, max_pad, consistent_transform): self.max_pad = max_pad self.consistent_transform = consistent_transform def _sample_pad(self): pad_x = random.randint(0, self.max_pad) pad_y = random.randint(0, self.max_pad) return pad_x, pad_y def __call__(self, datapoint, **kwargs): if self.consistent_transform: pad_x, pad_y = self._sample_pad() for i in range(len(datapoint.images)): datapoint = pad(datapoint, i, (pad_x, pad_y)) return datapoint for i in range(len(datapoint.images)): pad_x, pad_y = self._sample_pad() datapoint = pad(datapoint, i, (pad_x, pad_y)) return datapoint class PadToSizeAPI: def __init__(self, size, consistent_transform, bottom_right=False, v2=False): self.size = size self.consistent_transform = consistent_transform self.v2 = v2 self.bottom_right = bottom_right def _sample_pad(self, w, h): pad_x = self.size - w pad_y = self.size - h assert pad_x >= 0 and pad_y >= 0 pad_left = random.randint(0, pad_x) pad_right = pad_x - pad_left pad_top = random.randint(0, pad_y) pad_bottom = pad_y - pad_top return pad_left, pad_top, pad_right, pad_bottom def __call__(self, datapoint, **kwargs): if self.consistent_transform: # Check that all the images are the same size w, h = datapoint.images[0].data.size for img in datapoint.images: assert img.size == (w, h) if self.bottom_right: pad_right = self.size - w pad_bottom = self.size - h padding = (pad_right, pad_bottom) else: padding = self._sample_pad(w, h) for i in range(len(datapoint.images)): datapoint = pad(datapoint, i, padding, v2=self.v2) return datapoint for i, img in enumerate(datapoint.images): w, h = img.data.size if self.bottom_right: pad_right = self.size - w pad_bottom = self.size - h padding = (pad_right, pad_bottom) else: padding = self._sample_pad(w, h) datapoint = pad(datapoint, i, padding, v2=self.v2) return datapoint class RandomMosaicVideoAPI: def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False): self.prob = prob self.grid_h = grid_h self.grid_w = grid_w self.use_random_hflip = use_random_hflip def __call__(self, datapoint, **kwargs): if random.random() > self.prob: return datapoint # select a random location to place the target mask in the mosaic target_grid_y = random.randint(0, self.grid_h - 1) target_grid_x = random.randint(0, self.grid_w - 1) # whether to flip each grid in the mosaic horizontally if self.use_random_hflip: should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5 else: should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool) for i in range(len(datapoint.images)): datapoint = random_mosaic_frame( datapoint, i, grid_h=self.grid_h, grid_w=self.grid_w, target_grid_y=target_grid_y, target_grid_x=target_grid_x, should_hflip=should_hflip, ) return datapoint def random_mosaic_frame( datapoint, index, grid_h, grid_w, target_grid_y, target_grid_x, should_hflip, ): # Step 1: downsize the images and paste them into a mosaic image_data = datapoint.images[index].data is_pil = isinstance(image_data, PILImage.Image) if is_pil: H_im = image_data.height W_im = image_data.width image_data_output = PILImage.new("RGB", (W_im, H_im)) else: H_im = image_data.size(-2) W_im = image_data.size(-1) image_data_output = torch.zeros_like(image_data) downsize_cache = {} for grid_y in range(grid_h): for grid_x in range(grid_w): y_offset_b = grid_y * H_im // grid_h x_offset_b = grid_x * W_im // grid_w y_offset_e = (grid_y + 1) * H_im // grid_h x_offset_e = (grid_x + 1) * W_im // grid_w H_im_downsize = y_offset_e - y_offset_b W_im_downsize = x_offset_e - x_offset_b if (H_im_downsize, W_im_downsize) in downsize_cache: image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)] else: image_data_downsize = F.resize( image_data, size=(H_im_downsize, W_im_downsize), interpolation=InterpolationMode.BILINEAR, antialias=True, # antialiasing for downsizing ) downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize if should_hflip[grid_y, grid_x].item(): image_data_downsize = F.hflip(image_data_downsize) if is_pil: image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b)) else: image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = ( image_data_downsize ) datapoint.images[index].data = image_data_output # Step 2: downsize the masks and paste them into the target grid of the mosaic # (note that we don't scale input/target boxes since they are not used in TA) for obj in datapoint.images[index].objects: if obj.segment is None: continue assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8 segment_output = torch.zeros_like(obj.segment) target_y_offset_b = target_grid_y * H_im // grid_h target_x_offset_b = target_grid_x * W_im // grid_w target_y_offset_e = (target_grid_y + 1) * H_im // grid_h target_x_offset_e = (target_grid_x + 1) * W_im // grid_w target_H_im_downsize = target_y_offset_e - target_y_offset_b target_W_im_downsize = target_x_offset_e - target_x_offset_b segment_downsize = F.resize( obj.segment[None, None], size=(target_H_im_downsize, target_W_im_downsize), interpolation=InterpolationMode.BILINEAR, antialias=True, # antialiasing for downsizing )[0, 0] if should_hflip[target_grid_y, target_grid_x].item(): segment_downsize = F.hflip(segment_downsize[None, None])[0, 0] segment_output[ target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e ] = segment_downsize obj.segment = segment_output return datapoint class ScheduledPadToSizeAPI(PadToSizeAPI): def __init__(self, size_scheduler, consistent_transform): self.size_scheduler = size_scheduler size = self.size_scheduler(epoch_num=0)["sizes"] super().__init__(size, consistent_transform) def __call__(self, datapoint, **kwargs): assert "epoch" in kwargs, "Param scheduler needs to know the current epoch" params = self.size_scheduler(kwargs["epoch"]) self.size = params["resolution"] return super(ScheduledPadToSizeAPI, self).__call__(datapoint, **kwargs) class IdentityAPI: def __call__(self, datapoint, **kwargs): return datapoint class RandomSelectAPI: """ Randomly selects between transforms1 and transforms2, with probability p for transforms1 and (1 - p) for transforms2 """ def __init__(self, transforms1=None, transforms2=None, p=0.5): self.transforms1 = transforms1 or IdentityAPI() self.transforms2 = transforms2 or IdentityAPI() self.p = p def __call__(self, datapoint, **kwargs): if random.random() < self.p: return self.transforms1(datapoint, **kwargs) return self.transforms2(datapoint, **kwargs) class ToTensorAPI: def __init__(self, v2=False): self.v2 = v2 def __call__(self, datapoint: Datapoint, **kwargs): for img in datapoint.images: if self.v2: img.data = Fv2.to_image_tensor(img.data) # img.data = Fv2.to_dtype(img.data, torch.uint8, scale=True) # img.data = Fv2.convert_image_dtype(img.data, torch.uint8) else: img.data = F.to_tensor(img.data) return datapoint class NormalizeAPI: def __init__(self, mean, std, v2=False): self.mean = mean self.std = std self.v2 = v2 def __call__(self, datapoint: Datapoint, **kwargs): for img in datapoint.images: if self.v2: img.data = Fv2.convert_image_dtype(img.data, torch.float32) img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std) else: img.data = F.normalize(img.data, mean=self.mean, std=self.std) for obj in img.objects: boxes = obj.bbox cur_h, cur_w = img.data.shape[-2:] boxes = box_xyxy_to_cxcywh(boxes) boxes = boxes / torch.tensor( [cur_w, cur_h, cur_w, cur_h], dtype=torch.float32 ) obj.bbox = boxes for query in datapoint.find_queries: if query.input_bbox is not None: boxes = query.input_bbox cur_h, cur_w = datapoint.images[query.image_id].data.shape[-2:] boxes = box_xyxy_to_cxcywh(boxes) boxes = boxes / torch.tensor( [cur_w, cur_h, cur_w, cur_h], dtype=torch.float32 ) query.input_bbox = boxes if query.input_points is not None: points = query.input_points cur_h, cur_w = datapoint.images[query.image_id].data.shape[-2:] points = points / torch.tensor([cur_w, cur_h, 1.0], dtype=torch.float32) query.input_points = points return datapoint class ComposeAPI: def __init__(self, transforms): self.transforms = transforms def __call__(self, datapoint, **kwargs): for t in self.transforms: datapoint = t(datapoint, **kwargs) return datapoint def __repr__(self): format_string = self.__class__.__name__ + "(" for t in self.transforms: format_string += "\n" format_string += " {0}".format(t) format_string += "\n)" return format_string class RandomGrayscale: def __init__(self, consistent_transform, p=0.5): self.p = p self.consistent_transform = consistent_transform self.Grayscale = T.Grayscale(num_output_channels=3) def __call__(self, datapoint: Datapoint, **kwargs): if self.consistent_transform: if random.random() < self.p: for img in datapoint.images: img.data = self.Grayscale(img.data) return datapoint for img in datapoint.images: if random.random() < self.p: img.data = self.Grayscale(img.data) return datapoint class ColorJitter: def __init__(self, consistent_transform, brightness, contrast, saturation, hue): self.consistent_transform = consistent_transform self.brightness = ( brightness if isinstance(brightness, list) else [max(0, 1 - brightness), 1 + brightness] ) self.contrast = ( contrast if isinstance(contrast, list) else [max(0, 1 - contrast), 1 + contrast] ) self.saturation = ( saturation if isinstance(saturation, list) else [max(0, 1 - saturation), 1 + saturation] ) self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue]) def __call__(self, datapoint: Datapoint, **kwargs): if self.consistent_transform: # Create a color jitter transformation params ( fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor, ) = T.ColorJitter.get_params( self.brightness, self.contrast, self.saturation, self.hue ) for img in datapoint.images: if not self.consistent_transform: ( fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor, ) = T.ColorJitter.get_params( self.brightness, self.contrast, self.saturation, self.hue ) for fn_id in fn_idx: if fn_id == 0 and brightness_factor is not None: img.data = F.adjust_brightness(img.data, brightness_factor) elif fn_id == 1 and contrast_factor is not None: img.data = F.adjust_contrast(img.data, contrast_factor) elif fn_id == 2 and saturation_factor is not None: img.data = F.adjust_saturation(img.data, saturation_factor) elif fn_id == 3 and hue_factor is not None: img.data = F.adjust_hue(img.data, hue_factor) return datapoint class RandomAffine: def __init__( self, degrees, consistent_transform, scale=None, translate=None, shear=None, image_mean=(123, 116, 103), log_warning=True, num_tentatives=1, image_interpolation="bicubic", ): """ The mask is required for this transform. if consistent_transform if True, then the same random affine is applied to all frames and masks. """ self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees]) self.scale = scale self.shear = ( shear if isinstance(shear, list) else ([-shear, shear] if shear else None) ) self.translate = translate self.fill_img = image_mean self.consistent_transform = consistent_transform self.log_warning = log_warning self.num_tentatives = num_tentatives if image_interpolation == "bicubic": self.image_interpolation = InterpolationMode.BICUBIC elif image_interpolation == "bilinear": self.image_interpolation = InterpolationMode.BILINEAR else: raise NotImplementedError def __call__(self, datapoint: Datapoint, **kwargs): for _tentative in range(self.num_tentatives): res = self.transform_datapoint(datapoint) if res is not None: return res if self.log_warning: logging.warning( f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives" ) return datapoint def transform_datapoint(self, datapoint: Datapoint): _, height, width = F.get_dimensions(datapoint.images[0].data) img_size = [width, height] if self.consistent_transform: # Create a random affine transformation affine_params = T.RandomAffine.get_params( degrees=self.degrees, translate=self.translate, scale_ranges=self.scale, shears=self.shear, img_size=img_size, ) for img_idx, img in enumerate(datapoint.images): this_masks = [ obj.segment.unsqueeze(0) if obj.segment is not None else None for obj in img.objects ] if not self.consistent_transform: # if not consistent we create a new affine params for every frame&mask pair Create a random affine transformation affine_params = T.RandomAffine.get_params( degrees=self.degrees, translate=self.translate, scale_ranges=self.scale, shears=self.shear, img_size=img_size, ) transformed_bboxes, transformed_masks = [], [] for i in range(len(img.objects)): if this_masks[i] is None: transformed_masks.append(None) # Dummy bbox for a dummy target transformed_bboxes.append(torch.tensor([[0, 0, 0, 0]])) else: transformed_mask = F.affine( this_masks[i], *affine_params, interpolation=InterpolationMode.NEAREST, fill=0.0, ) if img_idx == 0 and transformed_mask.max() == 0: # We are dealing with a video and the object is not visible in the first frame # Return the datapoint without transformation return None transformed_bbox = masks_to_boxes(transformed_mask) transformed_bboxes.append(transformed_bbox) transformed_masks.append(transformed_mask.squeeze()) for i in range(len(img.objects)): img.objects[i].bbox = transformed_bboxes[i] img.objects[i].segment = transformed_masks[i] img.data = F.affine( img.data, *affine_params, interpolation=self.image_interpolation, fill=self.fill_img, ) return datapoint class RandomResizedCrop: def __init__( self, consistent_transform, size, scale=None, ratio=None, log_warning=True, num_tentatives=4, keep_aspect_ratio=False, ): """ The mask is required for this transform. if consistent_transform if True, then the same random resized crop is applied to all frames and masks. """ if isinstance(size, numbers.Number): self.size = (int(size), int(size)) elif isinstance(size, Sequence) and len(size) == 1: self.size = (size[0], size[0]) elif len(size) != 2: raise ValueError("Please provide only two dimensions (h, w) for size.") else: self.size = size self.scale = scale if scale is not None else (0.08, 1.0) self.ratio = ratio if ratio is not None else (3.0 / 4.0, 4.0 / 3.0) self.consistent_transform = consistent_transform self.log_warning = log_warning self.num_tentatives = num_tentatives self.keep_aspect_ratio = keep_aspect_ratio def __call__(self, datapoint: Datapoint, **kwargs): for _tentative in range(self.num_tentatives): res = self.transform_datapoint(datapoint) if res is not None: return res if self.log_warning: logging.warning( f"Skip RandomResizeCrop for zero-area mask in first frame after {self.num_tentatives} tentatives" ) return datapoint def transform_datapoint(self, datapoint: Datapoint): if self.keep_aspect_ratio: original_size = datapoint.images[0].size original_ratio = original_size[1] / original_size[0] ratio = [r * original_ratio for r in self.ratio] else: ratio = self.ratio if self.consistent_transform: # Create a random crop transformation crop_params = T.RandomResizedCrop.get_params( img=datapoint.images[0].data, scale=self.scale, ratio=ratio, ) for img_idx, img in enumerate(datapoint.images): if not self.consistent_transform: # Create a random crop transformation crop_params = T.RandomResizedCrop.get_params( img=img.data, scale=self.scale, ratio=ratio, ) this_masks = [ obj.segment.unsqueeze(0) if obj.segment is not None else None for obj in img.objects ] transformed_bboxes, transformed_masks = [], [] for i in range(len(img.objects)): if this_masks[i] is None: transformed_masks.append(None) # Dummy bbox for a dummy target transformed_bboxes.append(torch.tensor([[0, 0, 0, 0]])) else: transformed_mask = F.resized_crop( this_masks[i], *crop_params, size=self.size, interpolation=InterpolationMode.NEAREST, ) if img_idx == 0 and transformed_mask.max() == 0: # We are dealing with a video and the object is not visible in the first frame # Return the datapoint without transformation return None transformed_masks.append(transformed_mask.squeeze()) transformed_bbox = masks_to_boxes(transformed_mask) transformed_bboxes.append(transformed_bbox) # Set the new boxes and masks if all transformed masks and boxes are good. for i in range(len(img.objects)): img.objects[i].bbox = transformed_bboxes[i] img.objects[i].segment = transformed_masks[i] img.data = F.resized_crop( img.data, *crop_params, size=self.size, interpolation=InterpolationMode.BILINEAR, ) return datapoint class ResizeToMaxIfAbove: # Resize datapoint image if one of its sides is larger that max_size def __init__( self, max_size=None, ): self.max_size = max_size def __call__(self, datapoint: Datapoint, **kwargs): _, height, width = F.get_dimensions(datapoint.images[0].data) if height <= self.max_size and width <= self.max_size: # The original frames are small enough return datapoint elif height >= width: new_height = self.max_size new_width = int(round(self.max_size * width / height)) else: new_height = int(round(self.max_size * height / width)) new_width = self.max_size size = new_height, new_width for index in range(len(datapoint.images)): datapoint.images[index].data = F.resize(datapoint.images[index].data, size) for obj in datapoint.images[index].objects: obj.segment = F.resize( obj.segment[None, None], size, interpolation=InterpolationMode.NEAREST, ).squeeze() h, w = size datapoint.images[index].size = (h, w) return datapoint def get_bbox_xyxy_abs_coords_from_mask(mask): """Get the bounding box (XYXY format w/ absolute coordinates) of a binary mask.""" assert mask.dim() == 2 rows = torch.any(mask, dim=1) cols = torch.any(mask, dim=0) row_inds = rows.nonzero().view(-1) col_inds = cols.nonzero().view(-1) if row_inds.numel() == 0: # mask is empty bbox = torch.zeros(1, 4, dtype=torch.float32) bbox_area = 0.0 else: ymin, ymax = row_inds.min(), row_inds.max() xmin, xmax = col_inds.min(), col_inds.max() bbox = torch.tensor([xmin, ymin, xmax, ymax], dtype=torch.float32).view(1, 4) bbox_area = float((ymax - ymin) * (xmax - xmin)) return bbox, bbox_area class MotionBlur: def __init__(self, kernel_size=5, consistent_transform=True, p=0.5): assert kernel_size % 2 == 1, "Kernel size must be odd." self.kernel_size = kernel_size self.consistent_transform = consistent_transform self.p = p def __call__(self, datapoint: Datapoint, **kwargs): if random.random() >= self.p: return datapoint if self.consistent_transform: # Generate a single motion blur kernel for all images kernel = self._generate_motion_blur_kernel() for img in datapoint.images: if not self.consistent_transform: # Generate a new motion blur kernel for each image kernel = self._generate_motion_blur_kernel() img.data = self._apply_motion_blur(img.data, kernel) return datapoint def _generate_motion_blur_kernel(self): kernel = torch.zeros((self.kernel_size, self.kernel_size)) direction = random.choice(["horizontal", "vertical", "diagonal"]) if direction == "horizontal": kernel[self.kernel_size // 2, :] = 1.0 elif direction == "vertical": kernel[:, self.kernel_size // 2] = 1.0 elif direction == "diagonal": for i in range(self.kernel_size): kernel[i, i] = 1.0 kernel /= kernel.sum() return kernel def _apply_motion_blur(self, image, kernel): if isinstance(image, PILImage.Image): image = F.to_tensor(image) channels = image.shape[0] kernel = kernel.to(image.device).unsqueeze(0).unsqueeze(0) blurred_image = torch.nn.functional.conv2d( image.unsqueeze(0), kernel.repeat(channels, 1, 1, 1), padding=self.kernel_size // 2, groups=channels, ) return F.to_pil_image(blurred_image.squeeze(0)) class LargeScaleJitter: def __init__( self, scale_range=(0.1, 2.0), aspect_ratio_range=(0.75, 1.33), crop_size=(640, 640), consistent_transform=True, p=0.5, ): """ Args:rack scale_range (tuple): Range of scaling factors (min_scale, max_scale). aspect_ratio_range (tuple): Range of aspect ratios (min_aspect_ratio, max_aspect_ratio). crop_size (tuple): Target size of the cropped region (width, height). consistent_transform (bool): Whether to apply the same transformation across all frames. p (float): Probability of applying the transformation. """ self.scale_range = scale_range self.aspect_ratio_range = aspect_ratio_range self.crop_size = crop_size self.consistent_transform = consistent_transform self.p = p def __call__(self, datapoint: Datapoint, **kwargs): if random.random() >= self.p: return datapoint # Sample a single scale factor and aspect ratio for all frames log_ratio = torch.log(torch.tensor(self.aspect_ratio_range)) scale_factor = torch.empty(1).uniform_(*self.scale_range).item() aspect_ratio = torch.exp( torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) ).item() for idx, img in enumerate(datapoint.images): if not self.consistent_transform: # Sample a new scale factor and aspect ratio for each frame log_ratio = torch.log(torch.tensor(self.aspect_ratio_range)) scale_factor = torch.empty(1).uniform_(*self.scale_range).item() aspect_ratio = torch.exp( torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) ).item() # Compute the dimensions of the jittered crop original_width, original_height = img.data.size target_area = original_width * original_height * scale_factor crop_width = int(round((target_area * aspect_ratio) ** 0.5)) crop_height = int(round((target_area / aspect_ratio) ** 0.5)) # Randomly select the top-left corner of the crop crop_x = random.randint(0, max(0, original_width - crop_width)) crop_y = random.randint(0, max(0, original_height - crop_height)) # Extract the cropped region datapoint = crop(datapoint, idx, (crop_x, crop_y, crop_width, crop_height)) # Resize the cropped region to the target crop size datapoint = resize(datapoint, idx, self.crop_size) return datapoint