# NOTE(review): no other import is visible in this file, so `torch`, `nn`,
# `conv3x3`, `GuidedResidualBlock`, `WaveletDecompose` and `WaveletReconstruct`
# are presumably all supplied by this star import -- confirm against modules.py.
from .modules import *


def data_normalize(data, lower=None, upper=None):
    """Scale each sample of ``data`` by its own (clamped) maximum.

    Args:
        data: Tensor with at least four dimensions; dims 1-3 are reduced to
            find one maximum per entry of dim 0.
        lower: Accepted for interface compatibility but ignored; the lower
            bound is always ``0``.
        upper: Accepted for interface compatibility but ignored; the upper
            bound is always recomputed from ``data``.

    Returns:
        Tuple ``(normalized, lower, upper)`` where ``lower`` is the int ``0``
        and ``upper`` is the per-sample maximum (dims 1-3 kept with size 1)
        clamped to ``[1e-5, 1]``.
    """
    # A per-sample minimum used to be computed here; the lower bound is now
    # pinned at zero.
    lower = 0
    # The 1e-5 floor guards against division by ~0 on an (unrealistically)
    # dark input; the ceiling of 1 means samples whose maximum exceeds 1 are
    # left unscaled.
    upper = torch.amax(data, dim=(1, 2, 3), keepdim=True).clip(1e-5, 1)
    data = (data - lower) / (upper - lower)
    return data, lower, upper


def data_inv_normalize(data, lower, upper):
    """Undo :func:`data_normalize`: map ``data`` back to ``[lower, upper]``.

    Args:
        data: Normalized tensor.
        lower: Lower bound returned by :func:`data_normalize`.
        upper: Upper bound returned by :func:`data_normalize`.

    Returns:
        ``data * (upper - lower) + lower``.
    """
    return data * (upper - lower) + lower


