import torch
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from typing import List, Tuple

# Constants from InternVL preprocessing
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size: int = 448) -> T.Compose:
    """
    Return a torchvision transform matching InternVL pre-training.

    Args:
        input_size: Input image size (default: 448)

    Returns:
        Composed torchvision transforms
    """
    return T.Compose([
        T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: List[Tuple[int, int]],
    width: int,
    height: int,
    image_size: int,
) -> Tuple[int, int]:
    """
    Find the target tile grid whose aspect ratio is closest to the image's.

    Args:
        aspect_ratio: Aspect ratio (width / height) of the input image
        target_ratios: Candidate tile grids as (columns, rows) tuples
        width: Original image width in pixels
        height: Original image height in pixels
        image_size: Side length of a single tile in pixels

    Returns:
        Best matching grid as a (columns, rows) tuple
    """
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        tgt_ar = ratio[0] / ratio[1]
        diff = abs(aspect_ratio - tgt_ar)
        # On a tie, prefer the larger grid only if the image has enough
        # pixels to fill at least half of it without heavy upscaling.
        if (diff < best_ratio_diff
                or (diff == best_ratio_diff
                    and area > 0.5 * image_size * image_size * ratio[0] * ratio[1])):
            best_ratio_diff = diff
            best_ratio = ratio
    return best_ratio


def dynamic_preprocess(
    image: Image.Image,
    min_num: int = 1,
    max_num: int = 12,
    image_size: int = 448,
    use_thumbnail: bool = False,
) -> List[Image.Image]:
    """
    Split an arbitrarily sized image into at most max_num tiles of
    image_size x image_size (InternVL defaults: 12 tiles of 448x448).

    Args:
        image: Input PIL Image
        min_num: Minimum number of tiles
        max_num: Maximum number of tiles
        image_size: Side length of each tile (default: 448)
        use_thumbnail: Whether to append a thumbnail of the full image

    Returns:
        List of image tiles (plus one thumbnail if requested and the
        image was split into more than one tile)
    """
    ow, oh = image.size
    aspect_ratio = ow / oh

    # Enumerate every (columns, rows) grid whose tile count lies in
    # [min_num, max_num], sorted by tile count.
    target_ratios = sorted(
        {(i, j)
         for n in range(min_num, max_num + 1)
         for i in range(1, n + 1)
         for j in range(1, n + 1)
         if min_num <= i * j <= max_num},
        key=lambda x: x[0] * x[1],
    )

    ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, ow, oh, image_size)
    tw, th = image_size * ratio[0], image_size * ratio[1]
    blocks = ratio[0] * ratio[1]
    resized = image.resize((tw, th))

    # Crop the resized image into a row-major grid of square tiles.
    cols = tw // image_size
    tiles = []
    for idx in range(blocks):
        col, row = idx % cols, idx // cols
        tiles.append(resized.crop((
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )))

    # Add a thumbnail of the whole image so the model also sees global context.
    if use_thumbnail and blocks != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles


def load_image(
    path: str,
    input_size: int = 448,
    max_num: int = 12,
) -> torch.Tensor:
    """
    Load and preprocess an image for an InternVL model.

    Args:
        path: Path to the image file
        input_size: Input image size (default: 448)
        max_num: Maximum number of tiles (default: 12)

    Returns:
        Tensor of shape (N, 3, input_size, input_size) ready for InternVL,
        where N is the number of tiles (including the thumbnail)
    """
    img = Image.open(path).convert("RGB")
    transform = build_transform(input_size)
    tiles = dynamic_preprocess(
        img, image_size=input_size, use_thumbnail=True, max_num=max_num
    )
    return torch.stack([transform(t) for t in tiles])
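

# ---------------------------------------------------------------------------
# Usage sketches added for illustration. `run_internvl_demo` is a hypothetical
# helper, not part of the original module; the checkpoint name, dtype, and
# `model.chat(...)` call follow the OpenGVLab InternVL2 model cards (the chat
# method is supplied via trust_remote_code), so verify them against the card
# for the checkpoint you actually use. The `__main__` block below is a
# download-free shape check on a synthetic image.
# ---------------------------------------------------------------------------
def run_internvl_demo(image_path: str) -> str:
    """
    Sketch: run one image/question pair through an InternVL2 checkpoint via
    Hugging Face transformers. Requires a CUDA device and a model download;
    treat the model ID and chat signature as assumptions to double-check.
    """
    from transformers import AutoModel, AutoTokenizer

    model_id = "OpenGVLab/InternVL2-8B"  # assumed checkpoint name
    model = AutoModel.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # Tiles go in as a single (N, 3, 448, 448) batch of pixel values.
    pixel_values = load_image(image_path).to(torch.bfloat16).cuda()
    question = "<image>\nDescribe this image."
    return model.chat(
        tokenizer, pixel_values, question, dict(max_new_tokens=256)
    )


if __name__ == "__main__":
    # Synthetic 16:9 test image; a real photo path works the same way.
    demo = Image.new("RGB", (1280, 720), color=(128, 128, 128))
    demo.save("demo.jpg")

    batch = load_image("demo.jpg")
    # For a 1280x720 input with max_num=12, the closest grid is 4x2
    # (8 tiles) plus the thumbnail, giving 9 tiles of 3x448x448.
    print(batch.shape)  # torch.Size([9, 3, 448, 448])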