| |
|
|
| from typing import Sequence, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer |
|
|
|
|
class UNesTBlock(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, then a conv block.

    ``forward(inp, skip)`` upsamples ``inp`` with a transposed convolution (kernel
    size and stride both equal to ``upsample_kernel_size``). It then concatenates
    the result with ``skip`` along dim 1 and refines the concatenation with either
    a ``UnetResBlock`` or a ``UnetBasicBlock`` run at stride 1.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of channels of the ``inp`` tensor fed to ``forward``.
            out_channels: number of output channels. ``skip`` is expected to carry
                this many channels as well, because the conv block is built for
                ``out_channels + out_channels`` input channels.
            kernel_size: convolution kernel size of the conv block.
            stride: accepted for interface compatibility but NOT used; the conv
                block always runs with stride 1.
            upsample_kernel_size: kernel size of the transposed convolution; also
                used as its stride.
            norm_name: feature normalization type and arguments.
            res_block: if True use ``UnetResBlock``, otherwise ``UnetBasicBlock``.
        """
        super().__init__()
        # The transposed conv's stride is tied to its kernel size.
        upsample_stride = upsample_kernel_size
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        block_cls = UnetResBlock if res_block else UnetBasicBlock
        # Input = upsampled features concatenated with the skip tensor.
        self.conv_block = block_cls(
            spatial_dims,
            out_channels + out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            norm_name=norm_name,
        )

    def forward(self, inp, skip):
        """Upsample ``inp``, concatenate ``skip`` along dim 1, and apply the conv block.

        Args:
            inp: input feature tensor with ``in_channels`` channels.
            skip: skip-connection tensor; its shape must match the upsampled
                ``inp`` on every dim except dim 1 for ``torch.cat`` to succeed.

        Returns:
            Output of the conv block (``out_channels`` channels).
        """
        out = self.transp_conv(inp)
        out = torch.cat((out, skip), dim=1)
        return self.conv_block(out)
|
|
|
|
class UNestUpBlock(nn.Module):
    """Upsampling path: an initial transposed conv followed by ``num_layer`` stages.

    With ``conv_block=True`` each stage is a transposed conv (kernel and stride
    ``upsample_kernel_size``) followed by a ``UnetResBlock``/``UnetBasicBlock``.
    Otherwise each stage is a single transposed conv with kernel size 1 and
    stride 1.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_layer: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        conv_block: bool = False,
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions; used for every layer
                built here.
            in_channels: number of input channels.
            out_channels: number of output channels (also the width of every stage).
            num_layer: number of stages appended after the initial transposed conv.
            kernel_size: convolution kernel size of the per-stage conv blocks.
            stride: convolution stride of the per-stage conv blocks (only used
                when ``conv_block`` is True).
            upsample_kernel_size: kernel size of the transposed convolutions;
                also used as their stride.
            norm_name: feature normalization type and arguments.
            conv_block: if True, each stage is transposed conv + conv block.
            res_block: with ``conv_block``, choose ``UnetResBlock`` over
                ``UnetBasicBlock``.
        """
        super().__init__()

        # The transposed conv's stride is tied to its kernel size.
        upsample_stride = upsample_kernel_size
        self.transp_conv_init = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        if conv_block:
            block_cls = UnetResBlock if res_block else UnetBasicBlock
            self.blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        get_conv_layer(
                            spatial_dims,
                            out_channels,
                            out_channels,
                            kernel_size=upsample_kernel_size,
                            stride=upsample_stride,
                            conv_only=True,
                            is_transposed=True,
                        ),
                        # Use spatial_dims (not a fixed 3) so the conv block has
                        # the same dimensionality as the transposed conv above.
                        block_cls(
                            spatial_dims=spatial_dims,
                            in_channels=out_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            norm_name=norm_name,
                        ),
                    )
                    for _ in range(num_layer)
                ]
            )
        else:
            # NOTE(review): 1x1, stride-1 transposed convs presumably leave the
            # spatial size unchanged, unlike the conv_block branch. Confirm this
            # is intended rather than upsample_kernel_size/upsample_stride.
            self.blocks = nn.ModuleList(
                [
                    get_conv_layer(
                        spatial_dims,
                        out_channels,
                        out_channels,
                        kernel_size=1,
                        stride=1,
                        conv_only=True,
                        is_transposed=True,
                    )
                    for _ in range(num_layer)
                ]
            )

    def forward(self, x):
        """Apply the initial transposed conv, then every stage in ``self.blocks`` in order."""
        x = self.transp_conv_init(x)
        for blk in self.blocks:
            x = blk(x)
        return x
|
|
|
|
class UNesTConvBlock(nn.Module):
    """Single convolutional block applied to one input tensor.

    Wraps either a ``UnetResBlock`` (``res_block=True``) or a ``UnetBasicBlock``.
    It takes no skip-connection input and performs no transposed-conv upsampling.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            norm_name: feature normalization type and arguments.
            res_block: if True use ``UnetResBlock``, otherwise ``UnetBasicBlock``.
        """
        super().__init__()

        block_cls = UnetResBlock if res_block else UnetBasicBlock
        self.layer = block_cls(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            norm_name=norm_name,
        )

    def forward(self, inp):
        """Run ``inp`` through the wrapped conv block and return its output."""
        return self.layer(inp)
|
|