"""
MobiusNet Trainer with TensorBoard, SafeTensors, and HuggingFace Upload
=======================================================================
"""
|
|
| import os |
| import re |
| import json |
| import math |
| import shutil |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
| from typing import Tuple, Optional, Dict, Any |
| from torchvision import datasets, transforms |
| from torch.utils.data import DataLoader |
| from torch.utils.tensorboard import SummaryWriter |
| from tqdm.auto import tqdm |
| from datetime import datetime |
| from pathlib import Path |
| from safetensors.torch import save_file as save_safetensors, load_file as load_safetensors |
| from huggingface_hub import HfApi, login |
|
|
| |
# Authenticate with HuggingFace when running inside Google Colab.
# Outside Colab (or when no HF_TOKEN secret is configured) this is a
# best-effort no-op so the script still runs unauthenticated.
try:
    from google.colab import userdata

    token = userdata.get('HF_TOKEN')
    if token:  # guard against a missing/empty secret
        os.environ['HF_TOKEN'] = token
        login(token=token)
        print("Logged in to HuggingFace via Colab")
except Exception:
    # Not in Colab, secret unavailable, or login failed — proceed without auth.
    pass
|
|
# Pick the compute device: prefer CUDA when a GPU is visible.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

# Enable TF32 fast paths (Ampere+): trades a little matmul precision for
# a large throughput gain; harmless on CPUs / older GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
|
|
|
|
| |
| |
| |
|
|
class MobiusLens(nn.Module):
    """Per-layer gating "lens" of twisted projections and sinusoidal gates.

    The input is rotated by a learned orthogonal "twist", gated by a product
    of scaled sinusoidal envelopes (left/middle/right waves combined with a
    soft XOR/AND mix), then rotated back by an opposite twist. Each layer is
    assigned its own slice of the global scale range via its depth position.
    """

    def __init__(
        self,
        dim: int,
        layer_idx: int,
        total_layers: int,
        scale_range: Tuple[float, float] = (1.0, 9.0),
    ):
        super().__init__()

        self.dim = dim
        self.layer_idx = layer_idx
        self.total_layers = total_layers
        # Normalized depth position in [0, 1] (0 = first layer, 1 = last).
        self.t = layer_idx / max(total_layers - 1, 1)

        # Depth-dependent [scale_low, scale_high] window inside scale_range.
        scale_span = scale_range[1] - scale_range[0]
        step = scale_span / max(total_layers, 1)
        scale_low = scale_range[0] + self.t * scale_span
        scale_high = scale_low + step

        self.register_buffer('scales', torch.tensor([scale_low, scale_high]))

        # Input twist: learned angle blending identity with an orthogonal map.
        self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi))
        self.twist_in_proj = nn.Linear(dim, dim, bias=False)
        nn.init.orthogonal_(self.twist_in_proj.weight)

        # Wave frequency and envelope sharpness (alpha is abs'd at use time).
        self.omega = nn.Parameter(torch.tensor(math.pi))
        self.alpha = nn.Parameter(torch.tensor(1.5))

        # Phase/drift pairs for the left / middle / right wave gates
        # (one value per scale; drifts initialize at +1 / 0 / -1).
        self.phase_l = nn.Parameter(torch.zeros(2))
        self.drift_l = nn.Parameter(torch.ones(2))
        self.phase_m = nn.Parameter(torch.zeros(2))
        self.drift_m = nn.Parameter(torch.zeros(2))
        self.phase_r = nn.Parameter(torch.zeros(2))
        self.drift_r = nn.Parameter(-torch.ones(2))

        # Softmax weights over (L, M, R) and a sigmoid XOR/AND mixing weight.
        self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4]))
        self.xor_weight = nn.Parameter(torch.tensor(0.7))

        self.gate_norm = nn.LayerNorm(dim)

        # Output twist mirrors the input twist with the opposite angle.
        self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi))
        self.twist_out_proj = nn.Linear(dim, dim, bias=False)
        nn.init.orthogonal_(self.twist_out_proj.weight)

    def _twist_in(self, x: Tensor) -> Tensor:
        # Rotation-style blend: x*cos(theta) + P(x)*sin(theta).
        cos_t = torch.cos(self.twist_in_angle)
        sin_t = torch.sin(self.twist_in_angle)
        return x * cos_t + self.twist_in_proj(x) * sin_t

    def _center_lens(self, x: Tensor) -> Tensor:
        # Assumes x is (..., dim) with the lens acting on the last axis —
        # TODO confirm against callers (MobiusConvBlock feeds channel-last).
        x_norm = torch.tanh(x)
        # Mean activation magnitude serves as a per-position "time" that the
        # drift terms shift the waves by.
        t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2)

        # Insert a new axis to broadcast against the two registered scales.
        x_exp = x_norm.unsqueeze(-2)
        s = self.scales.view(-1, 1)

        def wave(phase, drift):
            # Gaussian-of-sine envelope, multiplied across both scales.
            a = self.alpha.abs() + 0.1
            pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1)
            return torch.exp(-a * torch.sin(pos).pow(2)).prod(dim=-2)

        L = wave(self.phase_l, self.drift_l)
        M = wave(self.phase_m, self.drift_m)
        R = wave(self.phase_r, self.drift_r)

        w = torch.softmax(self.accum_weights, dim=0)
        xor_w = torch.sigmoid(self.xor_weight)

        # Soft logic on the outer gates: XOR ~ |L + R - 2LR|, AND ~ LR.
        xor_comp = (L + R - 2 * L * R).abs()
        and_comp = L * R
        lr = xor_w * xor_comp + (1 - xor_w) * and_comp

        # Weighted wave sum, modulated by the soft-logic term, squashed to
        # (0, 1) through a LayerNorm + sigmoid before gating the raw input.
        gate = w[0] * L + w[1] * M + w[2] * R
        gate = gate * (0.5 + 0.5 * lr)
        gate = torch.sigmoid(self.gate_norm(gate))

        return x * gate

    def _twist_out(self, x: Tensor) -> Tensor:
        # Mirror of _twist_in with the independently-learned output angle.
        cos_t = torch.cos(self.twist_out_angle)
        sin_t = torch.sin(self.twist_out_angle)
        return x * cos_t + self.twist_out_proj(x) * sin_t

    def forward(self, x: Tensor) -> Tensor:
        """Twist in, apply the wave gate, twist back out."""
        return self._twist_out(self._center_lens(self._twist_in(x)))

    def get_lens_stats(self) -> Dict[str, float]:
        """Return lens parameters for logging."""
        return {
            'omega': self.omega.item(),
            'alpha': self.alpha.item(),
            'twist_in_angle': self.twist_in_angle.item(),
            'twist_out_angle': self.twist_out_angle.item(),
            'xor_weight': torch.sigmoid(self.xor_weight).item(),
            'accum_weights_l': torch.softmax(self.accum_weights, dim=0)[0].item(),
            'accum_weights_m': torch.softmax(self.accum_weights, dim=0)[1].item(),
            'accum_weights_r': torch.softmax(self.accum_weights, dim=0)[2].item(),
        }
