from collections import OrderedDict
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .model_misc import LayerScale


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    ):
        super().__init__()

        # batch_first=True: attention consumes and produces (batch, seq, width) tensors.
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)

        # Per-branch LayerScale is applied only when ls_init_value is given.
        self.ls_1 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )
        self.ls_2 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )

        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                ]
            )
        )

    def attention(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Self-attention by default; cross-attention when separate key/value tensors are given.
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x
        if attn_mask is not None:
            # Boolean masks are kept as-is; additive masks must match the query dtype.
            if attn_mask.dtype != torch.bool:
                attn_mask = attn_mask.to(q_x.dtype)
        return self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)[0]

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # ln_1_kv is not defined in this class; when it is absent, k_x/v_x stay None
        # and attention() falls back to self-attention on q_x.
        k_x = (
            self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
        )
        v_x = (
            self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
        )
        # Pre-norm residual branches: attention first, then the MLP.
        x = q_x + self.ls_1(
            self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
        )
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x


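# Usage sketch for a single block (illustrative sizes; inputs are batch-first,
# matching the batch_first MultiheadAttention above):
#
#     block = ResidualAttentionBlock(d_model=512, n_head=8)
#     x = torch.randn(2, 77, 512)   # (batch, seq, width)
#     y = block(x)                  # -> (2, 77, 512)

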
class Transformer(nn.Module):
    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = use_act_checkpoint
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                )
                for _ in range(layers)
            ]
        )

        if compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=compile_mode, fullgraph=True
            )
            if self.grad_checkpointing:
                # Activation checkpointing under torch.compile needs dynamo's DDP
                # optimization disabled.
                torch._dynamo.config.optimize_ddp = False

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        for r in self.resblocks:
            if (
                self.grad_checkpointing
                and not torch.jit.is_scripting()
                and self.training
            ):
                # checkpoint() passes positional args, so k_x and v_x are given as None.
                x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = r(x, attn_mask=attn_mask)
        return x


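# Usage sketch for the stack (illustrative sizes; the additive mask follows the same
# float("-inf") / triu convention as TextTransformer.build_causal_mask below):
#
#     encoder = Transformer(width=512, layers=12, heads=8)
#     x = torch.randn(2, 77, 512)
#     causal = torch.full((77, 77), float("-inf")).triu_(1)
#     y = encoder(x, attn_mask=causal)   # -> (2, 77, 512)

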
def text_global_pool(
    x: torch.Tensor, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
) -> Tuple[torch.Tensor, torch.Tensor]:
    if pool_type == "first":
        pooled, tokens = x[:, 0], x[:, 1:]
    elif pool_type == "last":
        pooled, tokens = x[:, -1], x[:, :-1]
    elif pool_type == "argmax":
        # Pool the feature at the position of the highest token id, which requires
        # the token ids themselves.
        assert text is not None
        pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
    else:
        pooled = tokens = x
    return pooled, tokens


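# Pooling sketch (shapes only; assumes a CLIP-style vocabulary where the end-of-text
# token has the highest id, which is what the "argmax" branch relies on):
#
#     x = torch.randn(2, 77, 512)
#     text = torch.randint(0, 49408, (2, 77))
#     pooled, tokens = text_global_pool(x, text, pool_type="argmax")
#     # pooled: (2, 512), tokens: (2, 77, 512)

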
class TextTransformer(nn.Module):
    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        output_dim: int = 512,
        no_causal_mask: bool = False,
        pool_type: str = "none",
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        output_tokens: bool = False,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "none")
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pool_type = pool_type

        self.token_embedding = nn.Embedding(self.vocab_size, width)
        # torch.empty leaves these parameters uninitialized; no in-module init is applied.
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()

        if no_causal_mask:
            self.attn_mask = None
        else:
            # Non-persistent buffer: the mask is rebuilt on construction rather than
            # stored in checkpoints.
            self.register_buffer(
                "attn_mask", self.build_causal_mask(), persistent=False
            )

        if proj_bias:
            self.text_projection = nn.Linear(width, output_dim)
        else:
            self.text_projection = nn.Parameter(torch.empty(width, output_dim))

    def build_causal_mask(self) -> torch.Tensor:
        # Additive mask: -inf above the diagonal blocks attention to future positions.
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(
        self, text: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_len = text.shape[1]
        x = self.token_embedding(text)

        # Crop the causal mask and positional embeddings to the actual sequence length.
        attn_mask = self.attn_mask
        if attn_mask is not None:
            attn_mask = attn_mask[:seq_len, :seq_len]

        x = x + self.positional_embedding[:seq_len]
        x = self.transformer(x, attn_mask=attn_mask)

        x = self.ln_final(x)
        pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                pooled = self.text_projection(pooled)
            else:
                pooled = pooled @ self.text_projection
        if self.output_tokens:
            return pooled, tokens
        return pooled


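# Usage sketch (shapes only; positional_embedding and text_projection above are left
# uninitialized, so meaningful outputs assume weights loaded from a checkpoint):
#
#     model = TextTransformer(pool_type="argmax")
#     tokens = torch.randint(1, 49408, (2, 77))   # hypothetical token ids
#     pooled = model(tokens)                      # -> (2, 512)

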
class VETextEncoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        tokenizer: Callable,
        width: int = 1024,
        heads: int = 16,
        layers: int = 24,
        context_length: int = 32,
        vocab_size: int = 49408,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = True,
    ):
        super().__init__()
        self.context_length = context_length
        self.use_ln_post = use_ln_post
        self.tokenizer = tokenizer

        self.encoder = TextTransformer(
            context_length=self.context_length,
            vocab_size=vocab_size,
            width=width,
            heads=heads,
            layers=layers,
            output_tokens=True,
            use_ln_post=use_ln_post,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        # Project text features from the encoder width to the model dimension.
        self.resizer = nn.Linear(self.encoder.width, d_model)

    def forward(
        self,
        text: Union[List[str], Tuple[torch.Tensor, torch.Tensor, dict]],
        input_boxes: Optional[List] = None,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if isinstance(text[0], str):
            assert input_boxes is None or len(input_boxes) == 0, "not supported"

            # Tokenize raw strings and build a padding mask (token id 0 is padding).
            tokenized = self.tokenizer(text, context_length=self.context_length).to(
                device
            )
            text_attention_mask = (tokenized != 0).bool()

            inputs_embeds = self.encoder.token_embedding(tokenized)
            _, text_memory = self.encoder(tokenized)

            assert text_memory.shape[1] == inputs_embeds.shape[1]

            # Invert so that True marks padding positions.
            text_attention_mask = text_attention_mask.ne(1)

            # (batch, seq, width) -> (seq, batch, width) before resizing to d_model.
            text_memory = text_memory.transpose(0, 1)
            text_memory_resized = self.resizer(text_memory)
        else:
            # Text is already encoded; unpack the precomputed tensors.
            text_attention_mask, text_memory_resized, tokenized = text
            inputs_embeds = tokenized["inputs_embeds"]
            assert (
                input_boxes is None or len(input_boxes) == 0
            ), "Can't replace boxes in text if it's already encoded"

        return (
            text_attention_mask,
            text_memory_resized,
            inputs_embeds.transpose(0, 1),
        )
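

# Integration sketch (hypothetical sizes and tokenizer; assumes an open_clip-style
# tokenizer callable returning a (batch, context_length) LongTensor with 0 as padding,
# which is what the mask computation in forward implies):
#
#     enc = VETextEncoder(d_model=256, tokenizer=my_tokenizer)
#     mask, memory, embeds = enc(["a photo of a cat"], device=torch.device("cpu"))
#     # mask:   (1, 32)       True at padding positions
#     # memory: (32, 1, 256)  resized text features, sequence-first
#     # embeds: (32, 1, 1024) raw token embeddings, sequence-first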