# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

"""Provides utility to combine a vision backbone with a language backbone."""

from copy import copy
from typing import List, Optional

import torch
import torch.nn as nn
from torch.nn.attention import sdpa_kernel, SDPBackend

from .act_ckpt_utils import activation_ckpt_wrapper
from .necks import Sam3DualViTDetNeck


class SAM3VLBackbone(nn.Module):
    """This backbone combines a vision backbone and a language backbone without fusion.

    As such, it is more of a convenience wrapper that handles the two backbones
    together. It adds support for activation checkpointing and compilation.
    """

    def __init__(
        self,
        visual: Sam3DualViTDetNeck,
        text,
        compile_visual: bool = False,
        act_ckpt_whole_vision_backbone: bool = False,
        act_ckpt_whole_language_backbone: bool = False,
        scalp: int = 0,
    ):
        """Initialize the backbone combiner.

        :param visual: The vision backbone to use
        :param text: The text encoder to use
        :param compile_visual: If True, wrap the vision backbone in torch.compile
        :param act_ckpt_whole_vision_backbone: If True, run the entire vision
            backbone under activation checkpointing during training
        :param act_ckpt_whole_language_backbone: If True, run the entire language
            backbone under activation checkpointing during training
        :param scalp: Number of lowest-resolution feature levels to discard
        """
        super().__init__()
        self.vision_backbone: Sam3DualViTDetNeck = (
            torch.compile(visual) if compile_visual else visual
        )
        self.language_backbone = text
        self.scalp = scalp
        # allow running activation checkpointing on the entire vision and language backbones
        self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone
        self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone

    def forward(
        self,
        samples: torch.Tensor,
        captions: List[str],
        input_boxes: Optional[torch.Tensor] = None,
        additional_text: Optional[List[str]] = None,
    ):
        """Forward pass of the backbone combiner.

        :param samples: The input images
        :param captions: The input captions
        :param input_boxes: If the text contains placeholders for boxes, this
            parameter is the tensor containing their spatial features
        :param additional_text: This can be used to encode some additional text
            (different from the captions) in the same forward of the backbone
        :return: Output dictionary with the following keys:
            - vision_features: The output of the vision backbone
            - language_features: The output of the language backbone
            - language_mask: The attention mask of the language backbone
            - vision_pos_enc: The positional encoding of the vision backbone
            - (optional) additional_text_features: The output of the language
              backbone for the additional text
            - (optional) additional_text_mask: The attention mask of the
              language backbone for the additional text
        """
        output = self.forward_image(samples)
        device = output["vision_features"].device
        output.update(self.forward_text(captions, input_boxes, additional_text, device))
        return output
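
    # The forward_image / forward_text wrappers below run their underlying
    # implementations through activation_ckpt_wrapper, so that (when enabled
    # and in training mode) the whole sub-backbone is activation-checkpointed,
    # trading extra recomputation in the backward pass for activation memory.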

    def forward_image(self, samples: torch.Tensor):
        return activation_ckpt_wrapper(self._forward_image_no_act_ckpt)(
            samples=samples,
            act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training,
        )

    def _forward_image_no_act_ckpt(self, samples):
        # Forward through backbone
        sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(
            samples
        )
        if self.scalp > 0:
            # Discard the lowest resolution features
            sam3_features, sam3_pos = (
                sam3_features[: -self.scalp],
                sam3_pos[: -self.scalp],
            )
            if sam2_features is not None and sam2_pos is not None:
                sam2_features, sam2_pos = (
                    sam2_features[: -self.scalp],
                    sam2_pos[: -self.scalp],
                )
        sam2_output = None
        if sam2_features is not None and sam2_pos is not None:
            sam2_src = sam2_features[-1]
            sam2_output = {
                "vision_features": sam2_src,
                "vision_pos_enc": sam2_pos,
                "backbone_fpn": sam2_features,
            }
        sam3_src = sam3_features[-1]
        output = {
            "vision_features": sam3_src,
            "vision_pos_enc": sam3_pos,
            "backbone_fpn": sam3_features,
            "sam2_backbone_out": sam2_output,
        }
        return output
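
    # Note: the neck is "dual" in that it can emit two parallel feature
    # pyramids. The SAM3 pyramid always populates the main output keys, while
    # the optional SAM2 pyramid, when present, travels alongside under
    # "sam2_backbone_out" in the same dict layout for downstream consumers.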

    def forward_text(
        self, captions, input_boxes=None, additional_text=None, device="cuda"
    ):
        return activation_ckpt_wrapper(self._forward_text_no_act_ckpt)(
            captions=captions,
            input_boxes=input_boxes,
            additional_text=additional_text,
            device=device,
            act_ckpt_enable=self.act_ckpt_whole_language_backbone and self.training,
        )

    def _forward_text_no_act_ckpt(
        self,
        captions,
        input_boxes=None,
        additional_text=None,
        device="cuda",
    ):
        output = {}
        # Forward through text_encoder
        text_to_encode = copy(captions)
        if additional_text is not None:
            # If there is additional_text, piggy-back it onto this forward
            # pass; it will be used later for output alignment.
            text_to_encode += additional_text
        # Restrict scaled_dot_product_attention inside the text encoder to
        # these backends; PyTorch dispatches to the fastest applicable one,
        # with MATH as the always-available fallback.
        sdpa_context = sdpa_kernel(
            [
                SDPBackend.MATH,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.FLASH_ATTENTION,
            ]
        )
        with sdpa_context:
            text_attention_mask, text_memory, text_embeds = self.language_backbone(
                text_to_encode, input_boxes, device=device
            )
        if additional_text is not None:
            # Split the piggy-backed additional_text back out of the batch.
            output["additional_text_features"] = text_memory[:, -len(additional_text) :]
            output["additional_text_mask"] = text_attention_mask[
                -len(additional_text) :
            ]
            text_memory = text_memory[:, : len(captions)]
            text_attention_mask = text_attention_mask[: len(captions)]
            text_embeds = text_embeds[:, : len(captions)]
        output["language_features"] = text_memory
        output["language_mask"] = text_attention_mask
        # Text embeddings before they are passed through the encoder
        output["language_embeds"] = text_embeds
        return output
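

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how the combiner is wired and called. The
# encoder instances and variable names here are assumptions, not part of this
# module; the text encoder is only expected to return
# (attention_mask, memory, embeds) as consumed by _forward_text_no_act_ckpt.
# ---------------------------------------------------------------------------
# backbone = SAM3VLBackbone(
#     visual=my_vitdet_neck,   # a Sam3DualViTDetNeck built elsewhere (assumed)
#     text=my_text_encoder,    # any encoder honoring the (mask, memory, embeds) contract
#     compile_visual=False,
#     scalp=1,                 # drop the lowest-resolution feature level
# )
# out = backbone(images, captions=["a red ball"])
# out["vision_features"]    # last retained SAM3 FPN level
# out["language_features"]  # encoded caption features
# out["language_mask"]      # attention mask for the captions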