import os
import traceback
from collections import OrderedDict

import torch
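
# Checkpoint utilities: savee exports a slimmed fp16 weight file for
# inference, while save_checkpoint and load_checkpoint persist and restore
# the full training state (model, optimizer, learning rate, iteration).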

def savee(ckpt, sr, if_f0, name, epoch, version):
    try:
        opt = OrderedDict()
        opt["weight"] = {}
        for key in ckpt.keys():
            # enc_q (the posterior encoder) is only needed during training,
            # so it is excluded from the exported inference weights.
            if "enc_q" in key:
                continue
            # Cast to fp16 to roughly halve the size of the exported file.
            opt["weight"][key] = ckpt[key].half()
        # Positional hyperparameters; they appear to correspond to the
        # SynthesizerTrn constructor arguments (spec_channels, segment_size,
        # inter_channels, hidden_channels, filter_channels, n_heads, n_layers,
        # kernel_size, p_dropout, resblock, resblock_kernel_sizes,
        # resblock_dilation_sizes, upsample_rates, upsample_initial_channel,
        # upsample_kernel_sizes, spk_embed_dim, gin_channels, sampling_rate).
        # Only spec_channels, the upsampling schedule, and the sampling rate
        # differ between the three target rates.
        if sr == "40k":
            opt["config"] = [
                1025,
                32,
                192,
                192,
                768,
                2,
                6,
                3,
                0,
                "1",
                [3, 7, 11],
                [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                [10, 10, 2, 2],
                512,
                [16, 16, 4, 4],
                109,
                256,
                40000,
            ]
        elif sr == "48k":
            opt["config"] = [
                1025,
                32,
                192,
                192,
                768,
                2,
                6,
                3,
                0,
                "1",
                [3, 7, 11],
                [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                [10, 6, 2, 2, 2],
                512,
                [16, 16, 4, 4, 4],
                109,
                256,
                48000,
            ]
        elif sr == "32k":
            opt["config"] = [
                513,
                32,
                192,
                192,
                768,
                2,
                6,
                3,
                0,
                "1",
                [3, 7, 11],
                [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                [10, 4, 2, 2, 2],
                512,
                [16, 16, 4, 4, 4],
                109,
                256,
                32000,
            ]
| opt["info"] = "%sepoch" % epoch | |
| opt["sr"] = sr | |
| opt["f0"] = if_f0 | |
| opt["version"] = version | |
| os.makedirs(os.path.dirname(name), exist_ok=True) | |
| torch.save(opt, name) | |
| return "Success." | |
| except: | |
| return traceback.format_exc() | |
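
# Usage sketch (illustrative; the state dict variable and output path are
# hypothetical, not part of the original file):
#   msg = savee(generator_state_dict, "40k", if_f0=1,
#               name="weights/my_model.pth", epoch=200, version="v2")
#   print(msg)  # "Success." or a traceback string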

def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    # Unwrap DataParallel/DistributedDataParallel so the saved keys carry no
    # "module." prefix.
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(
        {
            "model": state_dict,
            "iteration": iteration,
            "optimizer": optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )
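
# Usage sketch (illustrative; net_g and optim_g are hypothetical names for a
# model and its optimizer):
#   save_checkpoint(net_g, optim_g, 1e-4, epoch, "logs/exp/G_latest.pth")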

def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    saved_state_dict = checkpoint_dict["model"]
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():  # iterate over the shapes the model expects
        try:
            new_state_dict[k] = saved_state_dict[k]
            if saved_state_dict[k].shape != state_dict[k].shape:
                print(
                    "shape-%s-mismatch|need-%s|get-%s"
                    % (k, state_dict[k].shape, saved_state_dict[k].shape)
                )
                raise KeyError
        except Exception:
            # logger.info(traceback.format_exc())
            # Key missing or shape mismatch: keep the model's own randomly
            # initialized value.
            new_state_dict[k] = v
    if hasattr(model, "module"):
        model.module.load_state_dict(new_state_dict, strict=False)
    else:
        model.load_state_dict(new_state_dict, strict=False)
    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    # If the saved optimizer state cannot be loaded (e.g. it is empty), the
    # optimizer is effectively reinitialized, which may also desync the LR
    # scheduler; this is caught at the outermost level of the train script.
    if optimizer is not None and load_opt == 1:
        # try:
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
        # except:
        #     traceback.print_exc()
    return model, optimizer, learning_rate, iteration
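
# Usage sketch (illustrative; resumes training from the checkpoint saved
# above, with the same hypothetical net_g/optim_g names):
#   net_g, optim_g, lr, it = load_checkpoint(
#       "logs/exp/G_latest.pth", net_g, optim_g
#   )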