| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | |
| | |
| |
|
| | |
class xATGLU(nn.Module):
    """Gated linear unit with an arctan gate whose output range is learnable.

    A single linear layer produces ``2 * output_dim`` features that are split
    in half along the last dimension: the first half drives the gate, the
    second half is the value. The gate is ``(arctan(g) + pi/2) / pi``, which
    lies in (0, 1), and is then stretched to the interval
    ``(-alpha, 1 + alpha)`` by the learnable scalar ``alpha``.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        bias: Whether the projection layer has a bias term.
    """

    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()

        # Plain Python constants used to map arctan's (-pi/2, pi/2) onto (0, 1).
        self.half_pi = torch.pi / 2
        self.inv_pi = 1 / torch.pi

        # One fused projection yields both the gate and the value halves.
        self.proj = nn.Linear(input_dim, 2 * output_dim, bias=bias)
        nn.init.kaiming_normal_(self.proj.weight, nonlinearity='linear')

        # alpha starts at 0, so the gate initially stays inside (0, 1).
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gate(x) * value(x)`` with the same leading dims as ``x``."""
        gate_logits, values = self.proj(x).chunk(2, dim=-1)

        # Squash the gate logits into (0, 1).
        unit_gate = (torch.arctan(gate_logits) + self.half_pi) * self.inv_pi

        # Widen the gate range to (-alpha, 1 + alpha).
        widened_gate = unit_gate * (1 + 2 * self.alpha) - self.alpha

        return widened_gate * values
| |
|
class ResBlock(nn.Module):
    """Pre-activation residual block that keeps channel count and spatial size.

    Applies two ``GroupNorm -> SiLU -> 3x3 conv (padding=1)`` stages and adds
    the result to the block input.

    Args:
        channels: Number of input (and output) channels.
        num_groups: Number of groups used by both GroupNorm layers. Defaults
            to 32 (the previously hard-coded value). ``nn.GroupNorm`` requires
            ``channels`` to be divisible by ``num_groups``, so narrower blocks
            need a smaller value.
    """

    def __init__(self, channels, num_groups=32):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(num_groups, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups, channels)

    def forward(self, x):
        """Return ``x`` plus the output of the two norm/activation/conv stages."""
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        # Residual connection around both stages.
        return x + h
| |
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block operating on 4-D feature maps.

    The input of shape ``(b, c, h, w)`` is flattened into ``h * w`` tokens of
    width ``c``, passed through self-attention and an xATGLU-based MLP (each
    with a residual connection), and reshaped back to ``(b, c, h, w)``.

    Args:
        channels: Token width; also the channel count of the feature map.
        num_heads: Number of attention heads passed to ``nn.MultiheadAttention``.
    """

    def __init__(self, channels, num_heads=8):
        super().__init__()
        hidden = 4 * channels
        self.norm1 = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(channels, num_heads)
        self.norm2 = nn.LayerNorm(channels)
        self.mlp = nn.Sequential(
            xATGLU(channels, hidden),
            nn.Linear(hidden, channels),
        )

    def forward(self, x):
        """Apply attention and MLP sub-layers; output has the same shape as ``x``."""
        batch, chans, height, width = x.shape

        # (b, c, h, w) -> (h*w, b, c): sequence-first layout, which is what
        # nn.MultiheadAttention expects when batch_first is left at its default.
        tokens = x.flatten(2).permute(2, 0, 1)

        # Self-attention sub-layer with residual connection.
        normed = self.norm1(tokens)
        attn_out, _ = self.attn(normed, normed, normed)
        tokens = tokens + attn_out

        # Feed-forward sub-layer with residual connection.
        tokens = tokens + self.mlp(self.norm2(tokens))

        # (h*w, b, c) -> (b, c, h*w) -> (b, c, h, w)
        return tokens.permute(1, 2, 0).reshape(batch, chans, height, width)
