| import os |
| import torch |
| from collections import OrderedDict |
| from abc import ABC, abstractmethod |
| from . import networks |
| import numpy as np |
| from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
| class BaseModel(ABC): |
| """This class is an abstract base class (ABC) for models. |
| To create a subclass, you need to implement the following five functions: |
| -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). |
| -- <set_input>: unpack data from dataset and apply preprocessing. |
| -- <forward>: produce intermediate results. |
| -- <optimize_parameters>: calculate losses, gradients, and update network weights. |
| -- <modify_commandline_options>: (optionally) add model-specific options and set default options. |
| """ |
|
|
| def __init__(self, opt): |
| """Initialize the BaseModel class. |
| |
| Parameters: |
| opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions |
| |
| When creating your custom class, you need to implement your own initialization. |
| In this fucntion, you should first call `BaseModel.__init__(self, opt)` |
| Then, you need to define four lists: |
| -- self.loss_names (str list): specify the training losses that you want to plot and save. |
| -- self.model_names (str list): specify the images that you want to display and save. |
| -- self.visual_names (str list): define networks used in our training. |
| -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. |
| """ |
| self.opt = opt |
| self.gpu_ids = opt.gpu_ids |
| self.isTrain = opt.isTrain |
| self.iter = 0 |
| self.last_iter = 0 |
| self.device = torch.device('cuda:{}'.format( |
| self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') |
| |
| self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) |
| try: |
| os.mkdir(self.save_dir) |
| except: |
| pass |
| |
| if opt.preprocess != 'scale_width': |
| torch.backends.cudnn.benchmark = True |
| self.loss_names = [] |
| self.model_names = [] |
| self.visual_names = [] |
| self.optimizers = [] |
| self.image_paths = [] |
|
|
| self.label_colours = np.random.randint(255, size=(100,3)) |
|
|
| def save_suppixel(self,l_inds): |
| im_target_rgb = np.array([self.label_colours[ c % 100 ] for c in l_inds]) |
| b,h,w = l_inds.shape |
| im_target_rgb = im_target_rgb.reshape(b,h,w,3).transpose(0,3,1,2)/127.5-1.0 |
| return torch.from_numpy(im_target_rgb) |
|
|
| @staticmethod |
| def modify_commandline_options(parser, is_train): |
| """Add new model-specific options, and rewrite default values for existing options. |
| |
| Parameters: |
| parser -- original option parser |
| is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. |
| |
| Returns: |
| the modified parser. |
| """ |
| return parser |
|
|
| @abstractmethod |
| def set_input(self, input): |
| """Unpack input data from the dataloader and perform necessary pre-processing steps. |
| |
| Parameters: |
| input (dict): includes the data itself and its metadata information. |
| """ |
| pass |
|
|
| @abstractmethod |
| def forward(self): |
| """Run forward pass; called by both functions <optimize_parameters> and <test>.""" |
| pass |
|
|
| def is_train(self): |
| """check if the current batch is good for training.""" |
| return True |
|
|
| @abstractmethod |
| def optimize_parameters(self): |
| """Calculate losses, gradients, and update network weights; called in every training iteration""" |
| pass |
|
|
| def setup(self, opt): |
| """Load and print networks; create schedulers |
| |
| Parameters: |
| opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions |
| """ |
| if self.isTrain: |
| self.schedulers = [networks.get_scheduler( |
| optimizer, opt) for optimizer in self.optimizers] |
| if not self.isTrain or opt.continue_train: |
| self.load_networks(opt.epoch) |
| self.print_networks(opt.verbose) |
|
|
| def eval(self): |
| """Make models eval mode during test time""" |
| for name in self.model_names: |
| if isinstance(name, str): |
| net = getattr(self, 'net' + name) |
| net.eval() |
|
|
| def test(self): |
| """Forward function used in test time. |
| |
| This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop |
| It also calls <compute_visuals> to produce additional visualization results |
| """ |
| with torch.no_grad(): |
| self.forward() |
| self.compute_visuals() |
|
|
| def compute_visuals(self): |
| """Calculate additional output images for visdom and HTML visualization""" |
| pass |
|
|
| def get_image_paths(self): |
| """ Return image paths that are used to load current data""" |
| return self.image_paths |
|
|
| def update_learning_rate(self): |
| """Update learning rates for all the networks; called at the end of every epoch""" |
| for scheduler in self.schedulers: |
| scheduler.step() |
| lr = self.optimizers[0].param_groups[0]['lr'] |
| print('learning rate = %.7f' % lr) |
|
|
| def get_current_visuals(self): |
| """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" |
| visual_ret = OrderedDict() |
| for name in self.visual_names: |
| if isinstance(name, str): |
| if 'Lab' in name: |
| labimg = getattr(self, name).cpu() |
| labimg[:,0,:,:]+=1 |
| labimg[:,0,:,:]*=50 |
| labimg[:,1:,:,:] *= 110 |
| labimg = labimg.permute((0,2,3,1)) |
| for i in range(labimg.shape[0]): |
| labimg[i,:,:,:]=lab2rgb(labimg[i,:,:,:]) |
| visual_ret[name] = (labimg.permute((0,3,1,2))*2-1.0).to(self.device) |
| elif 'Fm' in name: |
| visual_ret[name] = self.save_suppixel(getattr(self, name).cpu()).to(self.device) |
| else: |
| visual_ret[name] = getattr(self, name) |
| return visual_ret |
|
|
| def get_current_losses(self): |
| """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" |
| errors_ret = OrderedDict() |
| for name in self.loss_names: |
| if isinstance(name, str): |
| |
| errors_ret[name] = float(getattr(self, 'loss_' + name)) |
| return errors_ret |
|
|
| def save_networks(self, epoch): |
| """Save all the networks to the disk. |
| |
| Parameters: |
| epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) |
| """ |
| for name in self.model_names: |
| if isinstance(name, str): |
| save_filename = '%s_net_%s.pth' % (epoch, name) |
| save_path = os.path.join(self.save_dir, save_filename) |
| |
| net = getattr(self, 'net' + name) |
|
|
| if len(self.gpu_ids) > 0 and torch.cuda.is_available(): |
| torch.save(net.state_dict(), save_path) |
| |
| else: |
| torch.save(net.cpu().state_dict(), save_path) |
|
|
| save_filename = '%s_net_opt.pth' % (epoch) |
| save_path = os.path.join(self.save_dir, save_filename) |
| save_dict = {'iter': str(self.iter // self.opt.print_freq * self.opt.print_freq)} |
| for i, name in enumerate(self.optimizer_names): |
| save_dict.update({name.lower(): self.optimizers[i].state_dict()}) |
| torch.save(save_dict, save_path) |
|
|
|
|
| def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): |
| """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" |
| key = keys[i] |
| if i + 1 == len(keys): |
| if module.__class__.__name__.startswith('InstanceNorm') and \ |
| (key == 'running_mean' or key == 'running_var'): |
| if getattr(module, key) is None: |
| state_dict.pop('.'.join(keys)) |
| if module.__class__.__name__.startswith('InstanceNorm') and \ |
| (key == 'num_batches_tracked'): |
| state_dict.pop('.'.join(keys)) |
| else: |
| self.__patch_instance_norm_state_dict( |
| state_dict, getattr(module, key), keys, i + 1) |
|
|
| def load_networks(self, epoch): |
| """Load all the networks from the disk. |
| |
| Parameters: |
| epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) |
| """ |
| for name in self.model_names: |
| if isinstance(name, str): |
| load_filename = '%s_net_%s.pth' % (epoch, name) |
| load_path = os.path.join(self.save_dir, load_filename) |
| net = getattr(self, 'net' + name) |
| |
| if isinstance(net, DDP): |
| net = net.module |
| |
| print('loading the model from %s' % load_path) |
| |
| |
| state_dict = torch.load( |
| load_path, map_location=lambda storage, loc: storage.cuda()) |
| if hasattr(state_dict, '_metadata'): |
| del state_dict._metadata |
|
|
| |
| |
| |
| |
| |
|
|
| net.load_state_dict(state_dict) |
| del state_dict |
|
|
| def print_networks(self, verbose): |
| """Print the total number of parameters in the network and (if verbose) network architecture |
| |
| Parameters: |
| verbose (bool) -- if verbose: print the network architecture |
| """ |
| print('---------- Networks initialized -------------') |
| for name in self.model_names: |
| if isinstance(name, str): |
| net = getattr(self, 'net' + name) |
| num_params = 0 |
| for param in net.parameters(): |
| num_params += param.numel() |
| if verbose: |
| print(net) |
| print('[Network %s] Total number of parameters : %.3f M' % |
| (name, num_params / 1e6)) |
| print('-----------------------------------------------') |
|
|
| def set_requires_grad(self, nets, requires_grad=False): |
| """Set requires_grad=False for all the networks to avoid unnecessary computations |
| Parameters: |
| nets (network list) -- a list of networks |
| requires_grad (bool) -- whether the networks require gradients or not |
| """ |
| if not isinstance(nets, list): |
| nets = [nets] |
| for net in nets: |
| if net is not None: |
| for param in net.parameters(): |
| param.requires_grad = requires_grad |
|
|