import torch

from . import bev_pool_ext

__all__ = ["bev_pool"]


class QuickCumsum(torch.autograd.Function):
    """Sum features sharing a voxel rank via the cumulative-sum trick."""

    @staticmethod
    def forward(ctx, x, geom_feats, ranks):
        x = x.cumsum(0)
        # Keep only the last row of each run of equal ranks; the running
        # sum at that row holds the total for the whole group.
        kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
        kept[:-1] = ranks[1:] != ranks[:-1]

        x, geom_feats = x[kept], geom_feats[kept]
        # Difference adjacent kept rows to turn cumulative sums back into
        # per-group sums.
        x = torch.cat((x[:1], x[1:] - x[:-1]))

        # save kept for backward
        ctx.save_for_backward(kept)

        # no gradient for geom_feats
        ctx.mark_non_differentiable(geom_feats)

        return x, geom_feats

    @staticmethod
    def backward(ctx, gradx, gradgeom):
        (kept,) = ctx.saved_tensors
        # Map every input row to the index of its group's output row, so
        # each row receives the gradient of the group sum it fed into.
        back = torch.cumsum(kept, 0)
        back[kept] -= 1

        val = gradx[back]

        return val, None, None

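# Worked illustration of the cumsum trick (hypothetical values, not part of
# the original module): for sorted ranks = [0, 0, 1, 1, 1, 2] and
# x = [1, 2, 3, 4, 5, 6], x.cumsum(0) = [1, 3, 6, 10, 15, 21] and
# kept = [F, T, F, F, T, T] marks the last row of each rank group, so
# x[kept] = [3, 15, 21]; differencing adjacent entries recovers the
# per-group sums [3, 12, 6].
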
class QuickCumsumCuda(torch.autograd.Function):
    """Interval-based BEV pooling backed by the bev_pool_ext CUDA kernels."""

    @staticmethod
    def forward(ctx, x, geom_feats, ranks, B, D, H, W):
        # Mark the first row of each run of equal ranks, then derive the
        # (start, length) of every interval of rows sharing a voxel.
        kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
        kept[1:] = ranks[1:] != ranks[:-1]
        interval_starts = torch.where(kept)[0].int()
        interval_lengths = torch.zeros_like(interval_starts)
        interval_lengths[:-1] = interval_starts[1:] - interval_starts[:-1]
        interval_lengths[-1] = x.shape[0] - interval_starts[-1]
        geom_feats = geom_feats.int()

        out = bev_pool_ext.bev_pool_forward(
            x,
            geom_feats,
            interval_lengths,
            interval_starts,
            B,
            D,
            H,
            W,
        )

        ctx.save_for_backward(interval_starts, interval_lengths, geom_feats)
        ctx.saved_shapes = B, D, H, W
        return out

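    # Worked illustration (hypothetical values, not part of the original
    # module): for sorted ranks = [0, 0, 1, 1, 1, 2], kept = [T, F, T, F, F, T],
    # giving interval_starts = [0, 2, 5] and interval_lengths = [2, 3, 1];
    # each (start, length) pair tells the kernel which contiguous rows to
    # sum into a single output voxel.
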
    @staticmethod
    def backward(ctx, out_grad):
        interval_starts, interval_lengths, geom_feats = ctx.saved_tensors
        B, D, H, W = ctx.saved_shapes

        # The kernel expects a contiguous gradient tensor; it scatters each
        # voxel's gradient back to every input row in that voxel's interval.
        out_grad = out_grad.contiguous()
        x_grad = bev_pool_ext.bev_pool_backward(
            out_grad,
            geom_feats,
            interval_lengths,
            interval_starts,
            B,
            D,
            H,
            W,
        )

        # One None per non-differentiable forward argument.
        return x_grad, None, None, None, None, None, None


def bev_pool(feats, coords, B, D, H, W):
    assert feats.shape[0] == coords.shape[0]

    # Flatten each coordinate tuple into a single scalar rank so that rows
    # falling into the same voxel become contiguous after sorting.
    ranks = (
        coords[:, 0] * (W * D * B)
        + coords[:, 1] * (D * B)
        + coords[:, 2] * B
        + coords[:, 3]
    )
    indices = ranks.argsort()
    feats, coords, ranks = feats[indices], coords[indices], ranks[indices]

    x = QuickCumsumCuda.apply(feats, coords, ranks, B, D, H, W)
    # (B, D, H, W, C) -> (B, C, D, H, W)
    x = x.permute(0, 4, 1, 2, 3).contiguous()
    return x
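
# A minimal usage sketch, not part of the original module: it assumes the
# compiled bev_pool_ext extension and a CUDA device are available, all
# shapes (N points, C channels, grid size) are invented, and the meaning of
# the coords columns is inferred from the rank computation in bev_pool.
if __name__ == "__main__":
    N, C = 1000, 80
    B, D, H, W = 1, 1, 128, 128
    feats = torch.rand(N, C, device="cuda")
    coords = torch.stack(
        [
            torch.randint(0, H, (N,), device="cuda"),  # assumed: height index
            torch.randint(0, W, (N,), device="cuda"),  # assumed: width index
            torch.randint(0, D, (N,), device="cuda"),  # assumed: depth index
            torch.randint(0, B, (N,), device="cuda"),  # assumed: batch index
        ],
        dim=1,
    )
    bev = bev_pool(feats, coords, B, D, H, W)
    print(bev.shape)  # expected: torch.Size([1, 80, 1, 128, 128])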