| """ |
| Adafactor Optimizer for BitTransformerLM Extensions |
| =================================================== |
| |
| Implementation of the Adafactor optimizer with memory-efficient factorization. |
| Based on "Adafactor: Adaptive Learning Rates with Sublinear Memory Cost" research. |
| |
| Key features: |
| - Factorized second moment estimates for memory efficiency |
| - Automatic scaling of learning rates |
| - Relative step size and clip threshold |
| - Compatible with BitTransformerLM's training infrastructure |
| """ |
|
|
| import math |
| import torch |
| from torch.optim.optimizer import Optimizer |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
|
class Adafactor(Optimizer):
    """
    Adafactor optimizer implementation.

    Adafactor reduces memory usage by factorizing the second moment estimates
    for parameters with 2 or more dimensions, making it highly memory efficient
    for large transformer models.

    Args:
        params: Iterable of parameters to optimize
        lr: External learning rate (default: None, uses automatic scaling)
        eps2: Regularization constant for second moment (default: 1e-30)
        cliping_threshold: Threshold for adaptive clipping (default: 1.0).
            The misspelled name is kept for backward compatibility with
            existing callers.
        decay_rate: Coefficient used for computing running averages (default: -0.8)
        beta1: Coefficient for the first moment; None disables it (default: None)
        weight_decay: Weight decay coefficient (default: 0.0)
        scale_parameter: If True, the relative step is scaled by the root mean
            square of the parameter (default: True)
        relative_step_size: If True, use relative step size (default: True).
            NOTE(review): this flag is stored but not consulted in step();
            relative steps are used whenever ``lr`` is None — confirm intent.
        warmup_init: If True, warm up the relative step size (default: False)
    """

    def __init__(
        self,
        params,
        lr: Optional[float] = None,
        eps2: float = 1e-30,
        cliping_threshold: float = 1.0,
        decay_rate: float = -0.8,
        beta1: Optional[float] = None,
        weight_decay: float = 0.0,
        scale_parameter: bool = True,
        relative_step_size: bool = True,
        warmup_init: bool = False,
    ):
        if lr is not None and lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            eps2=eps2,
            cliping_threshold=cliping_threshold,
            decay_rate=decay_rate,
            beta1=beta1,
            weight_decay=weight_decay,
            scale_parameter=scale_parameter,
            relative_step_size=relative_step_size,
            warmup_init=warmup_init,
        )
        super().__init__(params, defaults)

    def _get_lr(self, param_group, param_state):
        """Compute the relative step size (used when no external lr is given)."""
        # With warmup_init the step size ramps linearly (1e-6 * step); without
        # it the step is capped at 1e-2 and decays as 1/sqrt(step).
        min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
        rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            param_scale = max(param_group["eps2"], param_state["RMS"])
        return param_scale * rel_step_sz

    def _get_options(self, param_group, param_shape):
        """Return (factored, use_first_moment) flags for a parameter shape."""
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    def _rms(self, tensor):
        """Root mean square of a tensor."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        """Reconstruct 1/sqrt(V) from the factored row/column statistics."""
        r_factor = ((exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True))
                    .rsqrt_())
        c_factor = exp_avg_sq_col.rsqrt()
        # unsqueeze(-2) (not unsqueeze(0)) so the column factor broadcasts
        # correctly for parameters with more than 2 dimensions.
        return torch.mul(r_factor.unsqueeze(-1), c_factor.unsqueeze(-2))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The loss returned by ``closure``, or None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                # Keep all optimizer math in fp32 for numerical stability.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)

                # Lazy state initialization on first step for this parameter.
                if len(state) == 0:
                    state["step"] = 0

                    if use_first_moment:
                        state["exp_avg"] = torch.zeros_like(grad, dtype=torch.float32)
                    if factored:
                        # Allocate factored statistics on the gradient's
                        # device; allocating on the default (CPU) device
                        # would break GPU training.
                        state["exp_avg_sq_row"] = torch.zeros(
                            grad_shape[:-1],
                            dtype=torch.float32, device=grad.device)
                        state["exp_avg_sq_col"] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:],
                            dtype=torch.float32, device=grad.device)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad, dtype=torch.float32)

                    state["RMS"] = 0

                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)

                lr = group["lr"]
                if group["lr"] is None:
                    lr = self._get_lr(group, state)

                # Time-dependent second-moment decay: beta2t -> 1 as steps grow.
                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = grad**2 + group["eps2"]

                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2t).add_(
                        update.mean(dim=-1), alpha=1.0 - beta2t)
                    exp_avg_sq_col.mul_(beta2t).add_(
                        update.mean(dim=-2), alpha=1.0 - beta2t)

                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]
                    exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Adaptive clipping of the unscaled update by its RMS.
                update.div_(max(1.0, self._rms(update) / group["cliping_threshold"]))

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"])
                    update = exp_avg

                if group["weight_decay"] != 0:
                    # Decoupled (AdamW-style) weight decay, scaled by lr.
                    p_data_fp32.mul_(1 - group["weight_decay"] * lr)

                p_data_fp32.add_(update, alpha=-lr)

                # For low-precision params, write the fp32 result back.
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss
|
|
|
|
def configure_adafactor_optimizer(
    model: torch.nn.Module,
    lr: Optional[float] = None,
    weight_decay: float = 0.0,
    total_steps: Optional[int] = None,
    warmup_ratio: float = 0.1,
    scale_parameter: bool = True,
    relative_step_size: bool = True,
    warmup_init: bool = False,
    cliping_threshold: float = 1.0,
    decay_rate: float = -0.8,
    beta1: Optional[float] = None,
    eps2: float = 1e-30,
    **adafactor_kwargs
) -> Tuple[Adafactor, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """
    Build an Adafactor optimizer and, optionally, a OneCycleLR scheduler.

    Drop-in replacement for BitTransformerLM's configure_optimizer helper,
    swapping AdamW for Adafactor.

    Args:
        model: Model whose trainable parameters are optimized
        lr: External learning rate; None enables Adafactor's automatic scaling
        weight_decay: Weight decay coefficient
        total_steps: Total number of training steps (enables scheduling)
        warmup_ratio: Fraction of total_steps spent warming up
        scale_parameter: Scale the relative step by the parameter RMS
        relative_step_size: Use relative step sizes
        warmup_init: Use warmup-style initialization of the relative step
        cliping_threshold: Adaptive update clipping threshold
        decay_rate: Decay rate for second-moment running averages
        beta1: First-moment coefficient; None disables the first moment
        eps2: Second-moment regularization constant
        **adafactor_kwargs: Extra keyword arguments forwarded to Adafactor

    Returns:
        Tuple of (optimizer, scheduler); scheduler is None unless both an
        explicit lr and a positive total_steps are provided.
    """
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    optimizer = Adafactor(
        trainable_params,
        lr=lr,
        weight_decay=weight_decay,
        scale_parameter=scale_parameter,
        relative_step_size=relative_step_size,
        warmup_init=warmup_init,
        cliping_threshold=cliping_threshold,
        decay_rate=decay_rate,
        beta1=beta1,
        eps2=eps2,
        **adafactor_kwargs
    )

    # OneCycleLR needs a concrete max_lr, so scheduling is only configured
    # when an explicit learning rate and a positive step budget are given.
    if lr is None or total_steps is None or total_steps <= 0:
        return optimizer, None

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        total_steps=total_steps,
        pct_start=warmup_ratio,
        anneal_strategy='cos',
        cycle_momentum=False,
        div_factor=25.0,
        final_div_factor=1e4,
    )
    return optimizer, scheduler
|
|
|
|
class AdafactorScheduler(torch.optim.lr_scheduler._LRScheduler):
    """
    Custom scheduler for Adafactor with linear warmup and polynomial decay.

    This scheduler is specifically designed to work with Adafactor's
    relative step size feature.

    Args:
        optimizer: The wrapped optimizer (typically Adafactor).
        warmup_steps: Number of steps of linear warmup from 0 to base lr.
        total_steps: Total training steps; None disables the decay phase.
        min_lr_ratio: Floor for the decay factor (fraction of base lr).
        polynomial_power: Exponent of the polynomial decay.
        last_epoch: Index of the last step (standard scheduler argument).
    """

    def __init__(
        self,
        optimizer: "Adafactor",
        warmup_steps: int = 1000,
        total_steps: Optional[int] = None,
        min_lr_ratio: float = 0.1,
        polynomial_power: float = 1.0,
        last_epoch: int = -1,
    ):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr_ratio = min_lr_ratio
        self.polynomial_power = polynomial_power
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch + 1

        # Linear warmup phase.
        if step < self.warmup_steps:
            return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]

        # Without a step budget there is nothing to decay toward.
        if self.total_steps is None:
            return self.base_lrs

        # Polynomial decay from base lr down to min_lr_ratio * base lr.
        # Guard the denominator: total_steps == warmup_steps would otherwise
        # raise ZeroDivisionError.
        decay_span = max(1, self.total_steps - self.warmup_steps)
        progress = (step - self.warmup_steps) / decay_span
        progress = min(progress, 1.0)
        decay_factor = (1 - progress) ** self.polynomial_power
        decay_factor = max(decay_factor, self.min_lr_ratio)

        return [base_lr * decay_factor for base_lr in self.base_lrs]
|
|
|
|
def configure_adafactor_with_scheduler(
    model: torch.nn.Module,
    lr: float = 1e-3,
    warmup_steps: int = 1000,
    total_steps: Optional[int] = None,
    weight_decay: float = 0.0,
    **kwargs
) -> Tuple[Adafactor, AdafactorScheduler]:
    """
    Build an Adafactor optimizer paired with the custom AdafactorScheduler.

    Args:
        model: Model whose trainable parameters are optimized
        lr: Base learning rate
        warmup_steps: Linear warmup duration in steps
        total_steps: Total training steps (None for no decay phase)
        weight_decay: Weight decay coefficient
        **kwargs: Extra keyword arguments forwarded to Adafactor

    Returns:
        Tuple of (optimizer, scheduler)
    """
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    # Relative stepping is disabled so the external scheduler has full
    # control of the learning rate.
    optimizer = Adafactor(
        trainable_params,
        lr=lr,
        weight_decay=weight_decay,
        relative_step_size=False,
        **kwargs
    )
    scheduler = AdafactorScheduler(
        optimizer,
        warmup_steps=warmup_steps,
        total_steps=total_steps,
    )
    return optimizer, scheduler
|
|
|
|
def create_adafactor_training_config(
    lr: Optional[float] = None,
    weight_decay: float = 0.0,
    scale_parameter: bool = True,
    relative_step_size: bool = True,
    warmup_init: bool = False,
    **kwargs
) -> Dict[str, Any]:
    """
    Build a training-configuration dictionary for the Adafactor optimizer.

    Args:
        lr: External learning rate (None for automatic scaling)
        weight_decay: Weight decay coefficient
        scale_parameter: Whether to scale by parameter RMS
        relative_step_size: Whether to use relative step size
        warmup_init: Whether to use warmup initialization
        **kwargs: Additional options merged into the optimizer config
            (later keys override the explicit ones)

    Returns:
        Dictionary with "optimizer_type", "optimizer_config" and
        "scheduler_type" entries
    """
    optimizer_config = dict(
        lr=lr,
        weight_decay=weight_decay,
        scale_parameter=scale_parameter,
        relative_step_size=relative_step_size,
        warmup_init=warmup_init,
    )
    optimizer_config.update(kwargs)

    # Automatic-lr mode pairs with the custom scheduler; a fixed lr pairs
    # with OneCycleLR.
    scheduler_type = "adafactor_custom" if lr is None else "onecycle"

    return {
        "optimizer_type": "adafactor",
        "optimizer_config": optimizer_config,
        "scheduler_type": scheduler_type,
    }
|
|
|
|
| |
def integrate_with_bittransformerlm():
    """
    Example of how to integrate the Adafactor optimizer with BitTransformerLM training.

    This function is documentation only: it performs no action at runtime and
    exists so the usage examples below are discoverable via help().

    Usage:
        from BTLM_Extensions.adafactor_optimizer import configure_adafactor_optimizer

        # Option 1: Use Adafactor with automatic learning rate scaling
        optimizer, scheduler = configure_adafactor_optimizer(
            model, lr=None, total_steps=1000  # lr=None enables auto-scaling
        )

        # Option 2: Use Adafactor with fixed learning rate
        optimizer, scheduler = configure_adafactor_optimizer(
            model, lr=1e-3, total_steps=1000
        )

        # Option 3: Use Adafactor with custom scheduler
        from BTLM_Extensions.adafactor_optimizer import configure_adafactor_with_scheduler

        optimizer, scheduler = configure_adafactor_with_scheduler(
            model, lr=1e-3, warmup_steps=100, total_steps=1000
        )

        # Use in training loop
        train_loop(model, data, optimizer=optimizer, scheduler=scheduler)
    """
    pass
|
|
|
|
def analyze_memory_usage(model: torch.nn.Module) -> Dict[str, float]:
    """
    Estimate optimizer memory usage for AdamW vs. Adafactor on a model.

    All figures assume fp32 (4 bytes per element) and count parameters,
    gradients, and optimizer state only (activations are excluded). The
    Adafactor estimate assumes no first moment (beta1=None).

    Args:
        model: PyTorch model to analyze

    Returns:
        Dictionary with memory usage estimates in MB:
        "adamw_mb", "adafactor_mb", "savings_mb", "savings_percent"
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    param_count = sum(p.numel() for p in trainable)
    param_bytes = param_count * 4  # fp32

    # AdamW keeps params + grads + exp_avg + exp_avg_sq: 4 fp32 copies.
    adamw_memory = param_bytes * 4

    # Adafactor keeps params + grads ...
    adafactor_memory = param_bytes * 2

    # ... plus the (possibly factored) second-moment state.
    second_moment_elems = 0
    for p in trainable:
        if p.dim() >= 2:
            # Factored: row state has shape p.shape[:-1] and column state
            # has shape p.shape[:-2] + p.shape[-1:], matching the tensors
            # allocated in Adafactor.step(). (Using shape[0] + shape[1]
            # would undercount parameters with more than 2 dimensions.)
            shape = tuple(p.shape)
            second_moment_elems += math.prod(shape[:-1])
            second_moment_elems += math.prod(shape[:-2]) * shape[-1]
        else:
            # 1-D parameters keep a full second-moment tensor.
            second_moment_elems += p.numel()

    adafactor_memory += second_moment_elems * 4

    # Guard against division by zero for parameter-free models.
    denom = max(adamw_memory, 1)

    return {
        "adamw_mb": adamw_memory / (1024 * 1024),
        "adafactor_mb": adafactor_memory / (1024 * 1024),
        "savings_mb": (adamw_memory - adafactor_memory) / (1024 * 1024),
        "savings_percent": ((adamw_memory - adafactor_memory) / denom) * 100,
    }
