Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,819 Bytes
917a889 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modules import sparse as sp
from .sparse_unet_vae import (
SparseResBlock3d,
SparseConvNeXtBlock3d,
SparseResBlockDownsample3d,
SparseResBlockUpsample3d,
SparseResBlockS2C3d,
SparseResBlockC2S3d,
)
from .sparse_unet_vae import (
SparseUnetVaeEncoder,
SparseUnetVaeDecoder,
)
from ...representations import Mesh
from o_voxel.convert import flexible_dual_grid_to_mesh
class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder):
    """
    Sparse U-Net VAE encoder whose input is a pair of sparse tensors
    (``vertices`` and ``intersected``) concatenated along the feature axis.

    All constructor arguments are forwarded unchanged to
    ``SparseUnetVaeEncoder``; only the leading argument is fixed (to 6) here.
    """

    def __init__(
        self,
        model_channels: List[int],
        latent_channels: int,
        num_blocks: List[int],
        block_type: List[str],
        down_block_type: List[str],
        block_args: List[Dict[str, Any]],
        use_fp16: bool = False,
    ):
        # The first positional argument of the parent is hard-wired to 6.
        # NOTE(review): presumably the input channel count (3 vertex channels
        # + 3 intersection channels, matching the concat in forward) -- confirm
        # against SparseUnetVaeEncoder's signature.
        parent_args = (
            6,
            model_channels,
            latent_channels,
            num_blocks,
            block_type,
            down_block_type,
            block_args,
            use_fp16,
        )
        super().__init__(*parent_args)

    def forward(self, vertices: sp.SparseTensor, intersected: sp.SparseTensor, sample_posterior=False, return_raw=False):
        """
        Encode a (vertices, intersected) pair.

        Both feature tensors are shifted by -0.5 (``intersected`` is cast to
        float first), concatenated along dim 1, and placed on the sparse
        structure of ``vertices`` before being handed to the parent encoder.

        Args:
            vertices: sparse tensor supplying the first group of features and
                the sparse structure of the encoder input.
            intersected: sparse tensor supplying the second group of features.
            sample_posterior: forwarded positionally to the parent ``forward``.
            return_raw: forwarded positionally to the parent ``forward``.

        Returns:
            Whatever ``SparseUnetVaeEncoder.forward`` returns for the combined
            input.
        """
        shifted_vertices = vertices.feats - 0.5
        shifted_flags = intersected.feats.float() - 0.5
        combined_feats = torch.cat([shifted_vertices, shifted_flags], dim=1)
        encoder_input = vertices.replace(combined_feats)
        return super().forward(encoder_input, sample_posterior, return_raw)
