from functools import partial
from typing import Any, Dict, List, Optional

import torch
from torch import nn

class BaseEncoder(nn.Module):
    def __init__(self, parent: nn.Module) -> None:
        super().__init__()
        # Hold the parent model in a plain list so it is not registered as a
        # submodule (its parameters would otherwise be picked up recursively).
        self._parent = [parent]

    @property
    def parent(self) -> nn.Module:
        return self._parent[0]
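
# Note: because `BaseEncoder` keeps its parent in a plain Python list, the
# parent's weights do not show up in the encoder's own parameters, e.g.:
#   len(list(BaseEncoder(parent=nn.Linear(4, 4)).parameters()))  # -> 0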


class BasicImageEncoder(BaseEncoder):
    def __init__(
        self,
        parent: torch.nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
    ) -> None:
        super().__init__(parent)
        self.start_tokens = start_tokens
        self.end_tokens = end_tokens

    def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
        # Embed optional marker text (e.g. "\n") with the LLM's input embeddings.
        if tokens is None:
            return None
        token_ids = self.parent.tokenizer(tokens).input_ids
        token_ids = torch.tensor(token_ids, device=self.parent.device)
        return self.parent.llm.model.embed_tokens(token_ids)

    def _process_features(
        self,
        features: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Wrap the (num_tokens, hidden) image features with optional
        # start/end token embeddings along the sequence dimension.
        if start_token_embeds is not None:
            features = torch.cat([start_token_embeds, features], dim=0)
        if end_token_embeds is not None:
            features = torch.cat([features, end_token_embeds], dim=0)
        return features

    def forward(self, images: List[torch.Tensor], config: Dict[str, Any], device: torch.device) -> List[torch.Tensor]:
        # Batch the images, encode them with the vision tower, then wrap each
        # image's features with the start/end token embeddings.
        images = torch.stack(images, dim=0)
        features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
        )
        return [process_features(f).to(device) for f in features]


class BasicVideoEncoder(BaseEncoder):
    def __init__(
        self,
        parent: torch.nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
    ) -> None:
        super().__init__(parent)
        self.start_tokens = start_tokens
        self.end_tokens = end_tokens

    def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
        # Embed optional marker text with the LLM's input embeddings.
        if tokens is None:
            return None
        token_ids = self.parent.tokenizer(tokens).input_ids
        token_ids = torch.tensor(token_ids, device=self.parent.device)
        return self.parent.llm.model.embed_tokens(token_ids)

    def _process_features(
        self,
        features: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # features: (num_frames, tokens_per_frame, hidden). Broadcast the
        # start/end embeddings to every frame, then flatten frames and tokens
        # into a single sequence.
        if start_token_embeds is not None:
            start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
            features = torch.cat([start_embeds, features], dim=1)
        if end_token_embeds is not None:
            end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
            features = torch.cat([features, end_embeds], dim=1)
        return features.flatten(0, 1)

    def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
        # Encode all frames of all videos in one batch, then split the
        # features back into per-video chunks.
        num_frames = [video.shape[0] for video in videos]
        images = torch.cat(videos, dim=0)
        features = self.parent.encode_images(images)
        features = torch.split(features, num_frames)
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
        )
        return [process_features(f) for f in features]


def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
    # Average-pool non-overlapping windows of `size` along `dim`; the size of
    # that dimension must be divisible by `size`.
    return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)
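
# For example (shapes are illustrative only): averaging every 8 frames of an
# (8, 4, 4, C) temporal-spatial feature map collapses the time axis:
#   pool(torch.randn(8, 4, 4, 64), size=8, dim=0).shape  # -> (1, 4, 4, 64)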


class TSPVideoEncoder(BasicVideoEncoder):
    def __init__(
        self,
        parent: torch.nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
        sep_tokens: Optional[str] = None,
    ) -> None:
        super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
        # Each pool size is (time, height, width); [8, 1, 1] averages every
        # 8 frames and leaves the spatial layout untouched.
        self.pool_sizes = [[8, 1, 1]]
        self.sep_tokens = sep_tokens

    def _process_features(
        self,
        inputs: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
        sep_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # inputs: (num_frames, tokens_per_frame, hidden), where each frame's
        # tokens form a square spatial grid.
        nt, ns = inputs.shape[:2]
        nl = int(ns**0.5)
        outputs = []
        for pool_size in self.pool_sizes:
            # Restore the (time, height, width, hidden) layout, pool each
            # axis, then flatten the spatial grid back into tokens per frame.
            features = inputs.view(nt, nl, nl, -1)
            for dim, p in enumerate(pool_size):
                features = pool(features, p, dim=dim)
            features = features.flatten(1, 2)
            features = super()._process_features(
                features,
                start_token_embeds=start_token_embeds,
                end_token_embeds=end_token_embeds,
            )
            if sep_token_embeds is not None:
                features = torch.cat([features, sep_token_embeds], dim=0)
            outputs.append(features)
        return torch.cat(outputs, dim=0)

    def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
        # Identical batching to BasicVideoEncoder, but with optional separator
        # tokens appended after each pooled view of the video.
        num_frames = [video.shape[0] for video in videos]
        images = torch.cat(videos, dim=0)
        features = self.parent.encode_images(images)
        features = torch.split(features, num_frames)
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
            sep_token_embeds=self.embed_tokens(self.sep_tokens),
        )
        return [process_features(f) for f in features]
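

if __name__ == "__main__":
    # Minimal smoke test with a stub parent model. The real parent is a full
    # VLM exposing `tokenizer`, `device`, `encode_images`, and the LLM input
    # embeddings; everything below is a stand-in to exercise the encoders'
    # shape logic on CPU, not part of the actual model.
    from types import SimpleNamespace

    hidden, vocab = 64, 128

    class _StubParent(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.embed = nn.Embedding(vocab, hidden)
            # Mimic the `parent.llm.model.embed_tokens(...)` attribute chain.
            self.llm = SimpleNamespace(model=SimpleNamespace(embed_tokens=self.embed))
            # A fake tokenizer: one token id per character.
            self.tokenizer = lambda text: SimpleNamespace(input_ids=[ord(c) % vocab for c in text])
            self.device = torch.device("cpu")

        def encode_images(self, images: torch.Tensor, block_sizes=None) -> torch.Tensor:
            # Pretend each image becomes a 4x4 grid of feature tokens.
            return torch.randn(images.shape[0], 16, hidden)

    parent = _StubParent()

    image_encoder = BasicImageEncoder(parent)
    [image_feats] = image_encoder([torch.randn(3, 32, 32)], config={}, device=parent.device)
    print(image_feats.shape)  # (16 tokens + 1 end token, hidden)

    video_encoder = TSPVideoEncoder(parent, sep_tokens="\n")
    [video_feats] = video_encoder([torch.randn(8, 3, 32, 32)], config={})
    print(video_feats.shape)  # 8 frames pooled to 1 x 16 tokens, plus markers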