|
|
|
|
if __name__ == "__main__":
    # Smoke test: run one optimization step in each configuration mode.
    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(100, 200),
        nn.ReLU(),
        nn.Linear(200, 50),
        nn.ReLU(),
        nn.Linear(50, 1)
    )

    print("Testing Adafactor optimizer...")

    # Option 1: automatic (relative) learning rate; no scheduler is created
    # because lr is None.
    optimizer, scheduler = configure_adafactor_optimizer(
        model, lr=None, total_steps=100
    )

    x = torch.randn(32, 100)
    y = torch.randn(32, 1)

    pred = model(x)
    loss = nn.functional.mse_loss(pred, y)
    initial_loss = loss.item()
    loss.backward()

    optimizer.step()
    if scheduler:
        scheduler.step()

    # Option 2: explicit learning rate with a OneCycleLR schedule.
    optimizer2, scheduler2 = configure_adafactor_optimizer(
        model, lr=1e-3, total_steps=100
    )

    # Clear gradients left over from the first step; otherwise the second
    # backward pass would accumulate on top of them and optimizer2 would
    # step on stale gradients.
    optimizer2.zero_grad()

    pred = model(x)
    loss = nn.functional.mse_loss(pred, y)
    loss.backward()
    optimizer2.step()
    if scheduler2:
        scheduler2.step()

    memory_analysis = analyze_memory_usage(model)

    print("Adafactor optimizer test completed successfully!")
    print(f"Initial loss: {initial_loss:.4f}")
    print(f"Final loss: {loss.item():.4f}")
    print(f"Memory analysis: {memory_analysis}")