| |
|
class LevelBlock(nn.Module):
    """A stack of identical blocks applied one after another at one resolution.

    Args:
        channels: Channel count handed to every block in the stack.
        num_blocks: How many blocks to create.
        block_type: The string ``'transformer'`` selects ``TransformerBlock``;
            any other value (default ``'res'``) selects ``ResBlock``.
    """

    def __init__(self, channels, num_blocks, block_type='res'):
        super().__init__()
        wants_transformer = block_type == 'transformer'
        self.blocks = nn.ModuleList([
            TransformerBlock(channels) if wants_transformer else ResBlock(channels)
            for _ in range(num_blocks)
        ])

    def forward(self, x):
        """Run ``x`` through every block in order and return the result."""
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out
| |
|
class AsymmetricResidualUDiT(nn.Module):
    """U-shaped image-to-image network with asymmetric encoder/decoder depth.

    Structure, as built by ``__init__``:

    * ``patch_embed``: strided conv turning the image into ``base_channels``
      feature maps at ``1 / patch_size`` resolution.
    * ``encoders``: for each level a ``LevelBlock``; between levels a 1x1 conv
      that doubles the channel count (followed by 2x average pooling in
      ``forward``).
    * ``middle``: ``mid_blocks`` ``TransformerBlock``s at the deepest level.
    * ``decoders``: for each level a ``LevelBlock``; between levels a 1x1 conv
      that halves the channel count (followed by 2x nearest upsampling and a
      skip addition in ``forward``).
    * ``final_proj``: transposed conv mapping back to ``in_channels`` at the
      input resolution.

    Args:
        in_channels: Channels of the input (and output) image.
        base_channels: Channels after patch embedding; doubled at each deeper level.
        patch_size: Kernel size and stride of the patch embedding / final projection.
        num_levels: Number of resolution levels.
        encoder_blocks: Blocks per encoder level.
        decoder_blocks: Blocks per decoder level.
        encoder_transformer_thresh: Encoder levels with index ``>=`` this value
            use transformer blocks; shallower levels use residual conv blocks.
        decoder_transformer_thresh: Decoder levels with index ``<=`` this value
            use transformer blocks (decoder level 0 is the deepest). With the
            defaults (3 levels, threshold 4) every decoder level qualifies.
        mid_blocks: Number of transformer blocks in the bottleneck.

    Note:
        The skip additions only line up when the patch-embedded feature map's
        height and width survive each 2x pool / 2x upsample round trip, i.e.
        are divisible by ``2 ** (num_levels - 1)``.
    """

    def __init__(self,
                 in_channels=3,
                 base_channels=128,
                 patch_size=2,
                 num_levels=3,
                 encoder_blocks=3,
                 decoder_blocks=7,
                 encoder_transformer_thresh=2,
                 decoder_transformer_thresh=4,
                 mid_blocks=16
                 ):
        super().__init__()

        # Non-overlapping patch embedding.
        self.patch_embed = nn.Conv2d(in_channels, base_channels,
                                     kernel_size=patch_size, stride=patch_size)

        # Encoder: LevelBlock, then (except after the last level) a 1x1 conv
        # that doubles the channels.
        self.encoders = nn.ModuleList()
        curr_channels = base_channels

        for level in range(num_levels):
            use_transformer = level >= encoder_transformer_thresh

            # LevelBlock selects its block class by comparing block_type to the
            # string 'transformer'; passing the boolean directly would always
            # compare unequal and silently build ResBlocks.
            self.encoders.append(
                LevelBlock(curr_channels, encoder_blocks,
                           'transformer' if use_transformer else 'res')
            )

            if level < num_levels - 1:
                self.encoders.append(
                    nn.Conv2d(curr_channels, curr_channels * 2, 1)
                )
                curr_channels *= 2

        # Bottleneck at the deepest resolution / widest channel count.
        self.middle = nn.ModuleList([
            TransformerBlock(curr_channels) for _ in range(mid_blocks)
        ])

        # Decoder: LevelBlock, then (except after the last level) a 1x1 conv
        # that halves the channels.
        self.decoders = nn.ModuleList()

        for level in range(num_levels):
            use_transformer = level <= decoder_transformer_thresh

            # Same string conversion as in the encoder (see comment there).
            self.decoders.append(
                LevelBlock(curr_channels, decoder_blocks,
                           'transformer' if use_transformer else 'res')
            )

            if level < num_levels - 1:
                self.decoders.append(
                    nn.Conv2d(curr_channels, curr_channels // 2, 1)
                )
                curr_channels //= 2

        # Inverse of patch_embed: back to in_channels at input resolution.
        self.final_proj = nn.ConvTranspose2d(base_channels, in_channels,
                                             kernel_size=patch_size, stride=patch_size)

    def downsample(self, x):
        """Halve the spatial resolution with 2x2 average pooling."""
        return F.avg_pool2d(x, kernel_size=2)

    def upsample(self, x):
        """Double the spatial resolution with nearest-neighbour interpolation."""
        return F.interpolate(x, scale_factor=2, mode='nearest')

    def forward(self, x, t=None):
        """Run the network on an image batch.

        Args:
            x: Input tensor; fed to a ``Conv2d`` expecting ``in_channels`` channels.
            t: Accepted for interface compatibility but currently unused.

        Returns:
            Tensor produced by ``final_proj`` (``in_channels`` channels).
        """
        x = self.patch_embed(x)

        # `residuals` stores the input of every encoder level except the
        # deepest; `curr_res` always holds the input of the current level.
        residuals = []
        curr_res = x

        h = x
        for stage in self.encoders:
            if isinstance(stage, LevelBlock):
                h = stage(h)
            else:
                # Transition between levels: remember this level's input,
                # then widen channels (1x1 conv) and pool to the next level.
                residuals.append(curr_res)
                h = self.downsample(stage(h))
                curr_res = h

        x = h
        for block in self.middle:
            x = block(x)

        # Remove the deepest level's input from the bottleneck output.
        x = x - curr_res

        for stage in self.decoders:
            if isinstance(stage, LevelBlock):
                x = stage(x)
            else:
                # Transition between levels: narrow channels (1x1 conv),
                # upsample, then add the matching encoder-level input.
                x = stage(x)
                x = self.upsample(x)
                curr_res = residuals.pop()
                x = x + curr_res

        x = self.final_proj(x)

        return x