# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

"""
Transforms and data augmentation for both image + bbox.
"""

import logging
import numbers
import random
from collections.abc import Sequence
from typing import Iterable

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as Fv2
from PIL import Image as PILImage

from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes
from sam3.train.data.sam3_image_dataset import Datapoint
from torchvision.transforms import InterpolationMode

def crop(
    datapoint,
    index,
    region,
    v2=False,
    check_validity=True,
    check_input_validity=True,
    recompute_box_from_mask=False,
):
    if v2:
        rtop, rleft, rheight, rwidth = (int(round(r)) for r in region)
        datapoint.images[index].data = Fv2.crop(
            datapoint.images[index].data,
            top=rtop,
            left=rleft,
            height=rheight,
            width=rwidth,
        )
    else:
        datapoint.images[index].data = F.crop(datapoint.images[index].data, *region)
    i, j, h, w = region
    # should we do something wrt the original size?
    datapoint.images[index].size = (h, w)

    for obj in datapoint.images[index].objects:
        # crop the mask
        if obj.segment is not None:
            obj.segment = F.crop(obj.segment, int(i), int(j), int(h), int(w))
        # crop the bounding box
        if recompute_box_from_mask and obj.segment is not None:
            # here the boxes are still in XYXY format with absolute coordinates (they are
            # converted to CxCyWH with relative coordinates in basic_for_api.NormalizeAPI)
            obj.bbox, obj.area = get_bbox_xyxy_abs_coords_from_mask(obj.segment)
        else:
            if recompute_box_from_mask and obj.segment is None and obj.area > 0:
                logging.warning(
                    "Cannot recompute bounding box from mask since `obj.segment` is None. "
                    "Falling back to directly cropping from the input bounding box."
                )
            boxes = obj.bbox.view(1, 4)
            max_size = torch.as_tensor([w, h], dtype=torch.float32)
            cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32)
            cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
            cropped_boxes = cropped_boxes.clamp(min=0)
            obj.area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
            obj.bbox = cropped_boxes.reshape(-1, 4)

    for query in datapoint.find_queries:
        if query.semantic_target is not None:
            query.semantic_target = F.crop(
                query.semantic_target, int(i), int(j), int(h), int(w)
            )
        if query.image_id == index and query.input_bbox is not None:
            boxes = query.input_bbox
            max_size = torch.as_tensor([w, h], dtype=torch.float32)
            cropped_boxes = boxes - torch.as_tensor([j, i, j, i], dtype=torch.float32)
            cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
            cropped_boxes = cropped_boxes.clamp(min=0)
            # cur_area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
            # if check_input_validity:
            #     assert (
            #         (cur_area > 0).all().item()
            #     ), "Some input box got cropped out by the crop transform"
            query.input_bbox = cropped_boxes.reshape(-1, 4)
        if query.image_id == index and query.input_points is not None:
            print(
                "Warning! Point cropping with this function may lead to unexpected results"
            )
            points = query.input_points
            # Unlike right-lower box edges, which are exclusive, the
            # point must be in [0, length-1], hence the -1
            max_size = torch.as_tensor([w, h], dtype=torch.float32) - 1
            cropped_points = points - torch.as_tensor([j, i, 0], dtype=torch.float32)
            cropped_points[:, :, :2] = torch.min(cropped_points[:, :, :2], max_size)
            cropped_points[:, :, :2] = cropped_points[:, :, :2].clamp(min=0)
            query.input_points = cropped_points

    if check_validity:
        # Check that all boxes are still valid
        for obj in datapoint.images[index].objects:
            assert obj.area > 0, "Box {} has no area".format(obj.bbox)
    return datapoint

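# Worked example of the box arithmetic in `crop` (illustrative values, not from the
# original module): the region is interpreted as (i, j, h, w) = (top, left, height,
# width). With region (30, 20, 200, 200), an absolute XYXY box [50, 40, 120, 100] is
# shifted by (left, top) to [30, 10, 100, 70], clamped into [0, w] x [0, h] (here
# unchanged), and its recomputed area is (100 - 30) * (70 - 10) = 4200.
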
def hflip(datapoint, index):
    datapoint.images[index].data = F.hflip(datapoint.images[index].data)
    w, h = datapoint.images[index].data.size

    for obj in datapoint.images[index].objects:
        boxes = obj.bbox.view(1, 4)
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor(
            [-1, 1, -1, 1]
        ) + torch.as_tensor([w, 0, w, 0])
        obj.bbox = boxes
        if obj.segment is not None:
            obj.segment = F.hflip(obj.segment)

    for query in datapoint.find_queries:
        if query.semantic_target is not None:
            query.semantic_target = F.hflip(query.semantic_target)
        if query.image_id == index and query.input_bbox is not None:
            boxes = query.input_bbox
            boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor(
                [-1, 1, -1, 1]
            ) + torch.as_tensor([w, 0, w, 0])
            query.input_bbox = boxes
        if query.image_id == index and query.input_points is not None:
            points = query.input_points
            points = points * torch.as_tensor([-1, 1, 1]) + torch.as_tensor([w, 0, 0])
            query.input_points = points
    return datapoint

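# Worked example of the flip formula above (illustrative values): in an image of
# width w = 100, the XYXY box [10, 20, 30, 40] is reindexed to [30, 20, 10, 40],
# scaled by [-1, 1, -1, 1] and shifted by [100, 0, 100, 0], giving [70, 20, 90, 40].
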
def get_size_with_aspect_ratio(image_size, size, max_size=None):
    w, h = image_size
    if max_size is not None:
        min_original_size = float(min((w, h)))
        max_original_size = float(max((w, h)))
        if max_original_size / min_original_size * size > max_size:
            size = max_size * min_original_size / max_original_size
    if (w <= h and w == size) or (h <= w and h == size):
        return (h, w)
    if w < h:
        ow = int(round(size))
        oh = int(round(size * h / w))
    else:
        oh = int(round(size))
        ow = int(round(size * w / h))
    return (oh, ow)

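# Worked example (illustrative values): for a 640x480 image (w=640, h=480) with
# size=800 and max_size=1000, the requested short side is first capped to
# 1000 * 480 / 640 = 750 so that the long side stays within max_size, and the
# helper returns (oh, ow) = (750, 1000).
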
def resize(datapoint, index, size, max_size=None, square=False, v2=False):
    # size can be min_size (scalar) or (w, h) tuple
    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    if square:
        size = size, size
    else:
        cur_size = (
            datapoint.images[index].data.size()[-2:][::-1]
            if v2
            else datapoint.images[index].data.size
        )
        size = get_size(cur_size, size, max_size)

    old_size = (
        datapoint.images[index].data.size()[-2:][::-1]
        if v2
        else datapoint.images[index].data.size
    )
    if v2:
        datapoint.images[index].data = Fv2.resize(
            datapoint.images[index].data, size, antialias=True
        )
    else:
        datapoint.images[index].data = F.resize(datapoint.images[index].data, size)
    new_size = (
        datapoint.images[index].data.size()[-2:][::-1]
        if v2
        else datapoint.images[index].data.size
    )
    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(new_size, old_size))
    ratio_width, ratio_height = ratios

    for obj in datapoint.images[index].objects:
        boxes = obj.bbox.view(1, 4)
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height], dtype=torch.float32
        )
        obj.bbox = scaled_boxes
        obj.area *= ratio_width * ratio_height
        if obj.segment is not None:
            obj.segment = F.resize(obj.segment[None, None], size).squeeze()

    for query in datapoint.find_queries:
        if query.semantic_target is not None:
            query.semantic_target = F.resize(
                query.semantic_target[None, None], size
            ).squeeze()
        if query.image_id == index and query.input_bbox is not None:
            boxes = query.input_bbox
            scaled_boxes = boxes * torch.as_tensor(
                [ratio_width, ratio_height, ratio_width, ratio_height],
                dtype=torch.float32,
            )
            query.input_bbox = scaled_boxes
        if query.image_id == index and query.input_points is not None:
            points = query.input_points
            scaled_points = points * torch.as_tensor(
                [ratio_width, ratio_height, 1],
                dtype=torch.float32,
            )
            query.input_points = scaled_points

    h, w = size
    datapoint.images[index].size = (h, w)
    return datapoint

def pad(datapoint, index, padding, v2=False):
    old_h, old_w = datapoint.images[index].size
    h, w = old_h, old_w
    if len(padding) == 2:
        # assumes that we only pad on the bottom-right corner
        if v2:
            datapoint.images[index].data = Fv2.pad(
                datapoint.images[index].data, (0, 0, padding[0], padding[1])
            )
        else:
            datapoint.images[index].data = F.pad(
                datapoint.images[index].data, (0, 0, padding[0], padding[1])
            )
        h += padding[1]
        w += padding[0]
    else:
        if v2:
            # left, top, right, bottom
            datapoint.images[index].data = Fv2.pad(
                datapoint.images[index].data,
                (padding[0], padding[1], padding[2], padding[3]),
            )
        else:
            # left, top, right, bottom
            datapoint.images[index].data = F.pad(
                datapoint.images[index].data,
                (padding[0], padding[1], padding[2], padding[3]),
            )
        h += padding[1] + padding[3]
        w += padding[0] + padding[2]
    datapoint.images[index].size = (h, w)

    for obj in datapoint.images[index].objects:
        if len(padding) != 2:
            obj.bbox += torch.as_tensor(
                [padding[0], padding[1], padding[0], padding[1]], dtype=torch.float32
            )
        if obj.segment is not None:
            if v2:
                if len(padding) == 2:
                    obj.segment = Fv2.pad(
                        obj.segment[None], (0, 0, padding[0], padding[1])
                    ).squeeze(0)
                else:
                    obj.segment = Fv2.pad(obj.segment[None], tuple(padding)).squeeze(0)
            else:
                if len(padding) == 2:
                    obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1]))
                else:
                    obj.segment = F.pad(obj.segment, tuple(padding))

    for query in datapoint.find_queries:
        if query.semantic_target is not None:
            if v2:
                if len(padding) == 2:
                    query.semantic_target = Fv2.pad(
                        query.semantic_target[None, None],
                        (0, 0, padding[0], padding[1]),
                    ).squeeze()
                else:
                    query.semantic_target = Fv2.pad(
                        query.semantic_target[None, None], tuple(padding)
                    ).squeeze()
            else:
                if len(padding) == 2:
                    query.semantic_target = F.pad(
                        query.semantic_target[None, None],
                        (0, 0, padding[0], padding[1]),
                    ).squeeze()
                else:
                    query.semantic_target = F.pad(
                        query.semantic_target[None, None], tuple(padding)
                    ).squeeze()
        if query.image_id == index and query.input_bbox is not None:
            if len(padding) != 2:
                query.input_bbox += torch.as_tensor(
                    [padding[0], padding[1], padding[0], padding[1]],
                    dtype=torch.float32,
                )
        if query.image_id == index and query.input_points is not None:
            if len(padding) != 2:
                query.input_points += torch.as_tensor(
                    [padding[0], padding[1], 0], dtype=torch.float32
                )
    return datapoint

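# Note on the padding formats handled above (illustrative): a 2-tuple such as
# (pad_x, pad_y) = (32, 16) pads only on the right/bottom, growing an (H, W) image
# to (H + 16, W + 32) without moving boxes or points, while a 4-tuple is
# interpreted as (left, top, right, bottom), so boxes and points are shifted by
# (left, top).
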
class RandomSizeCropAPI:
    def __init__(
        self,
        min_size: int,
        max_size: int,
        respect_boxes: bool,
        consistent_transform: bool,
        respect_input_boxes: bool = True,
        v2: bool = False,
        recompute_box_from_mask: bool = False,
    ):
        self.min_size = min_size
        self.max_size = max_size
        self.respect_boxes = respect_boxes  # if True we can't crop a box out
        self.respect_input_boxes = respect_input_boxes
        self.consistent_transform = consistent_transform
        self.v2 = v2
        self.recompute_box_from_mask = recompute_box_from_mask

    def _sample_no_respect_boxes(self, img):
        w = random.randint(self.min_size, min(img.width, self.max_size))
        h = random.randint(self.min_size, min(img.height, self.max_size))
        return T.RandomCrop.get_params(img, (h, w))

    def _sample_respect_boxes(self, img, boxes, points, min_box_size=10.0):
        """
        Ensure that no box or point is dropped via cropping, though portions
        of boxes may be removed.
        """
        if len(boxes) == 0 and len(points) == 0:
            return self._sample_no_respect_boxes(img)
        if self.v2:
            img_height, img_width = img.size()[-2:]
        else:
            img_width, img_height = img.size
        minW, minH, maxW, maxH = (
            min(img_width, self.min_size),
            min(img_height, self.min_size),
            min(img_width, self.max_size),
            min(img_height, self.max_size),
        )
        # The crop box must extend one pixel beyond points to the bottom/right
        # to ensure the exclusive box contains the points.
        minX = (
            torch.cat([boxes[:, 0] + min_box_size, points[:, 0] + 1], dim=0)
            .max()
            .item()
        )
        minY = (
            torch.cat([boxes[:, 1] + min_box_size, points[:, 1] + 1], dim=0)
            .max()
            .item()
        )
        minX = min(img_width, minX)
        minY = min(img_height, minY)
        maxX = torch.cat([boxes[:, 2] - min_box_size, points[:, 0]], dim=0).min().item()
        maxY = torch.cat([boxes[:, 3] - min_box_size, points[:, 1]], dim=0).min().item()
        maxX = max(0.0, maxX)
        maxY = max(0.0, maxY)
        minW = max(minW, minX - maxX)
        minH = max(minH, minY - maxY)
        w = random.uniform(minW, max(minW, maxW))
        h = random.uniform(minH, max(minH, maxH))
        if minX > maxX:
            # i = random.uniform(max(0, minX - w + 1), max(maxX, max(0, minX - w + 1)))
            i = random.uniform(max(0, minX - w), max(maxX, max(0, minX - w)))
        else:
            i = random.uniform(
                max(0, minX - w + 1), max(maxX - 1, max(0, minX - w + 1))
            )
        if minY > maxY:
            # j = random.uniform(max(0, minY - h + 1), max(maxY, max(0, minY - h + 1)))
            j = random.uniform(max(0, minY - h), max(maxY, max(0, minY - h)))
        else:
            j = random.uniform(
                max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
            )
        return [j, i, h, w]

    def __call__(self, datapoint, **kwargs):
        if self.respect_boxes or self.respect_input_boxes:
            if self.consistent_transform:
                # Check that all the images are the same size
                w, h = datapoint.images[0].data.size
                for img in datapoint.images:
                    assert img.data.size == (w, h)
                all_boxes = []
                # Getting all boxes in all the images
                if self.respect_boxes:
                    all_boxes += [
                        obj.bbox.view(-1, 4)
                        for img in datapoint.images
                        for obj in img.objects
                    ]
                # Get all the boxes in the find queries
                if self.respect_input_boxes:
                    all_boxes += [
                        q.input_bbox.view(-1, 4)
                        for q in datapoint.find_queries
                        if q.input_bbox is not None
                    ]
                if all_boxes:
                    all_boxes = torch.cat(all_boxes, 0)
                else:
                    all_boxes = torch.empty(0, 4)
                all_points = [
                    q.input_points.view(-1, 3)[:, :2]
                    for q in datapoint.find_queries
                    if q.input_points is not None
                ]
                if all_points:
                    all_points = torch.cat(all_points, 0)
                else:
                    all_points = torch.empty(0, 2)
                crop_param = self._sample_respect_boxes(
                    datapoint.images[0].data, all_boxes, all_points
                )
                for i in range(len(datapoint.images)):
                    datapoint = crop(
                        datapoint,
                        i,
                        crop_param,
                        v2=self.v2,
                        check_validity=self.respect_boxes,
                        check_input_validity=self.respect_input_boxes,
                        recompute_box_from_mask=self.recompute_box_from_mask,
                    )
                return datapoint
            else:
                for i in range(len(datapoint.images)):
                    all_boxes = []
                    # Get all boxes in the current image
                    if self.respect_boxes:
                        all_boxes += [
                            obj.bbox.view(-1, 4) for obj in datapoint.images[i].objects
                        ]
                    # Get all the boxes in the find queries that correspond to this image
                    if self.respect_input_boxes:
                        all_boxes += [
                            q.input_bbox.view(-1, 4)
                            for q in datapoint.find_queries
                            if q.image_id == i and q.input_bbox is not None
                        ]
                    if all_boxes:
                        all_boxes = torch.cat(all_boxes, 0)
                    else:
                        all_boxes = torch.empty(0, 4)
                    all_points = [
                        q.input_points.view(-1, 3)[:, :2]
                        for q in datapoint.find_queries
                        if q.input_points is not None
                    ]
                    if all_points:
                        all_points = torch.cat(all_points, 0)
                    else:
                        all_points = torch.empty(0, 2)
                    crop_param = self._sample_respect_boxes(
                        datapoint.images[i].data, all_boxes, all_points
                    )
                    datapoint = crop(
                        datapoint,
                        i,
                        crop_param,
                        v2=self.v2,
                        check_validity=self.respect_boxes,
                        check_input_validity=self.respect_input_boxes,
                        recompute_box_from_mask=self.recompute_box_from_mask,
                    )
                return datapoint
        else:
            if self.consistent_transform:
                # Check that all the images are the same size
                w, h = datapoint.images[0].data.size
                for img in datapoint.images:
                    assert img.data.size == (w, h)
                crop_param = self._sample_no_respect_boxes(datapoint.images[0].data)
                for i in range(len(datapoint.images)):
                    datapoint = crop(
                        datapoint,
                        i,
                        crop_param,
                        v2=self.v2,
                        check_validity=self.respect_boxes,
                        check_input_validity=self.respect_input_boxes,
                        recompute_box_from_mask=self.recompute_box_from_mask,
                    )
                return datapoint
            else:
                for i in range(len(datapoint.images)):
                    crop_param = self._sample_no_respect_boxes(datapoint.images[i].data)
                    datapoint = crop(
                        datapoint,
                        i,
                        crop_param,
                        v2=self.v2,
                        check_validity=self.respect_boxes,
                        check_input_validity=self.respect_input_boxes,
                        recompute_box_from_mask=self.recompute_box_from_mask,
                    )
                return datapoint

class CenterCropAPI:
    def __init__(self, size, consistent_transform, recompute_box_from_mask=False):
        self.size = size
        self.consistent_transform = consistent_transform
        self.recompute_box_from_mask = recompute_box_from_mask

    def _sample_crop(self, image_width, image_height):
        crop_height, crop_width = self.size
        crop_top = int(round((image_height - crop_height) / 2.0))
        crop_left = int(round((image_width - crop_width) / 2.0))
        return crop_top, crop_left, crop_height, crop_width

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            # Check that all the images are the same size
            w, h = datapoint.images[0].data.size
            for img in datapoint.images:
                assert img.size == (w, h)
            crop_top, crop_left, crop_height, crop_width = self._sample_crop(w, h)
            for i in range(len(datapoint.images)):
                datapoint = crop(
                    datapoint,
                    i,
                    (crop_top, crop_left, crop_height, crop_width),
                    recompute_box_from_mask=self.recompute_box_from_mask,
                )
            return datapoint
        for i in range(len(datapoint.images)):
            w, h = datapoint.images[i].data.size
            crop_top, crop_left, crop_height, crop_width = self._sample_crop(w, h)
            datapoint = crop(
                datapoint,
                i,
                (crop_top, crop_left, crop_height, crop_width),
                recompute_box_from_mask=self.recompute_box_from_mask,
            )
        return datapoint

class RandomHorizontalFlip:
    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            if random.random() < self.p:
                for i in range(len(datapoint.images)):
                    datapoint = hflip(datapoint, i)
            return datapoint
        for i in range(len(datapoint.images)):
            if random.random() < self.p:
                datapoint = hflip(datapoint, i)
        return datapoint

class RandomResizeAPI:
    def __init__(
        self, sizes, consistent_transform, max_size=None, square=False, v2=False
    ):
        if isinstance(sizes, int):
            sizes = (sizes,)
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square
        self.consistent_transform = consistent_transform
        self.v2 = v2

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            size = random.choice(self.sizes)
            for i in range(len(datapoint.images)):
                datapoint = resize(
                    datapoint, i, size, self.max_size, square=self.square, v2=self.v2
                )
            return datapoint
        for i in range(len(datapoint.images)):
            size = random.choice(self.sizes)
            datapoint = resize(
                datapoint, i, size, self.max_size, square=self.square, v2=self.v2
            )
        return datapoint

class ScheduledRandomResizeAPI(RandomResizeAPI):
    def __init__(self, size_scheduler, consistent_transform, square=False):
        self.size_scheduler = size_scheduler
        # Just a meaningful init value for super
        params = self.size_scheduler(epoch_num=0)
        sizes, max_size = params["sizes"], params["max_size"]
        super().__init__(sizes, consistent_transform, max_size=max_size, square=square)

    def __call__(self, datapoint, **kwargs):
        assert "epoch" in kwargs, "Param scheduler needs to know the current epoch"
        params = self.size_scheduler(kwargs["epoch"])
        sizes, max_size = params["sizes"], params["max_size"]
        self.sizes = sizes
        self.max_size = max_size
        datapoint = super(ScheduledRandomResizeAPI, self).__call__(datapoint, **kwargs)
        return datapoint

class RandomPadAPI:
    def __init__(self, max_pad, consistent_transform):
        self.max_pad = max_pad
        self.consistent_transform = consistent_transform

    def _sample_pad(self):
        pad_x = random.randint(0, self.max_pad)
        pad_y = random.randint(0, self.max_pad)
        return pad_x, pad_y

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            pad_x, pad_y = self._sample_pad()
            for i in range(len(datapoint.images)):
                datapoint = pad(datapoint, i, (pad_x, pad_y))
            return datapoint
        for i in range(len(datapoint.images)):
            pad_x, pad_y = self._sample_pad()
            datapoint = pad(datapoint, i, (pad_x, pad_y))
        return datapoint

class PadToSizeAPI:
    def __init__(self, size, consistent_transform, bottom_right=False, v2=False):
        self.size = size
        self.consistent_transform = consistent_transform
        self.v2 = v2
        self.bottom_right = bottom_right

    def _sample_pad(self, w, h):
        pad_x = self.size - w
        pad_y = self.size - h
        assert pad_x >= 0 and pad_y >= 0
        pad_left = random.randint(0, pad_x)
        pad_right = pad_x - pad_left
        pad_top = random.randint(0, pad_y)
        pad_bottom = pad_y - pad_top
        return pad_left, pad_top, pad_right, pad_bottom

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            # Check that all the images are the same size
            w, h = datapoint.images[0].data.size
            for img in datapoint.images:
                assert img.size == (w, h)
            if self.bottom_right:
                pad_right = self.size - w
                pad_bottom = self.size - h
                padding = (pad_right, pad_bottom)
            else:
                padding = self._sample_pad(w, h)
            for i in range(len(datapoint.images)):
                datapoint = pad(datapoint, i, padding, v2=self.v2)
            return datapoint
        for i, img in enumerate(datapoint.images):
            w, h = img.data.size
            if self.bottom_right:
                pad_right = self.size - w
                pad_bottom = self.size - h
                padding = (pad_right, pad_bottom)
            else:
                padding = self._sample_pad(w, h)
            datapoint = pad(datapoint, i, padding, v2=self.v2)
        return datapoint

class RandomMosaicVideoAPI:
    def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
        self.prob = prob
        self.grid_h = grid_h
        self.grid_w = grid_w
        self.use_random_hflip = use_random_hflip

    def __call__(self, datapoint, **kwargs):
        if random.random() > self.prob:
            return datapoint
        # select a random location to place the target mask in the mosaic
        target_grid_y = random.randint(0, self.grid_h - 1)
        target_grid_x = random.randint(0, self.grid_w - 1)
        # whether to flip each grid in the mosaic horizontally
        if self.use_random_hflip:
            should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5
        else:
            should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool)
        for i in range(len(datapoint.images)):
            datapoint = random_mosaic_frame(
                datapoint,
                i,
                grid_h=self.grid_h,
                grid_w=self.grid_w,
                target_grid_y=target_grid_y,
                target_grid_x=target_grid_x,
                should_hflip=should_hflip,
            )
        return datapoint

def random_mosaic_frame(
    datapoint,
    index,
    grid_h,
    grid_w,
    target_grid_y,
    target_grid_x,
    should_hflip,
):
    # Step 1: downsize the images and paste them into a mosaic
    image_data = datapoint.images[index].data
    is_pil = isinstance(image_data, PILImage.Image)
    if is_pil:
        H_im = image_data.height
        W_im = image_data.width
        image_data_output = PILImage.new("RGB", (W_im, H_im))
    else:
        H_im = image_data.size(-2)
        W_im = image_data.size(-1)
        image_data_output = torch.zeros_like(image_data)

    downsize_cache = {}
    for grid_y in range(grid_h):
        for grid_x in range(grid_w):
            y_offset_b = grid_y * H_im // grid_h
            x_offset_b = grid_x * W_im // grid_w
            y_offset_e = (grid_y + 1) * H_im // grid_h
            x_offset_e = (grid_x + 1) * W_im // grid_w
            H_im_downsize = y_offset_e - y_offset_b
            W_im_downsize = x_offset_e - x_offset_b

            if (H_im_downsize, W_im_downsize) in downsize_cache:
                image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)]
            else:
                image_data_downsize = F.resize(
                    image_data,
                    size=(H_im_downsize, W_im_downsize),
                    interpolation=InterpolationMode.BILINEAR,
                    antialias=True,  # antialiasing for downsizing
                )
                downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize
            if should_hflip[grid_y, grid_x].item():
                image_data_downsize = F.hflip(image_data_downsize)

            if is_pil:
                image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b))
            else:
                image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = (
                    image_data_downsize
                )

    datapoint.images[index].data = image_data_output

    # Step 2: downsize the masks and paste them into the target grid of the mosaic
    # (note that we don't scale input/target boxes since they are not used in TA)
    for obj in datapoint.images[index].objects:
        if obj.segment is None:
            continue
        assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
        segment_output = torch.zeros_like(obj.segment)

        target_y_offset_b = target_grid_y * H_im // grid_h
        target_x_offset_b = target_grid_x * W_im // grid_w
        target_y_offset_e = (target_grid_y + 1) * H_im // grid_h
        target_x_offset_e = (target_grid_x + 1) * W_im // grid_w
        target_H_im_downsize = target_y_offset_e - target_y_offset_b
        target_W_im_downsize = target_x_offset_e - target_x_offset_b

        segment_downsize = F.resize(
            obj.segment[None, None],
            size=(target_H_im_downsize, target_W_im_downsize),
            interpolation=InterpolationMode.BILINEAR,
            antialias=True,  # antialiasing for downsizing
        )[0, 0]
        if should_hflip[target_grid_y, target_grid_x].item():
            segment_downsize = F.hflip(segment_downsize[None, None])[0, 0]

        segment_output[
            target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e
        ] = segment_downsize
        obj.segment = segment_output
    return datapoint

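# Worked example of the grid offsets above (illustrative values): with H_im = 480,
# W_im = 640 and a 2x2 grid, cell (grid_y, grid_x) = (1, 0) covers rows 240..479
# and columns 0..319, so the frame is downsized to 240x320 before being pasted
# into that cell.
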
class ScheduledPadToSizeAPI(PadToSizeAPI):
    def __init__(self, size_scheduler, consistent_transform):
        self.size_scheduler = size_scheduler
        size = self.size_scheduler(epoch_num=0)["sizes"]
        super().__init__(size, consistent_transform)

    def __call__(self, datapoint, **kwargs):
        assert "epoch" in kwargs, "Param scheduler needs to know the current epoch"
        params = self.size_scheduler(kwargs["epoch"])
        self.size = params["resolution"]
        return super(ScheduledPadToSizeAPI, self).__call__(datapoint, **kwargs)


class IdentityAPI:
    def __call__(self, datapoint, **kwargs):
        return datapoint

class RandomSelectAPI:
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2
    """

    def __init__(self, transforms1=None, transforms2=None, p=0.5):
        self.transforms1 = transforms1 or IdentityAPI()
        self.transforms2 = transforms2 or IdentityAPI()
        self.p = p

    def __call__(self, datapoint, **kwargs):
        if random.random() < self.p:
            return self.transforms1(datapoint, **kwargs)
        return self.transforms2(datapoint, **kwargs)

class ToTensorAPI:
    def __init__(self, v2=False):
        self.v2 = v2

    def __call__(self, datapoint: Datapoint, **kwargs):
        for img in datapoint.images:
            if self.v2:
                img.data = Fv2.to_image_tensor(img.data)
                # img.data = Fv2.to_dtype(img.data, torch.uint8, scale=True)
                # img.data = Fv2.convert_image_dtype(img.data, torch.uint8)
            else:
                img.data = F.to_tensor(img.data)
        return datapoint

class NormalizeAPI:
    def __init__(self, mean, std, v2=False):
        self.mean = mean
        self.std = std
        self.v2 = v2

    def __call__(self, datapoint: Datapoint, **kwargs):
        for img in datapoint.images:
            if self.v2:
                img.data = Fv2.convert_image_dtype(img.data, torch.float32)
                img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
            else:
                img.data = F.normalize(img.data, mean=self.mean, std=self.std)
            for obj in img.objects:
                boxes = obj.bbox
                cur_h, cur_w = img.data.shape[-2:]
                boxes = box_xyxy_to_cxcywh(boxes)
                boxes = boxes / torch.tensor(
                    [cur_w, cur_h, cur_w, cur_h], dtype=torch.float32
                )
                obj.bbox = boxes

        for query in datapoint.find_queries:
            if query.input_bbox is not None:
                boxes = query.input_bbox
                cur_h, cur_w = datapoint.images[query.image_id].data.shape[-2:]
                boxes = box_xyxy_to_cxcywh(boxes)
                boxes = boxes / torch.tensor(
                    [cur_w, cur_h, cur_w, cur_h], dtype=torch.float32
                )
                query.input_bbox = boxes
            if query.input_points is not None:
                points = query.input_points
                cur_h, cur_w = datapoint.images[query.image_id].data.shape[-2:]
                points = points / torch.tensor([cur_w, cur_h, 1.0], dtype=torch.float32)
                query.input_points = points
        return datapoint

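# Worked example of the box normalization above (illustrative values; assumes the
# usual box_xyxy_to_cxcywh convention of center + width/height): on a 640x480
# image (cur_w=640, cur_h=480), the absolute XYXY box [100, 50, 300, 250] becomes
# CxCyWH [200, 150, 200, 200], which divided by [640, 480, 640, 480] gives roughly
# [0.3125, 0.3125, 0.3125, 0.4167].
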
class ComposeAPI:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, datapoint, **kwargs):
        for t in self.transforms:
            datapoint = t(datapoint, **kwargs)
        return datapoint

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += " {0}".format(t)
        format_string += "\n)"
        return format_string

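# Minimal usage sketch (illustrative only; the sizes and mean/std values below are
# assumptions, not taken from any training config in this repository):
#
#   transforms = ComposeAPI(
#       [
#           RandomHorizontalFlip(consistent_transform=True, p=0.5),
#           RandomResizeAPI(sizes=[640, 704, 768], consistent_transform=True, max_size=1024),
#           ToTensorAPI(),
#           NormalizeAPI(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#       ]
#   )
#   datapoint = transforms(datapoint)
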
class RandomGrayscale:
    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform
        self.Grayscale = T.Grayscale(num_output_channels=3)

    def __call__(self, datapoint: Datapoint, **kwargs):
        if self.consistent_transform:
            if random.random() < self.p:
                for img in datapoint.images:
                    img.data = self.Grayscale(img.data)
            return datapoint
        for img in datapoint.images:
            if random.random() < self.p:
                img.data = self.Grayscale(img.data)
        return datapoint

class ColorJitter:
    def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
        self.consistent_transform = consistent_transform
        self.brightness = (
            brightness
            if isinstance(brightness, list)
            else [max(0, 1 - brightness), 1 + brightness]
        )
        self.contrast = (
            contrast
            if isinstance(contrast, list)
            else [max(0, 1 - contrast), 1 + contrast]
        )
        self.saturation = (
            saturation
            if isinstance(saturation, list)
            else [max(0, 1 - saturation), 1 + saturation]
        )
        self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue])

    def __call__(self, datapoint: Datapoint, **kwargs):
        if self.consistent_transform:
            # Sample color jitter params once for all images
            (
                fn_idx,
                brightness_factor,
                contrast_factor,
                saturation_factor,
                hue_factor,
            ) = T.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue
            )
        for img in datapoint.images:
            if not self.consistent_transform:
                (
                    fn_idx,
                    brightness_factor,
                    contrast_factor,
                    saturation_factor,
                    hue_factor,
                ) = T.ColorJitter.get_params(
                    self.brightness, self.contrast, self.saturation, self.hue
                )
            for fn_id in fn_idx:
                if fn_id == 0 and brightness_factor is not None:
                    img.data = F.adjust_brightness(img.data, brightness_factor)
                elif fn_id == 1 and contrast_factor is not None:
                    img.data = F.adjust_contrast(img.data, contrast_factor)
                elif fn_id == 2 and saturation_factor is not None:
                    img.data = F.adjust_saturation(img.data, saturation_factor)
                elif fn_id == 3 and hue_factor is not None:
                    img.data = F.adjust_hue(img.data, hue_factor)
        return datapoint

class RandomAffine:
    def __init__(
        self,
        degrees,
        consistent_transform,
        scale=None,
        translate=None,
        shear=None,
        image_mean=(123, 116, 103),
        log_warning=True,
        num_tentatives=1,
        image_interpolation="bicubic",
    ):
        """
        The mask is required for this transform.
        If consistent_transform is True, the same random affine is applied to all frames and masks.
        """
        self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
        self.scale = scale
        self.shear = (
            shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
        )
        self.translate = translate
        self.fill_img = image_mean
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives
        if image_interpolation == "bicubic":
            self.image_interpolation = InterpolationMode.BICUBIC
        elif image_interpolation == "bilinear":
            self.image_interpolation = InterpolationMode.BILINEAR
        else:
            raise NotImplementedError

    def __call__(self, datapoint: Datapoint, **kwargs):
        for _tentative in range(self.num_tentatives):
            res = self.transform_datapoint(datapoint)
            if res is not None:
                return res
        if self.log_warning:
            logging.warning(
                f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
            )
        return datapoint

    def transform_datapoint(self, datapoint: Datapoint):
        _, height, width = F.get_dimensions(datapoint.images[0].data)
        img_size = [width, height]

        if self.consistent_transform:
            # Create a random affine transformation
            affine_params = T.RandomAffine.get_params(
                degrees=self.degrees,
                translate=self.translate,
                scale_ranges=self.scale,
                shears=self.shear,
                img_size=img_size,
            )

        for img_idx, img in enumerate(datapoint.images):
            this_masks = [
                obj.segment.unsqueeze(0) if obj.segment is not None else None
                for obj in img.objects
            ]
            if not self.consistent_transform:
                # If not consistent, create new random affine params for every frame & mask pair
                affine_params = T.RandomAffine.get_params(
                    degrees=self.degrees,
                    translate=self.translate,
                    scale_ranges=self.scale,
                    shears=self.shear,
                    img_size=img_size,
                )

            transformed_bboxes, transformed_masks = [], []
            for i in range(len(img.objects)):
                if this_masks[i] is None:
                    transformed_masks.append(None)
                    # Dummy bbox for a dummy target
                    transformed_bboxes.append(torch.tensor([[0, 0, 0, 0]]))
                else:
                    transformed_mask = F.affine(
                        this_masks[i],
                        *affine_params,
                        interpolation=InterpolationMode.NEAREST,
                        fill=0.0,
                    )
                    if img_idx == 0 and transformed_mask.max() == 0:
                        # We are dealing with a video and the object is not visible in the first frame
                        # Return the datapoint without transformation
                        return None
                    transformed_bbox = masks_to_boxes(transformed_mask)
                    transformed_bboxes.append(transformed_bbox)
                    transformed_masks.append(transformed_mask.squeeze())

            for i in range(len(img.objects)):
                img.objects[i].bbox = transformed_bboxes[i]
                img.objects[i].segment = transformed_masks[i]

            img.data = F.affine(
                img.data,
                *affine_params,
                interpolation=self.image_interpolation,
                fill=self.fill_img,
            )
        return datapoint

class RandomResizedCrop:
    def __init__(
        self,
        consistent_transform,
        size,
        scale=None,
        ratio=None,
        log_warning=True,
        num_tentatives=4,
        keep_aspect_ratio=False,
    ):
        """
        The mask is required for this transform.
        If consistent_transform is True, the same random resized crop is applied to all frames and masks.
        """
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        elif isinstance(size, Sequence) and len(size) == 1:
            self.size = (size[0], size[0])
        elif len(size) != 2:
            raise ValueError("Please provide only two dimensions (h, w) for size.")
        else:
            self.size = size
        self.scale = scale if scale is not None else (0.08, 1.0)
        self.ratio = ratio if ratio is not None else (3.0 / 4.0, 4.0 / 3.0)
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives
        self.keep_aspect_ratio = keep_aspect_ratio

    def __call__(self, datapoint: Datapoint, **kwargs):
        for _tentative in range(self.num_tentatives):
            res = self.transform_datapoint(datapoint)
            if res is not None:
                return res
        if self.log_warning:
            logging.warning(
                f"Skip RandomResizeCrop for zero-area mask in first frame after {self.num_tentatives} tentatives"
            )
        return datapoint

    def transform_datapoint(self, datapoint: Datapoint):
        if self.keep_aspect_ratio:
            original_size = datapoint.images[0].size
            original_ratio = original_size[1] / original_size[0]
            ratio = [r * original_ratio for r in self.ratio]
        else:
            ratio = self.ratio

        if self.consistent_transform:
            # Create a random crop transformation
            crop_params = T.RandomResizedCrop.get_params(
                img=datapoint.images[0].data,
                scale=self.scale,
                ratio=ratio,
            )

        for img_idx, img in enumerate(datapoint.images):
            if not self.consistent_transform:
                # Create a random crop transformation
                crop_params = T.RandomResizedCrop.get_params(
                    img=img.data,
                    scale=self.scale,
                    ratio=ratio,
                )
            this_masks = [
                obj.segment.unsqueeze(0) if obj.segment is not None else None
                for obj in img.objects
            ]
            transformed_bboxes, transformed_masks = [], []
            for i in range(len(img.objects)):
                if this_masks[i] is None:
                    transformed_masks.append(None)
                    # Dummy bbox for a dummy target
                    transformed_bboxes.append(torch.tensor([[0, 0, 0, 0]]))
                else:
                    transformed_mask = F.resized_crop(
                        this_masks[i],
                        *crop_params,
                        size=self.size,
                        interpolation=InterpolationMode.NEAREST,
                    )
                    if img_idx == 0 and transformed_mask.max() == 0:
                        # We are dealing with a video and the object is not visible in the first frame
                        # Return the datapoint without transformation
                        return None
                    transformed_masks.append(transformed_mask.squeeze())
                    transformed_bbox = masks_to_boxes(transformed_mask)
                    transformed_bboxes.append(transformed_bbox)

            # Set the new boxes and masks if all transformed masks and boxes are good.
            for i in range(len(img.objects)):
                img.objects[i].bbox = transformed_bboxes[i]
                img.objects[i].segment = transformed_masks[i]

            img.data = F.resized_crop(
                img.data,
                *crop_params,
                size=self.size,
                interpolation=InterpolationMode.BILINEAR,
            )
        return datapoint

class ResizeToMaxIfAbove:
    # Resize the datapoint images if one of their sides is larger than max_size
    def __init__(
        self,
        max_size=None,
    ):
        self.max_size = max_size

    def __call__(self, datapoint: Datapoint, **kwargs):
        _, height, width = F.get_dimensions(datapoint.images[0].data)
        if height <= self.max_size and width <= self.max_size:
            # The original frames are small enough
            return datapoint
        elif height >= width:
            new_height = self.max_size
            new_width = int(round(self.max_size * width / height))
        else:
            new_height = int(round(self.max_size * height / width))
            new_width = self.max_size

        size = new_height, new_width
        for index in range(len(datapoint.images)):
            datapoint.images[index].data = F.resize(datapoint.images[index].data, size)
            for obj in datapoint.images[index].objects:
                obj.segment = F.resize(
                    obj.segment[None, None],
                    size,
                    interpolation=InterpolationMode.NEAREST,
                ).squeeze()
            h, w = size
            datapoint.images[index].size = (h, w)
        return datapoint

def get_bbox_xyxy_abs_coords_from_mask(mask):
    """Get the bounding box (XYXY format w/ absolute coordinates) of a binary mask."""
    assert mask.dim() == 2
    rows = torch.any(mask, dim=1)
    cols = torch.any(mask, dim=0)
    row_inds = rows.nonzero().view(-1)
    col_inds = cols.nonzero().view(-1)
    if row_inds.numel() == 0:
        # mask is empty
        bbox = torch.zeros(1, 4, dtype=torch.float32)
        bbox_area = 0.0
    else:
        ymin, ymax = row_inds.min(), row_inds.max()
        xmin, xmax = col_inds.min(), col_inds.max()
        bbox = torch.tensor([xmin, ymin, xmax, ymax], dtype=torch.float32).view(1, 4)
        bbox_area = float((ymax - ymin) * (xmax - xmin))
    return bbox, bbox_area

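# Worked example (illustrative values): for a 4x4 binary mask whose nonzero pixels
# occupy rows 1..2 and columns 1..3, the helper returns the XYXY box
# [[1., 1., 3., 2.]] and an area of (2 - 1) * (3 - 1) = 2.0.
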
class MotionBlur:
    def __init__(self, kernel_size=5, consistent_transform=True, p=0.5):
        assert kernel_size % 2 == 1, "Kernel size must be odd."
        self.kernel_size = kernel_size
        self.consistent_transform = consistent_transform
        self.p = p

    def __call__(self, datapoint: Datapoint, **kwargs):
        if random.random() >= self.p:
            return datapoint
        if self.consistent_transform:
            # Generate a single motion blur kernel for all images
            kernel = self._generate_motion_blur_kernel()
        for img in datapoint.images:
            if not self.consistent_transform:
                # Generate a new motion blur kernel for each image
                kernel = self._generate_motion_blur_kernel()
            img.data = self._apply_motion_blur(img.data, kernel)
        return datapoint

    def _generate_motion_blur_kernel(self):
        kernel = torch.zeros((self.kernel_size, self.kernel_size))
        direction = random.choice(["horizontal", "vertical", "diagonal"])
        if direction == "horizontal":
            kernel[self.kernel_size // 2, :] = 1.0
        elif direction == "vertical":
            kernel[:, self.kernel_size // 2] = 1.0
        elif direction == "diagonal":
            for i in range(self.kernel_size):
                kernel[i, i] = 1.0
        kernel /= kernel.sum()
        return kernel

    def _apply_motion_blur(self, image, kernel):
        if isinstance(image, PILImage.Image):
            image = F.to_tensor(image)
        channels = image.shape[0]
        kernel = kernel.to(image.device).unsqueeze(0).unsqueeze(0)
        blurred_image = torch.nn.functional.conv2d(
            image.unsqueeze(0),
            kernel.repeat(channels, 1, 1, 1),
            padding=self.kernel_size // 2,
            groups=channels,
        )
        return F.to_pil_image(blurred_image.squeeze(0))

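# Worked example of the kernel construction above (illustrative values): with
# kernel_size=5 and the "horizontal" direction, row 2 of the 5x5 kernel is set to
# ones and then normalized, so every entry in that row becomes 0.2; the depthwise
# conv2d then averages each pixel with its two neighbours on each side.
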
class LargeScaleJitter:
    def __init__(
        self,
        scale_range=(0.1, 2.0),
        aspect_ratio_range=(0.75, 1.33),
        crop_size=(640, 640),
        consistent_transform=True,
        p=0.5,
    ):
        """
        Args:
            scale_range (tuple): Range of scaling factors (min_scale, max_scale).
            aspect_ratio_range (tuple): Range of aspect ratios (min_aspect_ratio, max_aspect_ratio).
            crop_size (tuple): Target size of the cropped region (width, height).
            consistent_transform (bool): Whether to apply the same transformation across all frames.
            p (float): Probability of applying the transformation.
        """
        self.scale_range = scale_range
        self.aspect_ratio_range = aspect_ratio_range
        self.crop_size = crop_size
        self.consistent_transform = consistent_transform
        self.p = p

    def __call__(self, datapoint: Datapoint, **kwargs):
        if random.random() >= self.p:
            return datapoint

        # Sample a single scale factor and aspect ratio for all frames
        log_ratio = torch.log(torch.tensor(self.aspect_ratio_range))
        scale_factor = torch.empty(1).uniform_(*self.scale_range).item()
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        for idx, img in enumerate(datapoint.images):
            if not self.consistent_transform:
                # Sample a new scale factor and aspect ratio for each frame
                log_ratio = torch.log(torch.tensor(self.aspect_ratio_range))
                scale_factor = torch.empty(1).uniform_(*self.scale_range).item()
                aspect_ratio = torch.exp(
                    torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
                ).item()

            # Compute the dimensions of the jittered crop
            original_width, original_height = img.data.size
            target_area = original_width * original_height * scale_factor
            crop_width = int(round((target_area * aspect_ratio) ** 0.5))
            crop_height = int(round((target_area / aspect_ratio) ** 0.5))

            # Randomly select the top-left corner of the crop
            crop_x = random.randint(0, max(0, original_width - crop_width))
            crop_y = random.randint(0, max(0, original_height - crop_height))

            # Extract the cropped region; crop() expects the region as
            # (top, left, height, width)
            datapoint = crop(datapoint, idx, (crop_y, crop_x, crop_height, crop_width))
            # Resize the cropped region to the target crop size
            datapoint = resize(datapoint, idx, self.crop_size)
        return datapoint

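# Worked example of the jittered crop size above (illustrative values): for a
# 640x480 frame with scale_factor = 0.5 and aspect_ratio = 1.0, the target area is
# 640 * 480 * 0.5 = 153600, so crop_width = crop_height = round(sqrt(153600)) = 392
# before the region is resized to `crop_size`.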