import functools import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from collections import OrderedDict # 小波分解相关代码 from pytorch_wavelets import DWTForward, DWTInverse # (or import DWT, IDWT) class WaveletDecompose(nn.Module): def __init__(self, mode='haar'): super().__init__() self.xfm = DWTForward(J=1, wave=mode, mode='reflect') def forward(self, x): """ 将一层小波分解结果转换为通道拼接格式 Args: x: 输入张量,形状为 (B, C, H, W) Returns: output: 拼接后的张量,形状为 (B, 4*C, H//2, W//2) 通道顺序: [LL, HL, LH, HH] """ yl, yh = self.xfm(x) # yl: (B, C, H//2, W//2) - LL子带 # yh[0]: (B, C, 3, H//2, W//2) - 高频系数 # 提取三个方向的高频系数 hl = yh[0][:, :, 0, :, :] # HL: 水平细节 lh = yh[0][:, :, 1, :, :] # LH: 垂直细节 hh = yh[0][:, :, 2, :, :] # HH: 对角细节 # 沿通道维度拼接 output = torch.cat([yl, hl, lh, hh], dim=1) return output class WaveletReconstruct(nn.Module): def __init__(self, mode='haar'): super().__init__() self.ifm = DWTInverse(wave=mode, mode='reflect') def forward(self, x): """ 将通道拼接的小波系数还原为原始图像 Args: x: 输入张量,形状为 (B, 4*C, H, W) Returns: 重构后的图像,形状为 (B, C, 2*H, 2*W) """ batch_size, total_channels, height, width = x.shape channels = total_channels // 4 # 分割通道 yl = x[:, :channels, :, :] # LL hl = x[:, channels:2*channels, :, :] # HL lh = x[:, 2*channels:3*channels, :, :] # LH hh = x[:, 3*channels:4*channels, :, :] # HH # 重新组织为 pytorch_wavelets 需要的格式 # 创建 yh 列表,第一个元素是形状为 (B, C, 3, H, W) 的张量 yh_coeff = torch.stack([hl, lh, hh], dim=2) # 在dim=2上堆叠 yh = [yh_coeff] # 必须放在列表中 # 执行逆变换 reconstructed = self.ifm((yl, yh)) return reconstructed def make_layer(block, n_layers): layers = [] for _ in range(n_layers): layers.append(block) return nn.Sequential(*layers) class SiLU(nn.Module): def forward(self, x): return x * torch.sigmoid(x) class GroupNorm32(nn.GroupNorm): def forward(self, x): return super().forward(x.float()).type(x.dtype) class Module_with_Init(nn.Module): def __init__(self,): super().__init__() def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): m.weight.data.normal_(0.0, 0.02) if m.bias is not None: m.bias.data.normal_(0.0, 0.02) if isinstance(m, nn.ConvTranspose2d): m.weight.data.normal_(0.0, 0.02) def lrelu(self, x): outt = torch.max(0.2*x, x) return outt class ResConvBlock_CBAM(nn.Module): def __init__(self, in_nc, nf=64, res_scale=1): super().__init__() self.res_scale = res_scale self.conv1 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.cbam = CBAM(nf) self.relu = nn.ReLU() def forward(self, x): x = self.relu(self.conv1(x)) out = self.res_scale * self.cbam(self.relu(self.conv2(x))) + x return x + out * self.res_scale class ResidualBlockNoBN(nn.Module): """Residual block without BN. It has a style of: ---Conv-ReLU-Conv-+- |________________| Args: nf (int): Channel number of intermediate features. Default: 64. res_scale (float): Residual scale. Default: 1. """ def __init__(self, nf=64, res_scale=1): super(ResidualBlockNoBN, self).__init__() self.res_scale = res_scale self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.relu = nn.ReLU() def forward(self, x): identity = x out = self.conv2(self.relu(self.conv1(x))) return identity + out * self.res_scale def conv1x1(in_nc, out_nc, groups=1): return nn.Conv2d(in_nc, out_nc,kernel_size=1,groups=groups,stride=1) class Identity(nn.Identity): def __init__(self, args): super().__init__() class ResidualBlock3D(nn.Module): def __init__(self, in_c, out_c, is_activate=True): super().__init__() self.activation = nn.ReLU(inplace=True) if is_activate else nn.Sequential() self.block = nn.Sequential( nn.Conv3d(in_c, out_c, kernel_size=3, padding=1, stride=1), self.activation, nn.Conv3d(out_c, out_c, kernel_size=3, padding=1, stride=1) ) if in_c != out_c: self.short_cut = nn.Sequential( nn.Conv3d(in_c, out_c, kernel_size=1, padding=0, stride=1) ) else: self.short_cut = nn.Sequential(OrderedDict([])) def forward(self, x): output = self.block(x) output += self.short_cut(x) output = self.activation(output) return output class conv3x3(nn.Module): def __init__(self, in_nc, out_nc, stride=2, is_activate=True): super().__init__() self.conv =nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1, stride=stride) if is_activate: self.conv.add_module("relu", nn.ReLU(inplace=True)) def forward(self, x): return self.conv(x) class convWithBN(nn.Module): def __init__(self, in_c, out_c, kernel_size=3, padding=1, stride=1, is_activate=True, is_bn=True): super(convWithBN, self).__init__() self.conv = nn.Sequential(OrderedDict([ ("conv", nn.Conv2d(in_c, out_c, kernel_size=kernel_size, padding=padding, stride=stride, bias=False)), ])) if is_bn: self.conv.add_module("BN", nn.BatchNorm2d(out_c)) if is_activate: self.conv.add_module("relu", nn.ReLU(inplace=True)) def forward(self, x): return self.conv(x) class DoubleCvBlock(nn.Module): def __init__(self, in_c, out_c): super(DoubleCvBlock, self).__init__() self.block = nn.Sequential( convWithBN(in_c, out_c, kernel_size=3, padding=1, stride=1, is_bn=False), convWithBN(out_c, out_c, kernel_size=3, padding=1, stride=1, is_bn=False) ) def forward(self, x): output = self.block(x) return output class nResBlocks(nn.Module): def __init__(self, nf, nlayers=2): super().__init__() self.blocks = make_layer(ResidualBlock(nf, nf), n_layers=nlayers) def forward(self, x): return self.blocks(x) class GuidedResidualBlock(nn.Module): def __init__(self, in_c, out_c, is_activate=False): super().__init__() # self.norm = nn.LayerNorm(out_c) self.act = nn.SiLU() self.conv1 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True) self.gamma = nn.Sequential( conv1x1(1, out_c), nn.SiLU(), conv1x1(out_c, out_c), ) self.beta = nn.Sequential( nn.SiLU(), conv1x1(out_c, out_c), ) if in_c != out_c: self.short_cut = nn.Sequential( conv1x1(in_c, out_c) ) else: self.short_cut = nn.Sequential(OrderedDict([])) def forward(self, x, t): if len(t.shape) > 0 and t.shape[-1] != 1: t = F.interpolate(t, size=x.shape[2:], mode='bilinear', align_corners=False) x = self.short_cut(x) z = self.act(x) z = self.conv1(z) tk = self.gamma(t) tb = self.beta(tk) z = z * tk + tb z = self.act(z) z = self.conv2(z) z += x return z class GuidedConvBlock(nn.Module): def __init__(self, in_c, out_c, is_activate=False): super().__init__() # self.norm = nn.LayerNorm(out_c) self.act = nn.SiLU() self.conv1 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True) self.gamma = nn.Sequential( conv1x1(1, out_c), nn.SiLU(), conv1x1(out_c, out_c), ) self.beta = nn.Sequential( nn.SiLU(), conv1x1(out_c, out_c), ) if in_c != out_c: self.short_cut = nn.Sequential( conv1x1(in_c, out_c) ) else: self.short_cut = nn.Sequential(OrderedDict([])) def forward(self, x, t): x = self.short_cut(x) z = self.act(x) z = self.conv1(z) tk = self.gamma(t) tb = self.beta(tk) z = z * tk + tb z = self.act(z) z = self.conv2(z) return z class SNR_Block(nn.Module): def __init__(self, in_c, out_c, is_activate=False): super().__init__() # self.norm = nn.LayerNorm(out_c) self.act = nn.SiLU() self.conv1 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True) self.sfm1 = nn.Sequential( conv1x1(1, out_c), nn.SiLU(), conv1x1(out_c, out_c), ) self.sfm2 = nn.Sequential( conv1x1(1, out_c), nn.SiLU(), conv1x1(out_c, out_c), ) if in_c != out_c: self.short_cut = nn.Sequential( conv1x1(in_c, out_c) ) else: self.short_cut = nn.Sequential(OrderedDict([])) def forward(self, x, t): x = self.short_cut(x) z = self.act(x) z = self.conv1(z) a1 = self.sfm1(t) z *= a1 z = self.act(z) z = self.conv2(z) a2 = self.sfm2(t) z *= a2 z += x return z class ResBlock(nn.Module): def __init__(self, in_c, out_c, is_activate=False): super().__init__() # self.norm = nn.LayerNorm(out_c) self.act = nn.LeakyReLU(0.2) if is_activate else nn.SiLU() self.conv1 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=True) self.gamma = nn.Sequential( conv1x1(1, out_c), self.act, conv1x1(out_c, out_c), ) self.beta = nn.Sequential( self.act, conv1x1(out_c, out_c), ) if in_c != out_c: self.short_cut = nn.Sequential( conv1x1(in_c, out_c) ) else: self.short_cut = nn.Sequential(OrderedDict([])) def forward(self, x): x = self.short_cut(x) z = self.act(x) z = self.conv1(z) z = self.act(z) z = self.conv2(z) z += x return z class ResidualBlock(nn.Module): def __init__(self, in_c, out_c, is_activate=True): super(ResidualBlock, self).__init__() self.block = nn.Sequential( convWithBN(in_c, out_c, kernel_size=3, padding=1, stride=1, is_bn=False), convWithBN(out_c, out_c, kernel_size=3, padding=1, stride=1, is_activate=False, is_bn=False) ) if in_c != out_c: self.short_cut = nn.Sequential( convWithBN(in_c, out_c, kernel_size=1, padding=0, stride=1, is_activate=False, is_bn=False) ) else: self.short_cut = nn.Sequential(OrderedDict([])) self.activation = nn.LeakyReLU(0.2, inplace=False) if is_activate else nn.Sequential() def forward(self, x): output = self.block(x) output = self.activation(output) output += self.short_cut(x) return output class ChannelAttention(nn.Module): def __init__(self, in_planes, ratio=16): super().__init__() self.in_nc = in_planes self.ratio = ratio self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) self.sharedMLP = nn.Sequential( nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(), nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)) self.sigmoid = nn.Sigmoid() def forward(self, x): avgout = self.sharedMLP(self.avg_pool(x)) maxout = self.sharedMLP(self.max_pool(x)) return self.sigmoid(avgout + maxout) class SpatialAttention(nn.Module): def __init__(self, kernel_size=3): super().__init__() self.conv = nn.Conv2d(2,1,kernel_size, padding=1, bias=False) self.sigmoid = nn.Sigmoid() self.concat = Concat() self.mean = torch.mean self.max = torch.max def forward(self, x): avgout = self.mean(x, 1, True) maxout, _ = self.max(x, 1, True) x = self.concat([avgout, maxout], 1) x = self.conv(x) return self.sigmoid(x) class CBAM(nn.Module): def __init__(self, planes): super().__init__() self.ca = ChannelAttention(planes) self.sa = SpatialAttention() def forward(self, x): x = self.ca(x) * x out = self.sa(x) * x return out class MaskMul(nn.Module): def __init__(self, scale_factor=1): super().__init__() self.scale_factor = scale_factor def forward(self, x, mask): if mask.shape[1] != x.shape[1]: mask = torch.mean(mask, dim=1, keepdim=True) pooled_mask = F.avg_pool2d(mask, self.scale_factor) out = torch.mul(x, pooled_mask) return out class UpsampleBLock(nn.Module): def __init__(self, in_channels, out_channels=None, up_scale=2, mode='bilinear'): super(UpsampleBLock, self).__init__() if mode == 'pixel_shuffle': self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1) self.up = nn.PixelShuffle(up_scale) elif mode=='bilinear': self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) self.up = nn.UpsamplingBilinear2d(scale_factor=up_scale) else: print(f"Please tell me what is '{mode}' mode ????") raise NotImplementedError self.relu = nn.ReLU(inplace=True) def forward(self, x): x = self.conv(x) x = self.up(x) x = self.relu(x) return x def pixel_unshuffle(input, downscale_factor): ''' input: batchSize * c * k*w * k*h kdownscale_factor: k batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h ''' c = input.shape[1] kernel = torch.zeros(size=[downscale_factor * downscale_factor * c, 1, downscale_factor, downscale_factor], device=input.device) for y in range(downscale_factor): for x in range(downscale_factor): kernel[x + y * downscale_factor::downscale_factor*downscale_factor, 0, y, x] = 1 return F.conv2d(input, kernel, stride=downscale_factor, groups=c) class PixelUnshuffle(nn.Module): def __init__(self, downscale_factor): super(PixelUnshuffle, self).__init__() self.downscale_factor = downscale_factor def forward(self, input): ''' input: batchSize * c * k*w * k*h kdownscale_factor: k batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h ''' return pixel_unshuffle(input, self.downscale_factor) class Concat(nn.Module): def __init__(self, dim=1): super().__init__() self.dim = 1 self.concat = torch.cat def padding(self, tensors): if len(tensors) > 2: return tensors x , y = tensors xb, xc, xh, xw = x.size() yb, yc, yh, yw = y.size() diffY = xh - yh diffX = xw - yw y = F.pad(y, (diffX // 2, diffX - diffX//2, diffY // 2, diffY - diffY//2)) return (x, y) def forward(self, x, dim=None): x = self.padding(x) return self.concat(x, dim if dim is not None else self.dim) # if __name__ == '__main__': # from torchsummary import summary # x = torch.randn((1,32,16,16)) # for k in range(1,3): # # up = upsample(32, 2**k) # # down = downsample(32//(2**k), 2**k) # # x_up = up(x) # # x_down = down(x_up) # # s_up = (32,16,16) # # summary(up,s,device='cpu') # # summary(down,s,device='cpu') # print(k)