# Uploaded by: bellmake
# SAM3 Video Segmentation - Clean deployment
# Revision: 14114e8
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from collections import OrderedDict
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from .model_misc import LayerScale
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: multi-head attention followed by an MLP.

    Each sub-layer is applied residually as ``x + ls(sublayer(norm(x)))``.
    ``ls`` is a ``LayerScale`` when ``ls_init_value`` is given and an
    ``nn.Identity`` otherwise. Inputs are batch-first (``batch_first=True``
    on the attention module).
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    ):
        super().__init__()

        def make_scale() -> nn.Module:
            # Learnable per-channel scaling only when an initial value is given.
            if ls_init_value is None:
                return nn.Identity()
            return LayerScale(d_model, ls_init_value)

        # NOTE: submodules are created in this exact order on purpose; the
        # order fixes parameter-init RNG consumption and state_dict layout.
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
        self.ln_1 = norm_layer(d_model)  # pre-norm for the attention branch
        self.ln_2 = norm_layer(d_model)  # pre-norm for the MLP branch
        self.ls_1 = make_scale()
        self.ls_2 = make_scale()

        hidden = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                (
                    ("c_fc", nn.Linear(d_model, hidden)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(hidden, d_model)),
                )
            )
        )

    def attention(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply ``self.attn``; missing keys/values fall back to the queries.

        Returns only the attention output (weights are not computed).
        """
        if k_x is None:
            k_x = q_x
        if v_x is None:
            v_x = q_x
        # Boolean masks are passed through untouched; any other mask is cast
        # to the query dtype so both tensors agree.
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q_x.dtype)
        output, _ = self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)
        return output

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the attention and MLP residual branches on ``q_x``.

        ``k_x``/``v_x`` are only used (after ``self.ln_1_kv``) when an
        ``ln_1_kv`` attribute exists on the module.
        NOTE(review): this class never creates ``ln_1_kv``, so unless a caller
        attaches one, any supplied ``k_x``/``v_x`` is discarded and the block
        does self-attention over ``ln_1(q_x)`` -- confirm this is intended.
        """
        if hasattr(self, "ln_1_kv"):
            k_x = None if k_x is None else self.ln_1_kv(k_x)
            v_x = None if v_x is None else self.ln_1_kv(v_x)
        else:
            k_x = v_x = None

        attended = self.attention(
            q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask
        )
        hidden = q_x + self.ls_1(attended)
        return hidden + self.ls_2(self.mlp(self.ln_2(hidden)))
class Transformer(nn.Module):
    """Stack of ``layers`` ``ResidualAttentionBlock`` modules of size ``width``.

    Optional extras:
      * ``compile_mode``: replaces this instance's ``forward`` with a
        ``torch.compile``-d version (``fullgraph=True``).
      * ``use_act_checkpoint``: in training mode (and outside TorchScript),
        every block is run through non-reentrant
        ``torch.utils.checkpoint.checkpoint``.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = use_act_checkpoint
        self.resblocks = nn.ModuleList(
            ResidualAttentionBlock(
                width,
                heads,
                mlp_ratio,
                ls_init_value=ls_init_value,
                act_layer=act_layer,
                norm_layer=norm_layer,
            )
            for _ in range(layers)
        )

        if compile_mode is None:
            return
        compiled_forward = torch.compile(
            self.forward, mode=compile_mode, fullgraph=True
        )
        self.forward = compiled_forward
        if self.grad_checkpointing:
            # Global side effect: turns off Dynamo's DDP optimisation for the
            # whole process when compiling with activation checkpointing.
            torch._dynamo.config.optimize_ddp = False

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply every block in order, sharing ``attn_mask`` across blocks."""
        # The checkpointing decision cannot change mid-loop, so compute it once.
        recompute = (
            self.grad_checkpointing
            and not torch.jit.is_scripting()
            and self.training
        )
        for block in self.resblocks:
            if recompute:
                # Positional args map to (q_x, k_x, v_x, attn_mask).
                x = checkpoint(block, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = block(x, attn_mask=attn_mask)
        return x
def text_global_pool(
    x: torch.Tensor, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split per-token features into a pooled vector and the token features.

    Args:
        x: Token features indexed as ``x[batch, position, ...]``.
        text: Token ids with the same leading ``[batch, position]`` layout;
            required only for ``pool_type == "argmax"``.
        pool_type: ``"first"`` / ``"last"`` pick that position and return the
            remaining positions as tokens; ``"argmax"`` picks, per batch row,
            the position of the largest token id in ``text`` and returns all of
            ``x`` as tokens; any other value (e.g. ``"none"``) returns
            ``(x, x)`` unchanged.

    Returns:
        ``(pooled, tokens)``.

    Raises:
        AssertionError: ``pool_type == "argmax"`` and ``text`` is None.
    """
    if pool_type == "first":
        return x[:, 0], x[:, 1:]
    if pool_type == "last":
        return x[:, -1], x[:, :-1]
    if pool_type == "argmax":
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        assert text is not None, "pool_type='argmax' requires the token ids `text`"
        # Build both index tensors on x's device; a default-device arange would
        # force a host->device index copy for accelerator inputs.
        rows = torch.arange(x.shape[0], device=x.device)
        cols = text.argmax(dim=-1).to(x.device)
        return x[rows, cols], x
    # "none" (or any unrecognised mode): no pooling.
    return x, x
class TextTransformer(nn.Module):
    """Text encoder: token + positional embeddings -> Transformer -> pooling.

    ``forward`` takes token ids indexed ``[batch, seq_len]``; ``seq_len`` must
    not exceed ``context_length`` (the number of positional embeddings). A
    causal additive mask is applied unless ``no_causal_mask`` is set. The
    output is pooled with ``text_global_pool`` and projected by
    ``text_projection``; with ``output_tokens`` the token features are
    returned as well.

    NOTE(review): ``positional_embedding`` (and ``text_projection`` when
    ``proj_bias`` is False) is allocated with ``torch.empty`` and never
    initialised here -- presumably filled from a checkpoint; confirm.
    """

    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        output_dim: int = 512,
        no_causal_mask: bool = False,
        pool_type: str = "none",  # no pooling
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        output_tokens: bool = False,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        supported_pools = ("first", "last", "argmax", "none")
        assert pool_type in supported_pools

        # Plain configuration attributes.
        self.output_tokens = output_tokens
        self.num_pos = context_length
        self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pool_type = pool_type

        # NOTE: creation order below is intentional (init RNG order and
        # state_dict layout depend on it).
        self.token_embedding = nn.Embedding(vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(context_length, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.ln_final = nn.Identity() if not use_ln_post else norm_layer(width)

        if not no_causal_mask:
            # Non-persistent: rebuilt in __init__, not stored in checkpoints.
            self.register_buffer(
                "attn_mask", self.build_causal_mask(), persistent=False
            )
        else:
            self.attn_mask = None

        self.text_projection = (
            nn.Linear(width, output_dim)
            if proj_bias
            else nn.Parameter(torch.empty(width, output_dim))
        )

    def build_causal_mask(self) -> torch.Tensor:
        """Return a ``(num_pos, num_pos)`` additive causal attention mask.

        Entries strictly above the diagonal are ``-inf`` (later positions are
        hidden); the diagonal and everything below it are 0.
        """
        size = self.num_pos
        return torch.full((size, size), float("-inf")).triu_(1)

    def forward(
        self, text: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Encode token ids; returns ``pooled`` or ``(pooled, tokens)``."""
        n_tokens = text.shape[1]

        mask = self.attn_mask
        if mask is not None:
            # Crop the full-context mask to the actual sequence length.
            mask = mask[:n_tokens, :n_tokens]

        hidden = self.token_embedding(text) + self.positional_embedding[:n_tokens]
        hidden = self.ln_final(self.transformer(hidden, attn_mask=mask))

        pooled, tokens = text_global_pool(hidden, text, pool_type=self.pool_type)

        projection = self.text_projection
        if projection is not None:
            if isinstance(projection, nn.Linear):
                pooled = projection(pooled)
            else:
                pooled = pooled @ projection

        if not self.output_tokens:
            return pooled
        return pooled, tokens
class VETextEncoder(nn.Module):
    """Tokenize text, encode it with a ``TextTransformer`` and resize to ``d_model``.

    Args:
        d_model: Output feature size of ``self.resizer``.
        tokenizer: Callable invoked as ``tokenizer(list_of_str,
            context_length=...)``; its result must support ``.to(device)``
            (presumably an integer tensor of token ids using 0 for padding --
            confirm against the tokenizer in use).
        width, heads, layers, context_length, vocab_size, use_ln_post,
        compile_mode, use_act_checkpoint: Forwarded to ``TextTransformer``.
    """

    def __init__(
        self,
        d_model: int,
        tokenizer: Callable,
        width: int = 1024,
        heads: int = 16,
        layers: int = 24,
        context_length: int = 32,
        vocab_size: int = 49408,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = True,
    ):
        super().__init__()
        self.context_length = context_length
        self.use_ln_post = use_ln_post
        self.tokenizer = tokenizer
        self.encoder = TextTransformer(
            context_length=self.context_length,
            vocab_size=vocab_size,
            width=width,
            heads=heads,
            layers=layers,
            # we want the tokens, not just the pooled output
            output_tokens=True,
            use_ln_post=use_ln_post,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.resizer = nn.Linear(self.encoder.width, d_model)

    def forward(
        self,
        text: Union[List[str], Tuple[torch.Tensor, torch.Tensor, dict]],
        input_boxes: Optional[List] = None,
        device: Optional[torch.device] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode a list of strings, or pass through an already-encoded triple.

        Args:
            text: Either a list of strings, or a tuple
                ``(text_attention_mask, text_memory_resized, tokenized)`` from
                an earlier call, where ``tokenized["inputs_embeds"]`` holds the
                token embeddings (presumably batch-first, since they are
                transposed on return -- confirm).
            input_boxes: Must be None or empty; box prompts are not supported.
            device: Device the token ids are moved to (string input only).

        Returns:
            ``(text_attention_mask, text_memory_resized, inputs_embeds)``. For
            string input the mask is ``[batch, seq_len]`` bool and True at
            padding positions (token id 0), memory is
            ``[seq_len, batch, d_model]`` and embeddings are
            ``[seq_len, batch, width]`` (sequence first).
        """
        if isinstance(text[0], str):
            # no use case for this
            assert input_boxes is None or len(input_boxes) == 0, "not supported"
            # Encode the text
            tokenized = self.tokenizer(text, context_length=self.context_length).to(
                device
            )  # [b, seq_len]
            # True marks padding. PyTorch attention modules treat True in a
            # key-padding mask as "ignore", the opposite of a HF-style mask.
            text_attention_mask = tokenized == 0
            # Embed the tokens separately so the raw embeddings can be returned.
            inputs_embeds = self.encoder.token_embedding(
                tokenized
            )  # [b, seq_len, width]
            _, text_memory = self.encoder(tokenized)  # [b, seq_len, width]
            assert text_memory.shape[1] == inputs_embeds.shape[1]
            # Transpose memory because pytorch's attention expects sequence first,
            # then resize it to the decoder's d_model.
            text_memory_resized = self.resizer(text_memory.transpose(0, 1))
        else:
            # The text is already encoded, use as is.
            text_attention_mask, text_memory_resized, tokenized = text
            inputs_embeds = tokenized["inputs_embeds"]
            assert (
                input_boxes is None or len(input_boxes) == 0
            ), "Can't replace boxes in text if it's already encoded"
        # Note that the input_embeds are returned in pytorch's convention (sequence first)
        return (
            text_attention_mask,
            text_memory_resized,
            inputs_embeds.transpose(0, 1),
        )