Delete src
- src/checkpointing/__init__.py +0 -23
- src/checkpointing/evaluation.py +0 -68
- src/checkpointing/learning_dynamics.py +0 -424
- src/checkpointing/training.py +0 -287
- src/config/__init__.py +0 -31
- src/config/_constants.py +0 -18
- src/config/checkpointing_config.py +0 -97
- src/config/data_config.py +0 -36
- src/config/evaluation_config.py +0 -28
- src/config/model_config.py +0 -33
- src/config/monitoring_config.py +0 -29
- src/config/training_config.py +0 -40
- src/evaluation/__init__.py +0 -103
- src/evaluation/tasks/paloma.py +0 -52
- src/model/__init__.py +0 -12
- src/model/pico_decoder.py +0 -911
- src/training/trainer.py +0 -753
- src/training/utils/__init__.py +0 -34
- src/training/utils/data.py +0 -35
- src/training/utils/initialization.py +0 -702
- src/training/utils/io.py +0 -52
- src/training/utils/logging.py +0 -48
src/checkpointing/__init__.py
DELETED
@@ -1,23 +0,0 @@
-"""
-Pico Checkpointing Package
-
-We subdivide the checkpointing into training, evaluation, and learning_dynamics. Training
-checkpoints store the model, optimizer, and learning rate scheduler. Evaluation checkpoints store
-the evaluation results on the defined metrics. Learning dynamics checkpoints store activations and gradients used for
-learning dynamics analysis.
-"""
-
-from .evaluation import save_evaluation_results
-from .learning_dynamics import (
-    compute_learning_dynamics_states,
-    save_learning_dynamics_states,
-)
-from .training import load_checkpoint, save_checkpoint
-
-__all__ = [
-    "compute_learning_dynamics_states",
-    "load_checkpoint",
-    "save_checkpoint",
-    "save_evaluation_results",
-    "save_learning_dynamics_states",
-]

src/checkpointing/evaluation.py
DELETED
@@ -1,68 +0,0 @@
-"""
-Utilities for checkpointing evaluation-related states (i.e. evaluation results, etc.)
-
-We save the evaluation results in a JSON file at the step-specific evaluation results directory.
-"""
-
-import json
-import os
-from typing import Any, Dict
-
-from huggingface_hub import upload_folder
-from lightning.fabric import Fabric
-from lightning.fabric.utilities.rank_zero import rank_zero_only
-
-from src.config import CheckpointingConfig
-from src.training.utils.io import use_backoff
-
-
-@rank_zero_only
-@use_backoff()
-def save_evaluation_results(
-    checkpointing_config: CheckpointingConfig,
-    checkpoint_step: int,
-    fabric: Fabric,
-    evaluation_results: Dict[str, Any],
-) -> None:
-    """Save evaluation results to disk and optionally to HuggingFace Hub.
-
-    The evaluation results are saved in the following directory structure:
-        {checkpointing_config.runs_dir}/
-        └── {checkpointing_config.run_name}/
-            └── {checkpointing_config.eval_results_dir}/
-                └── step_{checkpoint_step}.json
-
-    NOTE: this function is only called on rank 0 to avoid conflicts; assumes that the evaluation
-    results are gathered on rank 0.
-
-    Args:
-        checkpointing_config: Configuration object containing checkpoint settings
-        checkpoint_step: Current training checkpoint step (i.e. number of learning steps taken)
-        fabric: Lightning Fabric instance
-        evaluation_results: Dictionary containing evaluation metrics
-    """
-
-    run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name)
-    eval_results_dir = os.path.join(
-        run_dir, checkpointing_config.evaluation.eval_results_dir
-    )
-
-    os.makedirs(eval_results_dir, exist_ok=True)
-
-    curr_eval_results_path = os.path.join(
-        eval_results_dir, f"step_{checkpoint_step}.json"
-    )
-
-    # save out as json
-    with open(curr_eval_results_path, "w") as f:
-        json.dump(evaluation_results, f)
-
-    if checkpointing_config.save_to_hf:
-        upload_folder(
-            folder_path=eval_results_dir,
-            path_in_repo=checkpointing_config.evaluation.eval_results_dir,
-            repo_id=checkpointing_config.hf_checkpoint.repo_id,
-            commit_message=f"Saving Evaluation Results -- Step {checkpoint_step}",
-            revision=checkpointing_config.run_name,
-            token=os.getenv("HF_TOKEN"),
-        )

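For context on how this deleted helper was invoked, here is a minimal sketch of a call site; the run name, step number, and results dictionary are illustrative placeholders, not values from the repository, and the real caller lived in the (also deleted) trainer.

# Hypothetical call site for save_evaluation_results; values are illustrative.
from lightning.fabric import Fabric

from src.checkpointing import save_evaluation_results
from src.config import CheckpointingConfig

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

checkpointing_config = CheckpointingConfig(run_name="demo-run")

# Results are assumed to have been gathered on rank 0 already.
save_evaluation_results(
    checkpointing_config=checkpointing_config,
    checkpoint_step=1000,
    fabric=fabric,
    evaluation_results={"paloma": 3.45},
)

With save_to_hf left at its default (False), this only writes runs/demo-run/eval_results/step_1000.json to disk.
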
src/checkpointing/learning_dynamics.py
DELETED
@@ -1,424 +0,0 @@
-"""
-Utilities for checkpointing learning dynamics-related states (i.e. activations, weights, grads, etc.)
-
-We save the learning dynamics states in a subdirectory of the checkpointing directory.
-"""
-
-import os
-import re
-from typing import Dict, Optional
-
-import deepspeed
-import torch
-import torch.nn as nn
-import torch.optim as optim
-from datasets import Dataset
-from huggingface_hub import upload_folder
-from lightning.fabric import Fabric
-from lightning.fabric.strategies import DeepSpeedStrategy
-from lightning.fabric.utilities.rank_zero import rank_zero_only
-from torch.nn import functional as F
-from torch.utils.data import DataLoader
-from transformers import PreTrainedTokenizerBase
-
-from src.config import CheckpointingConfig
-from src.config.checkpointing_config import LearningDynamicsCheckpointingConfig
-from src.training.utils.initialization import initialize_model
-from src.training.utils.io import use_backoff
-
-
-# NOTE: DeepSpeed requires a dummy optimizer to be passed in to the setup function
-class DummyOptimizer(optim.Optimizer):
-    def __init__(self, params):
-        super().__init__(params, defaults={})
-
-
-class CheckpointStateExtractor:
-    """
-    Class to extract and save the states of a model at a given checkpoint step for learning
-    dynamics research.
-    """
-
-    def __init__(
-        self,
-        learning_dynamics_config: LearningDynamicsCheckpointingConfig,
-        fabric: Fabric,
-        model: nn.Module,
-    ):
-        self.learning_dynamics_config = learning_dynamics_config
-        self.fabric = fabric
-        self.model = model
-
-    def extract_states(self, dataloader, compute_gradients: bool = False):
-        """Extracts model states (activations, weights, and optionally gradients).
-
-        Given a dataloader, this function will perform a forward pass of the model on each batch,
-        and save the activations and weights at each layer. If compute_gradients is True, it will
-        also compute the gradients of the model parameters.
-
-        Args:
-            dataloader: The dataloader containing the dataset to extract states from.
-            compute_gradients: Whether to compute the gradients of the model parameters.
-
-        Returns:
-            A dictionary containing the activations, weights, and optionally gradients of the model.
-        """
-        checkpoint_activations = {}
-        checkpoint_weights = {}
-
-        # NOTE: to extract activations and weights, we need to setup forward hooks on the layers
-        # of the model that we are interested in. This is a good intro to forward hooks if you
-        # are not familiar: https://web.stanford.edu/~nanbhas/blog/forward-hooks-pytorch/
-        forward_hooks = self._setup_forward_hooks(
-            checkpoint_activations,
-            checkpoint_weights,
-        )
-
-        ########################################################
-        #
-        # Forward Pass: Extract activations and weights; and compute gradients
-        #
-        ########################################################
-
-        for sub_batch in dataloader:
-            _input_ids = torch.tensor(sub_batch["input_ids"], device=self.fabric.device)
-
-            if compute_gradients:
-                if "labels" in sub_batch:
-                    input_ids = _input_ids
-                    labels = torch.tensor(
-                        sub_batch["labels"], device=self.fabric.device
-                    )
-                else:
-                    input_ids = _input_ids[:, :-1]
-                    labels = _input_ids[:, 1:]
-            else:
-                input_ids = _input_ids
-                labels = None
-
-            if labels is None:
-                # we can throw away the outputs, we are only interested in the hidden states
-                with torch.no_grad():
-                    _ = self.model(input_ids)
-            else:
-                # NOTE: if we are computing gradients, calling backwards will compute the gradients
-                # of the model parameters.
-                outputs, _ = self.model(input_ids)
-                outputs = outputs.transpose(1, 2)
-                loss = F.cross_entropy(outputs, labels)
-                self.fabric.backward(loss, model=self.model)
-
-        # cleanup forward hooks
-        # NOTE this is not strictly necessary, since self.model is a deepcopy of the original model
-        # but it is good practice to remove the hooks after the forward pass is complete.
-        for hook in forward_hooks:
-            hook.remove()
-
-        ########################################################
-        #
-        # Extract gradients from the target tensors of the model
-        #
-        ########################################################
-
-        layer_suffixes = self.learning_dynamics_config.layer_suffixes
-        checkpoint_gradients = {}
-        if compute_gradients:
-            for name, param in self.model.named_parameters():
-                # only do this for the weight matrix of the layer_suffixes
-                if (
-                    any(layer_suffix in name for layer_suffix in layer_suffixes)
-                    and "weight" in name
-                ):
-                    if isinstance(self.fabric.strategy, DeepSpeedStrategy):
-                        _grad = deepspeed.utils.safe_get_full_grad(param)
-                    else:
-                        _grad = param.grad
-
-                    assert _grad is not None, f"Gradient is None for layer: {name}"
-                    name = re.sub(r"\.weight", "", name)
-                    checkpoint_gradients[name] = _grad.detach().cpu()
-
-            # zero out the gradients
-            self.model.zero_grad()
-
-        return checkpoint_activations, checkpoint_weights, checkpoint_gradients
-
-    ########################################################
-    #
-    # Setup forward hooks to save activations and weights at each layer
-    #
-    ########################################################
-
-    def _setup_forward_hooks(self, checkpoint_activations, checkpoint_weights):
-        """Setup forward hooks for the model to save activations and weights at each layer.
-
-        This function will setup forward hooks on the layers of the model that we are interested in.
-        The forward hooks will save the activations and weights at each layer whenever the forward pass
-        is performed.
-
-        Args:
-            checkpoint_activations: A dictionary to store the activations at each layer.
-            checkpoint_weights: A dictionary to store the weights at each layer.
-
-        Returns:
-            A list of forward hooks. We do this so that we can remove the hooks after the forward pass
-            is complete.
-        """
-
-        forward_hooks = []
-        layer_suffixes = self.learning_dynamics_config.layer_suffixes
-
-        for name, module in self.model.named_modules():
-            if any(layer_suffix in name for layer_suffix in layer_suffixes):
-                _forward_hook = module.register_forward_hook(
-                    self._get_forward_hook(
-                        name, checkpoint_activations, checkpoint_weights
-                    )
-                )
-                forward_hooks.append(_forward_hook)
-        return forward_hooks
-
-    def _get_forward_hook(
-        self, module_name, checkpoint_activations, checkpoint_weights
-    ):
-        """Get a forward hook for a given module.
-
-        This function is called by the _setup_forward_hooks function to setup a forward hook for a given
-        module. This functions is a closure that captures the module_name, checkpoint_activations, and
-        checkpoint_weights.
-
-        Args:
-            module_name: The name of the module to setup a forward hook for.
-            checkpoint_activations: A dictionary to store the activations at each layer.
-            checkpoint_weights: A dictionary to store the weights at each layer.
-
-        Returns:
-            A forward hook for the given module.
-        """
-
-        def _forward_hook(module, _, module_out):
-            sequence_idx = self.learning_dynamics_config.sequence_idx
-
-            local_activations = module_out[:, sequence_idx, :].detach()
-
-            # Gather activations from all processes using fabric
-            gathered_activations = self.fabric.all_gather(local_activations)
-
-            # Reshape from [num_processes, batch_size, hidden_dim] to [total_batch_size, hidden_dim]
-            # NOTE: transposing allows us to interleave the activations from each process so that
-            # they are in the correct order. (i.e. activation N is from data sample N)
-            gathered_activations = gathered_activations.transpose(0, 1).reshape(
-                -1, gathered_activations.shape[-1]
-            )
-
-            # check if there is already a key for the module name
-            if module_name not in checkpoint_activations:
-                # if there is no key, then we create a new key and store the hidden states
-                checkpoint_activations[module_name] = (
-                    gathered_activations.detach().cpu()
-                )
-
-                # extract the weight matrix just once
-                weight_matrix = module.weight.detach().cpu()
-                checkpoint_weights[module_name] = weight_matrix
-            else:
-                # if there is already a key, then we concatenate the new hidden states to the existing ones
-                checkpoint_activations[module_name] = torch.cat(
-                    (
-                        checkpoint_activations[module_name],
-                        gathered_activations.detach().cpu(),
-                    )
-                )
-
-        return _forward_hook
-
-
-def compute_learning_dynamics_states(
-    checkpointing_config: CheckpointingConfig,
-    fabric: Fabric,
-    model: nn.Module,
-    dataset: Dataset,
-    compute_gradients: bool = False,
-) -> Dict[str, torch.Tensor]:
-    """Computes the learning dynamics metrics for a given checkpoint step.
-
-    Uses the CheckpointStateExtractor to extract the activations, weights, and optionally gradients
-    of the model at a given checkpoint step.
-
-    Args:
-        checkpointing_config: The configuration object for checkpointing.
-        fabric: The Fabric instance for distributed training.
-        model: The model to extract states from.
-        dataset: The dataset to extract states from.
-        compute_gradients: Whether to compute the gradients of the model parameters.
-
-    Returns:
-        A dictionary containing the activations, weights, and optionally gradients of the model.
-    """
-
-    # NOTE: Synchronizing processes for fabric dataloader setup
-    fabric.barrier()
-    model.to("cpu")  # Offloading model to CPU
-
-    # Setting up Dataloader for learning dynamics
-    def _collate_fn(batch):
-        return {"input_ids": [entry["input_ids"] for entry in batch]}
-
-    batch_size = checkpointing_config.learning_dynamics.batch_size
-    sub_batch_size = batch_size // fabric.world_size
-
-    # NOTE: Make sure to set drop_last to False, otherwise the last batch will be dropped
-    # and we will not have a complete set of activations for the last sample. Also,
-    # we need to set shuffle to False, otherwise the activations will be shuffled across
-    # processes and we will not be able to interleave them correctly.
-    extractor_dataloader = DataLoader(
-        dataset,
-        batch_size=sub_batch_size,
-        shuffle=False,
-        collate_fn=_collate_fn,
-        drop_last=False,
-    )
-    extractor_dataloader = fabric.setup_dataloaders(
-        extractor_dataloader, use_distributed_sampler=True
-    )
-
-    # Create a new model instance with same parameters but zero gradients
-    _model = initialize_model(model.config)
-    _model.load_state_dict(model.state_dict())
-
-    if isinstance(fabric.strategy, DeepSpeedStrategy):
-        _model, _ = fabric.setup(_model, DummyOptimizer(_model.parameters()))
-    else:
-        _model = fabric.setup(_model)
-
-    _model.zero_grad()
-
-    # setup forward hooks for the model to save activations and weights at each layer
-    state_extractor = CheckpointStateExtractor(
-        checkpointing_config.learning_dynamics, fabric, _model
-    )
-
-    checkpoint_activations, checkpoint_weights, checkpoint_gradients = (
-        state_extractor.extract_states(
-            extractor_dataloader, compute_gradients=compute_gradients
-        )
-    )
-
-    del _model
-    torch.cuda.empty_cache()
-
-    # NOTE: Synchronizing processes for model setup
-    fabric.barrier()
-
-    model.to(fabric.device)
-
-    # NOTE: Trimming down the activations to match the dataset size;
-    # This is because the DataSampler might add extra samples to the dataset to make it evenly divisible
-    # by the number of processes. We need to remove these extra samples.
-    for layer_name, layer_activations in checkpoint_activations.items():
-        if len(layer_activations) > len(dataset):
-            checkpoint_activations[layer_name] = layer_activations[: len(dataset)]
-        elif len(layer_activations) < len(dataset):
-            raise ValueError(
-                f"Number of activations ({len(layer_activations)}) in layer {layer_name} does not match number of samples in dataset ({len(dataset)})"
-            )
-
-    return {
-        "activations": checkpoint_activations,
-        "weights": checkpoint_weights,
-        "gradients": checkpoint_gradients,
-    }
-
-
-@rank_zero_only
-@use_backoff()
-def save_learning_dynamics_states(
-    checkpointing_config: CheckpointingConfig,
-    checkpoint_step: int,
-    prefix: str,
-    fabric: Fabric,
-    learning_dynamics_states: Dict[str, torch.Tensor],
-    learning_dynamics_dataset: Optional[Dataset] = None,
-    tokenizer: Optional[PreTrainedTokenizerBase] = None,
-) -> None:
-    """Save the learning dynamics metrics to the checkpointing directory.
-
-    By default only the learning dynamics states are saved. If the learning dynamics dataset
-    is provided, it is also saved; if a tokenizer is provided, the dataset is also detokenized
-    (i.e. a new column with the text is added to the dataset).
-
-    The learning dynamics dataset is saved in the checkpointing directory as a HuggingFace
-    dataset.
-
-    Creates a versioned checkpoint directory with the following structure:
-
-    {checkpointing_config.runs_dir}/
-    └── {checkpointing_config.run_name}/
-        └── {checkpointing_config.checkpoints_dir}/
-            ├── step_{checkpoint_step}/
-            │   └── {checkpointing_config.learning_dynamics_dir}/  # Learning Dynamics files
-            │       ├── {prefix}_activations.pt
-            │       ├── {prefix}_weights.pt
-            │       └── {prefix}_gradients.pt
-            │       └── {prefix}_data/  # if learning_dynamics_dataset is provided
-            └── latest -> step_{checkpoint_step}/
-
-    NOTE: this function is only called on rank 0
-
-    Args:
-        checkpointing_config: The configuration object for checkpointing.
-        checkpoint_step: The checkpoint step at which the learning dynamics states were computed.
-        prefix: The prefix for the learning dynamics states.
-        fabric: The Fabric instance for distributed training.
-        learning_dynamics_states: The learning dynamics states to save.
-        learning_dynamics_dataset: The dataset containing learning dynamics data,
-            including input IDs that need to be decoded. (optional)
-        tokenizer: The tokenizer used to decode input IDs into text. (optional)
-    """
-
-    runs_dir = checkpointing_config.runs_dir
-    run_name = checkpointing_config.run_name
-    checkpoints_dir = checkpointing_config.checkpoints_dir
-    learning_dynamics_dir = checkpointing_config.learning_dynamics_dir
-
-    run_path = os.path.join(runs_dir, run_name)
-    root_checkpoint_path = os.path.join(run_path, checkpoints_dir)
-    checkpoint_path = os.path.join(root_checkpoint_path, f"step_{checkpoint_step}")
-    learning_dynamics_path = os.path.join(checkpoint_path, learning_dynamics_dir)
-    os.makedirs(learning_dynamics_path, exist_ok=True)
-
-    # save the learning dynamics states
-    for key, value in learning_dynamics_states.items():
-        if value is not None and len(value) > 0:
-            torch.save(
-                value, os.path.join(learning_dynamics_path, f"{prefix}_{key}.pt")
-            )
-
-    if learning_dynamics_dataset is not None:
-        if tokenizer is not None:
-            # go through dataset and decode the input ids; and add back into dataset
-            detokenized_dataset = {"input_ids": [], "text": []}
-
-            for entry in learning_dynamics_dataset:
-                input_ids = entry["input_ids"]
-                decoded_text = tokenizer.decode(input_ids, skip_special_tokens=True)
-                detokenized_dataset["input_ids"].append(input_ids)
-                detokenized_dataset["text"].append(decoded_text)
-
-            learning_dynamics_dataset = Dataset.from_dict(detokenized_dataset)
-
-        learning_dynamics_dataset_path = os.path.join(
-            learning_dynamics_path, f"{prefix}_data"
-        )
-        learning_dynamics_dataset.save_to_disk(learning_dynamics_dataset_path)
-
-    if checkpointing_config.save_to_hf:
-        # Upload the HF model
-        upload_folder(
-            folder_path=learning_dynamics_path,
-            path_in_repo=learning_dynamics_dir,
-            repo_id=checkpointing_config.hf_checkpoint.repo_id,
-            commit_message=f"Saving Learning Dynamics Data ({prefix}) -- Step {checkpoint_step}",
-            revision=checkpointing_config.run_name,
-            token=os.getenv("HF_TOKEN"),
-        )

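As a rough illustration of how the two entry points above were meant to be chained at a checkpoint step, here is a hedged sketch. The step number, prefix, and sample count are placeholders; checkpointing_config, fabric, model, and tokenizer are assumed to already exist (the real call site lived in the deleted src/training/trainer.py).

# Illustrative only: chaining compute_learning_dynamics_states and
# save_learning_dynamics_states at a checkpoint step.
from datasets import load_dataset

from src.checkpointing import (
    compute_learning_dynamics_states,
    save_learning_dynamics_states,
)

# A small batch of held-out data; the config default points at
# "pico-lm/pretokenized-paloma-tinsy", and we keep only a handful of rows here.
eval_dataset = load_dataset(
    "pico-lm/pretokenized-paloma-tinsy", split="val"
).select(range(8))

learning_dynamics_states = compute_learning_dynamics_states(
    checkpointing_config=checkpointing_config,  # assumed CheckpointingConfig
    fabric=fabric,                              # assumed Lightning Fabric instance
    model=model,                                # assumed Pico decoder instance
    dataset=eval_dataset,
    compute_gradients=True,
)

save_learning_dynamics_states(
    checkpointing_config=checkpointing_config,
    checkpoint_step=1000,
    prefix="val",
    fabric=fabric,
    learning_dynamics_states=learning_dynamics_states,
    learning_dynamics_dataset=eval_dataset,
    tokenizer=tokenizer,                        # assumed OLMo tokenizer
)

This writes val_activations.pt, val_weights.pt, val_gradients.pt, and a val_data/ dataset under the step's learning_dynamics directory.
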
src/checkpointing/training.py
DELETED
@@ -1,287 +0,0 @@
-"""
-Utilities for checkpointing training-related states (i.e. model, optimizer, lr_scheduler, etc.)
-
-We save both a HuggingFace model and a Fabric-specific checkpoint. The HuggingFace model is
-saved at the step-specific checkpoint directory, while the Fabric-specific checkpoint is saved
-in a subdirectory. This is done to facilitate easier versioning of the HuggingFace model files
-(which are what gets uploaded to the Hub).
-"""
-
-import os
-from dataclasses import asdict
-from typing import Any, Dict, Tuple, Union
-
-import yaml
-from huggingface_hub import upload_file, upload_folder
-from lightning.fabric import Fabric
-from lightning.fabric.strategies import DeepSpeedStrategy
-from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
-from torch import nn
-from torch.optim import Optimizer
-from torch.optim.lr_scheduler import LRScheduler
-from transformers import PreTrainedTokenizerBase
-
-from src.config import CheckpointingConfig
-from src.training.utils.io import use_backoff
-
-
-@use_backoff()
-def load_checkpoint(
-    checkpointing_config: CheckpointingConfig,
-    checkpoint_step: Union[str, int],
-    fabric: Fabric,
-    model: nn.Module,
-    optimizer: Optimizer,
-    lr_scheduler: LRScheduler,
-) -> Tuple[nn.Module, Optimizer, LRScheduler, int]:
-    """Load model checkpoint and associated states from a given step.
-
-    Args:
-        checkpointing_config: Configuration object containing checkpoint settings
-        checkpoint_step: The step at which to load the checkpoint
-        fabric: Lightning Fabric instance for distributed training support
-        model: The model instance to load weights into
-        optimizer: The optimizer instance to load states into
-        lr_scheduler: The learning rate scheduler to load states into
-
-    Returns:
-        Tuple containing the model, optimizer, lr_scheduler, and checkpoint step.
-        Returns None if no checkpoint is found.
-    """
-
-    if isinstance(checkpoint_step, int):
-        checkpoint_step = f"step_{checkpoint_step}"
-
-    checkpoint_path = os.path.join(
-        checkpointing_config.runs_dir,
-        checkpointing_config.run_name,
-        checkpointing_config.checkpoints_dir,
-        checkpoint_step,
-    )
-
-    if not os.path.exists(checkpoint_path):
-        return None
-
-    # Load from specified fabric checkpoint subdirectory
-    fabric_checkpoint_path = os.path.join(
-        checkpoint_path, checkpointing_config.fabric_checkpoint_dir
-    )
-
-    checkpoint_state = {
-        "_model": model,
-        "_optimizer": optimizer,
-        "_lr_scheduler": lr_scheduler,
-    }
-
-    if not isinstance(fabric.strategy, DeepSpeedStrategy):
-        fabric_load_file = os.path.join(
-            fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename
-        )
-    else:
-        # Deepspeed checkpoints create sub-directory with distributed checkpoint file
-        fabric_load_file = fabric_checkpoint_path
-
-    extra_state = fabric.load(os.path.join(fabric_load_file), state=checkpoint_state)
-
-    # NOTE: extra_state will contain any additional states that were saved in the checkpoint
-    checkpoint_step = extra_state["_checkpoint_step"]
-
-    if "_rng_states" in extra_state:
-        _rng_states = extra_state["_rng_states"]
-        _set_rng_states(_rng_states)
-
-    return model, optimizer, lr_scheduler, checkpoint_step
-
-
-@use_backoff()
-def save_checkpoint(
-    configs: Dict[str, Any],
-    checkpoint_step: int,
-    fabric: Fabric,
-    model: nn.Module,
-    optimizer: Optimizer,
-    lr_scheduler: LRScheduler,
-    tokenizer: PreTrainedTokenizerBase,
-    upload_logs: bool = False,
-) -> None:
-    """Save training checkpoint and associated states to disk and optionally to HuggingFace Hub.
-
-    We save the following files:
-        - HuggingFace model files (config.json, pytorch_model.bin)
-        - Tokenizer files (vocab.json, merges.txt)
-        - Fabric-specific files - fabric state of the model, optimizer, and lr_scheduler. If using
-            DeepSpeed, the checkpoint is saved in a subdirectory, otherwise it is saved in a single file.
-
-    Note that the HuggingFace model files are saved at the step-specific checkpoint directory, while the
-    Fabric-specific files are saved in a subdirectory. This is done to facilitate easier
-    versioning of the HuggingFace model files (which are what gets uploaded to the Hub).
-
-    NOTE: Why do we save a HF model at all? We do this because it makes it easier to load the model
-    in a separate script for evaluation and to play nicely with the HuggingFace Hub.
-
-    Creates a versioned checkpoint directory with the following structure:
-
-    {checkpointing_config.runs_dir}/
-    └── {checkpointing_config.run_name}/
-        └── training_config.yaml  # Training config
-        └── {checkpointing_config.checkpoints_dir}/
-            ├── step_{checkpoint_step}/
-            │   ├── config.json            # HuggingFace model config
-            │   ├── model.safetensors      # HuggingFace model weights
-            │   ├── pico_{model_type}.py   # HuggingFace custom model class
-            │   ├── tokenizer.json         # Tokenizer vocab
-            │   ├── tokenizer_config.json  # Tokenizer config
-            │   └── {checkpointing_config.fabric_checkpoint_dir}/  # Fabric-specific files
-            │       └── checkpoint/    # Distributed model checkpoint files (if using DeepSpeed)
-            │       OR
-            │       └── checkpoint.pt  # Single checkpoint file (if using other strategies)
-            └── latest -> step_{checkpoint_step}/
-
-    Args:
-        configs: A dictionary containing the initialized configuration objects.
-        checkpoint_step: The current training checkpoint step (i.e. number of learning steps taken)
-        fabric: Lightning Fabric instance for distributed training support
-        model: The model instance to save
-        optimizer: The optimizer instance to save
-        lr_scheduler: The learning rate scheduler to save
-        tokenizer: The tokenizer to save
-        upload_logs: Whether to upload training logs to HF Hub (default: False)
-
-    """
-
-    checkpointing_config = configs["checkpointing"]
-
-    # Get the directories from the training config
-    runs_dir = checkpointing_config.runs_dir
-    checkpoints_dir = checkpointing_config.checkpoints_dir
-    fabric_checkpoint_dir = checkpointing_config.fabric_checkpoint_dir
-    logs_dir = checkpointing_config.logs_dir
-
-    run_path = os.path.join(runs_dir, checkpointing_config.run_name)
-    root_checkpoint_path = os.path.join(run_path, checkpoints_dir)
-    checkpoint_path = os.path.join(root_checkpoint_path, f"step_{checkpoint_step}")
-
-    # Create directories
-    os.makedirs(checkpoint_path, exist_ok=True)
-
-    ########################################################
-    #
-    # Save HuggingFace files
-    #
-    ########################################################
-
-    # NOTE: we convert the Pico model to a HuggingFace model before saving it. See `model.py`
-    # for more details.
-    if fabric.global_rank == 0:
-        hf_model = model.convert_to_hf_model()
-        hf_model.save_pretrained(checkpoint_path)
-        tokenizer.save_pretrained(checkpoint_path)
-
-    ########################################################
-    #
-    # Save Fabric-specific files
-    #
-    ########################################################
-
-    # Create fabric-specific subdirectory
-    fabric_checkpoint_path = os.path.join(checkpoint_path, fabric_checkpoint_dir)
-    os.makedirs(fabric_checkpoint_path, exist_ok=True)
-
-    # Save model states (use underscore to avoid conflicts with third-party libraries)
-    checkpoint_state = {
-        "_model": model,
-        "_optimizer": optimizer,
-        "_lr_scheduler": lr_scheduler,
-        "_checkpoint_step": checkpoint_step,
-    }
-
-    if not isinstance(fabric.strategy, DeepSpeedStrategy):
-        checkpoint_state["_rng_states"] = _collect_rng_states()
-        fabric_save_file = os.path.join(
-            fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename
-        )
-    else:
-        # Deepspeed checkpoints create sub-directory with distributed checkpoint file
-        fabric_save_file = fabric_checkpoint_path
-
-    fabric.save(fabric_save_file, checkpoint_state)
-
-    if fabric.global_rank == 0:
-        # Save config in fabric directory
-        config_path = os.path.join(run_path, "training_config.yaml")
-        if not os.path.exists(config_path):
-            # Converting dataclasses to joined dicts and saving to file
-            _training_config = {}
-            for config_name, config in configs.items():
-                _training_config[config_name] = asdict(config)
-            with open(config_path, "w") as f:
-                yaml.dump(_training_config, f)
-
-        # Update latest symlink
-        latest_symlink_path = os.path.join(root_checkpoint_path, "latest")
-        if os.path.lexists(latest_symlink_path):
-            os.remove(latest_symlink_path)
-        os.symlink(
-            f"step_{checkpoint_step}", latest_symlink_path, target_is_directory=True
-        )
-
-    ########################################################
-    #
-    # Push to HuggingFace Hub (if configured)
-    #
-    ########################################################
-
-    if fabric.global_rank == 0:
-        # Push only on rank zero thread
-
-        if checkpointing_config.save_to_hf:
-            repo_id = checkpointing_config.hf_checkpoint.repo_id
-
-            # Upload the HF model
-            hf_model.push_to_hub(
-                repo_id=repo_id,
-                commit_message=f"Saving HF Model -- Step {checkpoint_step}",
-                revision=checkpointing_config.run_name,
-                token=os.getenv("HF_TOKEN"),
-            )
-
-            if checkpoint_step == 0:
-                # Uploading Tokenizer during first step since it never changes
-                tokenizer.push_to_hub(
-                    repo_id=repo_id,
-                    commit_message=f"Saving Tokenizer -- Step {checkpoint_step}",
-                    revision=checkpointing_config.run_name,
-                    token=os.getenv("HF_TOKEN"),
-                )
-
-                # Upload training config, also only in first step
-                upload_file(
-                    path_or_fileobj=config_path,
-                    path_in_repo="training_config.yaml",
-                    repo_id=repo_id,
-                    commit_message=f"Saving Training Config -- Step {checkpoint_step}",
-                    revision=checkpointing_config.run_name,
-                    token=os.getenv("HF_TOKEN"),
-                )
-
-            # Upload the fabric checkpoint directory
-            upload_folder(
-                folder_path=fabric_checkpoint_path,
-                path_in_repo=fabric_checkpoint_dir,
-                repo_id=repo_id,
-                commit_message=f"Saving Fabric Checkpoint -- Step {checkpoint_step}",
-                revision=checkpointing_config.run_name,
-                token=os.getenv("HF_TOKEN"),
-            )
-
-            # Upload logs if requested
-            if upload_logs:
-                logs_path = os.path.join(run_path, logs_dir)
-                upload_folder(
-                    folder_path=logs_path,
-                    path_in_repo=logs_dir,
-                    repo_id=repo_id,
-                    commit_message=f"Saving Logs -- Step {checkpoint_step}",
-                    revision=checkpointing_config.run_name,
-                    token=os.getenv("HF_TOKEN"),
-                )

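A hedged sketch of how these two functions bracketed the training loop (resume first, then periodic saves). The surrounding trainer lived in the deleted src/training/trainer.py; configs, fabric, model, optimizer, lr_scheduler, and tokenizer are assumed to be already initialized, and the loop body is elided.

# Illustrative resume-and-save pattern around the deleted helpers above.
from src.checkpointing import load_checkpoint, save_checkpoint

# Try to resume from the "latest" symlink (see the directory layout above).
resumed = load_checkpoint(
    checkpointing_config=configs["checkpointing"],
    checkpoint_step="latest",
    fabric=fabric,
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
)
start_step = 0
if resumed is not None:
    model, optimizer, lr_scheduler, start_step = resumed

for step in range(start_step, configs["training"].max_steps):
    ...  # forward/backward/optimizer step elided

    if step % configs["checkpointing"].save_every_n_steps == 0:
        save_checkpoint(
            configs=configs,
            checkpoint_step=step,
            fabric=fabric,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            tokenizer=tokenizer,
            upload_logs=False,
        )
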
src/config/__init__.py
DELETED
@@ -1,31 +0,0 @@
-"""
-Pico Config Package
-
-The modules of this package are where you can specify the hyperparameters for the Pico model,
-the dataset, the training process, evaluation, etc.
-
-As with anything else in Pico, we've designed for the configuration setup to be as flexible
-as possible. By default the configs are implemented as vanilla dataclasses -- this makes it easy to
-switch to different config management systems if you want, like hydra.
-
-Some things to NOTE:
-    - All hyperparameters are initialized with default values, which can be overridden.
-    - The default vocab size is set to the size of the OLMo tokenizer.
-"""
-
-# For convenience, we export the config classes here
-from .checkpointing_config import CheckpointingConfig
-from .data_config import DataConfig
-from .evaluation_config import EvaluationConfig
-from .model_config import ModelConfig
-from .monitoring_config import MonitoringConfig
-from .training_config import TrainingConfig
-
-__all__ = [
-    "CheckpointingConfig",
-    "DataConfig",
-    "EvaluationConfig",
-    "ModelConfig",
-    "MonitoringConfig",
-    "TrainingConfig",
-]

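Because the configs are plain dataclasses, overriding a default is just a keyword argument at construction time. A minimal sketch of building the config bundle the rest of the code expects; the run name and the smaller model sizes are illustrative values, not repository defaults.

# Minimal sketch: build the config bundle with a few overridden defaults.
from dataclasses import asdict

from src.config import (
    CheckpointingConfig,
    DataConfig,
    EvaluationConfig,
    ModelConfig,
    MonitoringConfig,
    TrainingConfig,
)

configs = {
    "checkpointing": CheckpointingConfig(run_name="pico-decoder-demo"),
    "data": DataConfig(),
    "evaluation": EvaluationConfig(),
    "model": ModelConfig(d_model=384, n_layers=6),  # smaller than the 768/12 defaults
    "monitoring": MonitoringConfig(),
    "training": TrainingConfig(),
}

print(asdict(configs["model"])["d_model"])  # -> 384
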
src/config/_constants.py
DELETED
@@ -1,18 +0,0 @@
-"""
-Constants used throughout the codebase
-"""
-
-# Basic Training Constants used throughout the codebase
-VOCAB_SIZE = 50304
-MAX_SEQ_LEN = 2048
-BATCH_SIZE = 1024
-GRADIENT_ACCUMULATION_STEPS = 128
-
-# Directories used to store training runs, checkpoints, logs, and evaluation results
-RUNS_DIR = "runs"
-CHECKPOINTS_DIR = "checkpoints"
-LOGS_DIR = "logs"
-FABRIC_CHECKPOINT_DIR = "fabric_state"
-FABRIC_CHECKPOINT_FILENAME = "checkpoint.pt"
-LEARNING_DYNAMICS_DIR = "learning_dynamics"
-EVAL_RESULTS_DIR = "eval_results"

src/config/checkpointing_config.py
DELETED
@@ -1,97 +0,0 @@
-"""
-Checkpointing Config
-
-Specifies the hyperparameters for the checkpointing process; checkpointing is used to save
-the model and optimizer states, as well as the learning dynamics metrics.
-"""
-
-from dataclasses import dataclass, field
-from typing import List, Optional
-
-from ._constants import (
-    CHECKPOINTS_DIR,
-    EVAL_RESULTS_DIR,
-    FABRIC_CHECKPOINT_DIR,
-    FABRIC_CHECKPOINT_FILENAME,
-    LEARNING_DYNAMICS_DIR,
-    LOGS_DIR,
-    RUNS_DIR,
-)
-
-
-@dataclass
-class TrainingCheckpointingConfig:
-    # Automatically resume training from the most recent checkpoint
-    auto_resume: bool = True
-
-
-@dataclass
-class EvaluationCheckpointingConfig:
-    # Directory in which evaluation results are saved
-    eval_results_dir: str = EVAL_RESULTS_DIR
-
-
-@dataclass
-class LearningDynamicsCheckpointingConfig:
-    # Suffixes of the layers to compute learning dynamics for
-    layer_suffixes: List[str] = field(
-        default_factory=lambda: [
-            "attention.v_proj",
-            "attention.o_proj",
-            "swiglu.w_2",
-        ]
-    )
-
-    # Sequence index at which to extract hidden states; by default, we extract the hidden states
-    # at the last token of the sequence (-1)
-    sequence_idx: int = -1
-
-    # size of the sub-batch used for extracting learning dynamics states
-    batch_size: int = 8
-
-    # Path to evaluation dataset - used across learning dynamics checkpointing for consistency
-    # NOTE: set to None to disable extracting learning dynamics states for an eval_batch
-    # NOTE: this dataset should be small, ideally just a batch of additional data
-    eval_data: Optional[str] = "pico-lm/pretokenized-paloma-tinsy"
-
-
-@dataclass
-class HuggingFaceCheckpointingConfig:
-    # Should be in the format of <(username or organization name)>/<repo_name>, e.g. pico-lm/demo
-    repo_id: str = ""
-
-    # HuggingFace Collection Slug (specifies a tag for the run)
-    collection_slug: Optional[str] = None
-
-
-@dataclass
-class CheckpointingConfig:
-    # Assign a name to the run
-    run_name: Optional[str] = None
-
-    # Defining checkpointing directories
-    runs_dir: str = RUNS_DIR
-    checkpoints_dir: str = CHECKPOINTS_DIR
-    logs_dir: str = LOGS_DIR
-    fabric_checkpoint_dir: str = FABRIC_CHECKPOINT_DIR
-    fabric_checkpoint_filename: str = FABRIC_CHECKPOINT_FILENAME
-    learning_dynamics_dir: str = LEARNING_DYNAMICS_DIR
-
-    # How often to save checkpoints
-    save_every_n_steps: int = 1000
-
-    # Whether to save checkpoints to HuggingFace
-    save_to_hf: Optional[bool] = False
-    hf_checkpoint: HuggingFaceCheckpointingConfig = field(
-        default_factory=HuggingFaceCheckpointingConfig
-    )
-
-    training: TrainingCheckpointingConfig = field(
-        default_factory=TrainingCheckpointingConfig
-    )
-    evaluation: EvaluationCheckpointingConfig = field(
-        default_factory=EvaluationCheckpointingConfig
-    )
-    learning_dynamics: LearningDynamicsCheckpointingConfig = field(
-        default_factory=LearningDynamicsCheckpointingConfig
-    )

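As a hedged example of how the nested fields compose, enabling Hub uploads means setting save_to_hf and filling in hf_checkpoint.repo_id; the repo id below reuses the "pico-lm/demo" placeholder from the comment above, and the run name and save interval are illustrative.

# Illustrative: turn on HuggingFace Hub uploads for a run.
from src.config import CheckpointingConfig
from src.config.checkpointing_config import HuggingFaceCheckpointingConfig

checkpointing_config = CheckpointingConfig(
    run_name="pico-decoder-demo",
    save_every_n_steps=500,
    save_to_hf=True,
    hf_checkpoint=HuggingFaceCheckpointingConfig(repo_id="pico-lm/demo"),
)
# The save/upload helpers above read the token from the HF_TOKEN environment variable.
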
src/config/data_config.py
DELETED
@@ -1,36 +0,0 @@
-"""
-Data Config
-
-Specifies the hyperparameters for the dataset, dataloader, and tokenizer.
-"""
-
-from dataclasses import dataclass, field
-
-from ._constants import BATCH_SIZE, VOCAB_SIZE
-
-
-@dataclass
-class DatasetConfig:
-    # Defines the HuggingFace name of a dataset
-    name: str = "pico-lm/pretokenized-dolma"
-
-
-@dataclass
-class DataLoaderConfig:
-    # NOTE: You should only change these values jointly with the training config; so that the
-    # sub-batch size is consistent with the gradient accumulation steps
-    batch_size: int = BATCH_SIZE
-
-
-@dataclass
-class TokenizerConfig:
-    # Specify a tokenizer to use
-    name: str = "allenai/OLMo-7B-0724-hf"
-    vocab_size: int = VOCAB_SIZE
-
-
-@dataclass
-class DataConfig:
-    dataset: DatasetConfig = field(default_factory=DatasetConfig)
-    dataloader: DataLoaderConfig = field(default_factory=DataLoaderConfig)
-    tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)

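The comment on DataLoaderConfig asks that batch_size stay consistent with gradient accumulation. With the defaults from _constants.py the arithmetic works out as follows, assuming a single device (the FabricConfig default); with more devices the micro-batch shrinks proportionally.

# With the defaults, each optimizer step accumulates 128 micro-batches of 8 sequences.
BATCH_SIZE = 1024                  # effective batch size per optimizer step
GRADIENT_ACCUMULATION_STEPS = 128  # from src/config/_constants.py
num_devices = 1                    # FabricConfig default

micro_batch_size = BATCH_SIZE // (GRADIENT_ACCUMULATION_STEPS * num_devices)
assert micro_batch_size == 8
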
src/config/evaluation_config.py
DELETED
@@ -1,28 +0,0 @@
-"""
-Evaluation Config
-
-Specifies the hyperparameters for the evaluation process, i.e. what metrics to compute, etc.
-"""
-
-from dataclasses import dataclass, field
-from typing import List, Optional
-
-from src.config._constants import MAX_SEQ_LEN
-
-
-@dataclass
-class PalomaEvaluationConfig:
-    dataset_name: str = "pico-lm/pretokenized-paloma-tinsy"
-    dataset_split: str = "val"
-    max_length: int = MAX_SEQ_LEN
-    batch_size: int = 16
-
-
-@dataclass
-class EvaluationConfig:
-    # Evaluation metrics to compute: by default, we compute the perplexity of the model on the paloma dataset
-    metrics: Optional[List[str]] = field(default_factory=lambda: ["paloma"])
-
-    # NOTE: Add other evaluation configs here
-    # Each evaluation metric should have its own config
-    paloma: PalomaEvaluationConfig = field(default_factory=PalomaEvaluationConfig)

src/config/model_config.py
DELETED
@@ -1,33 +0,0 @@
-"""
-Model Config
-
-Specifies the hyperparameters for the Pico model/model architecture.
-"""
-
-from dataclasses import dataclass
-from typing import Optional
-
-from ._constants import BATCH_SIZE, MAX_SEQ_LEN, VOCAB_SIZE
-
-
-@dataclass
-class ModelConfig:
-    model_type: str = "pico_decoder"
-
-    # Pico Decoder default hyperparameters
-
-    d_model: int = 768
-    n_layers: int = 12
-
-    vocab_size: int = VOCAB_SIZE
-    batch_size: int = BATCH_SIZE
-    max_seq_len: int = MAX_SEQ_LEN
-
-    attention_n_heads: int = 12
-    attention_n_kv_heads: Optional[int] = 4
-
-    activation_hidden_dim: int = 3072
-
-    norm_eps: float = 1e-6
-
-    position_emb_theta: float = 10000.0

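A quick sanity check of the attention defaults: with d_model=768 and 12 query heads, each head is 64-dimensional, and the 4 KV heads mean each key/value head is shared by 3 query heads (grouped-query attention). A small sketch of that arithmetic:

# Head-dimension arithmetic implied by the defaults above.
d_model = 768
attention_n_heads = 12
attention_n_kv_heads = 4

head_dim = d_model // attention_n_heads                           # 64
queries_per_kv_head = attention_n_heads // attention_n_kv_heads   # 3
assert head_dim == 64 and queries_per_kv_head == 3
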
src/config/monitoring_config.py
DELETED
@@ -1,29 +0,0 @@
-"""
-Monitoring Config
-
-Specifies the monitoring process, e.g. how to log metrics and keep track of training progress.
-"""
-
-from dataclasses import dataclass, field
-
-
-@dataclass
-class LoggingConfig:
-    log_level: str = "INFO"
-    log_every_n_steps: int = 100
-
-
-@dataclass
-class WandbConfig:
-    # configure logging to Weights and Biases
-    project: str = ""
-    entity: str = ""
-
-
-@dataclass
-class MonitoringConfig:
-    logging: LoggingConfig = field(default_factory=LoggingConfig)
-
-    # Weights and Biases
-    save_to_wandb: bool = False
-    wandb: WandbConfig = field(default_factory=WandbConfig)

src/config/training_config.py
DELETED
@@ -1,40 +0,0 @@
-"""
-Training Config
-
-Specifies the hyperparameters for the training process, i.e. the optimizer, learning rate, etc.
-"""
-
-from dataclasses import dataclass, field
-
-from ._constants import GRADIENT_ACCUMULATION_STEPS
-
-
-@dataclass
-class FabricConfig:
-    # Configure nodes/devices for parallelised training
-    num_nodes: int = 1
-    num_devices: int = 1
-    precision: str = "bf16-mixed"
-    # Hardware accelerator to use, can be cpu/cuda/mps etc.
-    accelerator: str = "cuda"
-
-
-@dataclass
-class OptimizationConfig:
-    # Optimizer
-    optimizer: str = "adamw"
-    lr: float = 3e-4
-
-    # Learning Rate Scheduler
-    lr_scheduler: str = "linear_with_warmup"
-    lr_warmup_steps: int = 2500
-
-    # Define number of gradient accumulation steps
-    gradient_accumulation_steps: int = GRADIENT_ACCUMULATION_STEPS
-
-
-@dataclass
-class TrainingConfig:
-    fabric: FabricConfig = field(default_factory=FabricConfig)
-    optimization: OptimizationConfig = field(default_factory=OptimizationConfig)
-    max_steps: int = 200_000

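The scheduler named by lr_scheduler was built in the deleted src/training/utils/initialization.py, so the snippet below is only a generic sketch of what a "linear_with_warmup" schedule with lr_warmup_steps=2500 and max_steps=200_000 looks like, not the repository's implementation; the stand-in model is a placeholder.

# Generic linear-warmup schedule sketch (not the deleted initialization.py code).
import torch

model = torch.nn.Linear(8, 8)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

lr_warmup_steps = 2500
max_steps = 200_000

def warmup_then_linear_decay(step: int) -> float:
    # Ramp linearly up to the base lr over the warmup steps,
    # then decay linearly back toward zero over the remaining steps.
    if step < lr_warmup_steps:
        return step / max(1, lr_warmup_steps)
    return max(0.0, (max_steps - step) / max(1, max_steps - lr_warmup_steps))

lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_then_linear_decay)
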
src/evaluation/__init__.py
DELETED
|
@@ -1,103 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Pico Evaluation Package
|
| 3 |
-
|
| 4 |
-
This package implements the evaluation pipeline for the Pico language model. It provides
|
| 5 |
-
functionality to evaluate model performance using various metrics and handles the complete
|
| 6 |
-
evaluation workflow.
|
| 7 |
-
|
| 8 |
-
We recommend that each evaluation metric should have its own config, and should be
|
| 9 |
-
implemented as a module in the `evaluation/tasks` directory that exposes a `run_<metric_name>` function.
|
| 10 |
-
|
| 11 |
-
NOTE: Out of the box we only support Paloma, but the structure is designed to be flexible and
|
| 12 |
-
you are meant to add whatever metrics you want. One of the main reasons we store out
|
| 13 |
-
the model in the HuggingFace format is so that its easy to use third-party evaluation
|
| 14 |
-
libraries/frameworks.
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
import os
|
| 18 |
-
|
| 19 |
-
import torch
|
| 20 |
-
from lightning.fabric import Fabric
|
| 21 |
-
from torch import nn
|
| 22 |
-
|
| 23 |
-
from src.config import CheckpointingConfig, EvaluationConfig
|
| 24 |
-
|
| 25 |
-
from .tasks.paloma import run_paloma_evaluation
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def run_evaluation(
|
| 29 |
-
evaluation_config: EvaluationConfig,
|
| 30 |
-
checkpointing_config: CheckpointingConfig,
|
| 31 |
-
fabric: Fabric,
|
| 32 |
-
model: nn.Module,
|
| 33 |
-
) -> None:
|
| 34 |
-
"""Run model evaluation using specified metrics in `evaluation_config`.
|
| 35 |
-
|
| 36 |
-
This function orchestrates the complete evaluation pipeline by:
|
| 37 |
-
1. Resolving the model checkpoint path (either specified or latest) to load the model from;
|
| 38 |
-
during training, this is the path to the latest checkpoint in the run directory.
|
| 39 |
-
2. Iterating over each evaluation metric, and running the corresponding evaluation function.
|
| 40 |
-
NOTE: we suggest you follow the pattern of the Paloma evaluation function, and implement
|
| 41 |
-
your own evaluation function for each metric in the `evaluation/tasks` directory.
|
| 42 |
-
3. Aggregating results across all metrics in a dictionary, and returning it.
|
| 43 |
-
|
| 44 |
-
Args:
|
| 45 |
-
evaluation_config (EvaluationConfig): Configuration object containing:
|
| 46 |
-
- metrics (List[str]): Metrics to evaluate; each metric should have its
|
| 47 |
-
own config. Currently supported: ["paloma"];
|
| 48 |
-
- paloma (PalomaConfig): Configuration for Paloma evaluation
|
| 49 |
-
- max_length (int): Maximum sequence length
|
| 50 |
-
- limit_eval_examples (Optional[int]): Number of examples to evaluate
|
| 51 |
-
checkpointing_config (CheckpointingConfig): Configuration object containing:
|
| 52 |
-
fabric (Fabric): Lightning Fabric instance
|
| 53 |
-
model (nn.Module): Original model instance
|
| 54 |
-
|
| 55 |
-
Returns:
|
| 56 |
-
Dict[str, float]: Dictionary mapping metric names to their values
|
| 57 |
-
Example: {"paloma": 3.45}
|
| 58 |
-
|
| 59 |
-
Raises:
|
| 60 |
-
ValueError: If an unsupported evaluation metric is requested
|
| 61 |
-
|
| 62 |
-
Example:
|
| 63 |
-
results = run_evaluation(
|
| 64 |
-
EvaluationConfig(
|
| 65 |
-
run_name="experiment_1",
|
| 66 |
-
metrics=["paloma"],
|
| 67 |
-
paloma=PalomaConfig(max_length=2048, batch_size=16)
|
| 68 |
-
)
|
| 69 |
-
)
|
| 70 |
-
|
| 71 |
-
"""
|
| 72 |
-
|
| 73 |
-
fabric.barrier()
|
| 74 |
-
|
| 75 |
-
model.to("cpu") # Offloading model to CPU
|
| 76 |
-
|
| 77 |
-
evaluation_results = {}
|
| 78 |
-
|
| 79 |
-
# NOTE: Evaluation is only run on the first process so that third-party evaluation libraries
|
| 80 |
-
# can determine how to handle distributed evaluation.
|
| 81 |
-
if fabric.global_rank == 0:
|
| 82 |
-
run_name = checkpointing_config.run_name
|
| 83 |
-
model_path = f"{os.getcwd()}/{checkpointing_config.runs_dir}/{run_name}/{checkpointing_config.checkpoints_dir}/latest"
|
| 84 |
-
os.makedirs(model_path, exist_ok=True)
|
| 85 |
-
|
| 86 |
-
for metric in evaluation_config.metrics:
|
| 87 |
-
# NOTE: add your own metrics here
|
| 88 |
-
if metric == "paloma":
|
| 89 |
-
evaluation_result = run_paloma_evaluation(
|
| 90 |
-
model_path, evaluation_config.paloma
|
| 91 |
-
)
|
| 92 |
-
else:
|
| 93 |
-
raise ValueError(f"Metric {metric} not supported")
|
| 94 |
-
|
| 95 |
-
evaluation_results[metric] = evaluation_result
|
| 96 |
-
|
| 97 |
-
torch.cuda.empty_cache()
|
| 98 |
-
|
| 99 |
-
fabric.barrier()
|
| 100 |
-
|
| 101 |
-
model.to(fabric.device)
|
| 102 |
-
|
| 103 |
-
return evaluation_results
|
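Since the docstring above recommends one `run_<metric_name>` module per metric, here is a hedged sketch of what such a module's interface could look like. The metric, its config fields, and the toy computation are hypothetical; they only illustrate the (model_path, config) -> float contract that run_evaluation dispatches on.

from dataclasses import dataclass

@dataclass
class WordCountConfig:
    max_examples: int = 3

def run_word_count_evaluation(model_path: str, config: WordCountConfig) -> float:
    """Toy metric stub: average whitespace-token count over a few fixed strings.

    A real task would load the HuggingFace checkpoint at `model_path`; this stub
    only demonstrates the interface a task module exposes to run_evaluation.
    """
    examples = ["hello world", "pico is small", "language models are fun"][: config.max_examples]
    return sum(len(text.split()) for text in examples) / len(examples)

print(run_word_count_evaluation("runs/demo/checkpoints/latest", WordCountConfig()))  # placeholder path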
src/evaluation/tasks/paloma.py
DELETED
|
@@ -1,52 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Paloma is a comprehensive evaluation benchmark for large language models (LLMs) that focuses
|
| 3 |
-
on measuring perplexity across diverse text domains.
|
| 4 |
-
|
| 5 |
-
To evaluate on Paloma, we use the HuggingFace evaluation framework.
|
| 6 |
-
|
| 7 |
-
For more details, see: https://huggingface.co/datasets/allenai/paloma
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
-
import evaluate
|
| 11 |
-
from datasets import load_dataset
|
| 12 |
-
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
| 13 |
-
|
| 14 |
-
from src.config.evaluation_config import PalomaEvaluationConfig
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def run_paloma_evaluation(
|
| 18 |
-
model_path: str,
|
| 19 |
-
paloma_config: PalomaEvaluationConfig,
|
| 20 |
-
) -> None:
|
| 21 |
-
"""Run Perplexity evaluation on the Paloma evaluation dataset.
|
| 22 |
-
|
| 23 |
-
We use the HuggingFace evaluate library to load in and compute the perplexity metric.
|
| 24 |
-
|
| 25 |
-
Args:
|
| 26 |
-
model_path (str): Path to the model checkpoint to be evaluated
|
| 27 |
-
paloma_config (PalomaEvaluationConfig): Configuration for Paloma evaluation
|
| 28 |
-
"""
|
| 29 |
-
|
| 30 |
-
disable_progress_bar()
|
| 31 |
-
|
| 32 |
-
# load custom evaluation space, see https://huggingface.co/spaces/pico-lm/perplexity
|
| 33 |
-
perplexity = evaluate.load("pico-lm/perplexity")
|
| 34 |
-
|
| 35 |
-
dataset = load_dataset(
|
| 36 |
-
paloma_config.dataset_name, split=paloma_config.dataset_split
|
| 37 |
-
)["text"]
|
| 38 |
-
|
| 39 |
-
# compute perplexity score on Paloma dataset
|
| 40 |
-
perplexity_result = perplexity.compute(
|
| 41 |
-
model_id=model_path,
|
| 42 |
-
predictions=dataset,
|
| 43 |
-
add_start_token=False,
|
| 44 |
-
max_length=paloma_config.max_length,
|
| 45 |
-
batch_size=paloma_config.batch_size,
|
| 46 |
-
trust_remote_code=True,
|
| 47 |
-
)
|
| 48 |
-
|
| 49 |
-
mean_perplexity = perplexity_result["mean_perplexity"]
|
| 50 |
-
|
| 51 |
-
enable_progress_bar()
|
| 52 |
-
return mean_perplexity
|
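A hedged usage sketch of run_paloma_evaluation follows. It assumes the pre-deletion src package is importable and that PalomaEvaluationConfig is a dataclass exposing the dataset_name, dataset_split, max_length, and batch_size fields the function reads; the checkpoint path and split name are placeholders, while max_length=2048 and batch_size=16 echo the example in the evaluation package docstring.

# Hedged sketch; requires a saved HuggingFace-format checkpoint and network access
# to fetch the dataset and the pico-lm/perplexity evaluation space.
from src.config.evaluation_config import PalomaEvaluationConfig
from src.evaluation.tasks.paloma import run_paloma_evaluation

paloma_config = PalomaEvaluationConfig(
    dataset_name="allenai/paloma",  # dataset referenced in the docstring above
    dataset_split="val",            # hypothetical split name
    max_length=2048,
    batch_size=16,
)
mean_ppl = run_paloma_evaluation("runs/demo/checkpoints/latest", paloma_config)
print(f"Paloma mean perplexity: {mean_ppl:.2f}")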
src/model/__init__.py
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Model Package
|
| 3 |
-
|
| 4 |
-
This package contains Pico models (currently only the Pico Decoder). We plan to implement other
|
| 5 |
-
architectures in the future.
|
| 6 |
-
|
| 7 |
-
If you have other models you'd like to implement, we recommend you add modules to this package.
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
-
from .pico_decoder import PicoDecoder
|
| 11 |
-
|
| 12 |
-
__all__ = ["PicoDecoder"]
|
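A hedged sketch of using the exported PicoDecoder directly, assuming the pre-deletion src package is importable and that ModelConfig supplies the fields the model reads (vocab_size, d_model, n_layers, and so on); the dummy batch shape is illustrative.

import torch
from src.config import ModelConfig
from src.model import PicoDecoder

model = PicoDecoder(ModelConfig())
print(f"parameters: {sum(p.numel() for p in model.parameters()):,}")

# One forward pass over dummy token ids; the model returns (logits, cached_key_values),
# and logits has shape (batch, seq_len, vocab_size).
dummy_ids = torch.randint(0, model.config.vocab_size, (1, 8))
logits, _ = model(dummy_ids)
print(logits.shape)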
src/model/pico_decoder.py
DELETED
|
@@ -1,911 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Pico Decoder: A Lightweight Causal Transformer Language Model
|
| 3 |
-
|
| 4 |
-
Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
|
| 5 |
-
|
| 6 |
-
Everything is written with a modular design for easy modification and experimentation.
|
| 7 |
-
|
| 8 |
-
Key features:
|
| 9 |
-
- RMSNorm for layer normalization
|
| 10 |
-
- Rotary Positional Embeddings (RoPE)
|
| 11 |
-
- Multi-head attention with KV-cache support
|
| 12 |
-
- SwiGLU activation function
|
| 13 |
-
- Residual connections throughout
|
| 14 |
-
|
| 15 |
-
- KV-cache for faster autoregressive generation
|
| 16 |
-
|
| 17 |
-
References:
|
| 18 |
-
- RoPE: https://arxiv.org/abs/2104.09864
|
| 19 |
-
- SwiGLU: https://arxiv.org/abs/2002.05202
|
| 20 |
-
- LLAMA: https://arxiv.org/abs/2302.13971
|
| 21 |
-
|
| 22 |
-
Adapted from:
|
| 23 |
-
- OLMO: https://github.com/allenai/OLMo
|
| 24 |
-
- LLAMA: https://github.com/meta/llama
|
| 25 |
-
"""
|
| 26 |
-
|
| 27 |
-
from dataclasses import asdict
|
| 28 |
-
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
| 29 |
-
|
| 30 |
-
import torch
|
| 31 |
-
import torch.nn as nn
|
| 32 |
-
import torch.nn.functional as F
|
| 33 |
-
|
| 34 |
-
# Handle PyTorch version compatibility for attention backend
|
| 35 |
-
try:
|
| 36 |
-
from torch.nn.attention import SDPBackend, sdpa_kernel
|
| 37 |
-
|
| 38 |
-
HAS_TORCH_ATTENTION = True
|
| 39 |
-
except ImportError:
|
| 40 |
-
# Fallback for older PyTorch versions
|
| 41 |
-
HAS_TORCH_ATTENTION = False
|
| 42 |
-
SDPBackend = None
|
| 43 |
-
sdpa_kernel = None
|
| 44 |
-
|
| 45 |
-
from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
|
| 46 |
-
from transformers.generation import GenerationConfig
|
| 47 |
-
from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
|
| 48 |
-
|
| 49 |
-
try:
|
| 50 |
-
if TYPE_CHECKING:
|
| 51 |
-
# We need to do this to avoid importing these when creating the HF-compatible models
|
| 52 |
-
from src.config import ModelConfig
|
| 53 |
-
except ImportError:
|
| 54 |
-
pass
|
| 55 |
-
|
| 56 |
-
########################################################
|
| 57 |
-
#
|
| 58 |
-
# Layer Normalization
|
| 59 |
-
#
|
| 60 |
-
########################################################
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
class RMSNorm(torch.nn.Module):
|
| 64 |
-
"""Root Mean Square Layer Normalization.
|
| 65 |
-
|
| 66 |
-
A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
|
| 67 |
-
resulting in improved stability and performance.
|
| 68 |
-
|
| 69 |
-
Args:
|
| 70 |
-
config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
|
| 71 |
-
- config.norm_eps: Small constant for numerical stability
|
| 72 |
-
- config.d_model: Model dimension for the weight parameter
|
| 73 |
-
|
| 74 |
-
References:
|
| 75 |
-
https://arxiv.org/abs/1910.07467
|
| 76 |
-
"""
|
| 77 |
-
|
| 78 |
-
def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
|
| 79 |
-
super().__init__()
|
| 80 |
-
self.eps = config.norm_eps
|
| 81 |
-
self.weight = nn.Parameter(torch.ones(config.d_model))
|
| 82 |
-
|
| 83 |
-
def _norm(self, x: torch.Tensor) -> torch.Tensor:
|
| 84 |
-
"""
|
| 85 |
-
Normalizes the input tensor by its RMS value.
|
| 86 |
-
"""
|
| 87 |
-
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 88 |
-
|
| 89 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 90 |
-
"""
|
| 91 |
-
Applies RMS normalization to the input tensor and scales it by the weight parameter.
|
| 92 |
-
"""
|
| 93 |
-
output = self._norm(x.float()).type_as(x)
|
| 94 |
-
return output * self.weight
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
########################################################
|
| 98 |
-
#
|
| 99 |
-
# Positional Embedding
|
| 100 |
-
#
|
| 101 |
-
########################################################
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
class RoPE(nn.Module):
|
| 105 |
-
"""Rotary Positional Embeddings (RoPE).
|
| 106 |
-
|
| 107 |
-
Implements position-dependent rotation of keys and queries in attention mechanism,
|
| 108 |
-
allowing better modeling of relative positions in sequences. Uses complex number
|
| 109 |
-
operations for efficient rotation.
|
| 110 |
-
|
| 111 |
-
Args:
|
| 112 |
-
config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
|
| 113 |
-
- config.position_emb_theta: Base for frequency computation
|
| 114 |
-
- config.d_model: Model dimension
|
| 115 |
-
- config.attention_n_heads: Number of attention heads
|
| 116 |
-
- config.max_seq_len: Maximum sequence length
|
| 117 |
-
|
| 118 |
-
References:
|
| 119 |
-
https://arxiv.org/abs/2104.09864
|
| 120 |
-
"""
|
| 121 |
-
|
| 122 |
-
_freqs_cis_tensor: torch.Tensor | None = None
|
| 123 |
-
|
| 124 |
-
def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
|
| 125 |
-
super().__init__()
|
| 126 |
-
|
| 127 |
-
self.theta = config.position_emb_theta
|
| 128 |
-
self.dim = config.d_model // config.attention_n_heads
|
| 129 |
-
|
| 130 |
-
max_seq_len = config.max_seq_len
|
| 131 |
-
|
| 132 |
-
# only gets set once, and then reused for all RoPE instances
|
| 133 |
-
if RoPE._freqs_cis_tensor is None:
|
| 134 |
-
RoPE._freqs_cis_tensor = self._setup_freqs_cis(
|
| 135 |
-
max_seq_len, self.theta, self.dim
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
# register _freqs_cis buffer
|
| 139 |
-
# can be easily recomputed so persistent=False
|
| 140 |
-
self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
|
| 141 |
-
|
| 142 |
-
@classmethod
|
| 143 |
-
def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
|
| 144 |
-
"""Setup Frequency Tensor for RoPE Embeddings
|
| 145 |
-
|
| 146 |
-
Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
|
| 147 |
-
|
| 148 |
-
Note that other implementations use cos and sin directly, but using the complex
|
| 149 |
-
number representation is (probably) more efficient:
|
| 150 |
-
|
| 151 |
-
e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
|
| 152 |
-
"""
|
| 153 |
-
_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 154 |
-
positions = torch.arange(seq_len)
|
| 155 |
-
freqs = torch.outer(positions, _freqs)
|
| 156 |
-
return torch.polar(torch.ones_like(freqs), freqs) # complex64
|
| 157 |
-
|
| 158 |
-
def get_freqs_cis(
|
| 159 |
-
self, input_shape: torch.Size, start_pos: int, end_pos: int
|
| 160 |
-
) -> torch.Tensor:
|
| 161 |
-
"""Reshape Frequency Tensor for RoPE Embeddings
|
| 162 |
-
|
| 163 |
-
Makes the frequency tensor broadcastable with the input tensor.
|
| 164 |
-
"""
|
| 165 |
-
_freqs_cis = self._freqs_cis[start_pos:end_pos]
|
| 166 |
-
ndim = len(input_shape)
|
| 167 |
-
assert 0 <= 1 < ndim
|
| 168 |
-
assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
|
| 169 |
-
|
| 170 |
-
# TODO: Check whether this is correct (might be able to remove this)
|
| 171 |
-
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
|
| 172 |
-
return _freqs_cis.view(*shape)
|
| 173 |
-
|
| 174 |
-
def forward(
|
| 175 |
-
self,
|
| 176 |
-
queries: torch.Tensor,
|
| 177 |
-
keys: torch.Tensor,
|
| 178 |
-
start_pos: int = 0,
|
| 179 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 180 |
-
"""Apply RoPE Embeddings to Queries and Keys
|
| 181 |
-
|
| 182 |
-
Applies the rotary positional embeddings to the input tensors via complex num multiplication
|
| 183 |
-
|
| 184 |
-
NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
|
| 185 |
-
"""
|
| 186 |
-
queries_ = torch.view_as_complex(
|
| 187 |
-
queries.float().reshape(*queries.shape[:-1], -1, 2)
|
| 188 |
-
)
|
| 189 |
-
keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
|
| 190 |
-
|
| 191 |
-
input_shape = (
|
| 192 |
-
queries_.shape
|
| 193 |
-
) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
|
| 194 |
-
freqs_start_pos = start_pos
|
| 195 |
-
freqs_end_pos = freqs_start_pos + queries_.shape[1]
|
| 196 |
-
|
| 197 |
-
freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
|
| 198 |
-
|
| 199 |
-
queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
|
| 200 |
-
keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
|
| 201 |
-
return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
########################################################
|
| 205 |
-
#
|
| 206 |
-
# Attention
|
| 207 |
-
#
|
| 208 |
-
########################################################
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
class Attention(nn.Module):
|
| 212 |
-
"""Multi-head Attention with Group Query Attention support.
|
| 213 |
-
|
| 214 |
-
Implements scaled dot-product attention and supports:
|
| 215 |
-
- Grouped Query Attention (GQA)
|
| 216 |
-
- Key-Value caching for efficient inference
|
| 217 |
-
- RoPE integration
|
| 218 |
-
|
| 219 |
-
Args:
|
| 220 |
-
config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
|
| 221 |
-
- config.attention_n_heads: Number of attention heads
|
| 222 |
-
- config.attention_n_kv_heads: Number of key/value heads
|
| 223 |
-
- config.d_model: Model dimension
|
| 224 |
-
- config.batch_size: Maximum batch size
|
| 225 |
-
- config.max_seq_len: Maximum sequence length
|
| 226 |
-
|
| 227 |
-
Shape:
|
| 228 |
-
- Input: (batch_size, seq_len, d_model)
|
| 229 |
-
- Output: (batch_size, seq_len, d_model)
|
| 230 |
-
"""
|
| 231 |
-
|
| 232 |
-
def __init__(
|
| 233 |
-
self,
|
| 234 |
-
config: Union["ModelConfig", "PicoDecoderHFConfig"],
|
| 235 |
-
):
|
| 236 |
-
super().__init__()
|
| 237 |
-
|
| 238 |
-
self.n_heads = config.attention_n_heads
|
| 239 |
-
self.n_kv_heads = config.attention_n_kv_heads
|
| 240 |
-
|
| 241 |
-
self.batch_size = config.batch_size
|
| 242 |
-
self.max_seq_len = config.max_seq_len
|
| 243 |
-
|
| 244 |
-
d_model = config.d_model
|
| 245 |
-
self.head_dim = d_model // self.n_heads
|
| 246 |
-
|
| 247 |
-
self.n_rep = self.n_heads // self.n_kv_heads
|
| 248 |
-
|
| 249 |
-
self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
|
| 250 |
-
self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
|
| 251 |
-
self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
|
| 252 |
-
self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
|
| 253 |
-
|
| 254 |
-
self.rope = RoPE(config)
|
| 255 |
-
|
| 256 |
-
def forward(
|
| 257 |
-
self,
|
| 258 |
-
input: torch.Tensor,
|
| 259 |
-
mask: Optional[torch.Tensor] = None,
|
| 260 |
-
past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
|
| 261 |
-
use_cache: bool = False,
|
| 262 |
-
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 263 |
-
"""Forward pass for the attention mechanism.
|
| 264 |
-
|
| 265 |
-
Computes queries, keys, and values for the attention mechanism. Applies rotary positional
|
| 266 |
-
embeddings to the queries and keys, and then computes attention scores and outputs.
|
| 267 |
-
|
| 268 |
-
For an introduction to the attention mechanism, see:
|
| 269 |
-
https://arxiv.org/abs/1706.03762
|
| 270 |
-
|
| 271 |
-
A few things to note:
|
| 272 |
-
- The past_key_values is used to implement the KV cache, which is used to speed up
|
| 273 |
-
generation by caching the KV pairs from previous forward passes. This is useful when doing
|
| 274 |
-
tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
|
| 275 |
-
modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
|
| 276 |
-
its own KV cache - this KV cache is implemented as a tuple.
|
| 277 |
-
"""
|
| 278 |
-
bsz, seq_len, _ = input.shape
|
| 279 |
-
_queries, _keys, _values = (
|
| 280 |
-
self.q_proj(input),
|
| 281 |
-
self.k_proj(input),
|
| 282 |
-
self.v_proj(input),
|
| 283 |
-
)
|
| 284 |
-
|
| 285 |
-
# Reshaping for multi-head attention
|
| 286 |
-
queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
|
| 287 |
-
keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
|
| 288 |
-
values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
|
| 289 |
-
|
| 290 |
-
# The start position is used to apply the RoPE embeddings to only the new tokens
|
| 291 |
-
# when using the kv_cache in the attention mechanism.
|
| 292 |
-
# We want to start from the last position in the cache.
|
| 293 |
-
start_pos = 0
|
| 294 |
-
if past_key_values is not None and past_key_values[0] is not None:
|
| 295 |
-
start_pos = past_key_values[0].shape[1]
|
| 296 |
-
|
| 297 |
-
# apply rotary positional embeddings
|
| 298 |
-
queries, keys = self.rope(queries, keys, start_pos)
|
| 299 |
-
|
| 300 |
-
if (
|
| 301 |
-
past_key_values is not None
|
| 302 |
-
and past_key_values[0] is not None
|
| 303 |
-
and past_key_values[1] is not None
|
| 304 |
-
):
|
| 305 |
-
keys = torch.cat([past_key_values[0], keys], dim=1)
|
| 306 |
-
values = torch.cat([past_key_values[1], values], dim=1)
|
| 307 |
-
|
| 308 |
-
if use_cache:
|
| 309 |
-
cached_keys = keys
|
| 310 |
-
cached_values = values
|
| 311 |
-
else:
|
| 312 |
-
cached_keys = None
|
| 313 |
-
cached_values = None
|
| 314 |
-
|
| 315 |
-
queries = queries.transpose(1, 2)
|
| 316 |
-
keys = keys.transpose(1, 2)
|
| 317 |
-
values = values.transpose(1, 2)
|
| 318 |
-
|
| 319 |
-
apply_gqa = self.n_rep > 1
|
| 320 |
-
if apply_gqa and queries.device.type == "mps":
|
| 321 |
-
# NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
|
| 322 |
-
# outside of the kernel to get the same effect.
|
| 323 |
-
# See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
| 324 |
-
keys = keys.repeat_interleave(self.n_rep, dim=-3)
|
| 325 |
-
values = values.repeat_interleave(self.n_rep, dim=-3)
|
| 326 |
-
apply_gqa = False
|
| 327 |
-
|
| 328 |
-
if HAS_TORCH_ATTENTION:
|
| 329 |
-
backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
|
| 330 |
-
with sdpa_kernel(backends=backends):
|
| 331 |
-
attn_output = F.scaled_dot_product_attention(
|
| 332 |
-
queries.contiguous(),
|
| 333 |
-
keys.contiguous(),
|
| 334 |
-
values.contiguous(),
|
| 335 |
-
attn_mask=mask.to(queries.dtype) if mask is not None else None,
|
| 336 |
-
enable_gqa=apply_gqa,
|
| 337 |
-
)
|
| 338 |
-
else:
|
| 339 |
-
# Fallback for older PyTorch versions - use default backend
|
| 340 |
-
attn_output = F.scaled_dot_product_attention(
|
| 341 |
-
queries.contiguous(),
|
| 342 |
-
keys.contiguous(),
|
| 343 |
-
values.contiguous(),
|
| 344 |
-
attn_mask=mask.to(queries.dtype) if mask is not None else None,
|
| 345 |
-
enable_gqa=apply_gqa,
|
| 346 |
-
)
|
| 347 |
-
|
| 348 |
-
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
|
| 349 |
-
output = self.o_proj(attn_output)
|
| 350 |
-
|
| 351 |
-
return output, (cached_keys, cached_values)
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
########################################################
|
| 355 |
-
#
|
| 356 |
-
# SwiGLU (Combines MLP and Activation)
|
| 357 |
-
#
|
| 358 |
-
########################################################
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
class SwiGLU(nn.Module):
|
| 362 |
-
"""SwiGLU Activation Function with Linear Projections.
|
| 363 |
-
|
| 364 |
-
Implements the SwiGLU activation function combined with linear transformations,
|
| 365 |
-
serving as the feed-forward network in transformer blocks.
|
| 366 |
-
|
| 367 |
-
Args:
|
| 368 |
-
config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
|
| 369 |
-
- config.d_model: Model dimension
|
| 370 |
-
- config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
|
| 371 |
-
|
| 372 |
-
References:
|
| 373 |
-
https://arxiv.org/abs/2002.05202
|
| 374 |
-
"""
|
| 375 |
-
|
| 376 |
-
def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
|
| 377 |
-
super().__init__()
|
| 378 |
-
|
| 379 |
-
model_dim = config.d_model
|
| 380 |
-
act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
|
| 381 |
-
|
| 382 |
-
self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
|
| 383 |
-
self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
|
| 384 |
-
self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
|
| 385 |
-
|
| 386 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 387 |
-
return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
########################################################
|
| 391 |
-
#
|
| 392 |
-
# PicoDecoderBlock
|
| 393 |
-
#
|
| 394 |
-
########################################################
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
class PicoDecoderBlock(nn.Module):
|
| 398 |
-
"""Single Transformer Block with Attention and Feed-forward layers.
|
| 399 |
-
|
| 400 |
-
Implements a standard transformer block with:
|
| 401 |
-
- Multi-head attention with normalization and residual connection
|
| 402 |
-
- SwiGLU feed-forward network with normalization and residual connection
|
| 403 |
-
|
| 404 |
-
Args:
|
| 405 |
-
config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
|
| 406 |
-
a HuggingFace PicoDecoderHFConfig
|
| 407 |
-
"""
|
| 408 |
-
|
| 409 |
-
def __init__(
|
| 410 |
-
self,
|
| 411 |
-
config: Union["ModelConfig", "PicoDecoderHFConfig"],
|
| 412 |
-
):
|
| 413 |
-
super().__init__()
|
| 414 |
-
|
| 415 |
-
self.attention = Attention(config)
|
| 416 |
-
self.swiglu = SwiGLU(config)
|
| 417 |
-
self.attention_norm = RMSNorm(config)
|
| 418 |
-
self.swiglu_norm = RMSNorm(config)
|
| 419 |
-
|
| 420 |
-
def forward(
|
| 421 |
-
self,
|
| 422 |
-
input: torch.Tensor,
|
| 423 |
-
mask: Optional[torch.Tensor] = None,
|
| 424 |
-
past_key_values: Optional[Tuple[torch.Tensor]] = None,
|
| 425 |
-
use_cache: bool = False,
|
| 426 |
-
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 427 |
-
attention_output, cached_key_values = self.attention(
|
| 428 |
-
self.attention_norm(input),
|
| 429 |
-
mask=mask,
|
| 430 |
-
past_key_values=past_key_values,
|
| 431 |
-
use_cache=use_cache,
|
| 432 |
-
)
|
| 433 |
-
# NOTE: cached_key_values is None if use_cache is False
|
| 434 |
-
|
| 435 |
-
h = input + attention_output
|
| 436 |
-
out = h + self.swiglu(self.swiglu_norm(h))
|
| 437 |
-
return out, cached_key_values
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
########################################################
|
| 441 |
-
#
|
| 442 |
-
# Pico Decoder (Causal Transformer Model)
|
| 443 |
-
#
|
| 444 |
-
########################################################
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
class PicoDecoder(nn.Module):
|
| 448 |
-
"""
|
| 449 |
-
Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
|
| 450 |
-
single autoregressive model.
|
| 451 |
-
|
| 452 |
-
For more information on the model, see the classes for the modules that make up the model.
|
| 453 |
-
"""
|
| 454 |
-
|
| 455 |
-
def __init__(
|
| 456 |
-
self,
|
| 457 |
-
model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
|
| 458 |
-
):
|
| 459 |
-
super().__init__()
|
| 460 |
-
self.config = model_config
|
| 461 |
-
|
| 462 |
-
self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
|
| 463 |
-
self.layers = nn.ModuleList(
|
| 464 |
-
[PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
|
| 465 |
-
)
|
| 466 |
-
self.output_norm = RMSNorm(self.config)
|
| 467 |
-
self.de_embedding_proj = nn.Linear(
|
| 468 |
-
self.config.d_model, self.config.vocab_size, bias=False
|
| 469 |
-
)
|
| 470 |
-
|
| 471 |
-
def convert_to_hf_model(self) -> "PicoDecoderHF":
|
| 472 |
-
"""Convert the Lightning model to a HuggingFace model."""
|
| 473 |
-
# Create HF config without fabric-specific settings
|
| 474 |
-
hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
|
| 475 |
-
|
| 476 |
-
# Create new HF model
|
| 477 |
-
hf_model = PicoDecoderHF(hf_config)
|
| 478 |
-
|
| 479 |
-
# Copy state dict, excluding fabric-specific keys
|
| 480 |
-
hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
|
| 481 |
-
|
| 482 |
-
return hf_model
|
| 483 |
-
|
| 484 |
-
def forward(
|
| 485 |
-
self,
|
| 486 |
-
input_ids: torch.Tensor,
|
| 487 |
-
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
| 488 |
-
use_cache: bool = False,
|
| 489 |
-
) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
|
| 490 |
-
"""
|
| 491 |
-
This is the forward pass for the entire Pico model. It boils down to:
|
| 492 |
-
- Embedding the input ids
|
| 493 |
-
- Creating a causal mask
|
| 494 |
-
- Processing through the pico layers
|
| 495 |
-
- Projecting the output to logits
|
| 496 |
-
|
| 497 |
-
NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
|
| 498 |
-
generation by caching the KV pairs from previous forward passes. This is useful when doing
|
| 499 |
-
tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
|
| 500 |
-
modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
|
| 501 |
-
its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
|
| 502 |
-
KV caches (so a tuple of tuples).
|
| 503 |
-
"""
|
| 504 |
-
|
| 505 |
-
seq_len = input_ids.shape[-1]
|
| 506 |
-
h = self.embedding_proj(input_ids)
|
| 507 |
-
|
| 508 |
-
# Calculate start position from past cached KV pairs. Remember that each layer has its
|
| 509 |
-
# own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
|
| 510 |
-
# correct layer and then for either the keys or values.
|
| 511 |
-
start_pos = 0
|
| 512 |
-
if (
|
| 513 |
-
past_key_values is not None
|
| 514 |
-
and past_key_values[0] is not None
|
| 515 |
-
and past_key_values[0][0] is not None
|
| 516 |
-
):
|
| 517 |
-
start_pos = past_key_values[0][0].shape[1]
|
| 518 |
-
|
| 519 |
-
# Create causal mask for current sequence
|
| 520 |
-
mask = None
|
| 521 |
-
if seq_len > 1:
|
| 522 |
-
mask = torch.full((seq_len, seq_len), float("-inf"))
|
| 523 |
-
mask = torch.triu(mask, diagonal=1)
|
| 524 |
-
|
| 525 |
-
# If using KV cache, extend mask to cover cached sequence length
|
| 526 |
-
if past_key_values is not None:
|
| 527 |
-
# Add zeros for cached tokens (we can attend to all of them)
|
| 528 |
-
mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
|
| 529 |
-
|
| 530 |
-
mask = mask.to(h.device)
|
| 531 |
-
|
| 532 |
-
# NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
|
| 533 |
-
# in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
|
| 534 |
-
cached_key_values = () if use_cache else None
|
| 535 |
-
|
| 536 |
-
# Process through transformer blocks
|
| 537 |
-
for idx, layer in enumerate(self.layers):
|
| 538 |
-
layer_past_key_values = None
|
| 539 |
-
if past_key_values is not None:
|
| 540 |
-
try:
|
| 541 |
-
# Handle both tuple-based cache and HuggingFace cache objects
|
| 542 |
-
if hasattr(past_key_values, "__getitem__") and idx < len(
|
| 543 |
-
past_key_values
|
| 544 |
-
):
|
| 545 |
-
layer_past_key_values = past_key_values[idx]
|
| 546 |
-
except (KeyError, IndexError, TypeError):
|
| 547 |
-
# If we can't access the cache properly, just skip it
|
| 548 |
-
layer_past_key_values = None
|
| 549 |
-
|
| 550 |
-
h, layer_cached_key_values = layer(
|
| 551 |
-
h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
|
| 552 |
-
)
|
| 553 |
-
|
| 554 |
-
if use_cache:
|
| 555 |
-
cached_key_values += (layer_cached_key_values,)
|
| 556 |
-
|
| 557 |
-
# Final norm and projection
|
| 558 |
-
h = self.output_norm(h)
|
| 559 |
-
logits = self.de_embedding_proj(h).float()
|
| 560 |
-
|
| 561 |
-
return logits, cached_key_values
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
########################################################
|
| 565 |
-
#
|
| 566 |
-
# HuggingFace Wrapper for the Pico Decoder model.
|
| 567 |
-
#
|
| 568 |
-
########################################################
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
class PicoDecoderHFConfig(PretrainedConfig):
|
| 572 |
-
"""Config class for the Pico Decoder HuggingFace wrapper."""
|
| 573 |
-
|
| 574 |
-
model_type = "pico_decoder"
|
| 575 |
-
|
| 576 |
-
@classmethod
|
| 577 |
-
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
|
| 578 |
-
"""
|
| 579 |
-
Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
|
| 580 |
-
this is because some kwargs require special handling, which can make this class
|
| 581 |
-
brittle.
|
| 582 |
-
"""
|
| 583 |
-
pico_config = cls(**config_dict)
|
| 584 |
-
|
| 585 |
-
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
| 586 |
-
unused_kwargs = {
|
| 587 |
-
key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
|
| 588 |
-
}
|
| 589 |
-
|
| 590 |
-
if return_unused_kwargs:
|
| 591 |
-
return pico_config, unused_kwargs
|
| 592 |
-
return pico_config
|
| 593 |
-
|
| 594 |
-
@classmethod
|
| 595 |
-
def from_dataclass(cls, model_config: "ModelConfig"):
|
| 596 |
-
"""Initialise from our custom config dataclass."""
|
| 597 |
-
return cls.from_dict(asdict(model_config))
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
class PicoDecoderHF(PreTrainedModel, GenerationMixin):
|
| 601 |
-
"""
|
| 602 |
-
HuggingFace wrapper for the Pico model with generation support.
|
| 603 |
-
|
| 604 |
-
Many evaluation frameworks require a model to be set up as a HuggingFace model, so we provide a simple
|
| 605 |
-
wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
|
| 606 |
-
Pico model as well as the model wrapped in this HuggingFace class.
|
| 607 |
-
|
| 608 |
-
This also lets you do cool things like:
|
| 609 |
-
|
| 610 |
-
`model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
|
| 611 |
-
"""
|
| 612 |
-
|
| 613 |
-
config_class = PicoDecoderHFConfig
|
| 614 |
-
_no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
|
| 615 |
-
main_input_name = "input_ids"
|
| 616 |
-
|
| 617 |
-
def __init__(self, config: PicoDecoderHFConfig):
|
| 618 |
-
super().__init__(config)
|
| 619 |
-
self.pico_decoder = PicoDecoder(config)
|
| 620 |
-
# Initialize generation config with defaults
|
| 621 |
-
self.generation_config = GenerationConfig()
|
| 622 |
-
# Set some reasonable defaults for the model
|
| 623 |
-
if hasattr(config, "max_position_embeddings"):
|
| 624 |
-
self.generation_config.max_length = config.max_position_embeddings
|
| 625 |
-
if hasattr(config, "vocab_size"):
|
| 626 |
-
self.generation_config.vocab_size = config.vocab_size
|
| 627 |
-
|
| 628 |
-
def forward(
|
| 629 |
-
self,
|
| 630 |
-
input_ids: torch.Tensor,
|
| 631 |
-
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
| 632 |
-
use_cache: bool = False,
|
| 633 |
-
**kwargs,
|
| 634 |
-
) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
|
| 635 |
-
"""HuggingFace forward pass wrapper.
|
| 636 |
-
|
| 637 |
-
Forward pass for the HuggingFace version of the Pico model. A thin wrapper around the
|
| 638 |
-
Pico model's forward pass that returns the output as a HuggingFace CausalLMOutput.
|
| 639 |
-
"""
|
| 640 |
-
logits, past_key_values = self.pico_decoder(
|
| 641 |
-
input_ids, past_key_values, use_cache
|
| 642 |
-
)
|
| 643 |
-
if use_cache:
|
| 644 |
-
return CausalLMOutputWithPast(
|
| 645 |
-
logits=logits,
|
| 646 |
-
past_key_values=past_key_values,
|
| 647 |
-
)
|
| 648 |
-
else:
|
| 649 |
-
return CausalLMOutput(
|
| 650 |
-
logits=logits,
|
| 651 |
-
)
|
| 652 |
-
|
| 653 |
-
def prepare_inputs_for_generation(
|
| 654 |
-
self,
|
| 655 |
-
input_ids: torch.LongTensor,
|
| 656 |
-
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 657 |
-
attention_mask: Optional[torch.LongTensor] = None,
|
| 658 |
-
**kwargs,
|
| 659 |
-
) -> Dict[str, Any]:
|
| 660 |
-
"""
|
| 661 |
-
Prepare inputs for generation.
|
| 662 |
-
|
| 663 |
-
Args:
|
| 664 |
-
input_ids: Input token IDs
|
| 665 |
-
past_key_values: Cached key-value pairs from previous forward passes
|
| 666 |
-
attention_mask: Attention mask for the input
|
| 667 |
-
**kwargs: Additional arguments
|
| 668 |
-
|
| 669 |
-
Returns:
|
| 670 |
-
Dictionary containing prepared inputs
|
| 671 |
-
"""
|
| 672 |
-
# If we have past_key_values, we only need the last token
|
| 673 |
-
if past_key_values is not None:
|
| 674 |
-
input_ids = input_ids[:, -1:]
|
| 675 |
-
|
| 676 |
-
return {
|
| 677 |
-
"input_ids": input_ids,
|
| 678 |
-
"past_key_values": past_key_values,
|
| 679 |
-
"use_cache": True,
|
| 680 |
-
}
|
| 681 |
-
|
| 682 |
-
def get_input_embeddings(self):
|
| 683 |
-
"""Get the input embeddings layer."""
|
| 684 |
-
return self.pico_decoder.embedding_proj
|
| 685 |
-
|
| 686 |
-
def set_input_embeddings(self, value):
|
| 687 |
-
"""Set the input embeddings layer."""
|
| 688 |
-
self.pico_decoder.embedding_proj = value
|
| 689 |
-
|
| 690 |
-
def get_output_embeddings(self):
|
| 691 |
-
"""Get the output embeddings layer."""
|
| 692 |
-
return self.pico_decoder.de_embedding_proj
|
| 693 |
-
|
| 694 |
-
def set_output_embeddings(self, value):
|
| 695 |
-
"""Set the output embeddings layer."""
|
| 696 |
-
self.pico_decoder.de_embedding_proj = value
|
| 697 |
-
|
| 698 |
-
def get_lm_head(self):
|
| 699 |
-
"""Get the language model head."""
|
| 700 |
-
return self.pico_decoder.de_embedding_proj
|
| 701 |
-
|
| 702 |
-
def can_generate(self) -> bool:
|
| 703 |
-
"""Check if the model can generate text."""
|
| 704 |
-
return True
|
| 705 |
-
|
| 706 |
-
@property
|
| 707 |
-
def is_encoder_decoder(self) -> bool:
|
| 708 |
-
"""Check if the model is an encoder-decoder model."""
|
| 709 |
-
return False
|
| 710 |
-
|
| 711 |
-
@property
|
| 712 |
-
def can_use_cache(self) -> bool:
|
| 713 |
-
"""Check if the model can use KV cache."""
|
| 714 |
-
return True
|
| 715 |
-
|
| 716 |
-
def resize_token_embeddings(
|
| 717 |
-
self, new_num_tokens: Optional[int] = None
|
| 718 |
-
) -> torch.nn.Embedding:
|
| 719 |
-
"""Resize token embeddings."""
|
| 720 |
-
old_embeddings = self.get_input_embeddings()
|
| 721 |
-
if new_num_tokens is None:
|
| 722 |
-
new_num_tokens = old_embeddings.num_embeddings
|
| 723 |
-
|
| 724 |
-
new_embeddings = torch.nn.Embedding(
|
| 725 |
-
new_num_tokens, old_embeddings.embedding_dim
|
| 726 |
-
)
|
| 727 |
-
new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
|
| 728 |
-
old_embeddings.weight.data
|
| 729 |
-
)
|
| 730 |
-
|
| 731 |
-
self.pico_decoder.embedding_proj = new_embeddings
|
| 732 |
-
self.pico_decoder.de_embedding_proj = torch.nn.Linear(
|
| 733 |
-
old_embeddings.embedding_dim, new_num_tokens, bias=False
|
| 734 |
-
)
|
| 735 |
-
|
| 736 |
-
return new_embeddings
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
# Register for auto classes
|
| 740 |
-
PicoDecoderHFConfig.register_for_auto_class()
|
| 741 |
-
PicoDecoderHF.register_for_auto_class("AutoModel")
|
| 742 |
-
PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
########################################################
|
| 746 |
-
#
|
| 747 |
-
# New PicoDecoderForCausalLM class for generation support
|
| 748 |
-
#
|
| 749 |
-
########################################################
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
|
| 753 |
-
"""
|
| 754 |
-
PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
|
| 755 |
-
|
| 756 |
-
This class is designed to work with existing checkpoints and provides full generation support.
|
| 757 |
-
It inherits from the right base classes that HuggingFace expects for text generation.
|
| 758 |
-
"""
|
| 759 |
-
|
| 760 |
-
config_class = PicoDecoderHFConfig
|
| 761 |
-
_no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
|
| 762 |
-
main_input_name = "input_ids"
|
| 763 |
-
|
| 764 |
-
def __init__(self, config: PicoDecoderHFConfig):
|
| 765 |
-
super().__init__(config)
|
| 766 |
-
self.pico_decoder = PicoDecoder(config)
|
| 767 |
-
# Initialize generation config with defaults
|
| 768 |
-
self.generation_config = GenerationConfig()
|
| 769 |
-
# Set some reasonable defaults for the model
|
| 770 |
-
if hasattr(config, "max_position_embeddings"):
|
| 771 |
-
self.generation_config.max_length = config.max_position_embeddings
|
| 772 |
-
if hasattr(config, "vocab_size"):
|
| 773 |
-
self.generation_config.vocab_size = config.vocab_size
|
| 774 |
-
|
| 775 |
-
def forward(
|
| 776 |
-
self,
|
| 777 |
-
input_ids: torch.Tensor,
|
| 778 |
-
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
| 779 |
-
use_cache: bool = False,
|
| 780 |
-
**kwargs,
|
| 781 |
-
) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
|
| 782 |
-
"""Forward pass for text generation."""
|
| 783 |
-
logits, past_key_values = self.pico_decoder(
|
| 784 |
-
input_ids, past_key_values, use_cache
|
| 785 |
-
)
|
| 786 |
-
if use_cache:
|
| 787 |
-
return CausalLMOutputWithPast(
|
| 788 |
-
logits=logits,
|
| 789 |
-
past_key_values=past_key_values,
|
| 790 |
-
)
|
| 791 |
-
else:
|
| 792 |
-
return CausalLMOutput(
|
| 793 |
-
logits=logits,
|
| 794 |
-
)
|
| 795 |
-
|
| 796 |
-
def prepare_inputs_for_generation(
|
| 797 |
-
self,
|
| 798 |
-
input_ids: torch.LongTensor,
|
| 799 |
-
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 800 |
-
attention_mask: Optional[torch.LongTensor] = None,
|
| 801 |
-
**kwargs,
|
| 802 |
-
) -> Dict[str, Any]:
|
| 803 |
-
"""Prepare inputs for generation."""
|
| 804 |
-
# If we have past_key_values, we only need the last token
|
| 805 |
-
if past_key_values is not None:
|
| 806 |
-
input_ids = input_ids[:, -1:]
|
| 807 |
-
|
| 808 |
-
return {
|
| 809 |
-
"input_ids": input_ids,
|
| 810 |
-
"past_key_values": past_key_values,
|
| 811 |
-
"use_cache": True,
|
| 812 |
-
}
|
| 813 |
-
|
| 814 |
-
def get_input_embeddings(self):
|
| 815 |
-
"""Get the input embeddings layer."""
|
| 816 |
-
return self.pico_decoder.embedding_proj
|
| 817 |
-
|
| 818 |
-
def set_input_embeddings(self, value):
|
| 819 |
-
"""Set the input embeddings layer."""
|
| 820 |
-
self.pico_decoder.embedding_proj = value
|
| 821 |
-
|
| 822 |
-
def get_output_embeddings(self):
|
| 823 |
-
"""Get the output embeddings layer."""
|
| 824 |
-
return self.pico_decoder.de_embedding_proj
|
| 825 |
-
|
| 826 |
-
def set_output_embeddings(self, value):
|
| 827 |
-
"""Set the output embeddings layer."""
|
| 828 |
-
self.pico_decoder.de_embedding_proj = value
|
| 829 |
-
|
| 830 |
-
def get_lm_head(self):
|
| 831 |
-
"""Get the language model head."""
|
| 832 |
-
return self.pico_decoder.de_embedding_proj
|
| 833 |
-
|
| 834 |
-
def can_generate(self) -> bool:
|
| 835 |
-
"""Check if the model can generate text."""
|
| 836 |
-
return True
|
| 837 |
-
|
| 838 |
-
@property
|
| 839 |
-
def is_encoder_decoder(self) -> bool:
|
| 840 |
-
"""Check if the model is an encoder-decoder model."""
|
| 841 |
-
return False
|
| 842 |
-
|
| 843 |
-
@property
|
| 844 |
-
def can_use_cache(self) -> bool:
|
| 845 |
-
"""Check if the model can use KV cache."""
|
| 846 |
-
return True
|
| 847 |
-
|
| 848 |
-
def resize_token_embeddings(
|
| 849 |
-
self, new_num_tokens: Optional[int] = None
|
| 850 |
-
) -> torch.nn.Embedding:
|
| 851 |
-
"""Resize token embeddings."""
|
| 852 |
-
old_embeddings = self.get_input_embeddings()
|
| 853 |
-
if new_num_tokens is None:
|
| 854 |
-
new_num_tokens = old_embeddings.num_embeddings
|
| 855 |
-
|
| 856 |
-
new_embeddings = torch.nn.Embedding(
|
| 857 |
-
new_num_tokens, old_embeddings.embedding_dim
|
| 858 |
-
)
|
| 859 |
-
new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
|
| 860 |
-
old_embeddings.weight.data
|
| 861 |
-
)
|
| 862 |
-
|
| 863 |
-
self.pico_decoder.embedding_proj = new_embeddings
|
| 864 |
-
self.pico_decoder.de_embedding_proj = torch.nn.Linear(
|
| 865 |
-
old_embeddings.embedding_dim, new_num_tokens, bias=False
|
| 866 |
-
)
|
| 867 |
-
|
| 868 |
-
return new_embeddings
|
| 869 |
-
|
| 870 |
-
@classmethod
|
| 871 |
-
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 872 |
-
"""
|
| 873 |
-
Load a pretrained model from a checkpoint.
|
| 874 |
-
|
| 875 |
-
This method handles loading from both the old PicoDecoderHF format and the new format.
|
| 876 |
-
"""
|
| 877 |
-
# First try to load with the new class
|
| 878 |
-
try:
|
| 879 |
-
return super().from_pretrained(
|
| 880 |
-
pretrained_model_name_or_path, *model_args, **kwargs
|
| 881 |
-
)
|
| 882 |
-
except Exception as e:
|
| 883 |
-
print(f"Failed to load with new class: {e}")
|
| 884 |
-
print("Attempting to load with legacy class and convert...")
|
| 885 |
-
|
| 886 |
-
# Try to load with the old class and convert
|
| 887 |
-
try:
|
| 888 |
-
from transformers import AutoModel
|
| 889 |
-
|
| 890 |
-
old_model = AutoModel.from_pretrained(
|
| 891 |
-
pretrained_model_name_or_path,
|
| 892 |
-
trust_remote_code=True,
|
| 893 |
-
*model_args,
|
| 894 |
-
**kwargs,
|
| 895 |
-
)
|
| 896 |
-
|
| 897 |
-
# Create new model instance
|
| 898 |
-
new_model = cls(old_model.config)
|
| 899 |
-
|
| 900 |
-
# Copy state dict
|
| 901 |
-
new_model.load_state_dict(old_model.state_dict(), strict=False)
|
| 902 |
-
|
| 903 |
-
return new_model
|
| 904 |
-
|
| 905 |
-
except Exception as e2:
|
| 906 |
-
print(f"Failed to convert from legacy format: {e2}")
|
| 907 |
-
raise e
|
| 908 |
-
|
| 909 |
-
|
| 910 |
-
# Register the new class
|
| 911 |
-
PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
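Because the wrapper classes above register themselves for the Auto classes, a checkpoint saved in this format can be loaded straight through transformers. A hedged sketch (paths are placeholders, and it assumes a tokenizer was saved alongside the model):

# Hedged sketch: load the HF-format checkpoint and generate a few tokens with the
# KV cache enabled. trust_remote_code is needed because the model code ships with
# the checkpoint via register_for_auto_class.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")

inputs = tokenizer("The Pico decoder is", return_tensors="pt")
output_ids = model.generate(inputs["input_ids"], max_new_tokens=20, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))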
src/training/trainer.py
DELETED
|
@@ -1,753 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Pico Language Model Trainer
|
| 3 |
-
|
| 4 |
-
This Trainer implements a minimalistic end-to-end training pipeline for the Pico language model with
|
| 5 |
-
distributed training support via Lightning Fabric. It provides a modular and configurable training
|
| 6 |
-
pipeline with the following features:
|
| 7 |
-
|
| 8 |
-
- Configuration Management: YAML-based configuration for all aspects of training
|
| 9 |
-
- Distributed Training: Multi-GPU support via Lightning Fabric
|
| 10 |
-
- Checkpointing: Regular model saving and training state recovery
|
| 11 |
-
- Evaluation: Periodic model evaluation on validation datasets
|
| 12 |
-
- Logging: Comprehensive metric tracking and experiment monitoring
|
| 13 |
-
- Optimization: Support for gradient accumulation, clipping, and LR scheduling
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
import logging
|
| 17 |
-
import os
|
| 18 |
-
import platform
|
| 19 |
-
from typing import Any, Dict
|
| 20 |
-
|
| 21 |
-
import lightning as L
|
| 22 |
-
import psutil
|
| 23 |
-
import torch
|
| 24 |
-
import torch.nn.functional as F
|
| 25 |
-
import yaml
|
| 26 |
-
from datasets import Dataset, load_dataset
|
| 27 |
-
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
| 28 |
-
|
| 29 |
-
from src.checkpointing import (
|
| 30 |
-
compute_learning_dynamics_states,
|
| 31 |
-
load_checkpoint,
|
| 32 |
-
save_checkpoint,
|
| 33 |
-
save_evaluation_results,
|
| 34 |
-
save_learning_dynamics_states,
|
| 35 |
-
)
|
| 36 |
-
from src.evaluation import run_evaluation
|
| 37 |
-
from src.training.utils import (
|
| 38 |
-
initialize_configuration,
|
| 39 |
-
initialize_dataloader,
|
| 40 |
-
initialize_dataset,
|
| 41 |
-
initialize_fabric,
|
| 42 |
-
initialize_hf_checkpointing,
|
| 43 |
-
initialize_logging,
|
| 44 |
-
initialize_lr_scheduler,
|
| 45 |
-
initialize_model,
|
| 46 |
-
initialize_optimizer,
|
| 47 |
-
initialize_run_dir,
|
| 48 |
-
initialize_tokenizer,
|
| 49 |
-
initialize_wandb,
|
| 50 |
-
)
|
| 51 |
-
from src.training.utils.logging import pretty_print_yaml_config
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
class Trainer:
|
| 55 |
-
def __init__(self, config_path: str):
|
| 56 |
-
"""
|
| 57 |
-
Initializes the Trainer class. This Trainer class implements a `train` method, which is the
|
| 58 |
-
main entry point for training the Pico model. Before calling `train`, the Trainer class
|
| 59 |
-
initializes the following:
|
| 60 |
-
|
| 61 |
-
- Configuration loading and validation
|
| 62 |
-
- Model, optimizer, and dataset setup
|
| 63 |
-
- Logging and experiment tracking setup
|
| 64 |
-
- Checkpoint management
|
| 65 |
-
|
| 66 |
-
Args:
|
| 67 |
-
config_path (str): Path to the YAML configuration file containing any overrides.
|
| 68 |
-
"""
|
| 69 |
-
|
| 70 |
-
########################################################
|
| 71 |
-
#
|
| 72 |
-
# Basic Initialization of Configs, Fabric, Model, Optimizer, etc.
|
| 73 |
-
#
|
| 74 |
-
########################################################
|
| 75 |
-
|
| 76 |
-
# Setup Config
|
| 77 |
-
self.configs = initialize_configuration(config_path)
|
| 78 |
-
|
| 79 |
-
# Setup Run Directory (i.e. where we store checkpoints, logs, etc.)
|
| 80 |
-
initialize_run_dir(checkpointing_config=self.configs["checkpointing"])
|
| 81 |
-
|
| 82 |
-
# Setup Logger
|
| 83 |
-
if self.configs["monitoring"].save_to_wandb:
|
| 84 |
-
wandb_logger = initialize_wandb(
|
| 85 |
-
monitoring_config=self.configs["monitoring"],
|
| 86 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 87 |
-
)
|
| 88 |
-
else:
|
| 89 |
-
wandb_logger = None
|
| 90 |
-
|
| 91 |
-
# Setup Fabric
|
| 92 |
-
self.fabric = initialize_fabric(
|
| 93 |
-
training_config=self.configs["training"],
|
| 94 |
-
wandb_logger=wandb_logger,
|
| 95 |
-
)
|
| 96 |
-
L.seed_everything(42, verbose=False)
|
| 97 |
-
|
| 98 |
-
# Optimize for Tensor Cores on RTX 5090
|
| 99 |
-
if self.fabric.device.type == "cuda":
|
| 100 |
-
torch.set_float32_matmul_precision(
|
| 101 |
-
"high"
|
| 102 |
-
) # Best performance for Tensor Cores
|
| 103 |
-
print(
|
| 104 |
-
"Enabled Tensor Core optimization: torch.set_float32_matmul_precision('high')"
|
| 105 |
-
)
|
| 106 |
-
|
| 107 |
-
# Set up logging
|
| 108 |
-
self.logger = initialize_logging(
|
| 109 |
-
monitoring_config=self.configs["monitoring"],
|
| 110 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 111 |
-
fabric=self.fabric,
|
| 112 |
-
)
|
| 113 |
-
|
| 114 |
-
# Setup Model, Optimizer, and Dataloaders
|
| 115 |
-
self.model = initialize_model(model_config=self.configs["model"])
|
| 116 |
-
self.optimizer = initialize_optimizer(
|
| 117 |
-
training_config=self.configs["training"], model=self.model
|
| 118 |
-
)
|
| 119 |
-
self.lr_scheduler = initialize_lr_scheduler(
|
| 120 |
-
training_config=self.configs["training"], optimizer=self.optimizer
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
# Wrap model and optimizer with Fabric
|
| 124 |
-
self.model, self.optimizer = self.fabric.setup(self.model, self.optimizer)
|
| 125 |
-
|
| 126 |
-
# Setup HuggingFace Checkpointing
|
| 127 |
-
if self.configs["checkpointing"].save_to_hf:
|
| 128 |
-
initialize_hf_checkpointing(
|
| 129 |
-
checkpointing_config=self.configs["checkpointing"], fabric=self.fabric
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
########################################################
|
| 133 |
-
#
|
| 134 |
-
# Boilerplate to deal with loading/resuming from checkpoints
|
| 135 |
-
#
|
| 136 |
-
########################################################
|
| 137 |
-
|
| 138 |
-
self.should_load_checkpoint = self.configs["checkpointing"].training.auto_resume
|
| 139 |
-
|
| 140 |
-
# Possibly load a checkpoint
|
| 141 |
-
if self.should_load_checkpoint:
|
| 142 |
-
resume_checkpoint = load_checkpoint(
|
| 143 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 144 |
-
checkpoint_step="latest",
|
| 145 |
-
fabric=self.fabric,
|
| 146 |
-
model=self.model,
|
| 147 |
-
optimizer=self.optimizer,
|
| 148 |
-
lr_scheduler=self.lr_scheduler,
|
| 149 |
-
)
|
| 150 |
-
|
| 151 |
-
if resume_checkpoint:
|
| 152 |
-
(
|
| 153 |
-
self.model,
|
| 154 |
-
self.optimizer,
|
| 155 |
-
self.lr_scheduler,
|
| 156 |
-
self.initial_batch_step,
|
| 157 |
-
) = resume_checkpoint
|
| 158 |
-
else:
|
| 159 |
-
self.initial_batch_step = 0
|
| 160 |
-
else:
|
| 161 |
-
self.initial_batch_step = 0
|
| 162 |
-
|
| 163 |
-
########################################################
|
| 164 |
-
#
|
| 165 |
-
# Initialization of Dataset & DataLoader (possibly fast-forwarding to correct batch)
|
| 166 |
-
#
|
| 167 |
-
########################################################
|
| 168 |
-
|
| 169 |
-
self.train_dataset, fast_forward_steps = initialize_dataset(
|
| 170 |
-
data_config=self.configs["data"],
|
| 171 |
-
fabric=self.fabric,
|
| 172 |
-
initial_batch_step=self.initial_batch_step,
|
| 173 |
-
return_fast_forward_steps=True,
|
| 174 |
-
)
|
| 175 |
-
|
| 176 |
-
self.train_dataloader = initialize_dataloader(
|
| 177 |
-
data_config=self.configs["data"],
|
| 178 |
-
training_config=self.configs["training"],
|
| 179 |
-
fabric=self.fabric,
|
| 180 |
-
dataset=self.train_dataset,
|
| 181 |
-
)
|
| 182 |
-
self.train_dataloader = self.fabric.setup_dataloaders(
|
| 183 |
-
self.train_dataloader, use_distributed_sampler=False
|
| 184 |
-
)
|
| 185 |
-
|
| 186 |
-
self.tokenizer = initialize_tokenizer(data_config=self.configs["data"])
|
| 187 |
-
|
| 188 |
-
# NOTE: We may need to fast-forward the iterator to the correct step so that we can
|
| 189 |
-
# continue from the correct batch of data we would have seen had training not
|
| 190 |
-
# previously stopped.
|
| 191 |
-
train_iterator = iter(self.train_dataloader)
|
| 192 |
-
if fast_forward_steps > 0:
|
| 193 |
-
fast_forward_sub_steps = (
|
| 194 |
-
fast_forward_steps
|
| 195 |
-
* self.configs["training"].optimization.gradient_accumulation_steps
|
| 196 |
-
)
|
| 197 |
-
for _ in range(fast_forward_sub_steps):
|
| 198 |
-
next(train_iterator)
|
| 199 |
-
|
| 200 |
-
self.train_iterator = train_iterator
|
| 201 |
-
|
| 202 |
-
# NOTE: Synchronizing processes after fast-forwarding the iterator
|
| 203 |
-
self.fabric.barrier()
|
| 204 |
-
|
| 205 |
-
########################################################
|
| 206 |
-
#
|
| 207 |
-
# Helper flags used during training for checkpointing and evaluation
|
| 208 |
-
#
|
| 209 |
-
########################################################
|
| 210 |
-
|
| 211 |
-
# Helper flag to determine if we should evaluate the model
|
| 212 |
-
self.should_evaluate = (
|
| 213 |
-
self.configs["evaluation"].metrics is not None
|
| 214 |
-
and len(self.configs["evaluation"].metrics) > 0
|
| 215 |
-
)
|
| 216 |
-
|
| 217 |
-
self.should_compute_learning_dynamics = (
|
| 218 |
-
self.configs["checkpointing"].learning_dynamics.layer_suffixes is not None
|
| 219 |
-
and len(self.configs["checkpointing"].learning_dynamics.layer_suffixes) > 0
|
| 220 |
-
)
|
| 221 |
-
|
| 222 |
-
if self.should_compute_learning_dynamics:
|
| 223 |
-
if self.configs["checkpointing"].learning_dynamics.eval_data is not None:
|
| 224 |
-
self.learning_dynamics_eval_dataset = load_dataset(
|
| 225 |
-
self.configs["checkpointing"].learning_dynamics.eval_data,
|
| 226 |
-
split="val",
|
| 227 |
-
)
|
| 228 |
-
else:
|
| 229 |
-
self.learning_dynamics_eval_dataset = None
|
| 230 |
-
|
| 231 |
-
def train(self) -> None:
|
| 232 |
-
"""Execute the main training pipeline.
|
| 233 |
-
|
| 234 |
-
This method orchestrates the complete training process by:
|
| 235 |
-
1. Creating an initial checkpoint to save the starting state and evaluate the model as a
|
| 236 |
-
baseline
|
| 237 |
-
2. Running the main training loop via `_training_loop`
|
| 238 |
-
3. Handling final checkpointing and evaluation
|
| 239 |
-
|
| 240 |
-
The training progress is tracked through checkpoints and evaluations
|
| 241 |
-
at intervals specified in the configuration.
|
| 242 |
-
"""
|
| 243 |
-
|
| 244 |
-
########################################################
|
| 245 |
-
#
|
| 246 |
-
# Initial Checkpointing and Evaluation
|
| 247 |
-
#
|
| 248 |
-
########################################################
|
| 249 |
-
|
| 250 |
-
# Save Initial Checkpoint -- If the checkpoint already exists, this performs a no-op
|
| 251 |
-
save_checkpoint(
|
| 252 |
-
configs=self.configs,
|
| 253 |
-
checkpoint_step=self.initial_batch_step,
|
| 254 |
-
fabric=self.fabric,
|
| 255 |
-
model=self.model,
|
| 256 |
-
optimizer=self.optimizer,
|
| 257 |
-
lr_scheduler=self.lr_scheduler,
|
| 258 |
-
tokenizer=self.tokenizer,
|
| 259 |
-
)
|
| 260 |
-
|
| 261 |
-
# Save Initial Evaluation Results
|
| 262 |
-
if self.should_evaluate:
|
| 263 |
-
if self.initial_batch_step == 0:
|
| 264 |
-
evaluation_results = run_evaluation(
|
| 265 |
-
evaluation_config=self.configs["evaluation"],
|
| 266 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 267 |
-
fabric=self.fabric,
|
| 268 |
-
model=self.model,
|
| 269 |
-
)
|
| 270 |
-
self._log_evaluation_results(
|
| 271 |
-
evaluation_results, self.initial_batch_step
|
| 272 |
-
)
|
| 273 |
-
save_evaluation_results(
|
| 274 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 275 |
-
fabric=self.fabric,
|
| 276 |
-
evaluation_results=evaluation_results,
|
| 277 |
-
checkpoint_step=self.initial_batch_step,
|
| 278 |
-
)
|
| 279 |
-
else:
|
| 280 |
-
# NOTE: If the run crashed while evaluating, we need to restart the evaluation
|
| 281 |
-
eval_results_path = os.path.join(
|
| 282 |
-
self.configs["checkpointing"].evaluation.eval_results_dir,
|
| 283 |
-
f"step_{self.initial_batch_step}.json",
|
| 284 |
-
)
|
| 285 |
-
if not os.path.exists(eval_results_path):
|
| 286 |
-
evaluation_results = run_evaluation(
|
| 287 |
-
evaluation_config=self.configs["evaluation"],
|
| 288 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 289 |
-
fabric=self.fabric,
|
| 290 |
-
model=self.model,
|
| 291 |
-
)
|
| 292 |
-
self._log_evaluation_results(
|
| 293 |
-
evaluation_results, self.initial_batch_step
|
| 294 |
-
)
|
| 295 |
-
save_evaluation_results(
|
| 296 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 297 |
-
fabric=self.fabric,
|
| 298 |
-
evaluation_results=evaluation_results,
|
| 299 |
-
checkpoint_step=self.initial_batch_step,
|
| 300 |
-
)
|
| 301 |
-
|
| 302 |
-
########################################################
|
| 303 |
-
#
|
| 304 |
-
# Main Training Loop (see `_training_loop` for details)
|
| 305 |
-
#
|
| 306 |
-
########################################################
|
| 307 |
-
|
| 308 |
-
if self.initial_batch_step < self.configs["training"].max_steps:
|
| 309 |
-
self._log_training_configuration()
|
| 310 |
-
final_step = self._training_loop()
|
| 311 |
-
else:
|
| 312 |
-
final_step = self.initial_batch_step
|
| 313 |
-
|
| 314 |
-
########################################################
|
| 315 |
-
#
|
| 316 |
-
# Final Checkpointing and Evaluation
|
| 317 |
-
#
|
| 318 |
-
########################################################
|
| 319 |
-
|
| 320 |
-
# Save Learning Dynamics States
|
| 321 |
-
if self.should_compute_learning_dynamics:
|
| 322 |
-
if self.learning_dynamics_eval_dataset is not None:
|
| 323 |
-
self.log(f"Step {final_step} -- 📈 Saving Learning Dynamics")
|
| 324 |
-
learning_dynamics_val_states = compute_learning_dynamics_states(
|
| 325 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 326 |
-
fabric=self.fabric,
|
| 327 |
-
model=self.model,
|
| 328 |
-
dataset=self.learning_dynamics_eval_dataset,
|
| 329 |
-
compute_gradients=True,
|
| 330 |
-
)
|
| 331 |
-
save_learning_dynamics_states(
|
| 332 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 333 |
-
fabric=self.fabric,
|
| 334 |
-
learning_dynamics_states=learning_dynamics_val_states,
|
| 335 |
-
checkpoint_step=final_step,
|
| 336 |
-
prefix="val",
|
| 337 |
-
)
|
| 338 |
-
|
| 339 |
-
# Handle checkpointing and final evaluation
|
| 340 |
-
if final_step % self.configs["checkpointing"].save_every_n_steps != 0:
|
| 341 |
-
self.log(f"Step {final_step} -- 💾 Saving Final Checkpoint")
|
| 342 |
-
save_checkpoint(
|
| 343 |
-
configs=self.configs,
|
| 344 |
-
checkpoint_step=final_step,
|
| 345 |
-
fabric=self.fabric,
|
| 346 |
-
model=self.model,
|
| 347 |
-
optimizer=self.optimizer,
|
| 348 |
-
lr_scheduler=self.lr_scheduler,
|
| 349 |
-
tokenizer=self.tokenizer,
|
| 350 |
-
)
|
| 351 |
-
|
| 352 |
-
# Final evaluation
|
| 353 |
-
if self.should_evaluate:
|
| 354 |
-
evaluation_results = run_evaluation(
|
| 355 |
-
evaluation_config=self.configs["evaluation"],
|
| 356 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 357 |
-
fabric=self.fabric,
|
| 358 |
-
model=self.model,
|
| 359 |
-
)
|
| 360 |
-
self._log_evaluation_results(evaluation_results, final_step)
|
| 361 |
-
save_evaluation_results(
|
| 362 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 363 |
-
checkpoint_step=final_step,
|
| 364 |
-
fabric=self.fabric,
|
| 365 |
-
evaluation_results=evaluation_results,
|
| 366 |
-
)
|
| 367 |
-
|
| 368 |
-
self.log(f"🎉 Training complete! Final step: {final_step}")
|
| 369 |
-
|
| 370 |
-
if final_step < self.configs["training"].max_steps:
|
| 371 |
-
self.log(
|
| 372 |
-
f"\t Note: Training stopped before max steps ({self.configs['training'].max_steps})",
|
| 373 |
-
level=logging.WARNING,
|
| 374 |
-
)
|
| 375 |
-
|
| 376 |
-
# Cleanup distributed training
|
| 377 |
-
self.fabric.barrier()
|
| 378 |
-
if torch.cuda.is_available():
|
| 379 |
-
torch.cuda.empty_cache()
|
| 380 |
-
if torch.distributed.is_initialized():
|
| 381 |
-
torch.distributed.destroy_process_group()
|
| 382 |
-
|
| 383 |
-
del self.train_dataloader # NOTE: shutting down worker nodes
|
| 384 |
-
|
| 385 |
-
self.fabric.barrier()
|
| 386 |
-
|
| 387 |
-
def _training_loop(self) -> int:
|
| 388 |
-
"""Execute the main training loop.
|
| 389 |
-
|
| 390 |
-
This method orchestrates the core training loop and includes the following features:
|
| 391 |
-
- Gradient accumulation
|
| 392 |
-
- Gradient clipping
|
| 393 |
-
- Periodic model evaluation and checkpointing
|
| 394 |
-
- Learning Dynamics Checkpointing
|
| 395 |
-
- Learning rate scheduling
|
| 396 |
-
- Logging of training metrics including loss and learning rate
|
| 397 |
-
- Handling of infinite/NaN losses
|
| 398 |
-
|
| 399 |
-
Returns:
|
| 400 |
-
int: The final step count reached during training.
|
| 401 |
-
NOTE: A complete training run should match the configured max_steps.
|
| 402 |
-
"""
|
| 403 |
-
# Setup training loop variables
|
| 404 |
-
batch_step = self.initial_batch_step
|
| 405 |
-
|
| 406 |
-
# NOTE: these are used to compute the average loss over a training interval.
|
| 407 |
-
# This is more accurate than using the loss at the end of the interval.
|
| 408 |
-
interval_loss = torch.tensor(0.0, device=self.fabric.device)
|
| 409 |
-
interval_steps = torch.tensor(0, device=self.fabric.device)
|
| 410 |
-
interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device)
|
| 411 |
-
|
| 412 |
-
if self.should_compute_learning_dynamics:
|
| 413 |
-
# NOTE: we basically re-construct the full batch here so that we can compute learning dynamics
|
| 414 |
-
training_batch = {"input_ids": []}
|
| 415 |
-
|
| 416 |
-
# NOTE: determine what sub-batch we should start from
|
| 417 |
-
initial_sub_batch_step = (
|
| 418 |
-
batch_step
|
| 419 |
-
* self.configs["training"].optimization.gradient_accumulation_steps
|
| 420 |
-
)
|
| 421 |
-
|
| 422 |
-
###############################################################
|
| 423 |
-
#
|
| 424 |
-
# Core loop starts here
|
| 425 |
-
# NOTE: the ratio between sub_batch_step and batch_step
|
| 426 |
-
# is the configured number of gradient_accumulation_steps
|
| 427 |
-
# i.e. with 32 configured gradient accumulation steps,
|
| 428 |
-
# there are 32 sub_batch_steps for each batch_step
|
| 429 |
-
#
|
| 430 |
-
###############################################################
|
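
As a quick sanity check of the bookkeeping described in the note above, a tiny standalone sketch (the value 32 is just the example figure from that note, and 96 sub-batches is an arbitrary illustration):

gradient_accumulation_steps = 32  # example figure from the note above
batch_step = 0
for sub_batch_step in range(96):  # 96 sub-batches -> 3 optimizer-level batch steps
    still_accumulating = (sub_batch_step + 1) % gradient_accumulation_steps != 0
    if still_accumulating:
        continue  # keep accumulating; no optimizer step yet
    batch_step += 1
assert batch_step == 3
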
| 431 |
-
|
| 432 |
-
for sub_batch_step, sub_batch in enumerate(
|
| 433 |
-
self.train_iterator, start=initial_sub_batch_step
|
| 434 |
-
):
|
| 435 |
-
# NOTE: We want to store the entire training batch whenever we are computing learning dynamics
|
| 436 |
-
# and we are at a checkpointing step.
|
| 437 |
-
should_store_training_batch = self.should_compute_learning_dynamics and (
|
| 438 |
-
batch_step % self.configs["checkpointing"].save_every_n_steps == 0
|
| 439 |
-
)
|
| 440 |
-
|
| 441 |
-
########################################################
|
| 442 |
-
#
|
| 443 |
-
# Forward Pass
|
| 444 |
-
#
|
| 445 |
-
########################################################
|
| 446 |
-
|
| 447 |
-
_input_ids = torch.tensor(sub_batch["input_ids"], device=self.fabric.device)
|
| 448 |
-
input_ids = _input_ids[:, :-1]
|
| 449 |
-
labels = _input_ids[:, 1:]
|
| 450 |
-
|
| 451 |
-
if should_store_training_batch:
|
| 452 |
-
gathered_input_ids = self.fabric.all_gather(_input_ids)
|
| 453 |
-
|
| 454 |
-
# NOTE: On multi-GPU, we need to reshape the input_ids to be a 2D tensor; on
|
| 455 |
-
# a single GPU, the input_ids are already a 2D tensor.
|
| 456 |
-
if self.fabric.world_size > 1:
|
| 457 |
-
gathered_input_ids = gathered_input_ids.reshape(
|
| 458 |
-
-1, *gathered_input_ids.shape[2:]
|
| 459 |
-
)
|
| 460 |
-
|
| 461 |
-
training_batch["input_ids"].extend(gathered_input_ids.tolist())
|
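
For reference, a small standalone sketch of the reshape performed above; it assumes, as the code does, that the gathered tensor carries a leading world-size dimension on multi-GPU runs (the sizes below are arbitrary):

import torch

world_size, per_device_batch, seq_len = 2, 4, 8
gathered = torch.zeros(world_size, per_device_batch, seq_len, dtype=torch.long)
flattened = gathered.reshape(-1, *gathered.shape[2:])
assert flattened.shape == (world_size * per_device_batch, seq_len)
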
| 462 |
-
|
| 463 |
-
# Forward pass
|
| 464 |
-
model_output, _ = self.model(input_ids)
|
| 465 |
-
model_output = model_output.transpose(1, 2)
|
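
A minimal sketch of the next-token shift and the transpose above, using random tensors in place of the model (shapes are arbitrary); `F.cross_entropy` expects the class dimension second, hence the transpose to (batch, vocab, seq_len):

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11
token_ids = torch.randint(0, vocab, (batch, seq_len + 1))
inputs, targets = token_ids[:, :-1], token_ids[:, 1:]    # shift by one position
logits = torch.randn(batch, seq_len, vocab)              # stand-in for the model output
loss = F.cross_entropy(logits.transpose(1, 2), targets)  # (batch, vocab, seq_len) vs (batch, seq_len)
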
| 466 |
-
|
| 467 |
-
########################################################
|
| 468 |
-
#
|
| 469 |
-
# Gradient accumulation
|
| 470 |
-
#
|
| 471 |
-
########################################################
|
| 472 |
-
|
| 473 |
-
should_accumulate_gradients = (sub_batch_step + 1) % self.configs[
|
| 474 |
-
"training"
|
| 475 |
-
].optimization.gradient_accumulation_steps != 0
|
| 476 |
-
|
| 477 |
-
with self.fabric.no_backward_sync(
|
| 478 |
-
self.model, enabled=should_accumulate_gradients
|
| 479 |
-
):
|
| 480 |
-
loss = F.cross_entropy(model_output, labels)
|
| 481 |
-
self.fabric.backward(
|
| 482 |
-
loss
|
| 483 |
-
/ self.configs["training"].optimization.gradient_accumulation_steps,
|
| 484 |
-
model=self.model,
|
| 485 |
-
)
|
| 486 |
-
|
| 487 |
-
if torch.isnan(loss) or torch.isinf(loss):
|
| 488 |
-
interval_inf_or_nan_count += 1
|
| 489 |
-
else:
|
| 490 |
-
interval_loss += loss.item()
|
| 491 |
-
interval_steps += 1
|
| 492 |
-
|
| 493 |
-
# NOTE: if we are still accumulating gradients (i.e. not at an accumulation boundary), we skip the logging and optimization steps
|
| 494 |
-
if should_accumulate_gradients:
|
| 495 |
-
continue
|
| 496 |
-
|
| 497 |
-
########################################################
|
| 498 |
-
#
|
| 499 |
-
# Logging
|
| 500 |
-
#
|
| 501 |
-
########################################################
|
| 502 |
-
|
| 503 |
-
if batch_step % self.configs["monitoring"].logging.log_every_n_steps == 0:
|
| 504 |
-
self._log_training_metrics(
|
| 505 |
-
interval_loss=interval_loss,
|
| 506 |
-
interval_steps=interval_steps,
|
| 507 |
-
interval_inf_or_nan_count=interval_inf_or_nan_count,
|
| 508 |
-
batch_step=batch_step,
|
| 509 |
-
)
|
| 510 |
-
interval_loss = torch.tensor(0.0, device=self.fabric.device)
|
| 511 |
-
interval_steps = torch.tensor(0, device=self.fabric.device)
|
| 512 |
-
interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device)
|
| 513 |
-
|
| 514 |
-
########################################################
|
| 515 |
-
#
|
| 516 |
-
# Learning Dynamics Checkpointing
|
| 517 |
-
#
|
| 518 |
-
########################################################
|
| 519 |
-
|
| 520 |
-
if batch_step % self.configs["checkpointing"].save_every_n_steps == 0:
|
| 521 |
-
if self.should_compute_learning_dynamics:
|
| 522 |
-
self.log(f"Step {batch_step} -- 📈 Saving Learning Dynamics")
|
| 523 |
-
|
| 524 |
-
# Training Batch Learning Dynamics
|
| 525 |
-
training_batch_dataset = Dataset.from_dict(training_batch)
|
| 526 |
-
|
| 527 |
-
learning_dynamics_train_states = compute_learning_dynamics_states(
|
| 528 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 529 |
-
fabric=self.fabric,
|
| 530 |
-
model=self.model,
|
| 531 |
-
dataset=training_batch_dataset,
|
| 532 |
-
compute_gradients=True,
|
| 533 |
-
)
|
| 534 |
-
|
| 535 |
-
save_learning_dynamics_states(
|
| 536 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 537 |
-
checkpoint_step=batch_step,
|
| 538 |
-
prefix="train",
|
| 539 |
-
fabric=self.fabric,
|
| 540 |
-
learning_dynamics_states=learning_dynamics_train_states,
|
| 541 |
-
learning_dynamics_dataset=training_batch_dataset,
|
| 542 |
-
tokenizer=self.tokenizer,
|
| 543 |
-
)
|
| 544 |
-
training_batch = {
|
| 545 |
-
"input_ids": []
|
| 546 |
-
} # Resetting training_batch for next training batch
|
| 547 |
-
|
| 548 |
-
# Validation Data Learning Dynamics
|
| 549 |
-
if self.learning_dynamics_eval_dataset is not None:
|
| 550 |
-
learning_dynamics_val_states = compute_learning_dynamics_states(
|
| 551 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 552 |
-
fabric=self.fabric,
|
| 553 |
-
model=self.model,
|
| 554 |
-
dataset=self.learning_dynamics_eval_dataset,
|
| 555 |
-
compute_gradients=True,
|
| 556 |
-
)
|
| 557 |
-
save_learning_dynamics_states(
|
| 558 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 559 |
-
checkpoint_step=batch_step,
|
| 560 |
-
prefix="val",
|
| 561 |
-
fabric=self.fabric,
|
| 562 |
-
learning_dynamics_states=learning_dynamics_val_states,
|
| 563 |
-
)
|
| 564 |
-
|
| 565 |
-
########################################################
|
| 566 |
-
#
|
| 567 |
-
# Optimization step
|
| 568 |
-
#
|
| 569 |
-
########################################################
|
| 570 |
-
|
| 571 |
-
self.optimizer.step()
|
| 572 |
-
self.optimizer.zero_grad()
|
| 573 |
-
self.lr_scheduler.step()
|
| 574 |
-
|
| 575 |
-
batch_step += 1
|
| 576 |
-
|
| 577 |
-
########################################################
|
| 578 |
-
#
|
| 579 |
-
# Training Checkpointing and evaluation
|
| 580 |
-
#
|
| 581 |
-
########################################################
|
| 582 |
-
|
| 583 |
-
if batch_step % self.configs["checkpointing"].save_every_n_steps == 0:
|
| 584 |
-
self.log(f"Step {batch_step} -- 💾 Saving Checkpoint")
|
| 585 |
-
save_checkpoint(
|
| 586 |
-
configs=self.configs,
|
| 587 |
-
checkpoint_step=batch_step,
|
| 588 |
-
fabric=self.fabric,
|
| 589 |
-
model=self.model,
|
| 590 |
-
optimizer=self.optimizer,
|
| 591 |
-
lr_scheduler=self.lr_scheduler,
|
| 592 |
-
tokenizer=self.tokenizer,
|
| 593 |
-
)
|
| 594 |
-
|
| 595 |
-
if self.should_evaluate:
|
| 596 |
-
evaluation_results = run_evaluation(
|
| 597 |
-
evaluation_config=self.configs["evaluation"],
|
| 598 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 599 |
-
fabric=self.fabric,
|
| 600 |
-
model=self.model,
|
| 601 |
-
)
|
| 602 |
-
if evaluation_results is not None:
|
| 603 |
-
self._log_evaluation_results(evaluation_results, batch_step)
|
| 604 |
-
save_evaluation_results(
|
| 605 |
-
checkpointing_config=self.configs["checkpointing"],
|
| 606 |
-
fabric=self.fabric,
|
| 607 |
-
evaluation_results=evaluation_results,
|
| 608 |
-
checkpoint_step=batch_step,
|
| 609 |
-
)
|
| 610 |
-
|
| 611 |
-
# Break if we've reached training steps
|
| 612 |
-
if batch_step >= self.configs["training"].max_steps:
|
| 613 |
-
break
|
| 614 |
-
|
| 615 |
-
return batch_step
|
| 616 |
-
|
| 617 |
-
########################################################
|
| 618 |
-
#
|
| 619 |
-
# Trainer Logging Functionalities
|
| 620 |
-
#
|
| 621 |
-
########################################################
|
| 622 |
-
|
| 623 |
-
def _log_training_metrics(
|
| 624 |
-
self,
|
| 625 |
-
interval_loss: torch.Tensor,
|
| 626 |
-
interval_steps: torch.Tensor,
|
| 627 |
-
interval_inf_or_nan_count: torch.Tensor,
|
| 628 |
-
batch_step: int,
|
| 629 |
-
):
|
| 630 |
-
"""
|
| 631 |
-
Gathers together the training metrics computed across all processes in distributed training
|
| 632 |
-
and logs them in a tree-style format.
|
| 633 |
-
"""
|
| 634 |
-
gathered_interval_loss = self.fabric.all_reduce(
|
| 635 |
-
interval_loss, reduce_op="sum"
|
| 636 |
-
).item()
|
| 637 |
-
gathered_interval_inf_or_nan_count = self.fabric.all_reduce(
|
| 638 |
-
interval_inf_or_nan_count, reduce_op="sum"
|
| 639 |
-
).item()
|
| 640 |
-
gathered_interval_steps = self.fabric.all_reduce(
|
| 641 |
-
interval_steps, reduce_op="sum"
|
| 642 |
-
).item()
|
| 643 |
-
|
| 644 |
-
avg_loss = (
|
| 645 |
-
gathered_interval_loss / gathered_interval_steps
|
| 646 |
-
if gathered_interval_steps > 0
|
| 647 |
-
else float("inf")
|
| 648 |
-
)
|
| 649 |
-
|
| 650 |
-
self.fabric.log("train/loss", avg_loss, step=batch_step)
|
| 651 |
-
self.fabric.log(
|
| 652 |
-
"trainer/inf_or_nan_count",
|
| 653 |
-
gathered_interval_inf_or_nan_count,
|
| 654 |
-
step=batch_step,
|
| 655 |
-
)
|
| 656 |
-
self.fabric.log(
|
| 657 |
-
"trainer/learning_rate",
|
| 658 |
-
self.lr_scheduler.get_last_lr()[0],
|
| 659 |
-
step=batch_step,
|
| 660 |
-
)
|
| 661 |
-
|
| 662 |
-
# Log to console in tree format
|
| 663 |
-
self.log(f"Step {batch_step} -- 🔄 Training Metrics")
|
| 664 |
-
self.log(f"├── Loss: {avg_loss:.4f}")
|
| 665 |
-
self.log(f"├── Learning Rate: {self.lr_scheduler.get_last_lr()[0]:.2e}")
|
| 666 |
-
self.log(f"└── Inf/NaN count: {gathered_interval_inf_or_nan_count}")
|
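
The reduction above sums losses and step counts separately before dividing; a toy two-rank example of the same arithmetic (numbers are made up):

per_rank = [(12.8, 10), (13.6, 10)]  # (summed loss, valid steps) on each rank
total_loss = sum(loss for loss, _ in per_rank)
total_steps = sum(steps for _, steps in per_rank)
avg_loss = total_loss / total_steps if total_steps > 0 else float("inf")  # 1.32
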
| 667 |
-
|
| 668 |
-
def _log_evaluation_results(
|
| 669 |
-
self, evaluation_results: Dict[str, Any], batch_step: int
|
| 670 |
-
):
|
| 671 |
-
"""Log model evaluation metrics to experiment tracking system and console."""
|
| 672 |
-
self.log(f"Step {batch_step} -- 📊 Evaluation Results")
|
| 673 |
-
for i, (metric, result) in enumerate(evaluation_results.items()):
|
| 674 |
-
prefix = "└──" if i == len(evaluation_results) - 1 else "├──"
|
| 675 |
-
self.log(f"{prefix} {metric}: {result}")
|
| 676 |
-
self.fabric.log(f"eval/{metric}", result, step=batch_step)
|
| 677 |
-
|
| 678 |
-
def _log_training_configuration(self):
|
| 679 |
-
"""
|
| 680 |
-
Log training configuration details as well as runtime information about the hardware,
|
| 681 |
-
software, and batch settings.
|
| 682 |
-
|
| 683 |
-
This function is called at the beginning of the training loop to provide a summary of the
|
| 684 |
-
training configuration.
|
| 685 |
-
"""
|
| 686 |
-
|
| 687 |
-
total_params = sum(p.numel() for p in self.model.parameters())
|
| 688 |
-
trainable_params = sum(
|
| 689 |
-
p.numel() for p in self.model.parameters() if p.requires_grad
|
| 690 |
-
)
|
| 691 |
-
global_batch_size = self.configs["data"].dataloader.batch_size
|
| 692 |
-
per_device_batch_size = self.train_dataloader.batch_size
|
| 693 |
-
gradient_accumulation_steps = self.configs[
|
| 694 |
-
"training"
|
| 695 |
-
].optimization.gradient_accumulation_steps
|
| 696 |
-
|
| 697 |
-
device_type = ""
|
| 698 |
-
fabric_device = str(self.fabric.device)
|
| 699 |
-
if torch.cuda.is_available() and "cuda" in fabric_device:
|
| 700 |
-
device_type = torch.cuda.get_device_name(self.fabric.device)
|
| 701 |
-
elif torch.backends.mps.is_available() and "mps" in fabric_device:
|
| 702 |
-
device_type = "MPS (Apple Silicon)"
|
| 703 |
-
else:
|
| 704 |
-
device_type = "CPU"
|
| 705 |
-
|
| 706 |
-
training_config_path = os.path.join(
|
| 707 |
-
self.configs["checkpointing"].runs_dir,
|
| 708 |
-
self.configs["checkpointing"].run_name,
|
| 709 |
-
"training_config.yaml",
|
| 710 |
-
)
|
| 711 |
-
if os.path.exists(training_config_path):
|
| 712 |
-
self.log("=" * 50)
|
| 713 |
-
self.log("✨ Training Configuration")
|
| 714 |
-
self.log("=" * 50)
|
| 715 |
-
training_config = yaml.safe_load(open(training_config_path, "r"))
|
| 716 |
-
pretty_print_yaml_config(self.logger, training_config)
|
| 717 |
-
|
| 718 |
-
self.log("=" * 50)
|
| 719 |
-
self.log("⛭ Runtime Summary:")
|
| 720 |
-
self.log("=" * 50)
|
| 721 |
-
self.log(f"Starting from step: {self.initial_batch_step}")
|
| 722 |
-
|
| 723 |
-
self.log("Model Setup:")
|
| 724 |
-
self.log(f"└─ Total Parameters: {total_params:,}")
|
| 725 |
-
self.log(f"└─ Trainable Parameters: {trainable_params:,}")
|
| 726 |
-
|
| 727 |
-
self.log("Distributed Setup:")
|
| 728 |
-
self.log(f"└─ Number of Devices: {self.fabric.world_size}")
|
| 729 |
-
self.log(f"└─ Device Type: {device_type}")
|
| 730 |
-
self.log(
|
| 731 |
-
f"└─ Available Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
|
| 732 |
-
if torch.cuda.is_available()
|
| 733 |
-
else f"└─ Available Memory: {psutil.virtual_memory().total / 1e9:.2f} GB"
|
| 734 |
-
)
|
| 735 |
-
|
| 736 |
-
self.log("Software Setup:")
|
| 737 |
-
self.log(f"└─ Python Version: {platform.python_version()}")
|
| 738 |
-
self.log(f"└─ PyTorch Version: {torch.__version__}")
|
| 739 |
-
self.log(
|
| 740 |
-
f"└─ CUDA Version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}"
|
| 741 |
-
)
|
| 742 |
-
self.log(f"└─ Operating System: {platform.system()} {platform.release()}")
|
| 743 |
-
|
| 744 |
-
self.log("Batch Size Configuration:")
|
| 745 |
-
self.log(f"└─ Global Batch Size: {global_batch_size}")
|
| 746 |
-
self.log(f"└─ Per Device Batch Size: {per_device_batch_size}")
|
| 747 |
-
self.log(f"└─ Gradient Accumulation Steps: {gradient_accumulation_steps}")
|
| 748 |
-
self.log("=" * 50)
|
| 749 |
-
|
| 750 |
-
@rank_zero_only
|
| 751 |
-
def log(self, msg: str, level: int = logging.INFO) -> None:
|
| 752 |
-
"""NOTE: Log messages only from rank zero process."""
|
| 753 |
-
self.logger.log(level, msg)
|
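
For context, a minimal sketch of how this deleted trainer was driven. It assumes the class defined in this file is named `Trainer` and that its constructor takes the YAML config path, as the `initialize_configuration(config_path)` call near the top suggests; the config path itself is hypothetical:

from src.training.trainer import Trainer

trainer = Trainer(config_path="configs/demo.yaml")  # hypothetical config file
trainer.train()
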
src/training/utils/__init__.py
DELETED
|
@@ -1,34 +0,0 @@
-"""
-Utility package that contains functions for the training process, e.g. initialization, logging, etc.
-"""
-
-# For convenience, we export the initialization functions here
-from .initialization import (
-    initialize_configuration,
-    initialize_dataloader,
-    initialize_dataset,
-    initialize_fabric,
-    initialize_hf_checkpointing,
-    initialize_logging,
-    initialize_lr_scheduler,
-    initialize_model,
-    initialize_optimizer,
-    initialize_run_dir,
-    initialize_tokenizer,
-    initialize_wandb,
-)
-
-__all__ = [
-    "initialize_configuration",
-    "initialize_dataloader",
-    "initialize_dataset",
-    "initialize_fabric",
-    "initialize_hf_checkpointing",
-    "initialize_logging",
-    "initialize_lr_scheduler",
-    "initialize_model",
-    "initialize_optimizer",
-    "initialize_run_dir",
-    "initialize_tokenizer",
-    "initialize_wandb",
-]
src/training/utils/data.py
DELETED
|
@@ -1,35 +0,0 @@
-"""
-Utilities for data loading and processing.
-"""
-
-from torch.utils.data import IterableDataset
-
-
-class ShardedIterableDataset(IterableDataset):
-    """
-    A super simple implementation of a sharded iterable dataset that enables DataParallelism
-    across multiple workers. Ensures that each worker gets a unique shard of the dataset.
-
-    NOTE: Also works fine if there is only one worker.
-    """
-
-    def __init__(self, dataset, rank, world_size):
-        self.dataset = dataset
-        self.rank = rank
-        self.world_size = world_size
-
-    def __iter__(self):
-        iterator = iter(self.dataset)
-        # NOTE: Start by skipping to this worker's shard
-        for _ in range(self.rank):
-            next(iterator)
-
-        # NOTE: Yield every world_size-th item
-        while True:
-            try:
-                yield next(iterator)
-                # Skip other workers' samples
-                for _ in range(self.world_size - 1):
-                    next(iterator)
-            except StopIteration:
-                break
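
A small usage sketch of `ShardedIterableDataset` on a plain Python list, showing the round-robin split it produces across two ranks:

toy_data = list(range(10))

shard_0 = list(ShardedIterableDataset(toy_data, rank=0, world_size=2))
shard_1 = list(ShardedIterableDataset(toy_data, rank=1, world_size=2))

assert shard_0 == [0, 2, 4, 6, 8]
assert shard_1 == [1, 3, 5, 7, 9]
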
src/training/utils/initialization.py
DELETED
|
@@ -1,702 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Utilities for initializing components of the training process.
|
| 3 |
-
|
| 4 |
-
Here, we initialize all of the components that are part of the learning process. From logging,
|
| 5 |
-
and checkpointing to the optimizer to the dataset and the dataloader, this file contains the
|
| 6 |
-
logic for setting up the classes and functions that are used in the training loop.
|
| 7 |
-
|
| 8 |
-
As always, this code is meant to be basic. We hard-code the obvious defaults, and leave the
|
| 9 |
-
more experimental stuff to you.
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
import logging
|
| 13 |
-
import math
|
| 14 |
-
import os
|
| 15 |
-
import warnings
|
| 16 |
-
from dataclasses import fields, is_dataclass
|
| 17 |
-
from datetime import datetime
|
| 18 |
-
from typing import Dict, Optional, Union
|
| 19 |
-
|
| 20 |
-
import lightning as L
|
| 21 |
-
import torch
|
| 22 |
-
import yaml
|
| 23 |
-
from datasets import Dataset, DownloadConfig, load_dataset
|
| 24 |
-
from datasets import config as datasets_config
|
| 25 |
-
from huggingface_hub import add_collection_item, create_branch, create_repo
|
| 26 |
-
from lightning.fabric.loggers import Logger as FabricLogger
|
| 27 |
-
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
| 28 |
-
from torch.utils.data import DataLoader
|
| 29 |
-
from transformers import AutoTokenizer
|
| 30 |
-
|
| 31 |
-
import wandb
|
| 32 |
-
from src.config import (
|
| 33 |
-
CheckpointingConfig,
|
| 34 |
-
DataConfig,
|
| 35 |
-
EvaluationConfig,
|
| 36 |
-
ModelConfig,
|
| 37 |
-
MonitoringConfig,
|
| 38 |
-
TrainingConfig,
|
| 39 |
-
)
|
| 40 |
-
from src.model import PicoDecoder
|
| 41 |
-
from src.training.utils.io import use_backoff
|
| 42 |
-
from wandb.integration.lightning.fabric import WandbLogger
|
| 43 |
-
|
| 44 |
-
warnings.filterwarnings(
|
| 45 |
-
"ignore",
|
| 46 |
-
message=".*This integration is tested and supported for lightning Fabric.*",
|
| 47 |
-
)
|
| 48 |
-
warnings.filterwarnings(
|
| 49 |
-
"ignore",
|
| 50 |
-
message=".*Please report any issues to.*",
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
########################################################
|
| 54 |
-
#
|
| 55 |
-
# Basic Initialization
|
| 56 |
-
#
|
| 57 |
-
########################################################
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
def _apply_config_overrides(config, overrides: dict):
|
| 61 |
-
"""Recursively apply configuration overrides to a dataclass config object.
|
| 62 |
-
|
| 63 |
-
Args:
|
| 64 |
-
config: Base configuration object (must be a dataclass)
|
| 65 |
-
overrides: Dictionary of override values matching config structure
|
| 66 |
-
|
| 67 |
-
Returns:
|
| 68 |
-
Modified config object with overrides to the config.
|
| 69 |
-
"""
|
| 70 |
-
for field in fields(config):
|
| 71 |
-
field_value = getattr(config, field.name)
|
| 72 |
-
if is_dataclass(field_value):
|
| 73 |
-
_apply_config_overrides(field_value, overrides.get(field.name, {}))
|
| 74 |
-
else:
|
| 75 |
-
if field.name in overrides:
|
| 76 |
-
setattr(config, field.name, overrides[field.name])
|
| 77 |
-
return config
|
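
A toy illustration of `_apply_config_overrides` with two hypothetical stand-in dataclasses (these are not the real config classes); nested override dictionaries map onto nested dataclass fields:

from dataclasses import dataclass, field

@dataclass
class _ToyOptimization:
    lr: float = 3e-4
    optimizer: str = "adamw"

@dataclass
class _ToyTraining:
    max_steps: int = 1000
    optimization: _ToyOptimization = field(default_factory=_ToyOptimization)

cfg = _apply_config_overrides(_ToyTraining(), {"max_steps": 200, "optimization": {"lr": 1e-3}})
assert cfg.max_steps == 200
assert cfg.optimization.lr == 1e-3 and cfg.optimization.optimizer == "adamw"
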
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def initialize_configuration(
|
| 81 |
-
config_path: Optional[str] = None,
|
| 82 |
-
) -> Dict[
|
| 83 |
-
str,
|
| 84 |
-
Union[
|
| 85 |
-
DataConfig,
|
| 86 |
-
ModelConfig,
|
| 87 |
-
TrainingConfig,
|
| 88 |
-
EvaluationConfig,
|
| 89 |
-
MonitoringConfig,
|
| 90 |
-
CheckpointingConfig,
|
| 91 |
-
],
|
| 92 |
-
]:
|
| 93 |
-
"""Initialize configuration objects with optional overrides from a YAML file.
|
| 94 |
-
|
| 95 |
-
This function initializes all of the configuration objects, and then applies
|
| 96 |
-
any overrides from the config_path file. If no config_path is provided,
|
| 97 |
-
the function will use the default configuration objects.
|
| 98 |
-
|
| 99 |
-
Args:
|
| 100 |
-
config_path: Path to a YAML file containing configuration overrides.
|
| 101 |
-
|
| 102 |
-
Returns:
|
| 103 |
-
A dictionary containing the initialized configuration objects.
|
| 104 |
-
"""
|
| 105 |
-
data_config = DataConfig()
|
| 106 |
-
model_config = ModelConfig()
|
| 107 |
-
training_config = TrainingConfig()
|
| 108 |
-
evaluation_config = EvaluationConfig()
|
| 109 |
-
monitoring_config = MonitoringConfig()
|
| 110 |
-
checkpointing_config = CheckpointingConfig()
|
| 111 |
-
|
| 112 |
-
if config_path:
|
| 113 |
-
overrides = yaml.safe_load(open(config_path, "r"))
|
| 114 |
-
data_config = _apply_config_overrides(data_config, overrides.get("data", {}))
|
| 115 |
-
model_config = _apply_config_overrides(model_config, overrides.get("model", {}))
|
| 116 |
-
training_config = _apply_config_overrides(
|
| 117 |
-
training_config, overrides.get("training", {})
|
| 118 |
-
)
|
| 119 |
-
evaluation_config = _apply_config_overrides(
|
| 120 |
-
evaluation_config, overrides.get("evaluation", {})
|
| 121 |
-
)
|
| 122 |
-
monitoring_config = _apply_config_overrides(
|
| 123 |
-
monitoring_config, overrides.get("monitoring", {})
|
| 124 |
-
)
|
| 125 |
-
checkpointing_config = _apply_config_overrides(
|
| 126 |
-
checkpointing_config, overrides.get("checkpointing", {})
|
| 127 |
-
)
|
| 128 |
-
|
| 129 |
-
configs = {
|
| 130 |
-
"data": data_config,
|
| 131 |
-
"model": model_config,
|
| 132 |
-
"training": training_config,
|
| 133 |
-
"evaluation": evaluation_config,
|
| 134 |
-
"monitoring": monitoring_config,
|
| 135 |
-
"checkpointing": checkpointing_config,
|
| 136 |
-
}
|
| 137 |
-
|
| 138 |
-
return configs
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
def initialize_run_dir(checkpointing_config: CheckpointingConfig) -> str:
|
| 142 |
-
"""Initialize a directory for the current training run.
|
| 143 |
-
|
| 144 |
-
Creates a unique directory for storing training, evaluation, and logging artifacts.
|
| 145 |
-
If no run name is specified in the config, generates a timestamp-based name.
|
| 146 |
-
|
| 147 |
-
Args:
|
| 148 |
-
checkpointing_config: Configuration object containing run settings.
|
| 149 |
-
NOTE: Must have a 'run_name' attribute that can be None, in which case
|
| 150 |
-
a timestamp-based name will be generated.
|
| 151 |
-
|
| 152 |
-
Returns:
|
| 153 |
-
str: The path to the run directory.
|
| 154 |
-
"""
|
| 155 |
-
run_name = checkpointing_config.run_name
|
| 156 |
-
if run_name is None:
|
| 157 |
-
run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
| 158 |
-
checkpointing_config.run_name = run_name
|
| 159 |
-
|
| 160 |
-
run_dir = os.path.join(checkpointing_config.runs_dir, run_name)
|
| 161 |
-
|
| 162 |
-
os.makedirs(run_dir, exist_ok=True)
|
| 163 |
-
return run_dir
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
def initialize_fabric(
|
| 167 |
-
training_config: TrainingConfig, wandb_logger: Optional[FabricLogger] = None
|
| 168 |
-
):
|
| 169 |
-
"""Initialize Lightning Fabric for distributed training.
|
| 170 |
-
|
| 171 |
-
Sets up a Lightning Fabric instance with the specified configuration for
|
| 172 |
-
handling distributed training, mixed precision, and logging.
|
| 173 |
-
|
| 174 |
-
Args:
|
| 175 |
-
training_config: Configuration object containing fabric settings
|
| 176 |
-
(accelerator, precision, devices, etc.).
|
| 177 |
-
wandb_logger: Optional weights and biases logger instance for experiment tracking
|
| 178 |
-
|
| 179 |
-
Returns:
|
| 180 |
-
L.Fabric: Initialized Lightning Fabric instance.
|
| 181 |
-
|
| 182 |
-
Example:
|
| 183 |
-
>>> fabric = initialize_fabric(training_config, wandb_logger)
|
| 184 |
-
"""
|
| 185 |
-
|
| 186 |
-
total_devices = (
|
| 187 |
-
training_config.fabric.num_devices * training_config.fabric.num_nodes
|
| 188 |
-
)
|
| 189 |
-
|
| 190 |
-
if total_devices > 1:
|
| 191 |
-
strategy = "deepspeed_stage_2"
|
| 192 |
-
else:
|
| 193 |
-
strategy = "auto" # Sets up SingleDevice Strategy by default
|
| 194 |
-
|
| 195 |
-
# NOTE: The strategy is set to use either DeepSpeed (Zero Stage 2) on multi-GPU,
|
| 196 |
-
# or SingleDevice Strategy on single-GPU set ups. If you'd like to use a different strategy,
|
| 197 |
-
# you can change the strategy flag in the fabric initialization, but be aware that this might
|
| 198 |
-
# cause issues with checkpointing, evaluation, etc.
|
| 199 |
-
|
| 200 |
-
fabric = L.Fabric(
|
| 201 |
-
accelerator=training_config.fabric.accelerator,
|
| 202 |
-
precision=training_config.fabric.precision,
|
| 203 |
-
devices=training_config.fabric.num_devices,
|
| 204 |
-
num_nodes=training_config.fabric.num_nodes,
|
| 205 |
-
loggers=[wandb_logger] if wandb_logger is not None else None,
|
| 206 |
-
strategy=strategy,
|
| 207 |
-
)
|
| 208 |
-
|
| 209 |
-
fabric.launch()
|
| 210 |
-
|
| 211 |
-
return fabric
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
########################################################
|
| 215 |
-
#
|
| 216 |
-
# Dataset and Tokenization Initialization
|
| 217 |
-
#
|
| 218 |
-
########################################################
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
@use_backoff(max_retries=20)
|
| 222 |
-
def initialize_dataset(
|
| 223 |
-
data_config: DataConfig,
|
| 224 |
-
fabric: L.Fabric,
|
| 225 |
-
initial_batch_step: Optional[int] = 0,
|
| 226 |
-
return_fast_forward_steps: bool = False,
|
| 227 |
-
):
|
| 228 |
-
"""Initialize dataset based on the given config.
|
| 229 |
-
|
| 230 |
-
This function will return a dataset object, and optionally a fast_forward_steps value.
|
| 231 |
-
|
| 232 |
-
The fast_forward_steps value is the number of steps that we need to fast-forward an iterator by,
|
| 233 |
-
so that we can continue from a certain batch of data we would have seen had training not previously
|
| 234 |
-
stopped. Depending on how the dataset is loaded, the amount of steps to fast-forward may be
|
| 235 |
-
different from the initial_batch_step value.
|
| 236 |
-
|
| 237 |
-
NOTE: This functionality is primarily useful for streaming datasets (which for large
|
| 238 |
-
datasets is most of the time).
|
| 239 |
-
|
| 240 |
-
Args:
|
| 241 |
-
data_config: Configuration object containing dataset settings.
|
| 242 |
-
fabric: A Lightning Fabric instance.
|
| 243 |
-
initial_batch_step: The initial batch step to fast-forward to.
|
| 244 |
-
return_fast_forward_steps: Whether to return the fast-forward steps value.
|
| 245 |
-
|
| 246 |
-
Returns:
|
| 247 |
-
Dataset: Initialized dataset object.
|
| 248 |
-
Optional[int]: Number of steps to fast-forward the iterator by, if return_fast_forward_steps is True.
|
| 249 |
-
"""
|
| 250 |
-
|
| 251 |
-
datasets_config.STREAMING_READ_MAX_RETRIES = 40 # default is 20
|
| 252 |
-
datasets_config.STREAMING_READ_RETRY_INTERVAL = 10 # default is 5
|
| 253 |
-
download_config = DownloadConfig(
|
| 254 |
-
max_retries=20,  # default is 1 and can lead to premature HTTPS errors
|
| 255 |
-
)
|
| 256 |
-
|
| 257 |
-
fast_forward_steps = 0
|
| 258 |
-
|
| 259 |
-
if data_config.dataset.name == "pico-lm/pretokenized-dolma":
|
| 260 |
-
# NOTE: We know that the dataset is sharded into 10,000 shards, so we can easily compute
|
| 261 |
-
# the data file that we need to load in that contains the batch of data at
|
| 262 |
-
# initial_batch_step.
|
| 263 |
-
|
| 264 |
-
if initial_batch_step is not None:
|
| 265 |
-
examples_per_shard = 20_480
|
| 266 |
-
total_shards = 10_000
|
| 267 |
-
batches_per_shard = examples_per_shard // data_config.dataloader.batch_size
|
| 268 |
-
shard_idx = initial_batch_step // batches_per_shard
|
| 269 |
-
|
| 270 |
-
data_files = [
|
| 271 |
-
f"data/train-{str(_shard_idx).zfill(5)}-of-{total_shards}.parquet"
|
| 272 |
-
for _shard_idx in range(shard_idx, total_shards)
|
| 273 |
-
]
|
| 274 |
-
|
| 275 |
-
fast_forward_steps = initial_batch_step % batches_per_shard
|
| 276 |
-
else:
|
| 277 |
-
data_files = None
|
| 278 |
-
|
| 279 |
-
base_dataset = load_dataset(
|
| 280 |
-
data_config.dataset.name,
|
| 281 |
-
split="train",
|
| 282 |
-
streaming=True,
|
| 283 |
-
data_files=data_files,
|
| 284 |
-
download_config=download_config,
|
| 285 |
-
)
|
| 286 |
-
else:
|
| 287 |
-
# NOTE: For other datasets, you might want to add some custom loading logic, especially
|
| 288 |
-
# to help with loading or fast-forwarding to the correct batch.
|
| 289 |
-
|
| 290 |
-
base_dataset = load_dataset(
|
| 291 |
-
data_config.dataset.name,
|
| 292 |
-
split="train",
|
| 293 |
-
streaming=True,
|
| 294 |
-
download_config=download_config,
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
if data_config.dataset.name == "pico-lm/pretokenized-dolma":
|
| 298 |
-
from .data import ShardedIterableDataset
|
| 299 |
-
|
| 300 |
-
# NOTE: We wrap the dataset in a ShardedIterableDataset, which is a custom class that
|
| 301 |
-
# allows us to shard an iterable dataset across multiple processes. This is useful for
|
| 302 |
-
# distributed training, where we want data-parallelism.
|
| 303 |
-
dataset = ShardedIterableDataset(
|
| 304 |
-
base_dataset, fabric.global_rank, fabric.world_size
|
| 305 |
-
)
|
| 306 |
-
else:
|
| 307 |
-
dataset = base_dataset
|
| 308 |
-
|
| 309 |
-
if return_fast_forward_steps:
|
| 310 |
-
return dataset, fast_forward_steps
|
| 311 |
-
else:
|
| 312 |
-
return dataset
|
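
The shard and fast-forward arithmetic above, worked through with an assumed dataloader batch size of 1,024 (the 20,480 examples-per-shard figure comes from the code above; the step count is arbitrary):

examples_per_shard = 20_480
batch_size = 1_024                  # assumed dataloader batch size for this example
batches_per_shard = examples_per_shard // batch_size    # 20
initial_batch_step = 4_310

shard_idx = initial_batch_step // batches_per_shard          # 215 -> start loading at shard 215
fast_forward_steps = initial_batch_step % batches_per_shard  # 10 batches into that shard
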
| 313 |
-
|
| 314 |
-
|
| 315 |
-
def initialize_tokenizer(data_config: DataConfig):
|
| 316 |
-
"""Initialize the tokenizer for text processing.
|
| 317 |
-
|
| 318 |
-
This function can be extended to include custom tokenization logic.
|
| 319 |
-
|
| 320 |
-
Args:
|
| 321 |
-
data_config: Configuration object containing tokenizer settings.
|
| 322 |
-
|
| 323 |
-
Returns:
|
| 324 |
-
AutoTokenizer: A HuggingFace tokenizer instance.
|
| 325 |
-
"""
|
| 326 |
-
|
| 327 |
-
return AutoTokenizer.from_pretrained(data_config.tokenizer.name)
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
def initialize_dataloader(
|
| 331 |
-
data_config: DataConfig,
|
| 332 |
-
training_config: TrainingConfig,
|
| 333 |
-
fabric: L.Fabric,
|
| 334 |
-
dataset: Dataset,
|
| 335 |
-
):
|
| 336 |
-
"""Initialize the DataLoader for efficient batch processing.
|
| 337 |
-
|
| 338 |
-
Creates a PyTorch DataLoader that handles batching and data loading for training.
|
| 339 |
-
Configured specifically for streaming tokenized text datasets.
|
| 340 |
-
|
| 341 |
-
You might also want to extend this function to add a sampler, or some sort of custom
|
| 342 |
-
collate function. For the default dataset, we don't need any of this, because the data are
|
| 343 |
-
pre-shuffled, and pre-tokenized.
|
| 344 |
-
|
| 345 |
-
Args:
|
| 346 |
-
data_config: Configuration object containing dataloader settings.
|
| 347 |
-
training_config: Configuration object containing training settings.
|
| 348 |
-
fabric: A Lightning Fabric instance.
|
| 349 |
-
dataset: A HuggingFace Dataset object containing tokenized text data.
|
| 350 |
-
Expected to have 'input_ids' field in its items.
|
| 351 |
-
|
| 352 |
-
Returns:
|
| 353 |
-
DataLoader: PyTorch DataLoader instance configured for the dataset.
|
| 354 |
-
"""
|
| 355 |
-
|
| 356 |
-
def _collate_fn(batch):
|
| 357 |
-
return {"input_ids": [entry["input_ids"] for entry in batch]}
|
| 358 |
-
|
| 359 |
-
sub_batch_size = data_config.dataloader.batch_size // (
|
| 360 |
-
fabric.world_size * training_config.optimization.gradient_accumulation_steps
|
| 361 |
-
)
|
| 362 |
-
|
| 363 |
-
# NOTE: We use the sub-batch size for the dataloader, which is the full batch size
|
| 364 |
-
# divided by the number of devices times the gradient accumulation steps. This ensures that the effective batch size
|
| 365 |
-
# is correct.
|
| 366 |
-
|
| 367 |
-
return DataLoader(
|
| 368 |
-
dataset,
|
| 369 |
-
batch_size=sub_batch_size,
|
| 370 |
-
shuffle=False, # Keep sequential for streaming datasets
|
| 371 |
-
pin_memory=True, # Speeds up transfer to GPU
|
| 372 |
-
collate_fn=_collate_fn,
|
| 373 |
-
)
|
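
Worked example of the sub-batch arithmetic above (all three figures are assumptions for illustration): the dataloader batch size is the global batch divided across devices and accumulation steps, so each optimizer step still sees the full global batch.

global_batch_size = 1_024            # data_config.dataloader.batch_size
world_size = 4                       # number of devices
gradient_accumulation_steps = 8

sub_batch_size = global_batch_size // (world_size * gradient_accumulation_steps)  # 32
assert sub_batch_size * world_size * gradient_accumulation_steps == global_batch_size
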
| 374 |
-
|
| 375 |
-
|
| 376 |
-
########################################################
|
| 377 |
-
#
|
| 378 |
-
# Model Initialization
|
| 379 |
-
#
|
| 380 |
-
########################################################
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
def initialize_model(model_config: ModelConfig):
|
| 384 |
-
"""Initialize the model for training.
|
| 385 |
-
|
| 386 |
-
Loads in a given model implemented in the `src.model` package and returns it.
|
| 387 |
-
|
| 388 |
-
NOTE: out of the box we currently only support the PicoDecoder model (a causal transformer
|
| 389 |
-
language model). If you'd like to implement your own model, you can do so by adding a new
|
| 390 |
-
model class in the `src.model` package, and then adding a new entry here.
|
| 391 |
-
|
| 392 |
-
Args:
|
| 393 |
-
model_config: Configuration object containing model settings.
|
| 394 |
-
|
| 395 |
-
Returns:
|
| 396 |
-
PyTorch model instance.
|
| 397 |
-
|
| 398 |
-
"""
|
| 399 |
-
if model_config.model_type == "pico_decoder":
|
| 400 |
-
return PicoDecoder(model_config)
|
| 401 |
-
else:
|
| 402 |
-
raise ValueError(f"Invalid model type: {model_config.model_type}")
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
########################################################
|
| 406 |
-
#
|
| 407 |
-
# Optimizer and Scheduler
|
| 408 |
-
#
|
| 409 |
-
########################################################
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
def initialize_optimizer(training_config: TrainingConfig, model: torch.nn.Module):
|
| 413 |
-
"""Initialize the optimizer for model training.
|
| 414 |
-
|
| 415 |
-
Creates an optimizer instance based on the configuration settings.
|
| 416 |
-
|
| 417 |
-
Add whatever other optimizers you want here.
|
| 418 |
-
|
| 419 |
-
Args:
|
| 420 |
-
training_config: Configuration object containing optimizer settings.
|
| 421 |
-
Must have:
|
| 422 |
-
- optimization.optimizer (str): Name of the optimizer ("adamw")
|
| 423 |
-
- optimization.lr (float): Learning rate for the optimizer
|
| 424 |
-
model: PyTorch model whose parameters will be optimized.
|
| 425 |
-
|
| 426 |
-
Returns:
|
| 427 |
-
torch.optim.Optimizer: Configured optimizer instance.
|
| 428 |
-
|
| 429 |
-
"""
|
| 430 |
-
|
| 431 |
-
if training_config.optimization.optimizer == "adamw":
|
| 432 |
-
optimizer = torch.optim.AdamW(
|
| 433 |
-
model.parameters(), lr=training_config.optimization.lr
|
| 434 |
-
)
|
| 435 |
-
else:
|
| 436 |
-
raise ValueError(f"Invalid optimizer: {training_config.optimization.optimizer}")
|
| 437 |
-
|
| 438 |
-
return optimizer
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
def initialize_lr_scheduler(
|
| 442 |
-
training_config: TrainingConfig, optimizer: torch.optim.Optimizer
|
| 443 |
-
):
|
| 444 |
-
"""Initialize a learning rate scheduler with warmup and decay.
|
| 445 |
-
|
| 446 |
-
The default is a learning rate scheduler that implements a linear warmup followed by
|
| 447 |
-
linear decay. The learning rate increases linearly from 0 to the initial lr
|
| 448 |
-
during warmup, then decreases linearly to 0 during the remaining steps.
|
| 449 |
-
|
| 450 |
-
Add other types of learning rate schedulers here.
|
| 451 |
-
|
| 452 |
-
Args:
|
| 453 |
-
training_config: Configuration object containing optimizer and scheduler settings.
|
| 454 |
-
optimizer: PyTorch optimizer whose learning rate will be scheduled.
|
| 455 |
-
|
| 456 |
-
Returns:
|
| 457 |
-
torch.optim.lr_scheduler.LambdaLR: Learning rate scheduler instance.
|
| 458 |
-
"""
|
| 459 |
-
|
| 460 |
-
if training_config.optimization.lr_scheduler == "linear_with_warmup":
|
| 461 |
-
# Credit where credit is due:
|
| 462 |
-
# https://github.com/huggingface/transformers/blob/e71a01a104dd663c730e494eb0b6467bb51df357/src/transformers/optimization.py#L102
|
| 463 |
-
def _lr_lambda(curr_step, num_warmup_steps, max_steps):
|
| 464 |
-
if curr_step < num_warmup_steps:
|
| 465 |
-
return float(curr_step) / float(max(1, num_warmup_steps))
|
| 466 |
-
else:
|
| 467 |
-
return max(
|
| 468 |
-
0.0,
|
| 469 |
-
float(max_steps - curr_step)
|
| 470 |
-
/ float(max(1, max_steps - num_warmup_steps)),
|
| 471 |
-
)
|
| 472 |
-
|
| 473 |
-
lr_lambda = lambda step: _lr_lambda( # noqa: E731
|
| 474 |
-
step,
|
| 475 |
-
training_config.optimization.lr_warmup_steps,
|
| 476 |
-
training_config.max_steps,
|
| 477 |
-
)
|
| 478 |
-
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
|
| 479 |
-
optimizer,
|
| 480 |
-
lr_lambda,
|
| 481 |
-
)
|
| 482 |
-
elif training_config.optimization.lr_scheduler == "cosine":
|
| 483 |
-
# Cosine decay with warmup: linear warmup followed by cosine decay
|
| 484 |
-
# This provides sustained learning over long training runs
|
| 485 |
-
def _cosine_lr_lambda(curr_step, num_warmup_steps, max_steps):
|
| 486 |
-
if curr_step < num_warmup_steps:
|
| 487 |
-
# Linear warmup
|
| 488 |
-
return float(curr_step) / float(max(1, num_warmup_steps))
|
| 489 |
-
else:
|
| 490 |
-
# Cosine decay to 0.1 * initial_lr (not to 0)
|
| 491 |
-
progress = float(curr_step - num_warmup_steps) / float(
|
| 492 |
-
max(1, max_steps - num_warmup_steps)
|
| 493 |
-
)
|
| 494 |
-
return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
|
| 495 |
-
|
| 496 |
-
lr_lambda = lambda step: _cosine_lr_lambda( # noqa: E731
|
| 497 |
-
step,
|
| 498 |
-
training_config.optimization.lr_warmup_steps,
|
| 499 |
-
training_config.max_steps,
|
| 500 |
-
)
|
| 501 |
-
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
|
| 502 |
-
optimizer,
|
| 503 |
-
lr_lambda,
|
| 504 |
-
)
|
| 505 |
-
else:
|
| 506 |
-
raise ValueError(
|
| 507 |
-
f"Invalid learning rate scheduler: {training_config.optimization.lr_scheduler}"
|
| 508 |
-
)
|
| 509 |
-
|
| 510 |
-
return lr_scheduler
|
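
For intuition, the two multipliers defined above restated as standalone functions and evaluated at a few steps (a warmup of 100 and 1,000 max steps are arbitrary example values):

import math

def linear_with_warmup(step, warmup=100, max_steps=1000):
    if step < warmup:
        return step / max(1, warmup)
    return max(0.0, (max_steps - step) / max(1, max_steps - warmup))

def cosine_with_warmup(step, warmup=100, max_steps=1000):
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, max_steps - warmup)
    return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

print([round(linear_with_warmup(s), 2) for s in (0, 50, 100, 550, 1000)])  # [0.0, 0.5, 1.0, 0.5, 0.0]
print([round(cosine_with_warmup(s), 2) for s in (0, 50, 100, 550, 1000)])  # [0.0, 0.5, 1.0, 0.5, 0.1]
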
| 511 |
-
|
| 512 |
-
|
| 513 |
-
########################################################
|
| 514 |
-
#
|
| 515 |
-
# Experiment Monitoring (Logging, Experiment Tracking, etc.)
|
| 516 |
-
#
|
| 517 |
-
########################################################
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
def _initialize_log_file(checkpointing_config: CheckpointingConfig) -> str:
|
| 521 |
-
"""Create and initialize a timestamped log file in the run's log directory.
|
| 522 |
-
|
| 523 |
-
Sets up a log file with a unique timestamp in the run's logging directory.
|
| 524 |
-
Creates the necessary directory structure if it doesn't exist.
|
| 525 |
-
|
| 526 |
-
Directory Structure:
|
| 527 |
-
{checkpointing_config.runs_dir}/
|
| 528 |
-
└── {checkpointing_config.run_name}/
|
| 529 |
-
└── {checkpointing_config.logs_dir}/
|
| 530 |
-
└── log_YYYYMMDD_HHMMSS.txt
|
| 531 |
-
|
| 532 |
-
Args:
|
| 533 |
-
checkpointing_config: Configuration object containing checkpointing settings.
|
| 534 |
-
|
| 535 |
-
Returns:
|
| 536 |
-
str: Absolute path to the created log file.
|
| 537 |
-
|
| 538 |
-
"""
|
| 539 |
-
|
| 540 |
-
run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name)
|
| 541 |
-
logs_dir = os.path.join(run_dir, checkpointing_config.logs_dir)
|
| 542 |
-
os.makedirs(logs_dir, exist_ok=True)
|
| 543 |
-
|
| 544 |
-
# datetime stamp
|
| 545 |
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 546 |
-
log_file_name = f"log_{timestamp}.log"
|
| 547 |
-
log_file_path = os.path.join(logs_dir, log_file_name)
|
| 548 |
-
|
| 549 |
-
open(log_file_path, "w").close() # Create an empty log file
|
| 550 |
-
|
| 551 |
-
return log_file_path
|
| 552 |
-
|
| 553 |
-
|
@use_backoff()
def initialize_wandb(
    monitoring_config: MonitoringConfig, checkpointing_config: CheckpointingConfig
):
    """Initialize Weights and Biases.

    This function initializes Weights and Biases based on the configuration settings.

    Args:
        monitoring_config: Configuration object containing monitoring settings.
        checkpointing_config: Configuration object containing checkpointing settings.

    Returns:
        Optional[WandbLogger]: An experiment tracker instance.
    """

    assert (
        monitoring_config.wandb.project is not None
        and monitoring_config.wandb.project != ""
    ), "Wandb project must be provided if wandb is to be used."
    assert (
        monitoring_config.wandb.entity is not None
        and monitoring_config.wandb.entity != ""
    ), "Wandb entity must be provided if wandb is to be used."

    _run_id = None
    if checkpointing_config.training.auto_resume:
        # If we are loading a checkpoint, we can try to find the run id of the previous run
        previous_runs = wandb.Api().runs(
            path=f"{monitoring_config.wandb.entity}/{monitoring_config.wandb.project}",
            filters={"display_name": checkpointing_config.run_name},
        )
        try:
            if len(previous_runs) == 1:
                _run_id = previous_runs[0].id
        except ValueError:
            pass

    wandb_logger = WandbLogger(
        project=monitoring_config.wandb.project,
        entity=monitoring_config.wandb.entity,
        id=_run_id,
        name=checkpointing_config.run_name,
    )

    return wandb_logger

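For context, a hypothetical sketch of how the returned `WandbLogger` would be consumed. The actual wiring lives in the trainer, so treat the Fabric setup below as illustrative only; `monitoring_config` and `checkpointing_config` are the same configuration objects assumed above.

```python
import lightning as L

# Hypothetical wiring: pass the W&B logger to Fabric so that metric dictionaries
# logged during training are forwarded to the resumed (or newly created) W&B run.
wandb_logger = initialize_wandb(monitoring_config, checkpointing_config)
fabric = L.Fabric(accelerator="auto", devices=1, loggers=[wandb_logger])
fabric.launch()

fabric.log_dict({"train/loss": 2.31, "trainer/step": 100})
```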
@rank_zero_only
def initialize_logging(
    monitoring_config: MonitoringConfig,
    checkpointing_config: CheckpointingConfig,
    fabric: L.Fabric,
):
    """Initialize logging system with default logging, to file and console.

    The default logging system uses a file handler and a stream handler.

    NOTE: this function is only called on rank 0.

    Args:
        monitoring_config: Configuration object containing monitoring settings.
        checkpointing_config: Configuration object containing checkpointing settings.
        fabric: Lightning Fabric instance (not used directly inside this function).

    Returns:
        logger: Standard Python logger configured for file and console output.
    """

    # ---- Standard Local Logger ---- #
    logger = logging.getLogger("pico-train")
    logger.setLevel(logging.INFO)

    # Create file handler
    log_file_path = _initialize_log_file(checkpointing_config)
    file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
    file_handler.setLevel(monitoring_config.logging.log_level)

    # Create formatter and add it to the handler
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    file_handler.setFormatter(formatter)

    # Add the handler to the logger
    logger.addHandler(file_handler)

    # Add a stream handler for console output
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(monitoring_config.logging.log_level)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    return logger

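One consequence of the `@rank_zero_only` decorator worth noting (this is standard Lightning behaviour, not something specific to this file): on every rank except 0 the wrapped function is skipped and returns `None`, so multi-process callers have to tolerate a missing logger. A hypothetical guard:

```python
logger = initialize_logging(monitoring_config, checkpointing_config, fabric)

# On non-zero ranks the call above returns None, so guard before using the logger.
if logger is not None:
    logger.info("Logging initialized on rank 0")
```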
########################################################
#
# HuggingFace/Remote Checkpointing
#
########################################################


@rank_zero_only
@use_backoff()
def initialize_hf_checkpointing(
    checkpointing_config: CheckpointingConfig, fabric: L.Fabric
):
    """Initialize HuggingFace Checkpointing.

    Creates a HuggingFace repository if it doesn't exist, and creates a branch named after the run.

    NOTE: this function is only called on rank 0.

    Args:
        checkpointing_config: Configuration object containing checkpointing settings; must have
            a 'hf_checkpoint' attribute that specifies the HuggingFace repository id and
            collection slug (if applicable) to save the checkpoint to.
        fabric: Lightning Fabric instance (not used directly inside this function).

    Raises:
        Exception: If unable to create the HuggingFace repository after multiple attempts
            (re-raised by the `use_backoff` wrapper).
    """

    huggingface_repo_id = checkpointing_config.hf_checkpoint.repo_id
    assert (
        huggingface_repo_id is not None and huggingface_repo_id != ""
    ), "hf_checkpoint.repo_id must be provided."

    repo = create_repo(huggingface_repo_id, exist_ok=True)

    # A repo can be created without a specified namespace (it defaults to the username),
    # but the rest of the HF calls need the fully qualified name. create_repo returns the
    # fully qualified id, so we update the config for later calls.
    checkpointing_config.hf_checkpoint.repo_id = repo.repo_id
    huggingface_repo_id = repo.repo_id

    if checkpointing_config.hf_checkpoint.collection_slug:
        add_collection_item(
            checkpointing_config.hf_checkpoint.collection_slug,
            huggingface_repo_id,
            repo.repo_type,
            exists_ok=True,
        )

    create_branch(
        repo_id=huggingface_repo_id,
        branch=checkpointing_config.run_name,
        exist_ok=True,
    )
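The exact upload calls live in the checkpointing modules above; the snippet below only illustrates how the run-named branch created here would typically be targeted. The local and remote paths are hypothetical, and `checkpointing_config` is the same object assumed above.

```python
from huggingface_hub import upload_folder

# Hypothetical follow-up call: push a local step checkpoint to the branch that
# initialize_hf_checkpointing created for this run.
upload_folder(
    repo_id=checkpointing_config.hf_checkpoint.repo_id,
    folder_path="runs/my-run/checkpoints/step_1000",  # hypothetical local path
    path_in_repo="checkpoints/step_1000",             # hypothetical remote layout
    revision=checkpointing_config.run_name,           # the branch created above
    commit_message="Save checkpoint at step 1000",
)
```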
src/training/utils/io.py
DELETED

@@ -1,52 +0,0 @@

"""Defines a retry wrapper for io operations."""

import time
from functools import wraps


def use_backoff(max_retries=2, initial_delay=1, backoff_factor=2):
    """
    Universal retry wrapper with exponential backoff for any function, but primarily for loading
    and storing HuggingFace datasets and objects.

    Example usage:

    >>> @use_backoff(max_retries=10, initial_delay=1, backoff_factor=2)
    >>> def important_io_operation(x):
    >>>     return x + 1

    Args:
        max_retries: Maximum number of retry attempts (default: 2)
        initial_delay: Initial delay between retries in seconds (default: 1)
        backoff_factor: Multiplier for the delay between retries (default: 2)

    Returns:
        A decorator that retries the wrapped function up to max_retries times with exponential backoff.

    Raises:
        Exception: If all retries fail.
    """

    def _decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            current_delay = initial_delay
            last_exception = None

            for attempt in range(max_retries):
                try:
                    return fn(*args, **kwargs)
                except Exception as e:
                    last_exception = e
                    if attempt < max_retries - 1:  # Don't sleep on the last attempt
                        time.sleep(current_delay)
                        current_delay *= backoff_factor

            raise Exception(
                f"IO Operation failed after {max_retries} attempts: {str(last_exception)}"
            )

        return wrapper

    return _decorator
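A minimal usage sketch of the decorator above. The helper `fetch_checkpoint_file` is hypothetical; it simply shows the kind of flaky Hub call this wrapper is meant to protect.

```python
from huggingface_hub import hf_hub_download


@use_backoff(max_retries=5, initial_delay=2, backoff_factor=2)
def fetch_checkpoint_file(repo_id: str, filename: str, revision: str) -> str:
    # Transient Hub/network errors are retried with exponential backoff
    # (2s, 4s, 8s, ...) before the wrapper gives up and re-raises.
    return hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)
```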
src/training/utils/logging.py
DELETED

@@ -1,48 +0,0 @@

"""
Miscellaneous logging utilities.
"""

from io import StringIO

import yaml
from lightning.fabric.utilities.rank_zero import rank_zero_only
from rich.console import Console
from rich.panel import Panel


@rank_zero_only
def pretty_print_yaml_config(logger, config: dict) -> None:
    """
    Pretty print config with rich formatting. Assumes that the config is already saved as a
    dictionary - this can be done by calling `asdict` on the dataclass or loading in the config
    from a yaml file.

    NOTE: this function is only called on rank 0.

    Args:
        logger: Logger object to log the formatted output to.
        config: Dictionary containing the config to pretty print.
    """
    # Create string buffer
    output = StringIO()
    console = Console(file=output, force_terminal=False)

    # Convert to YAML string first
    yaml_str = yaml.dump(
        config, default_flow_style=False, sort_keys=False, Dumper=yaml.SafeDumper
    )

    # Create formatted panel
    panel = Panel(
        yaml_str,
        border_style="blue",
        padding=(0, 1),  # Reduced padding
        expand=False,  # Don't expand to terminal width
    )

    # Print to buffer
    console.print(panel)

    # Log the formatted output
    for line in output.getvalue().splitlines():
        logger.info(line)
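A small usage sketch of the helper above, assuming `pretty_print_yaml_config` is importable. `DemoConfig` is only a stand-in for one of the dataclasses in `src/config`.

```python
import logging
from dataclasses import asdict, dataclass


@dataclass
class DemoConfig:
    # Stand-in for one of the dataclasses in src/config.
    lr: float = 3e-4
    max_steps: int = 1000


logger = logging.getLogger("pico-train")
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())

# Logs the config as a rich-formatted YAML panel, one line at a time.
pretty_print_yaml_config(logger, asdict(DemoConfig()))
```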