| |
| |
| |
|
|
| import numpy as np |
| import os, math, gc |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision as vision |
| import pytorch_lightning as pl |
| from pytorch_lightning.utilities import rank_zero_info, rank_zero_only |
| from pytorch_lightning.strategies import DeepSpeedStrategy |
| import deepspeed |
| from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam |
| |
|
|
def __nop(ob):
    """Return ``ob`` unchanged; eager-mode stand-in for ``torch.jit.script_method``."""
    return ob


# Base class and method decorator for the TorchScript-able modules in this file.
# TorchScript is the default; set the environment variable RWKV_JIT_ON to any
# value other than "1" to run those modules as plain eager ``nn.Module``s
# (useful for debugging), in which case ``__nop`` replaces ``script_method``.
if os.environ.get("RWKV_JIT_ON", "1") == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method
else:
    MyModule = nn.Module
    MyFunction = __nop
|
|
| import clip |
| from transformers import CLIPModel |
|
|
class L2pooling(nn.Module):
    """Depthwise L2 pooling, used by DISTS in place of VGG max-pooling.

    Computes ``sqrt(conv2d(x**2, w) + 1e-12)`` channel by channel, where ``w``
    is a normalised 2-D Hann window: the square root of a Hann-weighted local
    mean of squares, sampled with the given stride.

    Args:
        filter_size: Hann window length *before* its two zero end points are
            trimmed, so the kernel is ``(filter_size - 2)`` pixels square.
        stride: spatial stride of the pooling.
        channels: number of input channels; the kernel is replicated once per
            channel. Must be provided (``None`` cannot size the buffer).
        pad_off: unused; kept for signature compatibility.
    """

    def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
        super(L2pooling, self).__init__()
        self.padding = (filter_size - 2) // 2
        self.stride = stride
        self.channels = channels

        # np.hanning(n) is zero at both ends; drop those taps.
        taps = np.hanning(filter_size)[1:-1]
        kernel = torch.Tensor(np.outer(taps, taps))
        kernel = kernel / kernel.sum()
        # Same (k, k) kernel for every channel -> (channels, 1, k, k).
        stacked = kernel.reshape(1, 1, *kernel.shape).repeat(channels, 1, 1, 1)
        self.register_buffer("filter", stacked)

    def forward(self, input):
        energy = input ** 2
        pooled = F.conv2d(
            energy,
            self.filter,
            stride=self.stride,
            padding=self.padding,
            groups=energy.shape[1],  # depthwise: one kernel per channel
        )
        # Epsilon keeps sqrt finite-gradient where the pooled energy is 0.
        return torch.sqrt(pooled + 1e-12)
