import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from typing import List, Tuple, Union

# Constants from InternVL preprocessing
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size: int = 448) -> T.Compose:
    """
    Return torchvision transform matching InternVL pre-training.

    Args:
        input_size: Input image size (default: 448)

    Returns:
        Composed torchvision transforms
    """
    return T.Compose([
        T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
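# Minimal sanity-check sketch (not part of the original Space; the helper name
# `_demo_build_transform` is an illustrative assumption and is never called at
# import time): applying the transform to a synthetic PIL image should yield a
# normalized float tensor of shape (3, input_size, input_size).
def _demo_build_transform() -> None:
    demo_img = Image.new("RGB", (640, 480), color=(128, 128, 128))
    tensor = build_transform(448)(demo_img)
    assert tensor.shape == (3, 448, 448)
    assert tensor.dtype == torch.float32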
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: List[Tuple[int, int]],
    width: int,
    height: int,
    image_size: int
) -> Tuple[int, int]:
    """
    Find the closest aspect ratio from target ratios.

    Args:
        aspect_ratio: Current image aspect ratio
        target_ratios: List of target aspect ratios as (width, height) tuples
        width: Original image width
        height: Original image height
        image_size: Target image size

    Returns:
        Best matching aspect ratio as (width, height) tuple
    """
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        tgt_ar = ratio[0] / ratio[1]
        diff = abs(aspect_ratio - tgt_ar)
        if (diff < best_ratio_diff or
                (diff == best_ratio_diff and area > 0.5 * image_size * image_size * ratio[0] * ratio[1])):
            best_ratio_diff = diff
            best_ratio = ratio
    return best_ratio
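# Illustrative sketch (an assumption, not part of the original file): for an
# 800x600 image the closest grid with at most 12 tiles is 4x3, since 4/3
# matches the image's aspect ratio exactly. When two grids tie on aspect
# ratio, the `area > 0.5 * image_size**2 * w * h` branch above prefers the
# larger grid only if the source image is big enough to fill it.
def _demo_find_closest_aspect_ratio() -> None:
    candidates = sorted(
        {(i, j) for n in range(1, 13)
         for i in range(1, n + 1)
         for j in range(1, n + 1)
         if 1 <= i * j <= 12},
        key=lambda x: x[0] * x[1],
    )
    best = find_closest_aspect_ratio(800 / 600, candidates, 800, 600, 448)
    assert best == (4, 3)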
def dynamic_preprocess(
    image: Image.Image,
    min_num: int = 1,
    max_num: int = 12,
    image_size: int = 448,
    use_thumbnail: bool = False
) -> List[Image.Image]:
    """
    Split an arbitrarily-sized image into ≤12 tiles sized 448×448 (InternVL spec).

    Args:
        image: Input PIL Image
        min_num: Minimum number of tiles
        max_num: Maximum number of tiles
        image_size: Size of each tile (default: 448)
        use_thumbnail: Whether to add a thumbnail version

    Returns:
        List of processed image tiles
    """
    ow, oh = image.size
    aspect_ratio = ow / oh

    # Generate target ratios
    target_ratios = sorted(
        {(i, j) for n in range(min_num, max_num + 1)
         for i in range(1, n + 1)
         for j in range(1, n + 1)
         if min_num <= i * j <= max_num},
        key=lambda x: x[0] * x[1],
    )

    ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, ow, oh, image_size)
    tw, th = image_size * ratio[0], image_size * ratio[1]
    blocks = ratio[0] * ratio[1]
    resized = image.resize((tw, th))

    # Create tiles
    tiles = []
    for idx in range(blocks):
        tile = resized.crop((
            (idx % (tw // image_size)) * image_size,
            (idx // (tw // image_size)) * image_size,
            ((idx % (tw // image_size)) + 1) * image_size,
            ((idx // (tw // image_size)) + 1) * image_size,
        ))
        tiles.append(tile)

    # Add thumbnail if requested and more than one tile
    if use_thumbnail and blocks != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
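# Illustrative sketch (an assumption, not shipped with the Space): an 800x600
# input maps to the 4x3 grid, i.e. 12 tiles of 448x448, plus one thumbnail
# tile when use_thumbnail=True, for 13 tiles in total.
def _demo_dynamic_preprocess() -> None:
    demo_img = Image.new("RGB", (800, 600), color=(32, 64, 96))
    tiles = dynamic_preprocess(demo_img, image_size=448, use_thumbnail=True, max_num=12)
    assert len(tiles) == 13
    assert all(t.size == (448, 448) for t in tiles)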
def load_image(
    path: str,
    input_size: int = 448,
    max_num: int = 12
) -> torch.Tensor:
    """
    Load and preprocess image for InternVL model.

    Args:
        path: Path to the image file
        input_size: Input image size (default: 448)
        max_num: Maximum number of tiles (default: 12)

    Returns:
        Tensor of shape (N, 3, H, W) ready for InternVL
    """
    img = Image.open(path).convert("RGB")
    transform = build_transform(input_size)
    tiles = dynamic_preprocess(
        img,
        image_size=input_size,
        use_thumbnail=True,
        max_num=max_num
    )
    return torch.stack([transform(t) for t in tiles])
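# Example usage sketch (assumptions: "example.jpg" is a placeholder path and
# the downstream model/tokenizer are loaded elsewhere in the Space). The
# InternVL model cards typically cast the stacked tiles to bfloat16 and move
# them to the GPU before passing them to the model as `pixel_values`.
if __name__ == "__main__":
    pixel_values = load_image("example.jpg", input_size=448, max_num=12)
    print(pixel_values.shape)  # (num_tiles, 3, 448, 448), num_tiles <= max_num + 1
    # e.g. pixel_values = pixel_values.to(torch.bfloat16).cuda() before inference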