"""Necks are the interface between a vision backbone and the rest of the detection model."""

from copy import deepcopy
from typing import List, Optional, Tuple

import torch
import torch.nn as nn


class Sam3DualViTDetNeck(nn.Module):
    def __init__(
        self,
        trunk: nn.Module,
        position_encoding: nn.Module,
        d_model: int,
        scale_factors: Tuple[float, ...] = (4.0, 2.0, 1.0, 0.5),
        add_sam2_neck: bool = False,
    ):
        """
        SimpleFPN neck a la ViTDet (from detectron2, very lightly adapted).

        It supports a "dual neck" setting, where we have two identical necks
        (one for SAM3 and one for SAM2) with different weights.

        :param trunk: the backbone
        :param position_encoding: the positional encoding to use
        :param d_model: the dimension of the model
        :param scale_factors: per-level scale of the output feature maps,
            relative to the trunk's last feature map
        :param add_sam2_neck: if True, build an identical, independently
            weighted copy of the neck for the SAM2 outputs
        """
        super().__init__()
        self.trunk = trunk
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.scale_factors = scale_factors
        use_bias = True
        dim: int = self.trunk.channel_list[-1]

        for scale in scale_factors:
            current = nn.Sequential()
            # Upsample (transposed convs) or downsample (max-pool) the trunk's
            # last feature map to the requested scale, as in ViTDet's SimpleFPN.
            if scale == 4.0:
                current.add_module(
                    "dconv_2x2_0",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                current.add_module("gelu", nn.GELU())
                current.add_module(
                    "dconv_2x2_1",
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                )
                out_dim = dim // 4
            elif scale == 2.0:
                current.add_module(
                    "dconv_2x2",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                out_dim = dim // 2
            elif scale == 1.0:
                out_dim = dim
            elif scale == 0.5:
                current.add_module(
                    "maxpool_2x2",
                    nn.MaxPool2d(kernel_size=2, stride=2),
                )
                out_dim = dim
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            # Project every branch to the model dimension: 1x1 conv, then 3x3 conv.
            current.add_module(
                "conv_1x1",
                nn.Conv2d(
                    in_channels=out_dim,
                    out_channels=d_model,
                    kernel_size=1,
                    bias=use_bias,
                ),
            )
            current.add_module(
                "conv_3x3",
                nn.Conv2d(
                    in_channels=d_model,
                    out_channels=d_model,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                ),
            )
            self.convs.append(current)

        self.sam2_convs = None
        if add_sam2_neck:
            # The SAM2 neck starts as an exact copy of the SAM3 neck;
            # deepcopy gives it its own, separate parameters.
            self.sam2_convs = deepcopy(self.convs)

    def forward(
        self, tensor_list: List[torch.Tensor]
    ) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
    ]:
        """
        Run the trunk, then build the SAM3 (and optionally SAM2) feature
        pyramids with their positional encodings. The SAM2 outputs are None
        unless the neck was built with add_sam2_neck=True.
        """
        xs = self.trunk(tensor_list)
        sam3_out, sam3_pos = [], []
        sam2_out, sam2_pos = None, None
        if self.sam2_convs is not None:
            sam2_out, sam2_pos = [], []
        # Every pyramid level is derived from the trunk's last feature map.
        x = xs[-1]
        for i in range(len(self.convs)):
            sam3_x_out = self.convs[i](x)
            sam3_pos_out = self.position_encoding(sam3_x_out).to(sam3_x_out.dtype)
            sam3_out.append(sam3_x_out)
            sam3_pos.append(sam3_pos_out)

            if self.sam2_convs is not None:
                sam2_x_out = self.sam2_convs[i](x)
                sam2_pos_out = self.position_encoding(sam2_x_out).to(sam2_x_out.dtype)
                sam2_out.append(sam2_x_out)
                sam2_pos.append(sam2_pos_out)

        return sam3_out, sam3_pos, sam2_out, sam2_pos
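

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# DummyTrunk and DummyPositionEncoding are hypothetical stand-ins: the real
# trunk is a ViT backbone exposing a `channel_list` attribute and returning a
# list of feature maps, and the real position encoding maps a feature map to
# a same-shaped positional tensor. The exact format of `tensor_list` depends
# on the trunk; the dummy here ignores its input.
if __name__ == "__main__":

    class DummyTrunk(nn.Module):
        channel_list = [256]

        def forward(self, x) -> List[torch.Tensor]:
            # Pretend the backbone maps a 256x256 image to a 16x16 feature map.
            return [torch.randn(1, 256, 16, 16)]

    class DummyPositionEncoding(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # A real implementation would return sine/cosine or learned
            # positional features; zeros keep the sketch self-contained.
            return torch.zeros_like(x)

    neck = Sam3DualViTDetNeck(
        trunk=DummyTrunk(),
        position_encoding=DummyPositionEncoding(),
        d_model=256,
        add_sam2_neck=True,
    )
    sam3_out, sam3_pos, sam2_out, sam2_pos = neck([torch.randn(1, 3, 256, 256)])
    # With scale_factors=(4.0, 2.0, 1.0, 0.5) and a 16x16 trunk map, the
    # pyramid levels come out at 64x64, 32x32, 16x16, and 8x8.
    print([tuple(t.shape) for t in sam3_out])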