import torch
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from typing import List, Tuple

# Constants from InternVL preprocessing
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size: int = 448) -> T.Compose:
    """
    Return a torchvision transform matching InternVL pre-training.
    
    Args:
        input_size: Input image size (default: 448)
        
    Returns:
        Composed torchvision transforms
    """
    return T.Compose([
        T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
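
# Illustrative usage ("sample.jpg" is a hypothetical path, not part of this
# module):
#
#     transform = build_transform(448)
#     tensor = transform(Image.open("sample.jpg"))
#     tensor.shape  # torch.Size([3, 448, 448])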


def find_closest_aspect_ratio(
    aspect_ratio: float, 
    target_ratios: List[Tuple[int, int]], 
    width: int, 
    height: int, 
    image_size: int
) -> Tuple[int, int]:
    """
    Find the closest aspect ratio from target ratios.
    
    Args:
        aspect_ratio: Current image aspect ratio
        target_ratios: List of target aspect ratios as (width, height) tuples
        width: Original image width
        height: Original image height
        image_size: Target image size
        
    Returns:
        Best matching aspect ratio as (width, height) tuple
    """
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    
    for ratio in target_ratios:
        tgt_ar = ratio[0] / ratio[1]
        diff = abs(aspect_ratio - tgt_ar)
        
        if (diff < best_ratio_diff or 
            (diff == best_ratio_diff and area > 0.5 * image_size * image_size * ratio[0] * ratio[1])):
            best_ratio_diff = diff
            best_ratio = ratio
            
    return best_ratio
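
# Illustrative tie-breaking example: a 1000x500 image (aspect ratio 2.0)
# matches both the (2, 1) and (4, 2) grids exactly; the area check prefers
# the larger grid only when the source exceeds half its pixel budget:
#
#     find_closest_aspect_ratio(2.0, [(1, 1), (2, 1), (4, 2)], 1000, 500, 448)
#     # -> (2, 1), because 1000 * 500 <= 0.5 * 448 * 448 * 4 * 2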


def dynamic_preprocess(
    image: Image.Image, 
    min_num: int = 1, 
    max_num: int = 12, 
    image_size: int = 448, 
    use_thumbnail: bool = False
) -> List[Image.Image]:
    """
    Split an arbitrarily-sized image into at most max_num tiles of
    image_size x image_size pixels (InternVL dynamic-resolution scheme).
    
    Args:
        image: Input PIL Image
        min_num: Minimum number of tiles
        max_num: Maximum number of tiles
        image_size: Size of each tile (default: 448)
        use_thumbnail: Whether to add a thumbnail version
        
    Returns:
        List of processed image tiles
    """
    ow, oh = image.size
    aspect_ratio = ow / oh
    
    # Enumerate every (cols, rows) grid whose total tile count falls in
    # [min_num, max_num], sorted by tile count
    target_ratios = sorted(
        {(i, j) for n in range(min_num, max_num + 1) 
         for i in range(1, n + 1) 
         for j in range(1, n + 1) 
         if min_num <= i * j <= max_num},
        key=lambda x: x[0] * x[1],
    )
    
    ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, ow, oh, image_size)
    tw, th = image_size * ratio[0], image_size * ratio[1]
    blocks = ratio[0] * ratio[1]
    
    resized = image.resize((tw, th))
    
    # Crop the resized image into a grid of image_size x image_size tiles,
    # left to right, top to bottom
    cols = tw // image_size  # number of tile columns, equals ratio[0]
    tiles = []
    for idx in range(blocks):
        left = (idx % cols) * image_size
        top = (idx // cols) * image_size
        tiles.append(resized.crop((left, top, left + image_size, top + image_size)))
    
    # Add thumbnail if requested and more than one tile
    if use_thumbnail and blocks != 1:
        tiles.append(image.resize((image_size, image_size)))
        
    return tiles
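
# Illustrative example with a synthetic image: an 896x448 input (aspect ratio
# 2.0) maps to a (2, 1) grid of two 448x448 tiles, and use_thumbnail=True
# appends a third, downscaled copy of the whole image:
#
#     img = Image.new("RGB", (896, 448))
#     len(dynamic_preprocess(img, image_size=448, use_thumbnail=True))  # -> 3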


def load_image(
    path: str, 
    input_size: int = 448, 
    max_num: int = 12
) -> torch.Tensor:
    """
    Load and preprocess image for InternVL model.
    
    Args:
        path: Path to the image file
        input_size: Input image size (default: 448)
        max_num: Maximum number of tiles (default: 12)
        
    Returns:
        Tensor of shape (N, 3, H, W) ready for InternVL
    """
    img = Image.open(path).convert("RGB")
    transform = build_transform(input_size)
    tiles = dynamic_preprocess(
        img, 
        image_size=input_size, 
        use_thumbnail=True, 
        max_num=max_num
    )
    return torch.stack([transform(t) for t in tiles])
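

if __name__ == "__main__":
    # Minimal usage sketch; "sample.jpg" is a placeholder path, not a file
    # shipped with this module.
    pixel_values = load_image("sample.jpg", input_size=448, max_num=12)
    # Shape is (num_tiles, 3, 448, 448); num_tiles can reach max_num + 1
    # because use_thumbnail=True may append one extra full-image tile.
    print(pixel_values.shape)
    # Depending on the model setup, InternVL typically expects something like:
    #     pixel_values = pixel_values.to(torch.bfloat16).cuda()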