|
|
|
|
| |
| |
| |
|
|
class MobiusConvBlock(nn.Module):
    """Residual depthwise-separable conv block gated by a MobiusLens.

    Pipeline: depthwise 3x3 + pointwise 1x1 + BatchNorm, then the lens gate
    applied channel-last, then a per-layer "thirds" attenuation mask, blended
    with the identity path by a learned sigmoid residual weight.
    """

    def __init__(
        self,
        channels: int,
        layer_idx: int,
        total_layers: int,
        scale_range: Tuple[float, float] = (1.0, 9.0),
        reduction: float = 0.5,
    ):
        super().__init__()

        # Depthwise-separable convolution keeps the parameter count low.
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
            nn.Conv2d(channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )

        self.lens = MobiusLens(channels, layer_idx, total_layers, scale_range)

        # Attenuate one rotating third of the channels per layer; the last
        # third absorbs the remainder when channels is not divisible by 3.
        third = channels // 3
        which_third = layer_idx % 3
        mask = torch.ones(channels)
        start = which_third * third
        end = start + third + (channels % 3 if which_third == 2 else 0)
        mask[start:end] = reduction
        self.register_buffer('thirds_mask', mask.view(1, -1, 1, 1))

        # Learned residual mix; sigmoid(0.9) ≈ 0.71 favors identity early on.
        self.residual_weight = nn.Parameter(torch.tensor(0.9))

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv -> lens gate -> thirds mask, blended with the identity."""
        identity = x

        h = self.conv(x)
        # The lens acts on the last dim, so move channels to NHWC and back.
        # (Removed an unused `B, D, H, W = h.shape` unpacking here.)
        h = h.permute(0, 2, 3, 1)
        h = self.lens(h)
        h = h.permute(0, 3, 1, 2)
        h = h * self.thirds_mask

        rw = torch.sigmoid(self.residual_weight)
        return rw * identity + (1 - rw) * h

    def get_residual_weight(self) -> float:
        """Current residual mix in (0, 1) after the sigmoid squash."""
        return torch.sigmoid(self.residual_weight).item()
|
|
|
|
| |
| |
| |
|
|
class MobiusNet(nn.Module):
    """Convolutional classifier built from MobiusConvBlock stages.

    Stem conv -> per-stage stacks of MobiusConvBlock with strided downsample
    convs between stages -> optional "integrator" conv -> global average
    pool -> linear classification head.
    """

    def __init__(
        self,
        in_chans: int = 3,
        num_classes: int = 200,
        channels: Tuple[int, ...] = (64, 128, 256, 512),
        depths: Tuple[int, ...] = (2, 2, 2, 2),
        scale_range: Tuple[float, float] = (0.5, 2.5),
        use_integrator: bool = True,
    ):
        super().__init__()

        num_stages = len(depths)
        total_layers = sum(depths)

        # Store the raw configuration for get_config()/checkpointing.
        self.total_layers = total_layers
        self.scale_range = scale_range
        self.channels = tuple(channels)
        self.depths = tuple(depths)
        self.num_stages = num_stages
        self.use_integrator = use_integrator
        self.num_classes = num_classes
        self.in_chans = in_chans

        # Pad the channel list by repeating the last width when fewer channel
        # entries than stages were given.
        channels = list(channels)
        while len(channels) < num_stages:
            channels.append(channels[-1])

        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, channels[0], 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(channels[0]),
        )

        # A global layer index drives each lens's depth position t.
        layer_idx = 0
        self.stages = nn.ModuleList()
        self.downsamples = nn.ModuleList()

        for stage_idx in range(num_stages):
            ch = channels[stage_idx]

            stage = nn.ModuleList()
            for _ in range(depths[stage_idx]):
                stage.append(MobiusConvBlock(ch, layer_idx, total_layers, scale_range))
                layer_idx += 1
            self.stages.append(stage)

            # Strided conv between stages halves resolution and changes width.
            if stage_idx < num_stages - 1:
                ch_next = channels[stage_idx + 1]
                self.downsamples.append(nn.Sequential(
                    nn.Conv2d(ch, ch_next, 3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(ch_next),
                ))

        final_ch = channels[num_stages - 1]
        if use_integrator:
            self.integrator = nn.Sequential(
                nn.Conv2d(final_ch, final_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(final_ch),
                nn.GELU(),
            )
        else:
            self.integrator = nn.Identity()

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(final_ch, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits of shape (batch, num_classes)."""
        x = self.stem(x)

        for i, stage in enumerate(self.stages):
            for block in stage:
                x = block(x)
            if i < len(self.downsamples):
                x = self.downsamples[i](x)

        x = self.integrator(x)
        return self.head(self.pool(x).flatten(1))

    def get_config(self) -> Dict[str, Any]:
        """Return model configuration for saving."""
        return {
            'in_chans': self.in_chans,
            'num_classes': self.num_classes,
            'channels': self.channels,
            'depths': self.depths,
            'scale_range': self.scale_range,
            'use_integrator': self.use_integrator,
            'total_layers': self.total_layers,
            'num_stages': self.num_stages,
        }

    def get_all_lens_stats(self) -> Dict[str, Dict[str, float]]:
        """Return stats from all lenses for logging.

        (Removed a dead `layer_idx` counter that was incremented but never
        read.)
        """
        stats = {}
        for stage_idx, stage in enumerate(self.stages):
            for block_idx, block in enumerate(stage):
                key = f"stage{stage_idx}_block{block_idx}"
                stats[key] = block.lens.get_lens_stats()
                stats[key]['residual_weight'] = block.get_residual_weight()
        return stats