|
|
|
|
class DISTS(torch.nn.Module):
    """DISTS image distance on a frozen VGG16 feature pyramid.

    The VGG16 ``features`` are cut into five stages. The layers at indices 4,
    9, 16 and 23 (VGG16's first four max-pools) are skipped and replaced by
    ``L2pooling``. For each of six feature maps (the raw input plus the five
    stage outputs), two per-channel similarities are computed:

    - S1 on spatial means;
    - S2 on spatial variances/covariance.

    They are weighted by ``alpha``/``beta``, which are normalised to sum to 1
    jointly. The result is ``1 - similarity``, so identical inputs score ~0.

    Args:
        load_weights: if True, replace the random ``alpha``/``beta`` with the
            pretrained values stored at ``weights_path``. If False they keep
            their N(0.1, 0.01) random initialisation.
        weights_path: checkpoint holding ``"alpha"`` and ``"beta"`` tensors.

    All parameters and buffers are frozen. Gradients can still reach the
    *inputs* of ``forward`` when ``require_grad=True``.
    """

    def __init__(self, load_weights=True, weights_path="test/DISTS_weights.pt"):
        super(DISTS, self).__init__()
        vgg_pretrained_features = vision.models.vgg16(
            weights="VGG16_Weights.IMAGENET1K_V1"
        ).features
        # Submodule names (stageK.<vgg index>) define the checkpoint keys.
        self.stage1 = torch.nn.Sequential()
        self.stage2 = torch.nn.Sequential()
        self.stage3 = torch.nn.Sequential()
        self.stage4 = torch.nn.Sequential()
        self.stage5 = torch.nn.Sequential()
        for x in range(0, 4):
            self.stage1.add_module(str(x), vgg_pretrained_features[x])
        self.stage2.add_module(str(4), L2pooling(channels=64))
        for x in range(5, 9):
            self.stage2.add_module(str(x), vgg_pretrained_features[x])
        self.stage3.add_module(str(9), L2pooling(channels=128))
        for x in range(10, 16):
            self.stage3.add_module(str(x), vgg_pretrained_features[x])
        self.stage4.add_module(str(16), L2pooling(channels=256))
        for x in range(17, 23):
            self.stage4.add_module(str(x), vgg_pretrained_features[x])
        self.stage5.add_module(str(23), L2pooling(channels=512))
        for x in range(24, 30):
            self.stage5.add_module(str(x), vgg_pretrained_features[x])

        # ImageNet normalisation statistics applied before the VGG stages.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
        )

        # Channel counts of the six feature maps returned by forward_once.
        self.chns = [3, 64, 128, 256, 512, 512]
        # Plain tensors, not nn.Parameter. Buffers are not returned by
        # self.parameters(), so a Parameter here would keep requires_grad=True
        # past the freeze loop below and collect .grad on every backward pass.
        self.register_buffer("alpha", torch.randn(1, sum(self.chns), 1, 1))
        self.register_buffer("beta", torch.randn(1, sum(self.chns), 1, 1))
        self.alpha.data.normal_(0.1, 0.01)
        self.beta.data.normal_(0.1, 0.01)
        if load_weights:
            # Load on CPU so CUDA-saved weights also load on CPU-only hosts;
            # buffers follow any later .to(device) on the module.
            weights = torch.load(weights_path, map_location="cpu")
            self.alpha.data = weights["alpha"]
            self.beta.data = weights["beta"]

        for param in self.parameters():
            param.requires_grad = False

    def forward_once(self, x):
        """Return ``[x, relu1_2, relu2_2, relu3_3, relu4_3, relu5_3]`` features.

        The first entry is the raw input. The VGG stages see the input after
        ``(x - mean) / std`` normalisation.
        """
        h = (x - self.mean) / self.std
        h = self.stage1(h)
        h_relu1_2 = h
        h = self.stage2(h)
        h_relu2_2 = h
        h = self.stage3(h)
        h_relu3_3 = h
        h = self.stage4(h)
        h_relu4_3 = h
        h = self.stage5(h)
        h_relu5_3 = h
        return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]

    def forward(self, x, y, require_grad=False, batch_average=False):
        """Return the DISTS distance between image batches ``x`` and ``y``.

        Args:
            x, y: image batches of equal shape. Presumably ``(N, 3, H, W)`` RGB
                in [0, 1], given the ImageNet ``mean``/``std`` (confirm).
            require_grad: if False, features are computed under
                ``torch.no_grad()``. Set True to backpropagate into the inputs.
            batch_average: return the batch mean instead of per-image scores.

        Returns:
            Per-image scores after ``squeeze()`` (shape ``(N,)``, or 0-d when
            N == 1), or their mean when ``batch_average`` is True.
        """
        if require_grad:
            feats0 = self.forward_once(x)
            feats1 = self.forward_once(y)
        else:
            with torch.no_grad():
                feats0 = self.forward_once(x)
                feats1 = self.forward_once(y)
        dist1 = 0
        dist2 = 0
        # Stabilising constants for the two ratios.
        c1 = 1e-6
        c2 = 1e-6
        # Normalise alpha and beta jointly so all weights sum to 1.
        w_sum = self.alpha.sum() + self.beta.sum()
        alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
        beta = torch.split(self.beta / w_sum, self.chns, dim=1)

        for f0, f1, a_k, b_k in zip(feats0, feats1, alpha, beta):
            # S1: similarity of per-channel spatial means.
            x_mean = f0.mean([2, 3], keepdim=True)
            y_mean = f1.mean([2, 3], keepdim=True)
            S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
            dist1 = dist1 + (a_k * S1).sum(1, keepdim=True)

            # S2: covariance relative to the variances, per channel.
            x_var = ((f0 - x_mean) ** 2).mean([2, 3], keepdim=True)
            y_var = ((f1 - y_mean) ** 2).mean([2, 3], keepdim=True)
            xy_cov = (f0 * f1).mean([2, 3], keepdim=True) - x_mean * y_mean
            S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
            dist2 = dist2 + (b_k * S2).sum(1, keepdim=True)

        score = 1 - (dist1 + dist2).squeeze()
        return score.mean() if batch_average else score