class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder):
    """
    Sparse U-Net VAE decoder whose output features are converted to meshes
    through ``flexible_dual_grid_to_mesh``.

    The parent decoder is constructed with 7 as its leading argument, and
    ``forward`` slices the decoded features into three heads:

    * channels ``0:3`` -> ``vertices``: sigmoid, rescaled to the open interval
      ``(-voxel_margin, 1 + voxel_margin)``;
    * channels ``3:6`` -> ``intersected``: raw logits (thresholded at 0 in
      eval mode);
    * channels ``6:7`` -> ``quad_lerp``: softplus, hence strictly positive.
    """

    def __init__(
        self,
        resolution: int,
        model_channels: List[int],
        latent_channels: int,
        num_blocks: List[int],
        block_type: List[str],
        up_block_type: List[str],
        block_args: List[Dict[str, Any]],
        voxel_margin: float = 0.5,
        use_fp16: bool = False,
    ):
        """
        Args:
            resolution: passed as ``grid_size`` to ``flexible_dual_grid_to_mesh``.
            model_channels, latent_channels, num_blocks, block_type,
            up_block_type, block_args, use_fp16: forwarded unchanged to
                ``SparseUnetVaeDecoder``.
            voxel_margin: how far the ``vertices`` head may extend beyond the
                ``[0, 1]`` interval on either side.
        """
        # Plain Python attributes, assigned before the parent constructor runs
        # (same order as the original implementation).
        self.resolution = resolution
        self.voxel_margin = voxel_margin
        super().__init__(
            7,
            model_channels,
            latent_channels,
            num_blocks,
            block_type,
            up_block_type,
            block_args,
            use_fp16,
        )

    def set_resolution(self, resolution: int) -> None:
        """Change the ``grid_size`` used for subsequent mesh extraction."""
        self.resolution = resolution

    def _fdg_heads(self, h: sp.SparseTensor):
        """Split decoded features into (vertices, intersected_logits, quad_lerp)."""
        feats = h.feats
        margin = self.voxel_margin
        vertices = h.replace((1 + 2 * margin) * F.sigmoid(feats[..., 0:3]) - margin)
        intersected_logits = h.replace(feats[..., 3:6])
        quad_lerp = h.replace(F.softplus(feats[..., 6:7]))
        return vertices, intersected_logits, quad_lerp

    def _fdg_meshes(self, h: sp.SparseTensor, vertices, intersected, quad_lerp, train: bool):
        """Build one ``Mesh`` per item yielded by iterating the three sparse tensors in lockstep."""
        # NOTE(review): coords are taken from the full tensor ``h`` while the
        # feats come from each iterated item, exactly as in the original code.
        # If iteration yields per-sample tensors, this presumably only lines up
        # for batch size 1 -- confirm against sp.SparseTensor before changing.
        coords = h.coords[:, 1:]
        return [
            # The converter's result is unpacked into Mesh's positional
            # arguments in BOTH modes (the training path previously passed the
            # whole result as a single argument).
            Mesh(*flexible_dual_grid_to_mesh(
                coords, v.feats, i.feats, q.feats,
                aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
                grid_size=self.resolution,
                train=train,
            ))
            for v, i, q in zip(vertices, intersected, quad_lerp)
        ]

    def forward(self, x: sp.SparseTensor, gt_intersected: Optional[sp.SparseTensor] = None, **kwargs):
        """
        Decode latents and extract meshes.

        Args:
            x: sparse latent tensor passed to the parent decoder.
            gt_intersected: intersection flags used for mesh extraction in
                training mode (instead of the predicted ones). Required when
                ``self.training`` is true; unused in eval mode.
            **kwargs: forwarded to ``SparseUnetVaeDecoder.forward``.

        Returns:
            Training mode: ``(mesh, vertices, intersected_logits, subs_gt, subs)``
            where ``mesh`` is a list of ``Mesh`` and ``subs_gt`` / ``subs`` are
            passed through from the parent decoder.

            Eval mode: the list of ``Mesh``; if the parent returned a tuple,
            the same tuple with its first element replaced by that list.

        Raises:
            ValueError: in training mode when ``gt_intersected`` is ``None``.
        """
        if self.training and gt_intersected is None:
            # Fail early with a clear message instead of an opaque TypeError
            # from zip() after the whole decoder has already run.
            raise ValueError("gt_intersected must be provided when the decoder is in training mode")

        decoded = super().forward(x, **kwargs)

        if self.training:
            h, subs_gt, subs = decoded
            vertices, intersected_logits, quad_lerp = self._fdg_heads(h)
            mesh = self._fdg_meshes(h, vertices, gt_intersected, quad_lerp, train=True)
            return mesh, vertices, intersected_logits, subs_gt, subs

        out_list = list(decoded) if isinstance(decoded, tuple) else [decoded]
        h = out_list[0]
        vertices, intersected_logits, quad_lerp = self._fdg_heads(h)
        # Hard decision at inference time: logit > 0.
        intersected = h.replace(intersected_logits.feats > 0)
        out_list[0] = self._fdg_meshes(h, vertices, intersected, quad_lerp, train=False)
        return out_list[0] if len(out_list) == 1 else tuple(out_list)
|