|
|
|
|
| |
| |
| |
|
|
def get_tiny_imagenet_loaders(data_dir='./data/tiny-imagenet-200', batch_size=128):
    """Build Tiny ImageNet train/val DataLoaders, fixing the val layout first."""
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val')

    # A raw Tiny ImageNet download keeps all val images in one flat folder;
    # ImageFolder needs one subfolder per class, so reorganize on first use.
    if os.path.exists(os.path.join(val_dir, 'images')):
        print("Reorganizing validation folder...")
        reorganize_val_folder(val_dir)

    # ImageNet channel statistics, shared by both pipelines.
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    train_transform = transforms.Compose([
        transforms.RandomCrop(64, padding=8),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([transforms.ToTensor(), normalize])

    train_loader = DataLoader(
        datasets.ImageFolder(train_dir, transform=train_transform),
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
    )
    val_loader = DataLoader(
        datasets.ImageFolder(val_dir, transform=val_transform),
        batch_size=256,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
    )

    return train_loader, val_loader
|
|
|
|
def reorganize_val_folder(val_dir):
    """Reorganize Tiny ImageNet val folder into class subfolders.

    The raw download ships val images flat under ``val/images`` with labels in
    ``val_annotations.txt``; ImageFolder requires ``val/<class_id>/<img>``.
    Moves each image into its class folder, then removes the flat folder and
    the annotations file. Idempotent: a no-op once ``val/images`` is gone.
    """
    val_images_dir = os.path.join(val_dir, 'images')
    val_annotations = os.path.join(val_dir, 'val_annotations.txt')

    if not os.path.exists(val_images_dir):
        return

    with open(val_annotations, 'r') as f:
        for line in f:
            parts = line.strip().split('\t')
            # Skip blank or malformed lines instead of raising IndexError.
            if len(parts) < 2:
                continue
            img_name, class_id = parts[0], parts[1]

            class_dir = os.path.join(val_dir, class_id)
            os.makedirs(class_dir, exist_ok=True)

            src = os.path.join(val_images_dir, img_name)
            dst = os.path.join(class_dir, img_name)

            if os.path.exists(src):
                shutil.move(src, dst)

    if os.path.exists(val_images_dir):
        shutil.rmtree(val_images_dir)
    if os.path.exists(val_annotations):
        os.remove(val_annotations)

    print("Validation folder reorganized.")
|
|
|
|
| |
| |
| |
|
|
# Model size presets: per-stage channel widths, per-stage block depths, and
# the (low, high) lens scale range swept across the network depth.
PRESETS = {
    'mobius_tiny_s': {
        'channels': (64, 128, 256),
        'depths': (2, 2, 2),
        'scale_range': (0.5, 2.5),
    },
    'mobius_tiny_m': {
        'channels': (64, 128, 256, 512, 768),
        'depths': (2, 2, 4, 2, 2),
        'scale_range': (0.25, 2.75),
    },
    'mobius_tiny_l': {
        'channels': (96, 192, 384, 768),
        'depths': (3, 3, 3, 3),
        'scale_range': (0.5, 3.5),
    },
    'mobius_base': {
        'channels': (128, 256, 512, 768, 1024),
        'depths': (2, 2, 2, 2, 2),
        'scale_range': (0.25, 2.75),
    },
}
|
|
|
|
| |
| |
| |
|
|
class CheckpointManager:
    """Checkpoint saving, TensorBoard logging, and HuggingFace uploads.

    Directory layout::

        <base_dir>/checkpoints/<variant>_<dataset>/<timestamp>/
            config.json
            best_accuracy.json / final_accuracy.json
            checkpoints/   (.pt with optimizer state + .safetensors weights)
            tensorboard/   (event files)

    Fixes vs. previous revision: the bare ``except:`` in the HuggingFace
    download fallback now catches ``Exception`` only, so KeyboardInterrupt
    and SystemExit propagate.
    """

    def __init__(
        self,
        base_dir: str,
        variant_name: str,
        dataset_name: str,
        hf_repo: str = "AbstractPhil/mobiusnet",
        upload_every_n_epochs: int = 10,
        save_every_n_epochs: int = 10,
        timestamp: Optional[str] = None,
    ):
        # A caller-supplied timestamp lets resumed runs write into the same
        # directory tree as the original run.
        self.timestamp = timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
        self.variant_name = variant_name
        self.dataset_name = dataset_name
        self.hf_repo = hf_repo
        self.upload_every_n_epochs = upload_every_n_epochs
        self.save_every_n_epochs = save_every_n_epochs

        self.run_name = f"{variant_name}_{dataset_name}"
        self.run_dir = Path(base_dir) / "checkpoints" / self.run_name / self.timestamp
        self.checkpoints_dir = self.run_dir / "checkpoints"
        self.tensorboard_dir = self.run_dir / "tensorboard"

        self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
        self.tensorboard_dir.mkdir(parents=True, exist_ok=True)

        self.writer = SummaryWriter(log_dir=str(self.tensorboard_dir))

        self.hf_api = HfApi()
        self.uploaded_files = set()

        # Best-model tracking; best_changed_since_upload batches best-model
        # uploads into the periodic upload epochs.
        self.best_acc = 0.0
        self.best_epoch = 0
        self.best_changed_since_upload = False

        print(f"Checkpoint directory: {self.run_dir}")

    @staticmethod
    def extract_timestamp(checkpoint_path: str) -> Optional[str]:
        """Extract a YYYYMMDD_HHMMSS timestamp from a checkpoint path, or None."""
        match = re.search(r'(\d{8}_\d{6})', checkpoint_path)
        if match:
            return match.group(1)
        return None

    def save_config(self, config: Dict[str, Any], training_config: Dict[str, Any]):
        """Save model and training configuration to config.json; return its path."""
        full_config = {
            'model': config,
            'training': training_config,
            'timestamp': self.timestamp,
            'variant_name': self.variant_name,
            'dataset_name': self.dataset_name,
        }

        config_path = self.run_dir / "config.json"
        with open(config_path, 'w') as f:
            json.dump(full_config, f, indent=2)

        return config_path

    def save_checkpoint(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        epoch: int,
        train_acc: float,
        val_acc: float,
        train_loss: float,
        is_best: bool = False,
    ):
        """Save checkpoint every N epochs, always save best (overwriting)."""
        # Unwrap torch.compile's wrapper so the saved state_dict has clean keys.
        raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model

        checkpoint = {
            'epoch': epoch,
            'train_acc': train_acc,
            'val_acc': val_acc,
            'train_loss': train_loss,
            'best_acc': self.best_acc,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }

        # Periodic epoch checkpoints: full .pt (with optimizer/scheduler) plus
        # a weights-only .safetensors for safe loading elsewhere.
        if epoch % self.save_every_n_epochs == 0:
            epoch_pt_path = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.pt"
            torch.save({**checkpoint, 'model_state_dict': raw_model.state_dict()}, epoch_pt_path)

            epoch_st_path = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.safetensors"
            save_safetensors(raw_model.state_dict(), str(epoch_st_path))

        # Best model is overwritten in place whenever val accuracy improves.
        if is_best:
            self.best_acc = val_acc
            self.best_epoch = epoch
            self.best_changed_since_upload = True

            best_pt_path = self.checkpoints_dir / "best_model.pt"
            torch.save({**checkpoint, 'model_state_dict': raw_model.state_dict()}, best_pt_path)

            best_st_path = self.checkpoints_dir / "best_model.safetensors"
            save_safetensors(raw_model.state_dict(), str(best_st_path))

            acc_path = self.run_dir / "best_accuracy.json"
            with open(acc_path, 'w') as f:
                json.dump({
                    'best_acc': val_acc,
                    'best_epoch': epoch,
                    'train_acc': train_acc,
                    'train_loss': train_loss,
                }, f, indent=2)

    def save_final(self, model: nn.Module, final_acc: float, final_epoch: int):
        """Save the final model (weights + summary json); return the two paths."""
        raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model

        final_st_path = self.checkpoints_dir / "final_model.safetensors"
        save_safetensors(raw_model.state_dict(), str(final_st_path))

        final_pt_path = self.checkpoints_dir / "final_model.pt"
        torch.save({
            'model_state_dict': raw_model.state_dict(),
            'final_acc': final_acc,
            'final_epoch': final_epoch,
            'best_acc': self.best_acc,
            'best_epoch': self.best_epoch,
        }, final_pt_path)

        acc_path = self.run_dir / "final_accuracy.json"
        with open(acc_path, 'w') as f:
            json.dump({
                'final_acc': final_acc,
                'final_epoch': final_epoch,
                'best_acc': self.best_acc,
                'best_epoch': self.best_epoch,
            }, f, indent=2)

        return final_st_path, final_pt_path

    def log_scalars(self, epoch: int, scalars: Dict[str, float], prefix: str = ""):
        """Log scalars to TensorBoard under optional ``prefix/``."""
        for name, value in scalars.items():
            tag = f"{prefix}/{name}" if prefix else name
            self.writer.add_scalar(tag, value, epoch)

    def log_lens_stats(self, epoch: int, model: nn.Module):
        """Log per-block lens statistics to TensorBoard."""
        raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
        stats = raw_model.get_all_lens_stats()

        for block_name, block_stats in stats.items():
            for stat_name, value in block_stats.items():
                self.writer.add_scalar(f"lens/{block_name}/{stat_name}", value, epoch)

    def log_histograms(self, epoch: int, model: nn.Module):
        """Log weight (and, when present, gradient) histograms to TensorBoard."""
        raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model

        for name, param in raw_model.named_parameters():
            if param.requires_grad:
                self.writer.add_histogram(f"weights/{name}", param.data, epoch)
                if param.grad is not None:
                    self.writer.add_histogram(f"gradients/{name}", param.grad, epoch)

    def upload_to_hf(self, epoch: int, force: bool = False):
        """Upload checkpoint every N epochs. Best uploads only on upload epochs if changed."""
        if not force and epoch % self.upload_every_n_epochs != 0:
            return

        try:
            hf_base_path = f"checkpoints/{self.run_name}/{self.timestamp}"

            files_to_upload = []

            config_path = self.run_dir / "config.json"
            if config_path.exists():
                files_to_upload.append(config_path)

            # Periodic epoch checkpoint files, if this epoch produced them.
            if epoch % self.save_every_n_epochs == 0:
                ckpt_st = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.safetensors"
                ckpt_pt = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.pt"
                if ckpt_st.exists():
                    files_to_upload.append(ckpt_st)
                if ckpt_pt.exists():
                    files_to_upload.append(ckpt_pt)

            # Best-model files only when the best changed since the last upload.
            if self.best_changed_since_upload:
                best_files = [
                    self.checkpoints_dir / "best_model.safetensors",
                    self.checkpoints_dir / "best_model.pt",
                    self.run_dir / "best_accuracy.json",
                ]
                for f in best_files:
                    if f.exists():
                        files_to_upload.append(f)
                self.best_changed_since_upload = False

            # Upload each file individually so one failure doesn't stop the rest.
            for local_path in files_to_upload:
                rel_path = local_path.relative_to(self.run_dir)
                hf_path = f"{hf_base_path}/{rel_path}"

                try:
                    self.hf_api.upload_file(
                        path_or_fileobj=str(local_path),
                        path_in_repo=hf_path,
                        repo_id=self.hf_repo,
                        repo_type="model",
                    )
                    print(f"Uploaded: {hf_path}")
                except Exception as e:
                    print(f"Failed to upload {rel_path}: {e}")

        except Exception as e:
            print(f"HuggingFace upload error: {e}")

    def close(self):
        """Close TensorBoard writer."""
        self.writer.close()

    @staticmethod
    def load_checkpoint(
        checkpoint_path: str,
        model: nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[Any] = None,
        hf_repo: str = "AbstractPhil/mobiusnet",
        device: torch.device = torch.device('cpu'),
    ) -> Dict[str, Any]:
        """
        Load checkpoint from local path or HuggingFace repo.

        Args:
            checkpoint_path: Either:
                - Local file path to .pt checkpoint
                - Local directory containing checkpoints
                - HuggingFace path like "checkpoints/variant_dataset/timestamp"
            model: Model to load weights into
            optimizer: Optional optimizer to restore state
            scheduler: Optional scheduler to restore state
            hf_repo: HuggingFace repo ID
            device: Device to load tensors to

        Returns:
            Dict with checkpoint info (epoch, best_acc, etc.)

        Raises:
            FileNotFoundError: when no checkpoint can be located locally or
                downloaded from the HuggingFace repo.
        """
        from huggingface_hub import hf_hub_download, list_repo_files

        checkpoint_file = None

        # Case 1: direct path to a checkpoint file.
        if os.path.isfile(checkpoint_path):
            checkpoint_file = checkpoint_path

        # Case 2: run directory — prefer best_model.pt, else the latest epoch.
        elif os.path.isdir(checkpoint_path):
            best_path = os.path.join(checkpoint_path, "checkpoints", "best_model.pt")
            if os.path.exists(best_path):
                checkpoint_file = best_path
            else:
                ckpt_dir = os.path.join(checkpoint_path, "checkpoints")
                if os.path.isdir(ckpt_dir):
                    pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.startswith("checkpoint_epoch_") and f.endswith(".pt")])
                    if pt_files:
                        checkpoint_file = os.path.join(ckpt_dir, pt_files[-1])

        # Case 3: treat the path as a location inside the HuggingFace repo.
        if checkpoint_file is None:
            print(f"Attempting to download from HuggingFace: {hf_repo}/{checkpoint_path}")
            try:
                if not checkpoint_path.endswith(".pt"):
                    # Try best_model.pt first, then fall back to listing the
                    # repo for the latest epoch checkpoint under this path.
                    try:
                        checkpoint_file = hf_hub_download(
                            repo_id=hf_repo,
                            filename=f"{checkpoint_path}/checkpoints/best_model.pt",
                            repo_type="model",
                        )
                        print(f"Downloaded best_model.pt from {hf_repo}")
                    except Exception:  # was a bare except: — narrowed
                        files = list_repo_files(repo_id=hf_repo, repo_type="model")
                        ckpt_files = sorted([f for f in files if checkpoint_path in f and f.endswith(".pt") and "checkpoint_epoch_" in f])
                        if ckpt_files:
                            checkpoint_file = hf_hub_download(
                                repo_id=hf_repo,
                                filename=ckpt_files[-1],
                                repo_type="model",
                            )
                            print(f"Downloaded {ckpt_files[-1]} from {hf_repo}")
                else:
                    checkpoint_file = hf_hub_download(
                        repo_id=hf_repo,
                        filename=checkpoint_path,
                        repo_type="model",
                    )
                    print(f"Downloaded {checkpoint_path} from {hf_repo}")
            except Exception as e:
                raise FileNotFoundError(f"Could not find or download checkpoint: {checkpoint_path}. Error: {e}")

        if checkpoint_file is None:
            raise FileNotFoundError(f"Could not find checkpoint: {checkpoint_path}")

        print(f"Loading checkpoint from: {checkpoint_file}")
        # NOTE: weights_only=False unpickles arbitrary objects (optimizer
        # state) — only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=False)

        # Unwrap torch.compile's wrapper before restoring weights.
        raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
        raw_model.load_state_dict(checkpoint['model_state_dict'])
        print("Loaded model weights")

        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("Loaded optimizer state")

        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            print("Loaded scheduler state")

        info = {
            'epoch': checkpoint.get('epoch', 0),
            'best_acc': checkpoint.get('best_acc', 0.0),
            'train_acc': checkpoint.get('train_acc', 0.0),
            'val_acc': checkpoint.get('val_acc', 0.0),
            'train_loss': checkpoint.get('train_loss', 0.0),
        }

        print(f"Resuming from epoch {info['epoch']} (best_acc: {info['best_acc']:.4f})")

        return info