|
|
class ToBinary(torch.autograd.Function):
    """Round to the nearest integer with a straight-through gradient.

    Forward computes ``floor(x + 0.5)``, so exact halves round up. Applied to
    values in (0, 1), such as a sigmoid output, this yields {0, 1}.
    Backward passes the incoming gradient through unchanged, treating the
    rounding as the identity.
    """

    @staticmethod
    def forward(ctx, x):
        shifted = x + 0.5
        return shifted.floor()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: d(round(x))/dx is taken to be 1.
        return torch.clone(grad_output)
|
|
| |
|
|
class R_ENCODER(MyModule):
    """Image -> ``args.my_img_bit``-channel code map at 1/8 resolution.

    The input has 3 channels (``CIN``). Height and width must be divisible
    by 8 because of the ``pixel_unshuffle`` calls. Stem features are reduced
    8x by three 2x space-to-depth steps, each followed by two residual conv
    pairs. A BatchNorm'd 8x space-to-depth copy of the stem (``Bxx``) is
    added back just before ``COUT``. The sigmoid output lies in (0, 1), with
    shape ``(N, args.my_img_bit, H/8, W/8)``.

    Submodule names and creation order are fixed on purpose: they define the
    checkpoint keys and the parameter order.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        dd = 8  # stem width; channels grow dd -> dd*4 -> dd*16 -> dd*64

        def conv3x3(c_in, c_out):
            # Size-preserving 3x3 convolution.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        # Long skip: normalised 8x space-to-depth of the stem features.
        self.Bxx = nn.BatchNorm2d(dd * 64)

        # Full-resolution stem plus one residual conv pair.
        self.CIN = conv3x3(3, dd)
        self.Cx0 = conv3x3(dd, 32)
        self.Cx1 = conv3x3(32, dd)

        # After the first 2x unshuffle: dd * 4 channels.
        self.B00 = nn.BatchNorm2d(dd * 4)
        self.C00 = conv3x3(dd * 4, 256)
        self.C01 = conv3x3(256, dd * 4)
        self.C02 = conv3x3(dd * 4, 256)
        self.C03 = conv3x3(256, dd * 4)

        # After the second: dd * 16 channels.
        self.B10 = nn.BatchNorm2d(dd * 16)
        self.C10 = conv3x3(dd * 16, 256)
        self.C11 = conv3x3(256, dd * 16)
        self.C12 = conv3x3(dd * 16, 256)
        self.C13 = conv3x3(256, dd * 16)

        # After the third: dd * 64 channels.
        self.B20 = nn.BatchNorm2d(dd * 64)
        self.C20 = conv3x3(dd * 64, 256)
        self.C21 = conv3x3(256, dd * 64)
        self.C22 = conv3x3(dd * 64, 256)
        self.C23 = conv3x3(256, dd * 64)

        # Projection to the code width.
        self.COUT = conv3x3(dd * 64, args.my_img_bit)

    @MyFunction
    def forward(self, img):
        h = self.CIN(img)
        skip = self.Bxx(F.pixel_unshuffle(h, 8))
        h = h + self.Cx1(F.mish(self.Cx0(h)))

        h = F.pixel_unshuffle(h, 2)
        h = h + self.C01(F.mish(self.C00(F.mish(self.B00(h)))))
        h = h + self.C03(F.mish(self.C02(h)))

        h = F.pixel_unshuffle(h, 2)
        h = h + self.C11(F.mish(self.C10(F.mish(self.B10(h)))))
        h = h + self.C13(F.mish(self.C12(h)))

        h = F.pixel_unshuffle(h, 2)
        h = h + self.C21(F.mish(self.C20(F.mish(self.B20(h)))))
        h = h + self.C23(F.mish(self.C22(h)))

        return torch.sigmoid(self.COUT(h + skip))
|
|
| |
|
|
class R_DECODER(MyModule):
    """``args.my_img_bit``-channel code map -> 3-channel image at 8x resolution.

    Mirrors ``R_ENCODER``. Two residual conv pairs run at each of three
    resolutions (dd*64, dd*16 and dd*4 channels), each followed by a 2x
    ``pixel_shuffle``. A final residual pair runs at full resolution, then
    ``COUT``. The sigmoid output lies in (0, 1), with shape
    ``(N, 3, 8*h, 8*w)`` for a code of spatial size ``(h, w)``.

    Submodule names and creation order are fixed on purpose: they define the
    checkpoint keys and the parameter order.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        dd = 8  # output stem width; channels shrink dd*64 -> dd*16 -> dd*4 -> dd

        def conv3x3(c_in, c_out):
            # Size-preserving 3x3 convolution.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        self.CIN = conv3x3(args.my_img_bit, dd * 64)

        # Code resolution: dd * 64 channels.
        self.B00 = nn.BatchNorm2d(dd * 64)
        self.C00 = conv3x3(dd * 64, 256)
        self.C01 = conv3x3(256, dd * 64)
        self.C02 = conv3x3(dd * 64, 256)
        self.C03 = conv3x3(256, dd * 64)

        # After the first 2x shuffle: dd * 16 channels.
        self.B10 = nn.BatchNorm2d(dd * 16)
        self.C10 = conv3x3(dd * 16, 256)
        self.C11 = conv3x3(256, dd * 16)
        self.C12 = conv3x3(dd * 16, 256)
        self.C13 = conv3x3(256, dd * 16)

        # After the second: dd * 4 channels.
        self.B20 = nn.BatchNorm2d(dd * 4)
        self.C20 = conv3x3(dd * 4, 256)
        self.C21 = conv3x3(256, dd * 4)
        self.C22 = conv3x3(dd * 4, 256)
        self.C23 = conv3x3(256, dd * 4)

        # Full resolution (dd channels) and projection to RGB.
        self.Cx0 = conv3x3(dd, 32)
        self.Cx1 = conv3x3(32, dd)
        self.COUT = conv3x3(dd, 3)

    @MyFunction
    def forward(self, code):
        h = self.CIN(code)

        h = h + self.C01(F.mish(self.C00(F.mish(self.B00(h)))))
        h = h + self.C03(F.mish(self.C02(h)))
        h = F.pixel_shuffle(h, 2)

        h = h + self.C11(F.mish(self.C10(F.mish(self.B10(h)))))
        h = h + self.C13(F.mish(self.C12(h)))
        h = F.pixel_shuffle(h, 2)

        h = h + self.C21(F.mish(self.C20(F.mish(self.B20(h)))))
        h = h + self.C23(F.mish(self.C22(h)))
        h = F.pixel_shuffle(h, 2)

        h = h + self.Cx1(F.mish(self.Cx0(h)))
        return torch.sigmoid(self.COUT(h))
