# backend/utils/image_processing.py
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from typing import List, Tuple, Union
# ImageNet normalization statistics used by InternVL preprocessing
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size: int = 448) -> T.Compose:
"""
    Return a torchvision transform matching InternVL pre-training.

    Args:
        input_size: Input image size (default: 448)

    Returns:
        Composed torchvision transforms
    """
return T.Compose([
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
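
# Example usage (a minimal sketch; "photo.jpg" is a hypothetical path):
#
#     transform = build_transform(448)
#     tensor = transform(Image.open("photo.jpg"))  # -> shape (3, 448, 448)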
def find_closest_aspect_ratio(
aspect_ratio: float,
target_ratios: List[Tuple[int, int]],
width: int,
height: int,
image_size: int
) -> Tuple[int, int]:
"""
    Find the target ratio closest to the image's aspect ratio.

    Ties are broken in favor of the grid with more tiles when the original
    image covers at least half of that grid's pixel area.

    Args:
        aspect_ratio: Current image aspect ratio
        target_ratios: List of target aspect ratios as (width, height) tuples
        width: Original image width
        height: Original image height
        image_size: Target tile size in pixels

    Returns:
        Best matching aspect ratio as (width, height) tuple
    """
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
    for ratio in target_ratios:
        tgt_ar = ratio[0] / ratio[1]
        diff = abs(aspect_ratio - tgt_ar)
        # Prefer the smallest difference; on a tie, take the larger grid only
        # when the source image fills at least half of that grid's pixel area
        if (diff < best_ratio_diff or
                (diff == best_ratio_diff and
                 area > 0.5 * image_size * image_size * ratio[0] * ratio[1])):
            best_ratio_diff = diff
            best_ratio = ratio
return best_ratio
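
# Worked example (illustrative): a 1280x720 image has aspect ratio ~1.78.
# Among grids of at most 12 tiles, (2, 1) and (4, 2) both give ratio 2.0, the
# closest match; the tie-break then picks (4, 2) because the image area
# 1280 * 720 = 921600 exceeds 0.5 * 448 * 448 * 8 = 802816.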
def dynamic_preprocess(
image: Image.Image,
min_num: int = 1,
max_num: int = 12,
image_size: int = 448,
use_thumbnail: bool = False
) -> List[Image.Image]:
"""
    Split an arbitrarily sized image into at most max_num tiles of
    image_size x image_size pixels (InternVL spec; 448x448 by default).

    Args:
        image: Input PIL Image
        min_num: Minimum number of tiles
        max_num: Maximum number of tiles
        image_size: Size of each tile in pixels (default: 448)
        use_thumbnail: Whether to append a whole-image thumbnail tile

    Returns:
        List of processed image tiles
    """
ow, oh = image.size
aspect_ratio = ow / oh
    # Generate candidate tiling grids
target_ratios = sorted(
{(i, j) for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if min_num <= i * j <= max_num},
key=lambda x: x[0] * x[1],
)
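    # e.g. with min_num=1 and max_num=12 this includes (1, 1), (1, 2), (2, 1),
    # ..., (3, 4), (2, 6), (12, 1): every (cols, rows) grid whose tile count
    # i * j lies in [min_num, max_num], sorted by tile count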
    ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, ow, oh, image_size)
    # Stretch the image to the chosen grid's dimensions, then cut it into tiles
    tw, th = image_size * ratio[0], image_size * ratio[1]
    blocks = ratio[0] * ratio[1]
    resized = image.resize((tw, th))
    # Create tiles by walking the grid left-to-right, top-to-bottom
    cols = tw // image_size  # == ratio[0]
    tiles = []
    for idx in range(blocks):
        left = (idx % cols) * image_size
        top = (idx // cols) * image_size
        tiles.append(resized.crop((left, top, left + image_size, top + image_size)))
# Add thumbnail if requested and more than one tile
if use_thumbnail and blocks != 1:
tiles.append(image.resize((image_size, image_size)))
return tiles
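
# Continuing the worked example above: a 1280x720 image maps to the (4, 2)
# grid, producing 8 tiles of 448x448; with use_thumbnail=True a ninth,
# whole-image thumbnail tile is appended.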
def load_image(
path: str,
input_size: int = 448,
max_num: int = 12
) -> torch.Tensor:
"""
    Load and preprocess an image for an InternVL model.

    Args:
        path: Path to the image file
        input_size: Input image size (default: 448)
        max_num: Maximum number of tiles (default: 12)

    Returns:
        Tensor of shape (N, 3, input_size, input_size), where N is the number
        of tiles (plus one thumbnail tile when more than one tile is produced)
    """
img = Image.open(path).convert("RGB")
transform = build_transform(input_size)
tiles = dynamic_preprocess(
img,
image_size=input_size,
use_thumbnail=True,
max_num=max_num
)
return torch.stack([transform(t) for t in tiles])
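

if __name__ == "__main__":
    # Minimal smoke test: a sketch, assuming a local image at the
    # hypothetical path "sample.jpg".
    pixel_values = load_image("sample.jpg")
    # N = tiles + 1 thumbnail (no thumbnail when only a single tile is produced)
    print(pixel_values.shape)  # e.g. torch.Size([9, 3, 448, 448])
    # InternVL checkpoints typically run in bfloat16 on GPU, e.g.:
    #     pixel_values = pixel_values.to(torch.bfloat16).cuda()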