| |
| |
| |
| |
| |
| |
| |
|
|
| from logging import getLogger |
| import math |
| import os |
| from typing import Dict, List, Optional, Union, Tuple |
| from types import MethodType |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from torch.nn.utils import parametrize |
|
|
|
|
| |
class DAMP(nn.Identity):
    """Parametrization that carries a DAMP noise standard deviation.

    ``enable_damp`` registers instances of this class on ``nn.Linear.weight``.
    The class subclasses ``nn.Identity`` and does not override ``forward``, so
    applying it returns the weight tensor unchanged. ``std`` is only stored on
    the instance.

    NOTE(review): if DAMP is supposed to perturb the weights, that must be
    implemented somewhere that reads ``self.std``. Nothing in this class uses
    it — confirm.

    Args:
        std: Standard deviation value to attach to this parametrization.
    """

    def __init__(self, std: float):
        super().__init__()
        # Stored for consumers of the parametrization; unused by this class itself.
        self.std = std

    def extra_repr(self) -> str:
        # Show ``std`` in ``repr(model)`` so parametrized layers are
        # distinguishable when a model is printed (otherwise just "DAMP()").
        return f"std={self.std}"
|
|
|
|
def enable_damp(model: nn.Module, std: float):
    """Register a ``DAMP(std)`` parametrization on every ``nn.Linear`` weight.

    Args:
        model: A module whose submodules (including itself) are searched for
            ``nn.Linear`` layers, or a list/tuple of such modules. Nested
            lists/tuples are handled recursively.
        std: Value passed to each newly created ``DAMP`` instance.

    Modules are modified in place and nothing is returned. Each Linear layer
    receives its own ``DAMP`` instance.

    NOTE(review): calling this again on the same model registers an additional
    ``DAMP`` on each weight rather than replacing the existing one — confirm
    whether repeated calls are expected.
    """
    if isinstance(model, (list, tuple)):
        for sub_model in model:
            enable_damp(sub_model, std)
        return

    # Collect targets before registering. register_parametrization adds a
    # ``parametrizations`` child to each Linear, so registering while the
    # module-tree generator is still running would mutate the structure
    # being traversed.
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    for linear in linears:
        parametrize.register_parametrization(linear, 'weight', DAMP(std))
|
|
|
|
def configure_damp_from_args(model: nn.Module, args):
    """Turn on DAMP for ``model`` if the parsed arguments request it.

    Reads ``args.damp``; a missing attribute counts as ``None``. Any falsy
    value (``None``, ``0``, ``0.0``) leaves ``model`` untouched. Otherwise the
    value is forwarded to ``enable_damp`` as the std.

    Args:
        model: Module, or list/tuple of modules, to modify in place.
        args: Namespace-like object; presumably an ``argparse.Namespace`` —
            confirm against callers.
    """
    requested_std = getattr(args, 'damp', None)
    if not requested_std:
        return
    enable_damp(model, requested_std)
|
|