|
|
| |
|
|
def cosine_loss(x, y):
    """Return ``1 - cos(x, y)`` computed along the last dimension.

    Args:
        x, y: tensors of identical shape ``(..., D)``. Each vector is
            L2-normalised with ``F.normalize``, whose default ``eps`` avoids
            division by zero, before the dot product.

    Returns:
        Tensor of shape ``x.shape[:-1]``: 0 for parallel vectors, 1 for
        orthogonal ones, 2 for opposite ones.
    """
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    # Dot product over the last dim. Unlike einsum('ij,ij->i'), this accepts
    # any number of leading batch dimensions (including none).
    return 1 - (x * y).sum(dim=-1)
|
|
class RWKV_IMG(pl.LightningModule):
    """Binary-bottleneck image autoencoder trained as a Lightning module.

    ``encoder`` (R_ENCODER) produces sigmoid code maps. ``ToBinary`` rounds
    them to {0, 1} with a straight-through gradient, and ``decoder``
    (R_DECODER) reconstructs the image. The training loss is DISTS plus a
    CLIP image-embedding cosine loss, plus an optional L1 term.

    Fields read from ``args``: ``my_img_clip``, ``lr_init``, ``betas``,
    ``adam_eps``, ``devices``, ``run_name``, ``my_img_clip_scale`` and
    ``my_img_l1_scale``. The encoder/decoder also read ``my_img_bit``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        self.encoder = R_ENCODER(args)
        self.decoder = R_DECODER(args)

        # Frozen CLIP image tower, used as a perceptual loss.
        openai_clip_aliases = {"B32": "ViT-B/32", "B16": "ViT-B/16", "L14": "ViT-L/14"}
        hf_clip_aliases = {"OB32": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"}
        clip_name = args.my_img_clip
        if clip_name in hf_clip_aliases:
            self.clip_model = CLIPModel.from_pretrained(hf_clip_aliases[clip_name])
            # Give the HF model the same encode_image() entry point as OpenAI CLIP.
            self.clip_model.encode_image = self.clip_model.get_image_features
        else:
            # Names without an alias are passed to clip.load() unchanged.
            self.clip_model, _ = clip.load(
                openai_clip_aliases.get(clip_name, clip_name), jit=True
            )
        # CLIP input normalisation statistics.
        self.register_buffer(
            "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)
        )

        for p in self.clip_model.parameters():
            p.requires_grad = False

        # DISTS freezes its own weights in its constructor.
        self.loss_dists = DISTS()

    def configure_optimizers(self):
        """Return DeepSpeedCPUAdam when ZeRO offloading is on, else FusedAdam.

        Both run plain Adam (no decoupled weight decay, weight_decay=0) with
        ``args.lr_init``, ``args.betas`` and ``args.adam_eps``.
        """
        # NOTE(review): this single group also holds the frozen CLIP and
        # DISTS/VGG parameters (requires_grad=False). They never receive a
        # .grad, so presumably the optimizer skips them; confirm with DeepSpeed.
        optim_groups = [
            {"params": list(self.parameters()), "weight_decay": 0.0},
        ]
        if self.deepspeed_offload:
            return DeepSpeedCPUAdam(
                optim_groups,
                lr=self.args.lr_init,
                betas=self.args.betas,
                eps=self.args.adam_eps,
                bias_correction=True,
                adamw_mode=False,
                weight_decay=0,
                amsgrad=False,
            )
        return FusedAdam(
            optim_groups,
            lr=self.args.lr_init,
            betas=self.args.betas,
            eps=self.args.adam_eps,
            bias_correction=True,
            adam_w_mode=False,
            weight_decay=0,
            amsgrad=False,
        )

    @property
    def deepspeed_offload(self) -> bool:
        """True when the DeepSpeed strategy offloads optimizer state or params."""
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            config = strategy.config["zero_optimization"]
            return bool(config.get("offload_optimizer") or config.get("offload_param"))
        return False

    def forward(self, img):
        """Encode ``img``, binarise the code, and return the decoded image."""
        z = self.encoder(img)
        z = ToBinary.apply(z)
        out = self.decoder(z)
        return out

    def training_step(self, batch, batch_idx):
        """Return the scalar loss for one ``(img, txt)`` batch (``txt`` unused).

        Every ``100 * args.devices`` steps, rank zero also writes the first
        four source and reconstructed images under
        ``test/image_model/<run_name>/``.
        """
        args = self.args
        img, _txt = batch
        out = self(img)
        if self.trainer.is_global_zero:
            if (self.trainer.global_step + 1) % (100 * int(args.devices)) == 0:
                img_dir = f"test/image_model/{args.run_name}"
                os.makedirs(img_dir, exist_ok=True)
                step = self.trainer.global_step
                vision.utils.save_image(img[:4], f"{img_dir}/{step}-src.jpg")
                vision.utils.save_image(out[:4], f"{img_dir}/{step}-out.jpg")

        loss_dists = self.loss_dists(out, img, require_grad=True, batch_average=True)

        # CLIP embeddings of source and reconstruction should point the same way.
        iii = self.clip_model.encode_image((img - self.clip_mean) / self.clip_std)
        ooo = self.clip_model.encode_image((out - self.clip_mean) / self.clip_std)
        loss_clip = torch.mean(cosine_loss(iii, ooo))

        loss = loss_dists + loss_clip * args.my_img_clip_scale
        if args.my_img_l1_scale > 0:
            loss = loss + F.l1_loss(out, img) * args.my_img_l1_scale
        return loss

    def training_step_end(self, batch_parts):
        """Gather per-device step outputs and store them on the trainer (rank 0).

        NOTE(review): Lightning 2.x removed the ``training_step_end`` hook;
        confirm the pinned pytorch_lightning version still calls it.
        ``trainer.my_loss_all`` is presumably read by a callback elsewhere.
        """
        gathered = self.all_gather(batch_parts)
        if self.trainer.is_global_zero:
            self.trainer.my_loss_all = gathered

    def generate_init_weight(self):
        """Return the state dict with every tensor on CPU, cast per RWKV_FLOAT_MODE.

        ``"fp16"`` casts to half and ``"bf16"`` to bfloat16; any other value
        keeps each tensor's dtype. In that case tensors already on CPU share
        storage with the module, because ``.cpu()`` is then a no-op. Raises
        KeyError if RWKV_FLOAT_MODE is not set.
        """
        print(
            f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        float_mode = os.environ["RWKV_FLOAT_MODE"]
        m = {}
        # Build the state dict once. Calling state_dict() per key is quadratic.
        for n, p in self.state_dict().items():
            p = p.cpu()
            if float_mode == "fp16":
                p = p.half()
            elif float_mode == "bf16":
                p = p.bfloat16()
            m[n] = p

        gc.collect()
        torch.cuda.empty_cache()
        return m
|
|