import math
from typing import Optional

import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the "Attention Is All You Need" paper, generalized to work on images.
    """

    def __init__(
        self,
        num_pos_feats,
        temperature: int = 10000,
        normalize: bool = True,
        scale: Optional[float] = None,
        precompute_resolution: Optional[int] = None,
    ):
        super().__init__()
        assert num_pos_feats % 2 == 0, "Expecting even model width"
        self.num_pos_feats = num_pos_feats // 2
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

        # Cache of positional embeddings, keyed by feature-map size (H, W)
        self.cache = {}

        if precompute_resolution is not None:
            # Warm up the cache for the feature-map sizes produced by
            # strides 4, 8, 16, and 32 (assumes a CUDA device is available)
            precompute_sizes = [
                (precompute_resolution // 4, precompute_resolution // 4),
                (precompute_resolution // 8, precompute_resolution // 8),
                (precompute_resolution // 16, precompute_resolution // 16),
                (precompute_resolution // 32, precompute_resolution // 32),
            ]
            for size in precompute_sizes:
                tensors = torch.zeros((1, 1) + size, device="cuda")
                self.forward(tensors)
                # Detach so the cached embedding is not tied to any graph
                self.cache[size] = self.cache[size].clone().detach()

    def _encode_xy(self, x, y):
        # The positions are expected to be normalized
        assert len(x) == len(y) and x.ndim == y.ndim == 1
        x_embed = x * self.scale
        y_embed = y * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, None] / dim_t
        pos_y = y_embed[:, None] / dim_t
        # Interleave sine (even) and cosine (odd) components per position
        pos_x = torch.stack(
            (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
        ).flatten(1)
        pos_y = torch.stack(
            (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
        ).flatten(1)
        return pos_x, pos_y

    @torch.no_grad()
    def encode_boxes(self, x, y, w, h):
        pos_x, pos_y = self._encode_xy(x, y)
        # Concatenate the sinusoidal (y, x) encodings with the raw box size (h, w)
        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
        return pos

    encode = encode_boxes  # Alias kept for backwards compatibility

    @torch.no_grad()
    def encode_points(self, x, y, labels):
        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
        assert bx == by and nx == ny and bx == bl and nx == nl
        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
        # Append the point labels as an extra channel
        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
        return pos

    @torch.no_grad()
    def forward(self, x):
        cache_key = (x.shape[-2], x.shape[-1])
        if cache_key in self.cache:
            # Reuse the cached embedding, expanded to the current batch size
            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
        y_embed = (
            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
            .view(1, -1, 1)
            .repeat(x.shape[0], 1, x.shape[-1])
        )
        x_embed = (
            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
            .view(1, 1, -1)
            .repeat(x.shape[0], x.shape[-2], 1)
        )

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        self.cache[cache_key] = pos[0]
        return pos
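

# A minimal usage sketch (not part of the original module): it builds the
# embedding for a dummy 256-channel feature map of size 32x32 and checks the
# output shape. The feature-map size and num_pos_feats value are illustrative
# assumptions only.
if __name__ == "__main__":
    pe = PositionEmbeddingSine(num_pos_feats=256)
    feats = torch.zeros(2, 256, 32, 32)  # (batch, channels, H, W) dummy features
    pos = pe(feats)
    # The embedding has num_pos_feats channels and matches the spatial size
    assert pos.shape == (2, 256, 32, 32)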