|
|
|
|
| |
| |
| |
|
|
def train_tiny_imagenet(
    preset: str = 'mobius_tiny_m',
    epochs: int = 100,
    lr: float = 1e-3,
    batch_size: int = 128,
    use_integrator: bool = True,
    data_dir: str = './data/tiny-imagenet-200',
    output_dir: str = './outputs',
    hf_repo: str = "AbstractPhil/mobiusnet",
    save_every_n_epochs: int = 10,
    upload_every_n_epochs: int = 10,
    log_histograms_every: int = 10,
    use_compile: bool = True,
    continue_from: Optional[str] = None,
):
    """
    Train MobiusNet on Tiny ImageNet.

    Args:
        preset: Model preset name
        epochs: Total epochs to train
        lr: Learning rate
        batch_size: Batch size
        use_integrator: Whether to use integrator layer
        data_dir: Path to Tiny ImageNet data
        output_dir: Output directory for checkpoints
        hf_repo: HuggingFace repo for uploads/downloads
        save_every_n_epochs: Save checkpoint every N epochs
        upload_every_n_epochs: Upload to HF every N epochs
        log_histograms_every: Log weight histograms every N epochs
        use_compile: Whether to use torch.compile
        continue_from: Resume from checkpoint. Can be:
            - Local .pt file path
            - Local checkpoint directory
            - HuggingFace path (e.g., "checkpoints/mobius_base_tiny_imagenet/20240101_120000")

    Returns:
        Tuple of (trained model, best validation accuracy).
    """
    config = PRESETS[preset]
    dataset_name = "tiny_imagenet"

    print("=" * 70)
    print(f"MÖBIUS NET - {preset.upper()} - TINY IMAGENET")
    print("=" * 70)
    print(f"Device: {device}")
    print(f"Channels: {config['channels']}")
    print(f"Depths: {config['depths']}")
    print(f"Scale range: {config['scale_range']}")
    print(f"Integrator: {use_integrator}")
    if continue_from:
        print(f"Continuing from: {continue_from}")
    print()

    # Reuse the original run's timestamp when resuming so checkpoints land
    # in the same directory tree (and the same HF path).
    resume_timestamp = None
    if continue_from:
        resume_timestamp = CheckpointManager.extract_timestamp(continue_from)
        if resume_timestamp:
            print(f"Using original timestamp: {resume_timestamp}")

    ckpt_manager = CheckpointManager(
        base_dir=output_dir,
        variant_name=preset,
        dataset_name=dataset_name,
        hf_repo=hf_repo,
        upload_every_n_epochs=upload_every_n_epochs,
        save_every_n_epochs=save_every_n_epochs,
        timestamp=resume_timestamp,
    )

    train_loader, val_loader = get_tiny_imagenet_loaders(data_dir, batch_size)

    model = MobiusNet(
        in_chans=3,
        num_classes=200,
        use_integrator=use_integrator,
        **config
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total_params:,}")
    print()

    training_config = {
        'epochs': epochs,
        'lr': lr,
        'batch_size': batch_size,
        'optimizer': 'AdamW',
        'weight_decay': 0.05,
        'scheduler': 'CosineAnnealingLR',
        'total_params': total_params,
    }
    ckpt_manager.save_config(model.get_config(), training_config)

    # Compile AFTER config is saved; checkpoint code unwraps _orig_mod.
    if use_compile:
        model = torch.compile(model, mode='reduce-overhead')

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Resume state (model/optimizer/scheduler) after both are constructed.
    start_epoch = 1
    best_acc = 0.0

    if continue_from:
        ckpt_info = CheckpointManager.load_checkpoint(
            checkpoint_path=continue_from,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            hf_repo=hf_repo,
            device=device,
        )
        start_epoch = ckpt_info['epoch'] + 1
        best_acc = ckpt_info['best_acc']
        ckpt_manager.best_acc = best_acc
        ckpt_manager.best_epoch = ckpt_info['epoch']
        print(f"Resuming training from epoch {start_epoch}")

    # NOTE(review): if continue_from resumes at or beyond `epochs`, this loop
    # never runs and `val_acc`/`train_acc` below are undefined — confirm
    # intended usage before resuming a completed run.
    for epoch in range(start_epoch, epochs + 1):
        # ---- training pass ----
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}")
        for x, y in pbar:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Weight the running loss by batch size for a correct epoch mean.
            train_loss += loss.item() * x.size(0)
            train_correct += (logits.argmax(1) == y).sum().item()
            train_total += x.size(0)

            pbar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()

        # ---- validation pass ----
        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                val_correct += (logits.argmax(1) == y).sum().item()
                val_total += x.size(0)

        # ---- epoch metrics ----
        train_acc = train_correct / train_total
        val_acc = val_correct / val_total
        avg_loss = train_loss / train_total
        current_lr = scheduler.get_last_lr()[0]

        is_best = val_acc > best_acc
        if is_best:
            best_acc = val_acc

        marker = " ★" if is_best else ""
        print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}")

        # ---- logging ----
        ckpt_manager.log_scalars(epoch, {
            'loss': avg_loss,
            'train_acc': train_acc,
            'val_acc': val_acc,
            'best_acc': best_acc,
            'learning_rate': current_lr,
        }, prefix="train")

        ckpt_manager.log_lens_stats(epoch, model)

        if epoch % log_histograms_every == 0:
            ckpt_manager.log_histograms(epoch, model)

        # ---- checkpointing + periodic upload ----
        ckpt_manager.save_checkpoint(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            epoch=epoch,
            train_acc=train_acc,
            val_acc=val_acc,
            train_loss=avg_loss,
            is_best=is_best,
        )

        ckpt_manager.upload_to_hf(epoch)

    # Final save: last-epoch weights plus a forced upload of everything.
    ckpt_manager.save_final(model, val_acc, epochs)

    ckpt_manager.upload_to_hf(epochs, force=True)
    ckpt_manager.close()

    print()
    print("=" * 70)
    print("FINAL RESULTS")
    print("=" * 70)
    print(f"Preset: {preset}")
    print(f"Best accuracy: {best_acc:.4f}")
    print(f"Total params: {total_params:,}")
    print(f"Checkpoints: {ckpt_manager.run_dir}")
    print("=" * 70)

    return model, best_acc
|
|
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    # Entry point: train (here, resume) the mobius_base preset on Tiny ImageNet.
    # NOTE(review): continue_from is a hardcoded Colab path tied to one specific
    # run timestamp — update or remove it before starting a fresh run.
    model, best_acc = train_tiny_imagenet(
        preset='mobius_base',
        epochs=200,
        lr=3e-4,
        batch_size=128,
        use_integrator=True,
        data_dir='./data/tiny-imagenet-200',
        output_dir='./outputs',
        hf_repo='AbstractPhil/mobiusnet',
        save_every_n_epochs=10,
        upload_every_n_epochs=10,
        log_histograms_every=10,
        use_compile=True,
        continue_from='/content/outputs/checkpoints/mobius_base_tiny_imagenet/20260110_132436/checkpoints/best_model.pt',
    )