# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Transformer decoder. Inspired by PyTorch's version, adds the pre-norm variant."""

from typing import Any, Dict, List, Optional

import numpy as np
import torch
from sam3.sam.transformer import RoPEAttention
from torch import nn, Tensor
from torchvision.ops.roi_align import RoIAlign

from .act_ckpt_utils import activation_ckpt_wrapper
from .box_ops import box_cxcywh_to_xyxy
from .model_misc import (
    gen_sineembed_for_position,
    get_activation_fn,
    get_clones,
    inverse_sigmoid,
    MLP,
)


class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        activation: str,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        cross_attention: nn.Module,
        n_heads: int,
        use_text_cross_attention: bool = False,
    ):
        super().__init__()

        # cross attention
        self.cross_attn = cross_attention
        self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm1 = nn.LayerNorm(d_model)

        # cross attention text
        self.use_text_cross_attention = use_text_cross_attention
        if use_text_cross_attention:
            self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
            self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
            self.catext_norm = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        with torch.amp.autocast(device_type="cuda", enabled=False):
            tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        # for tgt
        tgt: Optional[Tensor],  # nq, bs, d_model
        tgt_query_pos: Optional[Tensor] = None,  # pos for query. MLP(Sine(pos))
        tgt_query_sine_embed: Optional[Tensor] = None,  # pos for query. Sine(pos)
        tgt_key_padding_mask: Optional[Tensor] = None,
        tgt_reference_points: Optional[Tensor] = None,  # nq, bs, 4
        memory_text: Optional[Tensor] = None,  # num_token, bs, d_model
        text_attention_mask: Optional[Tensor] = None,  # bs, num_token
        # for memory
        memory: Optional[Tensor] = None,  # hw, bs, d_model
        memory_key_padding_mask: Optional[Tensor] = None,
        memory_level_start_index: Optional[Tensor] = None,  # num_levels
        memory_spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        memory_pos: Optional[Tensor] = None,  # pos for memory
        # sa
        self_attn_mask: Optional[Tensor] = None,  # mask used for self-attention
        cross_attn_mask: Optional[Tensor] = None,  # mask used for cross-attention
        # dac
        dac=False,
        dac_use_selfatt_ln=True,
        presence_token=None,
        # skip inside deformable attn
        identity=0.0,
        **kwargs,  # additional kwargs for compatibility
    ):
        """
        Input:
            - tgt/tgt_query_pos: nq, bs, d_model
        """
        # self attention
        if self.self_attn is not None:
            if dac:
                # we only apply self attention to the first half of the queries
                assert tgt.shape[0] % 2 == 0
                num_o2o_queries = tgt.shape[0] // 2
                tgt_o2o = tgt[:num_o2o_queries]
                tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries]
                tgt_o2m = tgt[num_o2o_queries:]
            else:
                tgt_o2o = tgt
                tgt_query_pos_o2o = tgt_query_pos

            if presence_token is not None:
                tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0)
                tgt_query_pos_o2o = torch.cat(
                    [torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0
                )
                tgt_query_pos = torch.cat(
                    [torch.zeros_like(presence_token), tgt_query_pos], dim=0
                )

            q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
            tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0]
            tgt_o2o = tgt_o2o + self.dropout2(tgt2)
            if dac:
                if not dac_use_selfatt_ln:
                    tgt_o2o = self.norm2(tgt_o2o)
                tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0)  # Recombine
                if dac_use_selfatt_ln:
                    tgt = self.norm2(tgt)
            else:
                tgt = tgt_o2o
                tgt = self.norm2(tgt)

        if self.use_text_cross_attention:
            tgt2 = self.ca_text(
                self.with_pos_embed(tgt, tgt_query_pos),
                memory_text,
                memory_text,
                key_padding_mask=text_attention_mask,
            )[0]
            tgt = tgt + self.catext_dropout(tgt2)
            tgt = self.catext_norm(tgt)

        if presence_token is not None:
            presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
            cross_attn_mask = torch.cat(
                [presence_token_mask, cross_attn_mask], dim=1
            )  # (bs*nheads, 1+nq, hw)

        # Cross attention to image
        tgt2 = self.cross_attn(
            query=self.with_pos_embed(tgt, tgt_query_pos),
            key=self.with_pos_embed(memory, memory_pos),
            value=memory,
            attn_mask=cross_attn_mask,
            key_padding_mask=(
                memory_key_padding_mask.transpose(0, 1)
                if memory_key_padding_mask is not None
                else None
            ),
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)

        presence_token_out = None
        if presence_token is not None:
            presence_token_out = tgt[:1]
            tgt = tgt[1:]
        return tgt, presence_token_out
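

# A minimal usage sketch for TransformerDecoderLayer (illustrative only, not part
# of the model). It assumes a plain nn.MultiheadAttention as the image
# cross-attention module, that get_activation_fn supports "relu", and tiny dummy
# shapes; the real model wires in its own cross-attention and positional
# embeddings.
def _example_decoder_layer_sketch():
    d_model, n_heads, nq, hw, bs = 32, 4, 8, 64, 2
    layer = TransformerDecoderLayer(
        activation="relu",
        d_model=d_model,
        dim_feedforward=64,
        dropout=0.0,
        cross_attention=nn.MultiheadAttention(d_model, n_heads),
        n_heads=n_heads,
    )
    tgt = torch.randn(nq, bs, d_model)  # queries: nq, bs, d_model
    memory = torch.randn(hw, bs, d_model)  # flattened image features: hw, bs, d_model
    out, presence = layer(
        tgt=tgt,
        tgt_query_pos=torch.zeros_like(tgt),
        memory=memory,
        memory_pos=torch.zeros_like(memory),
    )
    assert out.shape == (nq, bs, d_model) and presence is None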


class TransformerDecoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        frozen: bool,
        interaction_layer,
        layer,
        num_layers: int,
        num_queries: int,
        return_intermediate: bool,
        box_refine: bool = False,
        num_o2m_queries: int = 0,
        dac: bool = False,
        boxRPB: str = "none",
        # Experimental: An object query for SAM 2 tasks
        instance_query: bool = False,
        # Defines the number of additional instance queries,
        # 1 or 4 are the most likely for single vs multi mask support
        num_instances: int = 1,  # Irrelevant if instance_query is False
        dac_use_selfatt_ln: bool = True,
        use_act_checkpoint: bool = False,
        compile_mode=None,
        presence_token: bool = False,
        clamp_presence_logits: bool = True,
        clamp_presence_logit_max_val: float = 10.0,
        use_normed_output_consistently: bool = True,
        separate_box_head_instance: bool = False,
        separate_norm_instance: bool = False,
        resolution: Optional[int] = None,
        stride: Optional[int] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.fine_layers = (
            get_clones(interaction_layer, num_layers)
            if interaction_layer is not None
            else [None] * num_layers
        )
        self.num_layers = num_layers
        self.num_queries = num_queries
        self.dac = dac
        if dac:
            self.num_o2m_queries = num_queries
            tot_num_queries = num_queries
        else:
            self.num_o2m_queries = num_o2m_queries
            tot_num_queries = num_queries + num_o2m_queries
        self.norm = nn.LayerNorm(d_model)
        self.return_intermediate = return_intermediate
        self.bbox_embed = MLP(d_model, d_model, 4, 3)
        self.query_embed = nn.Embedding(tot_num_queries, d_model)
        self.instance_query_embed = None
        self.instance_query_reference_points = None
        self.use_instance_query = instance_query
        self.num_instances = num_instances
        self.use_normed_output_consistently = use_normed_output_consistently
        self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None
        self.instance_bbox_embed = None
        if separate_box_head_instance:
            self.instance_bbox_embed = MLP(d_model, d_model, 4, 3)
        if instance_query:
            self.instance_query_embed = nn.Embedding(num_instances, d_model)
        self.box_refine = box_refine
        if box_refine:
            nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
            self.reference_points = nn.Embedding(num_queries, 4)
            if instance_query:
                self.instance_reference_points = nn.Embedding(num_instances, 4)
        assert boxRPB in ["none", "log", "linear", "both"]
        self.boxRPB = boxRPB
        if boxRPB != "none":
            try:
                nheads = self.layers[0].cross_attn_image.num_heads
            except AttributeError:
                nheads = self.layers[0].cross_attn.num_heads
            n_input = 4 if boxRPB == "both" else 2
            self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2)
            self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2)
            self.compilable_cord_cache = None
            self.compilable_stored_size = None
            self.coord_cache = {}
            if resolution is not None and stride is not None:
                feat_size = resolution // stride
                coords_h, coords_w = self._get_coords(
                    feat_size, feat_size, device="cuda"
                )
                self.compilable_cord_cache = (coords_h, coords_w)
                self.compilable_stored_size = (feat_size, feat_size)
        self.roi_pooler = (
            RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True)
            if interaction_layer is not None
            else None
        )
        if frozen:
            for p in self.parameters():
                p.requires_grad_(False)
        self.presence_token = None
        self.clamp_presence_logits = clamp_presence_logits
        self.clamp_presence_logit_max_val = clamp_presence_logit_max_val
        if presence_token:
            self.presence_token = nn.Embedding(1, d_model)
            self.presence_token_head = MLP(d_model, d_model, 1, 3)
            self.presence_token_out_norm = nn.LayerNorm(d_model)

        self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2)
        self.dac_use_selfatt_ln = dac_use_selfatt_ln
        self.use_act_checkpoint = use_act_checkpoint
        nn.init.normal_(self.query_embed.weight.data)
        if self.instance_query_embed is not None:
            nn.init.normal_(self.instance_query_embed.weight.data)
        assert self.roi_pooler is None
        assert self.return_intermediate, "support return_intermediate only"
        assert self.box_refine, "support box refine only"
        self.compile_mode = compile_mode
        self.compiled = False
        # We defer compilation till after the first forward, to first warm-up the boxRPB cache
        # assign layer index to each layer so that some layers can decide what to do
        # based on which layer index they are (e.g. cross attention to memory bank only
        # in selected layers)
        for layer_idx, layer in enumerate(self.layers):
            layer.layer_idx = layer_idx

    @staticmethod
    def _get_coords(H, W, device):
        coords_h = torch.arange(0, H, device=device, dtype=torch.float32) / H
        coords_w = torch.arange(0, W, device=device, dtype=torch.float32) / W
        return coords_h, coords_w

    def _get_rpb_matrix(self, reference_boxes, feat_size):
        H, W = feat_size
        boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes).transpose(0, 1)
        bs, num_queries, _ = boxes_xyxy.shape
        if self.compilable_cord_cache is None:
            self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device)
            self.compilable_stored_size = (H, W)
        if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
            H,
            W,
        ):
            # good, hitting the cache, will be compilable
            coords_h, coords_w = self.compilable_cord_cache
        else:
            # cache miss, will create compilation issue
            # In case we're not compiling, we'll still rely on the dict-based cache
            if feat_size not in self.coord_cache:
                self.coord_cache[feat_size] = self._get_coords(
                    H, W, reference_boxes.device
                )
            coords_h, coords_w = self.coord_cache[feat_size]
        assert coords_h.shape == (H,)
        assert coords_w.shape == (W,)
        deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
        deltas_y = deltas_y.view(bs, num_queries, -1, 2)
        deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
        deltas_x = deltas_x.view(bs, num_queries, -1, 2)
        if self.boxRPB in ["log", "both"]:
            deltas_x_log = deltas_x * 8  # normalize to -8, 8
            deltas_x_log = (
                torch.sign(deltas_x_log)
                * torch.log2(torch.abs(deltas_x_log) + 1.0)
                / np.log2(8)
            )
            deltas_y_log = deltas_y * 8  # normalize to -8, 8
            deltas_y_log = (
                torch.sign(deltas_y_log)
                * torch.log2(torch.abs(deltas_y_log) + 1.0)
                / np.log2(8)
            )
            if self.boxRPB == "log":
                deltas_x = deltas_x_log
                deltas_y = deltas_y_log
            else:
                deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1)
                deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1)
        if self.training:
            assert self.use_act_checkpoint, "activation ckpt not enabled in decoder"
        deltas_x = activation_ckpt_wrapper(self.boxRPB_embed_x)(
            x=deltas_x,
            act_ckpt_enable=self.training and self.use_act_checkpoint,
        )  # bs, num_queries, W, n_heads
        deltas_y = activation_ckpt_wrapper(self.boxRPB_embed_y)(
            x=deltas_y,
            act_ckpt_enable=self.training and self.use_act_checkpoint,
        )  # bs, num_queries, H, n_heads
        if not torch.compiler.is_dynamo_compiling():
            assert deltas_x.shape[:3] == (bs, num_queries, W)
            assert deltas_y.shape[:3] == (bs, num_queries, H)
        B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
            2
        )  # bs, num_queries, H, W, n_heads
        if not torch.compiler.is_dynamo_compiling():
            assert B.shape[:4] == (bs, num_queries, H, W)
        B = B.flatten(2, 3)  # bs, num_queries, H*W, n_heads
        B = B.permute(0, 3, 1, 2)  # bs, n_heads, num_queries, H*W
        B = B.contiguous()  # memeff attn likes ordered strides
        if not torch.compiler.is_dynamo_compiling():
            assert B.shape[2:] == (num_queries, H * W)
        return B

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        reference_boxes: Optional[Tensor] = None,  # num_queries, bs, 4
        # for memory
        level_start_index: Optional[Tensor] = None,  # num_levels
        spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
        valid_ratios: Optional[Tensor] = None,
        # for text
        memory_text: Optional[Tensor] = None,
        text_attention_mask: Optional[Tensor] = None,
        # if `apply_dac` is None, it will default to `self.dac`
        apply_dac: Optional[bool] = None,
        is_instance_prompt=False,
        decoder_extra_kwargs: Optional[Dict] = None,
        # ROI memory bank
        obj_roi_memory_feat=None,
        obj_roi_memory_mask=None,
        box_head_trk=None,
    ):
        """
        Input:
            - tgt: nq, bs, d_model
            - memory: \\sum{hw}, bs, d_model
            - pos: \\sum{hw}, bs, d_model
            - reference_boxes: nq, bs, 4 (after sigmoid)
            - valid_ratios/spatial_shapes: bs, nlevel, 2
        """
        if memory_mask is not None:
            assert (
                self.boxRPB == "none"
            ), "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
        apply_dac = apply_dac if apply_dac is not None else self.dac
        if apply_dac:
            assert (tgt.shape[0] == self.num_queries) or (
                self.use_instance_query
                and (tgt.shape[0] == self.instance_query_embed.num_embeddings)
            )
            tgt = tgt.repeat(2, 1, 1)
            # note that we don't tile tgt_mask, since DAC doesn't
            # use self-attention in o2m queries
            if reference_boxes is not None:
                assert (reference_boxes.shape[0] == self.num_queries) or (
                    self.use_instance_query
                    and (
                        reference_boxes.shape[0]
                        == self.instance_query_embed.num_embeddings
                    )
                )
                reference_boxes = reference_boxes.repeat(2, 1, 1)

        bs = tgt.shape[1]
        intermediate = []
        intermediate_presence_logits = []
        presence_feats = None

        if self.box_refine:
            if reference_boxes is None:
                # In this case, we're in a one-stage model, so we generate the reference boxes
                reference_boxes = self.reference_points.weight.unsqueeze(1)
                reference_boxes = (
                    reference_boxes.repeat(2, bs, 1)
                    if apply_dac
                    else reference_boxes.repeat(1, bs, 1)
                )
                reference_boxes = reference_boxes.sigmoid()
            intermediate_ref_boxes = [reference_boxes]
        else:
            reference_boxes = None
            intermediate_ref_boxes = None

        output = tgt
        presence_out = None
        if self.presence_token is not None and is_instance_prompt is False:
            # expand to batch dim
            presence_out = self.presence_token.weight[None].expand(1, bs, -1)

        box_head = self.bbox_embed
        if is_instance_prompt and self.instance_bbox_embed is not None:
            box_head = self.instance_bbox_embed
        out_norm = self.norm
        if is_instance_prompt and self.instance_norm is not None:
            out_norm = self.instance_norm

        for layer_idx, layer in enumerate(self.layers):
            reference_points_input = (
                reference_boxes[:, :, None]
                * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
            )  # nq, bs, nlevel, 4
            query_sine_embed = gen_sineembed_for_position(
                reference_points_input[:, :, 0, :], self.d_model
            )  # nq, bs, d_model*2

            # conditional query
            query_pos = self.ref_point_head(query_sine_embed)  # nq, bs, d_model

            if self.boxRPB != "none" and reference_boxes is not None:
                assert (
                    spatial_shapes.shape[0] == 1
                ), "only single scale support implemented"
                memory_mask = self._get_rpb_matrix(
                    reference_boxes,
                    (spatial_shapes[0, 0], spatial_shapes[0, 1]),
                )
                memory_mask = memory_mask.flatten(0, 1)  # (bs*n_heads, nq, H*W)

            if self.training:
                assert (
                    self.use_act_checkpoint
                ), "Activation checkpointing not enabled in the decoder"
            output, presence_out = activation_ckpt_wrapper(layer)(
                tgt=output,
                tgt_query_pos=query_pos,
                tgt_query_sine_embed=query_sine_embed,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_reference_points=reference_points_input,
                memory_text=memory_text,
                text_attention_mask=text_attention_mask,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
                memory_level_start_index=level_start_index,
                memory_spatial_shapes=spatial_shapes,
                memory_pos=pos,
                self_attn_mask=tgt_mask,
                cross_attn_mask=memory_mask,
                dac=apply_dac,
                dac_use_selfatt_ln=self.dac_use_selfatt_ln,
                presence_token=presence_out,
                **(decoder_extra_kwargs or {}),
                act_ckpt_enable=self.training and self.use_act_checkpoint,
                # ROI memory bank
                obj_roi_memory_feat=obj_roi_memory_feat,
                obj_roi_memory_mask=obj_roi_memory_mask,
            )

            # iter update
            if self.box_refine:
                reference_before_sigmoid = inverse_sigmoid(reference_boxes)
                if box_head_trk is None:
                    # delta_unsig = self.bbox_embed(output)
                    if not self.use_normed_output_consistently:
                        delta_unsig = box_head(output)
                    else:
                        delta_unsig = box_head(out_norm(output))
                else:
                    # box_head_trk uses a separate box head for tracking queries
                    Q_det = decoder_extra_kwargs["Q_det"]
                    assert output.size(0) >= Q_det
                    delta_unsig_det = self.bbox_embed(output[:Q_det])
                    delta_unsig_trk = box_head_trk(output[Q_det:])
                    delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0)
                outputs_unsig = delta_unsig + reference_before_sigmoid
                new_reference_points = outputs_unsig.sigmoid()
                reference_boxes = new_reference_points.detach()
                if layer_idx != self.num_layers - 1:
                    intermediate_ref_boxes.append(new_reference_points)
            else:
                raise NotImplementedError("not implemented yet")

            intermediate.append(out_norm(output))

            if self.presence_token is not None and is_instance_prompt is False:
                # norm, mlp head
                intermediate_layer_presence_logits = self.presence_token_head(
                    self.presence_token_out_norm(presence_out)
                ).squeeze(-1)
                # clamp to mitigate numerical issues (clamp is out-of-place, so
                # reassign the result)
                if self.clamp_presence_logits:
                    intermediate_layer_presence_logits = (
                        intermediate_layer_presence_logits.clamp(
                            min=-self.clamp_presence_logit_max_val,
                            max=self.clamp_presence_logit_max_val,
                        )
                    )
                intermediate_presence_logits.append(intermediate_layer_presence_logits)
                presence_feats = presence_out.clone()

        if not self.compiled and self.compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=self.compile_mode, fullgraph=True
            )
            self.compiled = True

        return (
            torch.stack(intermediate),
            torch.stack(intermediate_ref_boxes),
            (
                torch.stack(intermediate_presence_logits)
                if self.presence_token is not None and is_instance_prompt is False
                else None
            ),
            presence_feats,
        )
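

# A worked sketch of the per-layer box refinement in TransformerDecoder.forward
# (illustrative only): the box head predicts an unsigmoided delta that is added
# to the inverse-sigmoid of the current reference boxes, and the sigmoid of the
# sum becomes the next, detached reference. torch.special.logit stands in for
# inverse_sigmoid here; the random delta stands in for box_head(out_norm(output)).
def _example_box_refine_step_sketch():
    nq, bs = 5, 2
    reference_boxes = torch.rand(nq, bs, 4).clamp(1e-3, 1 - 1e-3)  # cx, cy, w, h in (0, 1)
    delta_unsig = 0.1 * torch.randn(nq, bs, 4)
    new_reference_points = (torch.special.logit(reference_boxes) + delta_unsig).sigmoid()
    reference_boxes = new_reference_points.detach()  # fed to the next decoder layer
    assert reference_boxes.shape == (nq, bs, 4)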


class TransformerEncoderCrossAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        frozen: bool,
        pos_enc_at_input: bool,
        layer,
        num_layers: int,
        use_act_checkpoint: bool = False,
        batch_first: bool = False,  # Do layers expect batch first input?
        # which layers to exclude cross attention? default: None, means all
        # layers use cross attention
        remove_cross_attention_layers: Optional[list] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.layers = get_clones(layer, num_layers)
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.use_act_checkpoint = use_act_checkpoint
        if frozen:
            for p in self.parameters():
                p.requires_grad_(False)
        self.batch_first = batch_first

        # remove cross attention layers if specified
        self.remove_cross_attention_layers = [False] * self.num_layers
        if remove_cross_attention_layers is not None:
            for i in remove_cross_attention_layers:
                self.remove_cross_attention_layers[i] = True
        assert len(self.remove_cross_attention_layers) == len(self.layers)
        for i, remove_cross_attention in enumerate(self.remove_cross_attention_layers):
            if remove_cross_attention:
                self.layers[i].cross_attn_image = None
                self.layers[i].norm2 = None
                self.layers[i].dropout2 = None

    def forward(
        self,
        src,  # self-attention inputs
        prompt,  # cross-attention inputs
        src_mask: Optional[Tensor] = None,  # att. mask for self-attention inputs
        prompt_mask: Optional[Tensor] = None,  # att. mask for cross-attention inputs
        src_key_padding_mask: Optional[Tensor] = None,
        prompt_key_padding_mask: Optional[Tensor] = None,
        src_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        prompt_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        feat_sizes: Optional[list] = None,
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ):
        if isinstance(src, list):
            assert isinstance(src_key_padding_mask, list) and isinstance(src_pos, list)
            assert len(src) == len(src_key_padding_mask) == len(src_pos) == 1
            src, src_key_padding_mask, src_pos = (
                src[0],
                src_key_padding_mask[0],
                src_pos[0],
            )
        assert (
            src.shape[1] == prompt.shape[1]
        ), "Batch size must be the same for src and prompt"

        output = src
        if self.pos_enc_at_input and src_pos is not None:
            output = output + 0.1 * src_pos

        if self.batch_first:
            # Convert to batch first
            output = output.transpose(0, 1)
            src_pos = src_pos.transpose(0, 1)
            prompt = prompt.transpose(0, 1)
            prompt_pos = prompt_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
            output = activation_ckpt_wrapper(layer)(
                tgt=output,
                memory=prompt,
                tgt_mask=src_mask,
                memory_mask=prompt_mask,
                tgt_key_padding_mask=src_key_padding_mask,
                memory_key_padding_mask=prompt_key_padding_mask,
                pos=prompt_pos,
                query_pos=src_pos,
                dac=False,
                attn_bias=None,
                act_ckpt_enable=self.training and self.use_act_checkpoint,
                **kwds,
            )
        normed_output = self.norm(output)

        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)
            src_pos = src_pos.transpose(0, 1)

        return {
            "memory": normed_output,
            "pos_embed": src_pos,
            "padding_mask": src_key_padding_mask,
        }


class TransformerDecoderLayerv1(nn.Module):
    def __init__(
        self,
        activation: str,
        cross_attention: nn.Module,
        d_model: int,
        dim_feedforward: int,
        dropout: float,
        pos_enc_at_attn: bool,
        pos_enc_at_cross_attn_keys: bool,
        pos_enc_at_cross_attn_queries: bool,
        pre_norm: bool,
        self_attention: nn.Module,
    ):
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = self_attention
        self.cross_attn_image = cross_attention

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation_str = activation
        self.activation = get_activation_fn(activation)

        self.pre_norm = pre_norm
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        **kwargs,
    ):
        q = k = tgt + query_pos if self.pos_enc_at_attn else tgt
        # Self attention
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # Cross attention to image
        tgt2 = self.cross_attn_image(
            query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt,
            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # FFN
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        **kwargs,
    ):
        if dac:
            # we only apply self attention to the first half of the queries
            assert tgt.shape[0] % 2 == 0
            other_tgt = tgt[tgt.shape[0] // 2 :]
            tgt = tgt[: tgt.shape[0] // 2]
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt2,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        if dac:
            # Recombine
            tgt = torch.cat((tgt, other_tgt), dim=0)
        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            attn_bias=attn_bias,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        dac: bool = False,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        **kwds: Any,
    ) -> torch.Tensor:
        fwd_fn = self.forward_pre if self.pre_norm else self.forward_post
        return fwd_fn(
            tgt,
            memory,
            dac=dac,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            pos=pos,
            query_pos=query_pos,
            attn_bias=attn_bias,
            **kwds,
        )
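

# A minimal sketch of TransformerDecoderLayerv1 on its post-norm path
# (illustrative only). It assumes plain nn.MultiheadAttention modules for both
# attentions and that get_activation_fn supports "relu". The pre-norm path
# additionally forwards attn_bias to the cross-attention module, so it needs an
# attention implementation that accepts that keyword.
def _example_postnorm_layer_sketch():
    d_model, n_heads, nq, hw, bs = 32, 4, 8, 64, 2
    layer = TransformerDecoderLayerv1(
        activation="relu",
        cross_attention=nn.MultiheadAttention(d_model, n_heads),
        d_model=d_model,
        dim_feedforward=64,
        dropout=0.0,
        pos_enc_at_attn=False,
        pos_enc_at_cross_attn_keys=True,
        pos_enc_at_cross_attn_queries=False,
        pre_norm=False,  # dispatches to forward_post
        self_attention=nn.MultiheadAttention(d_model, n_heads),
    )
    tgt = torch.randn(nq, bs, d_model)  # queries: nq, bs, d_model
    memory = torch.randn(hw, bs, d_model)  # flattened image features
    out = layer(
        tgt, memory, pos=torch.zeros_like(memory), query_pos=torch.zeros_like(tgt)
    )
    assert out.shape == (nq, bs, d_model)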


class TransformerDecoderLayerv2(TransformerDecoderLayerv1):
    def __init__(self, cross_attention_first=False, *args: Any, **kwds: Any):
        super().__init__(*args, **kwds)
        self.cross_attention_first = cross_attention_first

    def _forward_sa(self, tgt, query_pos):
        # Self-Attention
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(q, k, v=tgt2)
        tgt = tgt + self.dropout1(tgt2)
        return tgt

    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
        if self.cross_attn_image is None:
            return tgt
        kwds = {}
        if num_k_exclude_rope > 0:
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}
        # Cross-Attention
        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            v=memory,
            **kwds,
        )
        tgt = tgt + self.dropout2(tgt2)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        dac: bool,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ):
        assert dac is False
        assert tgt_mask is None
        assert memory_mask is None
        assert tgt_key_padding_mask is None
        assert memory_key_padding_mask is None
        assert attn_bias is None
        if self.cross_attention_first:
            tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
            tgt = self._forward_sa(tgt, query_pos)
        else:
            tgt = self._forward_sa(tgt, query_pos)
            tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
        # MLP
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, *args: Any, **kwds: Any) -> torch.Tensor:
        if self.pre_norm:
            return self.forward_pre(*args, **kwds)
        raise NotImplementedError
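

# A shape sketch of the boxRPB relative-position bias computed in
# TransformerDecoder._get_rpb_matrix (illustrative only). With dummy numbers it
# walks the same steps: grid-to-box-edge deltas, the "log" variant's
# sign-preserving log compression, per-axis head embeddings, and the outer sum
# that yields a (bs, n_heads, nq, H*W) cross-attention bias. The nn.Linear
# stand-ins replace the model's boxRPB_embed_x / boxRPB_embed_y MLPs.
def _example_boxrpb_shape_sketch():
    bs, nq, H, W, n_heads = 2, 3, 8, 8, 4
    boxes_xyxy = torch.rand(bs, nq, 4)  # normalized x0, y0, x1, y1
    coords_h = torch.arange(H, dtype=torch.float32) / H
    coords_w = torch.arange(W, dtype=torch.float32) / W
    # signed distance from every grid row/column to the box's y/x edges
    deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
    deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
    deltas_y = deltas_y.view(bs, nq, H, 2)
    deltas_x = deltas_x.view(bs, nq, W, 2)
    # "log" variant: scale to roughly [-8, 8], then compress with a signed log
    deltas_y = torch.sign(deltas_y * 8) * torch.log2(torch.abs(deltas_y * 8) + 1.0) / np.log2(8)
    deltas_x = torch.sign(deltas_x * 8) * torch.log2(torch.abs(deltas_x * 8) + 1.0) / np.log2(8)
    embed_y, embed_x = nn.Linear(2, n_heads), nn.Linear(2, n_heads)
    bias = embed_y(deltas_y).unsqueeze(3) + embed_x(deltas_x).unsqueeze(2)  # bs, nq, H, W, n_heads
    bias = bias.flatten(2, 3).permute(0, 3, 1, 2)  # bs, n_heads, nq, H*W
    assert bias.shape == (bs, n_heads, nq, H * W)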