# SID Unet
class UNetSeeInDark(nn.Module):
    """Five-level U-Net built from plain 3x3 convolutions.

    The encoder doubles the channel count at each of four 2x max-pool steps
    (``nf`` -> ``nf*16``); the decoder mirrors it with stride-2 transposed
    convolutions and concatenated skip connections, followed by a 1x1 output
    convolution.

    Config keys read from ``args``:
        nf, in_nc, out_nc: required channel counts.
        res: required; when truthy a 4-channel slice of the input is added
            to the output.
        nframes: optional, default 1; the first convolution takes
            ``in_nc * nframes`` channels.
        norm: optional, default False; when truthy the input is normalized
            with :func:`data_normalize` and the output is de-normalized.
    """

    def __init__(self, args=None):
        super().__init__()
        self.args = args
        # Single lookup with a default. Previously the attribute was read with
        # args['nframes'] (KeyError when absent) while the local copy fell
        # back to 1, so the fallback could never take effect.
        self.nframes = nframes = args.get('nframes', 1)
        self.cf = 0  # index of the frame used for the residual connection
        self.res = args['res']
        self.norm = args.get('norm', False)
        nf = args['nf']
        in_nc = args['in_nc']
        out_nc = args['out_nc']

        # Encoder. Layer names and construction order are kept unchanged so
        # existing checkpoints and seeded initialisation are unaffected.
        self.conv1_1 = nn.Conv2d(in_nc*nframes, nf, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2_1 = nn.Conv2d(nf, nf*2, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3_1 = nn.Conv2d(nf*2, nf*4, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4_1 = nn.Conv2d(nf*4, nf*8, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck.
        self.conv5_1 = nn.Conv2d(nf*8, nf*16, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(nf*16, nf*16, kernel_size=3, stride=1, padding=1)

        # Decoder: each conv*_1 takes twice the upsampled channel count
        # because the matching encoder feature map is concatenated first.
        self.upv6 = nn.ConvTranspose2d(nf*16, nf*8, 2, stride=2)
        self.conv6_1 = nn.Conv2d(nf*16, nf*8, kernel_size=3, stride=1, padding=1)
        self.conv6_2 = nn.Conv2d(nf*8, nf*8, kernel_size=3, stride=1, padding=1)

        self.upv7 = nn.ConvTranspose2d(nf*8, nf*4, 2, stride=2)
        self.conv7_1 = nn.Conv2d(nf*8, nf*4, kernel_size=3, stride=1, padding=1)
        self.conv7_2 = nn.Conv2d(nf*4, nf*4, kernel_size=3, stride=1, padding=1)

        self.upv8 = nn.ConvTranspose2d(nf*4, nf*2, 2, stride=2)
        self.conv8_1 = nn.Conv2d(nf*4, nf*2, kernel_size=3, stride=1, padding=1)
        self.conv8_2 = nn.Conv2d(nf*2, nf*2, kernel_size=3, stride=1, padding=1)

        self.upv9 = nn.ConvTranspose2d(nf*2, nf, 2, stride=2)
        self.conv9_1 = nn.Conv2d(nf*2, nf, kernel_size=3, stride=1, padding=1)
        self.conv9_2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)

        self.conv10_1 = nn.Conv2d(nf, out_nc, kernel_size=1, stride=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Run the network on ``x`` and return the ``out_nc``-channel output.

        Args:
            x: Input tensor with ``in_nc * nframes`` channels in dim 1. Its
                spatial size must survive four 2x poolings and line up with
                the skip connections (presumably a multiple of 16 -- verify
                against the data pipeline).
        """
        if self.norm:
            x, lb, ub = data_normalize(x)

        # Encoder. Each level uses its own pool module; the result is the
        # same as reusing pool1 everywhere because MaxPool2d holds no state.
        conv1 = self.relu(self.conv1_1(x))
        conv1 = self.relu(self.conv1_2(conv1))
        pool1 = self.pool1(conv1)

        conv2 = self.relu(self.conv2_1(pool1))
        conv2 = self.relu(self.conv2_2(conv2))
        pool2 = self.pool2(conv2)

        conv3 = self.relu(self.conv3_1(pool2))
        conv3 = self.relu(self.conv3_2(conv3))
        pool3 = self.pool3(conv3)

        conv4 = self.relu(self.conv4_1(pool3))
        conv4 = self.relu(self.conv4_2(conv4))
        pool4 = self.pool4(conv4)

        conv5 = self.relu(self.conv5_1(pool4))
        conv5 = self.relu(self.conv5_2(conv5))

        # Decoder with skip connections.
        up6 = self.upv6(conv5)
        up6 = torch.cat([up6, conv4], 1)
        conv6 = self.relu(self.conv6_1(up6))
        conv6 = self.relu(self.conv6_2(conv6))

        up7 = self.upv7(conv6)
        up7 = torch.cat([up7, conv3], 1)
        conv7 = self.relu(self.conv7_1(up7))
        conv7 = self.relu(self.conv7_2(conv7))

        up8 = self.upv8(conv7)
        up8 = torch.cat([up8, conv2], 1)
        conv8 = self.relu(self.conv8_1(up8))
        conv8 = self.relu(self.conv8_2(conv8))

        up9 = self.upv9(conv8)
        up9 = torch.cat([up9, conv1], 1)
        conv9 = self.relu(self.conv9_1(up9))
        conv9 = self.relu(self.conv9_2(conv9))

        out = self.conv10_1(conv9)
        if self.res:
            # Residual on the 4 channels starting at cf*4.
            # NOTE(review): hard-codes 4 channels per frame -- presumably
            # in_nc == out_nc == 4; confirm before using other channel counts.
            out = out + x[:, self.cf*4:self.cf*4+4]

        if self.norm:
            out = data_inv_normalize(out, lb, ub)

        return out


def get_updown_module(nf, updown_type='conv', mode='up'):
    """Build a resampling module that also changes the channel count.

    ``mode='down'`` maps ``nf`` -> ``nf*2`` channels and ``mode='up'`` maps
    ``nf`` -> ``nf//2`` channels.

    Args:
        nf: Number of input channels.
        updown_type: ``'conv'``, an interpolation mode (``'bilinear'``,
            ``'bicubic'``, ``'nearest'``), ``'shuffle'`` (pixel (un)shuffle)
            or a wavelet name (``'haar'``, ``'db1'``, ``'db2'``, ``'db3'``).
        mode: ``'down'`` or ``'up'``.

    Returns:
        An ``nn.Module`` implementing the requested resampling.

    Raises:
        ValueError: If ``updown_type`` or ``mode`` is not supported
            (previously ``None`` was returned silently and the failure only
            surfaced later, when the module was called).
    """
    if mode not in ('down', 'up'):
        raise ValueError(f"mode must be 'down' or 'up', got {mode!r}")
    is_down = mode == 'down'

    if updown_type == 'conv':
        if is_down:
            # NOTE(review): conv3x3 is defined elsewhere; the other 'down'
            # variants halve the spatial size -- confirm this one does too.
            return conv3x3(nf, nf*2)
        return nn.ConvTranspose2d(nf, nf//2, 2, stride=2)

    if updown_type in ['bilinear', 'bicubic', 'nearest']:
        # The factor must be passed as `scale_factor`: the first positional
        # parameter of nn.Upsample is `size`, so the former
        # nn.Upsample(2, ...) forced every feature map to 2x2 and
        # nn.Upsample(1/2, ...) requested an output size of 0.5.
        if is_down:
            return nn.Sequential(
                nn.Upsample(scale_factor=0.5, mode=updown_type),
                nn.Conv2d(nf, nf*2, kernel_size=3, stride=1, padding=1),
            )
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode=updown_type),
            nn.Conv2d(nf, nf//2, kernel_size=3, stride=1, padding=1),
        )

    if updown_type == 'shuffle':
        if is_down:
            # PixelUnshuffle(2) multiplies channels by 4: nf -> nf*4 -> nf*2.
            return nn.Sequential(
                nn.PixelUnshuffle(2),
                nn.Conv2d(nf*4, nf*2, kernel_size=3, stride=1, padding=1),
            )
        # PixelShuffle(2) divides channels by 4: nf -> nf//4 -> nf//2.
        return nn.Sequential(
            nn.PixelShuffle(2),
            nn.Conv2d(nf//4, nf//2, kernel_size=3, stride=1, padding=1),
        )

    if updown_type in ['haar', 'db1', 'db2', 'db3']:
        # Conv widths assume the wavelet modules change the channel count by
        # a factor of 4, like pixel (un)shuffle -- TODO confirm in modules.py.
        if is_down:
            return nn.Sequential(
                WaveletDecompose(updown_type),
                nn.Conv2d(nf*4, nf*2, kernel_size=3, stride=1, padding=1),
            )
        return nn.Sequential(
            WaveletReconstruct(updown_type),
            nn.Conv2d(nf//4, nf//2, kernel_size=3, stride=1, padding=1),
        )

    raise ValueError(f"Unsupported updown_type: {updown_type!r}")


class GuidedResUnet(nn.Module):
    """Five-level U-Net of ``GuidedResidualBlock`` stages with a guidance input.

    Every residual block receives the extra tensor ``t`` passed to
    :meth:`forward`. Resampling between levels is built by
    :func:`get_updown_module`.

    Config keys read from ``args`` (all optional):
        nframes (1), res (False), norm (False), nf (32), in_nc (4),
        out_nc (4): same roles as in ``UNetSeeInDark``.
        updown_type ('conv'): resampling type used between levels.
        downsample (False): ``'shuffle'`` or a wavelet mode name. When set,
            the input is passed through ``down_fn`` before the network and
            the output through ``up_fn`` after it, and the first/last
            convolutions are widened by a factor of 4.
    """

    def __init__(self, args=None):
        super().__init__()
        if args is None:
            # Every key has a default, so a missing config means "all
            # defaults" instead of an AttributeError on None.get.
            args = {}
        self.args = args
        self.cf = 0  # index of the frame used for the residual connection
        self.nframes = nframes = args.get('nframes', 1)
        self.res = args.get('res', False)
        self.norm = args.get('norm', False)
        self.updown_type = args.get('updown_type', 'conv')
        self.downsample = args.get('downsample', False)
        if self.downsample == 'shuffle':
            self.down_fn = nn.PixelUnshuffle(2)
            self.up_fn = nn.PixelShuffle(2)
        elif self.downsample:
            # Truthiness check, matching `ext` below and forward(); the old
            # `!= False` test also fired for values such as None, building
            # wavelet modules that were never used.
            self.down_fn = WaveletDecompose(mode=self.downsample)
            self.up_fn = WaveletReconstruct(mode=self.downsample)
        ext = 4 if self.downsample else 1

        nf = args.get('nf', 32)
        in_nc = args.get('in_nc', 4)
        out_nc = args.get('out_nc', 4)

        self.conv_in = nn.Conv2d(in_nc*nframes*ext, nf, kernel_size=3, stride=1, padding=1)

        # Encoder.
        self.conv1 = GuidedResidualBlock(nf, nf, is_activate=False)
        self.pool1 = get_updown_module(nf, self.updown_type, mode='down')

        self.conv2 = GuidedResidualBlock(nf*2, nf*2, is_activate=False)
        self.pool2 = get_updown_module(nf*2, self.updown_type, mode='down')

        self.conv3 = GuidedResidualBlock(nf*4, nf*4, is_activate=False)
        self.pool3 = get_updown_module(nf*4, self.updown_type, mode='down')

        self.conv4 = GuidedResidualBlock(nf*8, nf*8, is_activate=False)
        self.pool4 = get_updown_module(nf*8, self.updown_type, mode='down')

        # Bottleneck.
        self.conv5 = GuidedResidualBlock(nf*16, nf*16, is_activate=False)

        # Decoder: each block takes the upsampled features concatenated with
        # the matching encoder features, hence twice the output width.
        self.upv6 = get_updown_module(nf*16, self.updown_type, mode='up')
        self.conv6 = GuidedResidualBlock(nf*16, nf*8, is_activate=False)

        self.upv7 = get_updown_module(nf*8, self.updown_type, mode='up')
        self.conv7 = GuidedResidualBlock(nf*8, nf*4, is_activate=False)

        self.upv8 = get_updown_module(nf*4, self.updown_type, mode='up')
        self.conv8 = GuidedResidualBlock(nf*4, nf*2, is_activate=False)

        self.upv9 = get_updown_module(nf*2, self.updown_type, mode='up')
        self.conv9 = GuidedResidualBlock(nf*2, nf, is_activate=False)

        self.conv10 = nn.Conv2d(nf, out_nc*ext, kernel_size=1, stride=1)
        self.lrelu = nn.LeakyReLU(0.01, inplace=True)

    def forward(self, x, t):
        """Run the network on ``x`` with guidance ``t``.

        Args:
            x: Input tensor with ``in_nc * nframes`` channels in dim 1.
            t: Guidance tensor handed to every ``GuidedResidualBlock``; its
                expected shape and meaning are defined by that block. When
                ``norm`` is enabled it is divided by the same range as ``x``.

        Returns:
            Output tensor, de-normalized and up-sampled back when ``norm`` /
            ``downsample`` are enabled.
        """
        if self.norm:
            x, lb, ub = data_normalize(x)
            # Keep the guidance on the same scale as the normalized input.
            t = t / (ub-lb)
        if self.downsample:
            x = self.down_fn(x)

        conv_in = self.lrelu(self.conv_in(x))

        # Encoder.
        conv1 = self.conv1(conv_in, t)
        pool1 = self.pool1(conv1)

        conv2 = self.conv2(pool1, t)
        pool2 = self.pool2(conv2)

        conv3 = self.conv3(pool2, t)
        pool3 = self.pool3(conv3)

        conv4 = self.conv4(pool3, t)
        pool4 = self.pool4(conv4)

        conv5 = self.conv5(pool4, t)

        # Decoder with skip connections.
        up6 = self.upv6(conv5)
        up6 = torch.cat([up6, conv4], 1)
        conv6 = self.conv6(up6, t)

        up7 = self.upv7(conv6)
        up7 = torch.cat([up7, conv3], 1)
        conv7 = self.conv7(up7, t)

        up8 = self.upv8(conv7)
        up8 = torch.cat([up8, conv2], 1)
        conv8 = self.conv8(up8, t)

        up9 = self.upv9(conv8)
        up9 = torch.cat([up9, conv1], 1)
        conv9 = self.conv9(up9, t)

        out = self.conv10(conv9)
        if self.res:
            # NOTE(review): with `downsample` enabled, `x` has already been
            # through down_fn and `out` has out_nc*4 channels, while this
            # slice is always 4 channels wide -- confirm `res` and
            # `downsample` are meant to be combined.
            out = out + x[:, self.cf*4:self.cf*4+4]

        if self.downsample:
            out = self.up_fn(out)

        if self.norm:
            out = data_inv_normalize(out, lb, ub)

        return out