ThomasTheMaker committed on
Commit 4d7adb3 · verified · 1 Parent(s): 84cb578

Delete src
src/checkpointing/__init__.py DELETED
@@ -1,23 +0,0 @@
-"""
-Pico Checkpointing Package
-
-We subdivide the checkpointing into training, evaluation, and learning_dynamics. Training
-checkpoints store the model, optimizer, and learning rate scheduler. Evaluation checkpoints store
-the evaluation results on the defined metrics. Learning dynamics checkpoints store activations and
-gradients used for learning dynamics analysis.
-"""
-
-from .evaluation import save_evaluation_results
-from .learning_dynamics import (
-    compute_learning_dynamics_states,
-    save_learning_dynamics_states,
-)
-from .training import load_checkpoint, save_checkpoint
-
-__all__ = [
-    "compute_learning_dynamics_states",
-    "load_checkpoint",
-    "save_checkpoint",
-    "save_evaluation_results",
-    "save_learning_dynamics_states",
-]
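Note: for orientation, the sketch below shows how a training loop would presumably drive this public API. The surrounding objects (configs, fabric, model, optimizer, lr_scheduler, tokenizer, batch_dataset, eval_results) are assumed to come from the trainer and are not defined in this package; the sketch is illustrative, not part of the commit.

# Hypothetical checkpointing step inside a training loop.
from src.checkpointing import (
    compute_learning_dynamics_states,
    save_checkpoint,
    save_evaluation_results,
    save_learning_dynamics_states,
)


def checkpoint_everything(
    configs, step, fabric, model, optimizer, lr_scheduler, tokenizer, batch_dataset, eval_results
):
    # Training checkpoint: model, optimizer, and LR scheduler states.
    save_checkpoint(configs, step, fabric, model, optimizer, lr_scheduler, tokenizer)

    # Learning dynamics checkpoint: activations, weights, and gradients for analysis.
    states = compute_learning_dynamics_states(
        configs["checkpointing"], fabric, model, batch_dataset, compute_gradients=True
    )
    save_learning_dynamics_states(
        configs["checkpointing"], step, "train", fabric, states, batch_dataset, tokenizer
    )

    # Evaluation checkpoint: metric results gathered on rank 0.
    save_evaluation_results(configs["checkpointing"], step, fabric, eval_results)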
 
src/checkpointing/evaluation.py DELETED
@@ -1,68 +0,0 @@
-"""
-Utilities for checkpointing evaluation-related states (i.e. evaluation results, etc.)
-
-We save the evaluation results in a JSON file at the step-specific evaluation results directory.
-"""
-
-import json
-import os
-from typing import Any, Dict
-
-from huggingface_hub import upload_folder
-from lightning.fabric import Fabric
-from lightning.fabric.utilities.rank_zero import rank_zero_only
-
-from src.config import CheckpointingConfig
-from src.training.utils.io import use_backoff
-
-
-@rank_zero_only
-@use_backoff()
-def save_evaluation_results(
-    checkpointing_config: CheckpointingConfig,
-    checkpoint_step: int,
-    fabric: Fabric,
-    evaluation_results: Dict[str, Any],
-) -> None:
-    """Save evaluation results to disk and optionally to HuggingFace Hub.
-
-    The evaluation results are saved in the following directory structure:
-        {checkpointing_config.runs_dir}/
-        └── {checkpointing_config.run_name}/
-            └── {checkpointing_config.eval_results_dir}/
-                └── step_{checkpoint_step}.json
-
-    NOTE: this function is only called on rank 0 to avoid conflicts; assumes that the evaluation
-    results are gathered on rank 0.
-
-    Args:
-        checkpointing_config: Configuration object containing checkpoint settings
-        checkpoint_step: Current training checkpoint step (i.e. number of learning steps taken)
-        fabric: Lightning Fabric instance
-        evaluation_results: Dictionary containing evaluation metrics
-    """
-
-    run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name)
-    eval_results_dir = os.path.join(
-        run_dir, checkpointing_config.evaluation.eval_results_dir
-    )
-
-    os.makedirs(eval_results_dir, exist_ok=True)
-
-    curr_eval_results_path = os.path.join(
-        eval_results_dir, f"step_{checkpoint_step}.json"
-    )
-
-    # save out as json
-    with open(curr_eval_results_path, "w") as f:
-        json.dump(evaluation_results, f)
-
-    if checkpointing_config.save_to_hf:
-        upload_folder(
-            folder_path=eval_results_dir,
-            path_in_repo=checkpointing_config.evaluation.eval_results_dir,
-            repo_id=checkpointing_config.hf_checkpoint.repo_id,
-            commit_message=f"Saving Evaluation Results -- Step {checkpoint_step}",
-            revision=checkpointing_config.run_name,
-            token=os.getenv("HF_TOKEN"),
-        )
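Note: with the default constants (runs_dir = "runs", eval_results_dir = "eval_results"), a call like the sketch below would presumably land at runs/{run_name}/eval_results/step_1000.json. The step number and metric value are illustrative only, and checkpointing_config and fabric are assumed to come from the trainer setup.

save_evaluation_results(
    checkpointing_config=checkpointing_config,
    checkpoint_step=1000,
    fabric=fabric,
    evaluation_results={"paloma": 42.7},  # metric name -> value (value is made up)
)
# Expected location on disk: runs/{run_name}/eval_results/step_1000.json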
 
src/checkpointing/learning_dynamics.py DELETED
@@ -1,424 +0,0 @@
-"""
-Utilities for checkpointing learning dynamics-related states (i.e. activations, weights, grads, etc.)
-
-We save the learning dynamics states in a subdirectory of the checkpointing directory.
-"""
-
-import os
-import re
-from typing import Dict, Optional
-
-import deepspeed
-import torch
-import torch.nn as nn
-import torch.optim as optim
-from datasets import Dataset
-from huggingface_hub import upload_folder
-from lightning.fabric import Fabric
-from lightning.fabric.strategies import DeepSpeedStrategy
-from lightning.fabric.utilities.rank_zero import rank_zero_only
-from torch.nn import functional as F
-from torch.utils.data import DataLoader
-from transformers import PreTrainedTokenizerBase
-
-from src.config import CheckpointingConfig
-from src.config.checkpointing_config import LearningDynamicsCheckpointingConfig
-from src.training.utils.initialization import initialize_model
-from src.training.utils.io import use_backoff
-
-
-# NOTE: DeepSpeed requires a dummy optimizer to be passed in to the setup function
-class DummyOptimizer(optim.Optimizer):
-    def __init__(self, params):
-        super().__init__(params, defaults={})
-
-
-class CheckpointStateExtractor:
-    """
-    Class to extract and save the states of a model at a given checkpoint step for learning
-    dynamics research.
-    """
-
-    def __init__(
-        self,
-        learning_dynamics_config: LearningDynamicsCheckpointingConfig,
-        fabric: Fabric,
-        model: nn.Module,
-    ):
-        self.learning_dynamics_config = learning_dynamics_config
-        self.fabric = fabric
-        self.model = model
-
-    def extract_states(self, dataloader, compute_gradients: bool = False):
-        """Extracts model states (activations, weights, and optionally gradients).
-
-        Given a dataloader, this function will perform a forward pass of the model on each batch,
-        and save the activations and weights at each layer. If compute_gradients is True, it will
-        also compute the gradients of the model parameters.
-
-        Args:
-            dataloader: The dataloader containing the dataset to extract states from.
-            compute_gradients: Whether to compute the gradients of the model parameters.
-
-        Returns:
-            A dictionary containing the activations, weights, and optionally gradients of the model.
-        """
-        checkpoint_activations = {}
-        checkpoint_weights = {}
-
-        # NOTE: to extract activations and weights, we need to setup forward hooks on the layers
-        # of the model that we are interested in. This is a good intro to forward hooks if you
-        # are not familiar: https://web.stanford.edu/~nanbhas/blog/forward-hooks-pytorch/
-        forward_hooks = self._setup_forward_hooks(
-            checkpoint_activations,
-            checkpoint_weights,
-        )
-
-        ########################################################
-        #
-        # Forward Pass: Extract activations and weights; and compute gradients
-        #
-        ########################################################
-
-        for sub_batch in dataloader:
-            _input_ids = torch.tensor(sub_batch["input_ids"], device=self.fabric.device)
-
-            if compute_gradients:
-                if "labels" in sub_batch:
-                    input_ids = _input_ids
-                    labels = torch.tensor(
-                        sub_batch["labels"], device=self.fabric.device
-                    )
-                else:
-                    input_ids = _input_ids[:, :-1]
-                    labels = _input_ids[:, 1:]
-            else:
-                input_ids = _input_ids
-                labels = None
-
-            if labels is None:
-                # we can throw away the outputs, we are only interested in the hidden states
-                with torch.no_grad():
-                    _ = self.model(input_ids)
-            else:
-                # NOTE: if we are computing gradients, calling backwards will compute the gradients
-                # of the model parameters.
-                outputs, _ = self.model(input_ids)
-                outputs = outputs.transpose(1, 2)
-                loss = F.cross_entropy(outputs, labels)
-                self.fabric.backward(loss, model=self.model)
-
-        # cleanup forward hooks
-        # NOTE this is not strictly necessary, since self.model is a deepcopy of the original model
-        # but it is good practice to remove the hooks after the forward pass is complete.
-        for hook in forward_hooks:
-            hook.remove()
-
-        ########################################################
-        #
-        # Extract gradients from the target tensors of the model
-        #
-        ########################################################
-
-        layer_suffixes = self.learning_dynamics_config.layer_suffixes
-        checkpoint_gradients = {}
-        if compute_gradients:
-            for name, param in self.model.named_parameters():
-                # only do this for the weight matrix of the layer_suffixes
-                if (
-                    any(layer_suffix in name for layer_suffix in layer_suffixes)
-                    and "weight" in name
-                ):
-                    if isinstance(self.fabric.strategy, DeepSpeedStrategy):
-                        _grad = deepspeed.utils.safe_get_full_grad(param)
-                    else:
-                        _grad = param.grad
-
-                    assert _grad is not None, f"Gradient is None for layer: {name}"
-                    name = re.sub(r"\.weight", "", name)
-                    checkpoint_gradients[name] = _grad.detach().cpu()
-
-            # zero out the gradients
-            self.model.zero_grad()
-
-        return checkpoint_activations, checkpoint_weights, checkpoint_gradients
-
-    ########################################################
-    #
-    # Setup forward hooks to save activations and weights at each layer
-    #
-    ########################################################
-
-    def _setup_forward_hooks(self, checkpoint_activations, checkpoint_weights):
-        """Setup forward hooks for the model to save activations and weights at each layer.
-
-        This function will setup forward hooks on the layers of the model that we are interested in.
-        The forward hooks will save the activations and weights at each layer whenever the forward pass
-        is performed.
-
-        Args:
-            checkpoint_activations: A dictionary to store the activations at each layer.
-            checkpoint_weights: A dictionary to store the weights at each layer.
-
-        Returns:
-            A list of forward hooks. We do this so that we can remove the hooks after the forward pass
-            is complete.
-        """
-
-        forward_hooks = []
-        layer_suffixes = self.learning_dynamics_config.layer_suffixes
-
-        for name, module in self.model.named_modules():
-            if any(layer_suffix in name for layer_suffix in layer_suffixes):
-                _forward_hook = module.register_forward_hook(
-                    self._get_forward_hook(
-                        name, checkpoint_activations, checkpoint_weights
-                    )
-                )
-                forward_hooks.append(_forward_hook)
-        return forward_hooks
-
-    def _get_forward_hook(
-        self, module_name, checkpoint_activations, checkpoint_weights
-    ):
-        """Get a forward hook for a given module.
-
-        This function is called by the _setup_forward_hooks function to setup a forward hook for a given
-        module. This functions is a closure that captures the module_name, checkpoint_activations, and
-        checkpoint_weights.
-
-        Args:
-            module_name: The name of the module to setup a forward hook for.
-            checkpoint_activations: A dictionary to store the activations at each layer.
-            checkpoint_weights: A dictionary to store the weights at each layer.
-
-        Returns:
-            A forward hook for the given module.
-        """
-
-        def _forward_hook(module, _, module_out):
-            sequence_idx = self.learning_dynamics_config.sequence_idx
-
-            local_activations = module_out[:, sequence_idx, :].detach()
-
-            # Gather activations from all processes using fabric
-            gathered_activations = self.fabric.all_gather(local_activations)
-
-            # Reshape from [num_processes, batch_size, hidden_dim] to [total_batch_size, hidden_dim]
-            # NOTE: transposing allows us to interleave the activations from each process so that
-            # they are in the correct order. (i.e. activation N is from data sample N)
-            gathered_activations = gathered_activations.transpose(0, 1).reshape(
-                -1, gathered_activations.shape[-1]
-            )
-
-            # check if there is already a key for the module name
-            if module_name not in checkpoint_activations:
-                # if there is no key, then we create a new key and store the hidden states
-                checkpoint_activations[module_name] = (
-                    gathered_activations.detach().cpu()
-                )
-
-                # extract the weight matrix just once
-                weight_matrix = module.weight.detach().cpu()
-                checkpoint_weights[module_name] = weight_matrix
-            else:
-                # if there is already a key, then we concatenate the new hidden states to the existing ones
-                checkpoint_activations[module_name] = torch.cat(
-                    (
-                        checkpoint_activations[module_name],
-                        gathered_activations.detach().cpu(),
-                    )
-                )
-
-        return _forward_hook
-
-
-def compute_learning_dynamics_states(
-    checkpointing_config: CheckpointingConfig,
-    fabric: Fabric,
-    model: nn.Module,
-    dataset: Dataset,
-    compute_gradients: bool = False,
-) -> Dict[str, torch.Tensor]:
-    """Computes the learning dynamics metrics for a given checkpoint step.
-
-    Uses the CheckpointStateExtractor to extract the activations, weights, and optionally gradients
-    of the model at a given checkpoint step.
-
-    Args:
-        checkpointing_config: The configuration object for checkpointing.
-        fabric: The Fabric instance for distributed training.
-        model: The model to extract states from.
-        dataset: The dataset to extract states from.
-        compute_gradients: Whether to compute the gradients of the model parameters.
-
-    Returns:
-        A dictionary containing the activations, weights, and optionally gradients of the model.
-    """
-
-    # NOTE: Synchronizing processes for fabric dataloader setup
-    fabric.barrier()
-    model.to("cpu")  # Offloading model to CPU
-
-    # Setting up Dataloader for learning dynamics
-    def _collate_fn(batch):
-        return {"input_ids": [entry["input_ids"] for entry in batch]}
-
-    batch_size = checkpointing_config.learning_dynamics.batch_size
-    sub_batch_size = batch_size // fabric.world_size
-
-    # NOTE: Make sure to set drop_last to False, otherwise the last batch will be dropped
-    # and we will not have a complete set of activations for the last sample. Also,
-    # we need to set shuffle to False, otherwise the activations will be shuffled across
-    # processes and we will not be able to interleave them correctly.
-    extractor_dataloader = DataLoader(
-        dataset,
-        batch_size=sub_batch_size,
-        shuffle=False,
-        collate_fn=_collate_fn,
-        drop_last=False,
-    )
-    extractor_dataloader = fabric.setup_dataloaders(
-        extractor_dataloader, use_distributed_sampler=True
-    )
-
-    # Create a new model instance with same parameters but zero gradients
-    _model = initialize_model(model.config)
-    _model.load_state_dict(model.state_dict())
-
-    if isinstance(fabric.strategy, DeepSpeedStrategy):
-        _model, _ = fabric.setup(_model, DummyOptimizer(_model.parameters()))
-    else:
-        _model = fabric.setup(_model)
-
-    _model.zero_grad()
-
-    # setup forward hooks for the model to save activations and weights at each layer
-    state_extractor = CheckpointStateExtractor(
-        checkpointing_config.learning_dynamics, fabric, _model
-    )
-
-    checkpoint_activations, checkpoint_weights, checkpoint_gradients = (
-        state_extractor.extract_states(
-            extractor_dataloader, compute_gradients=compute_gradients
-        )
-    )
-
-    del _model
-    torch.cuda.empty_cache()
-
-    # NOTE: Synchronizing processes for model setup
-    fabric.barrier()
-
-    model.to(fabric.device)
-
-    # NOTE: Trimming down the activations to match the dataset size;
-    # This is because the DataSampler might add extra samples to the dataset to make it evenly divisible
-    # by the number of processes. We need to remove these extra samples.
-    for layer_name, layer_activations in checkpoint_activations.items():
-        if len(layer_activations) > len(dataset):
-            checkpoint_activations[layer_name] = layer_activations[: len(dataset)]
-        elif len(layer_activations) < len(dataset):
-            raise ValueError(
-                f"Number of activations ({len(layer_activations)}) in layer {layer_name} does not match number of samples in dataset ({len(dataset)})"
-            )
-
-    return {
-        "activations": checkpoint_activations,
-        "weights": checkpoint_weights,
-        "gradients": checkpoint_gradients,
-    }
-
-
-@rank_zero_only
-@use_backoff()
-def save_learning_dynamics_states(
-    checkpointing_config: CheckpointingConfig,
-    checkpoint_step: int,
-    prefix: str,
-    fabric: Fabric,
-    learning_dynamics_states: Dict[str, torch.Tensor],
-    learning_dynamics_dataset: Optional[Dataset] = None,
-    tokenizer: Optional[PreTrainedTokenizerBase] = None,
-) -> None:
-    """Save the learning dynamics metrics to the checkpointing directory.
-
-    By default only the learning dynamics states are saved. If the learning dynamics dataset
-    is provided, it is also saved; if a tokenizer is provided, the dataset is also detokenized
-    (i.e. a new column with the text is added to the dataset).
-
-    The learning dynamics dataset is saved in the checkpointing directory as a HuggingFace
-    dataset.
-
-    Creates a versioned checkpoint directory with the following structure:
-
-    {checkpointing_config.runs_dir}/
-    └── {checkpointing_config.run_name}/
-        └── {checkpointing_config.checkpoints_dir}/
-            ├── step_{checkpoint_step}/
-            │   └── {checkpointing_config.learning_dynamics_dir}/  # Learning Dynamics files
-            │       ├── {prefix}_activations.pt
-            │       ├── {prefix}_weights.pt
-            │       └── {prefix}_gradients.pt
-            │       └── {prefix}_data/  # if learning_dynamics_dataset is provided
-            └── latest -> step_{checkpoint_step}/
-
-    NOTE: this function is only called on rank 0
-
-    Args:
-        checkpointing_config: The configuration object for checkpointing.
-        checkpoint_step: The checkpoint step at which the learning dynamics states were computed.
-        prefix: The prefix for the learning dynamics states.
-        fabric: The Fabric instance for distributed training.
-        learning_dynamics_states: The learning dynamics states to save.
-        learning_dynamics_dataset: The dataset containing learning dynamics data,
-            including input IDs that need to be decoded. (optional)
-        tokenizer: The tokenizer used to decode input IDs into text. (optional)
-    """
-
-    runs_dir = checkpointing_config.runs_dir
-    run_name = checkpointing_config.run_name
-    checkpoints_dir = checkpointing_config.checkpoints_dir
-    learning_dynamics_dir = checkpointing_config.learning_dynamics_dir
-
-    run_path = os.path.join(runs_dir, run_name)
-    root_checkpoint_path = os.path.join(run_path, checkpoints_dir)
-    checkpoint_path = os.path.join(root_checkpoint_path, f"step_{checkpoint_step}")
-    learning_dynamics_path = os.path.join(checkpoint_path, learning_dynamics_dir)
-    os.makedirs(learning_dynamics_path, exist_ok=True)
-
-    # save the learning dynamics states
-    for key, value in learning_dynamics_states.items():
-        if value is not None and len(value) > 0:
-            torch.save(
-                value, os.path.join(learning_dynamics_path, f"{prefix}_{key}.pt")
-            )
-
-    if learning_dynamics_dataset is not None:
-        if tokenizer is not None:
-            # go through dataset and decode the input ids; and add back into dataset
-            detokenized_dataset = {"input_ids": [], "text": []}
-
-            for entry in learning_dynamics_dataset:
-                input_ids = entry["input_ids"]
-                decoded_text = tokenizer.decode(input_ids, skip_special_tokens=True)
-                detokenized_dataset["input_ids"].append(input_ids)
-                detokenized_dataset["text"].append(decoded_text)
-
-            learning_dynamics_dataset = Dataset.from_dict(detokenized_dataset)
-
-        learning_dynamics_dataset_path = os.path.join(
-            learning_dynamics_path, f"{prefix}_data"
-        )
-        learning_dynamics_dataset.save_to_disk(learning_dynamics_dataset_path)
-
-    if checkpointing_config.save_to_hf:
-        # Upload the HF model
-        upload_folder(
-            folder_path=learning_dynamics_path,
-            path_in_repo=learning_dynamics_dir,
-            repo_id=checkpointing_config.hf_checkpoint.repo_id,
-            commit_message=f"Saving Learning Dynamics Data ({prefix}) -- Step {checkpoint_step}",
-            revision=checkpointing_config.run_name,
-            token=os.getenv("HF_TOKEN"),
-        )
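Note: the heart of CheckpointStateExtractor is PyTorch's register_forward_hook. The self-contained toy below (plain PyTorch, using a throwaway nn.Sequential rather than the Pico model) shows the same capture-then-remove pattern in isolation.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
captured = {}


def make_hook(name):
    def hook(module, inputs, output):
        # Store activations on CPU, detached from the graph (as the extractor does).
        captured[name] = output.detach().cpu()
    return hook


handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear)  # stand-in for matching `layer_suffixes`
]

with torch.no_grad():
    model(torch.randn(4, 8))

for handle in handles:  # always remove hooks once the pass is done
    handle.remove()

print({name: tensor.shape for name, tensor in captured.items()})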
 
src/checkpointing/training.py DELETED
@@ -1,287 +0,0 @@
-"""
-Utilities for checkpointing training-related states (i.e. model, optimizer, lr_scheduler, etc.)
-
-We save both a HuggingFace model and a Fabric-specific checkpoint. The HuggingFace model is
-saved at the step-specific checkpoint directory, while the Fabric-specific checkpoint is saved
-in a subdirectory. This is done to facilitate easier versioning of the HuggingFace model files
-(which are what gets uploaded to the Hub).
-"""
-
-import os
-from dataclasses import asdict
-from typing import Any, Dict, Tuple, Union
-
-import yaml
-from huggingface_hub import upload_file, upload_folder
-from lightning.fabric import Fabric
-from lightning.fabric.strategies import DeepSpeedStrategy
-from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
-from torch import nn
-from torch.optim import Optimizer
-from torch.optim.lr_scheduler import LRScheduler
-from transformers import PreTrainedTokenizerBase
-
-from src.config import CheckpointingConfig
-from src.training.utils.io import use_backoff
-
-
-@use_backoff()
-def load_checkpoint(
-    checkpointing_config: CheckpointingConfig,
-    checkpoint_step: Union[str, int],
-    fabric: Fabric,
-    model: nn.Module,
-    optimizer: Optimizer,
-    lr_scheduler: LRScheduler,
-) -> Tuple[nn.Module, Optimizer, LRScheduler, int]:
-    """Load model checkpoint and associated states from a given step.
-
-    Args:
-        checkpointing_config: Configuration object containing checkpoint settings
-        checkpoint_step: The step at which to load the checkpoint
-        fabric: Lightning Fabric instance for distributed training support
-        model: The model instance to load weights into
-        optimizer: The optimizer instance to load states into
-        lr_scheduler: The learning rate scheduler to load states into
-
-    Returns:
-        Tuple containing the model, optimizer, lr_scheduler, and checkpoint step.
-        Returns None if no checkpoint is found.
-    """
-
-    if isinstance(checkpoint_step, int):
-        checkpoint_step = f"step_{checkpoint_step}"
-
-    checkpoint_path = os.path.join(
-        checkpointing_config.runs_dir,
-        checkpointing_config.run_name,
-        checkpointing_config.checkpoints_dir,
-        checkpoint_step,
-    )
-
-    if not os.path.exists(checkpoint_path):
-        return None
-
-    # Load from specified fabric checkpoint subdirectory
-    fabric_checkpoint_path = os.path.join(
-        checkpoint_path, checkpointing_config.fabric_checkpoint_dir
-    )
-
-    checkpoint_state = {
-        "_model": model,
-        "_optimizer": optimizer,
-        "_lr_scheduler": lr_scheduler,
-    }
-
-    if not isinstance(fabric.strategy, DeepSpeedStrategy):
-        fabric_load_file = os.path.join(
-            fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename
-        )
-    else:
-        # Deepspeed checkpoints create sub-directory with distributed checkpoint file
-        fabric_load_file = fabric_checkpoint_path
-
-    extra_state = fabric.load(os.path.join(fabric_load_file), state=checkpoint_state)
-
-    # NOTE: extra_state will contain any additional states that were saved in the checkpoint
-    checkpoint_step = extra_state["_checkpoint_step"]
-
-    if "_rng_states" in extra_state:
-        _rng_states = extra_state["_rng_states"]
-        _set_rng_states(_rng_states)
-
-    return model, optimizer, lr_scheduler, checkpoint_step
-
-
-@use_backoff()
-def save_checkpoint(
-    configs: Dict[str, Any],
-    checkpoint_step: int,
-    fabric: Fabric,
-    model: nn.Module,
-    optimizer: Optimizer,
-    lr_scheduler: LRScheduler,
-    tokenizer: PreTrainedTokenizerBase,
-    upload_logs: bool = False,
-) -> None:
-    """Save training checkpoint and associated states to disk and optionally to HuggingFace Hub.
-
-    We save the following files:
-        - HuggingFace model files (config.json, pytorch_model.bin)
-        - Tokenizer files (vocab.json, merges.txt)
-        - Fabric-specific files - fabric state of the model, optimizer, and lr_scheduler. If using
-          DeepSpeed, the checkpoint is saved in a subdirectory, otherwise it is saved in a single file.
-
-    Note that the HuggingFace model files are saved at the step-specific checkpoint directory, while the
-    Fabric-specific files are saved in a subdirectory. This is done to facilitate easier
-    versioning of the HuggingFace model files (which are what gets uploaded to the Hub).
-
-    NOTE: Why do we save a HF model at all? We do this because it makes it easier to load the model
-    in a separate script for evaluation and to play nicely with the HuggingFace Hub.
-
-    Creates a versioned checkpoint directory with the following structure:
-
-    {checkpointing_config.runs_dir}/
-    └── {checkpointing_config.run_name}/
-        └── training_config.yaml  # Training config
-        └── {checkpointing_config.checkpoints_dir}/
-            ├── step_{checkpoint_step}/
-            │   ├── config.json  # HuggingFace model config
-            │   ├── model.safetensors  # HuggingFace model weights
-            │   ├── pico_{model_type}.py  # HuggingFace custom model class
-            │   ├── tokenizer.json  # Tokenizer vocab
-            │   ├── tokenizer_config.json  # Tokenizer config
-            │   └── {checkpointing_config.fabric_checkpoint_dir}/  # Fabric-specific files
-            │       └── checkpoint/  # Distributed model checkpoint files (if using DeepSpeed)
-            │       OR
-            │       └── checkpoint.pt  # Single checkpoint file (if using other strategies)
-            └── latest -> step_{checkpoint_step}/
-
-    Args:
-        configs: A dictionary containing the initialized configuration objects.
-        checkpoint_step: The current training checkpoint step (i.e. number of learning steps taken)
-        fabric: Lightning Fabric instance for distributed training support
-        model: The model instance to save
-        optimizer: The optimizer instance to save
-        lr_scheduler: The learning rate scheduler to save
-        tokenizer: The tokenizer to save
-        upload_logs: Whether to upload training logs to HF Hub (default: False)
-
-    """
-
-    checkpointing_config = configs["checkpointing"]
-
-    # Get the directories from the training config
-    runs_dir = checkpointing_config.runs_dir
-    checkpoints_dir = checkpointing_config.checkpoints_dir
-    fabric_checkpoint_dir = checkpointing_config.fabric_checkpoint_dir
-    logs_dir = checkpointing_config.logs_dir
-
-    run_path = os.path.join(runs_dir, checkpointing_config.run_name)
-    root_checkpoint_path = os.path.join(run_path, checkpoints_dir)
-    checkpoint_path = os.path.join(root_checkpoint_path, f"step_{checkpoint_step}")
-
-    # Create directories
-    os.makedirs(checkpoint_path, exist_ok=True)
-
-    ########################################################
-    #
-    # Save HuggingFace files
-    #
-    ########################################################
-
-    # NOTE: we convert the Pico model to a HuggingFace model before saving it. See `model.py`
-    # for more details.
-    if fabric.global_rank == 0:
-        hf_model = model.convert_to_hf_model()
-        hf_model.save_pretrained(checkpoint_path)
-        tokenizer.save_pretrained(checkpoint_path)
-
-    ########################################################
-    #
-    # Save Fabric-specific files
-    #
-    ########################################################
-
-    # Create fabric-specific subdirectory
-    fabric_checkpoint_path = os.path.join(checkpoint_path, fabric_checkpoint_dir)
-    os.makedirs(fabric_checkpoint_path, exist_ok=True)
-
-    # Save model states (use underscore to avoid conflicts with third-party libraries)
-    checkpoint_state = {
-        "_model": model,
-        "_optimizer": optimizer,
-        "_lr_scheduler": lr_scheduler,
-        "_checkpoint_step": checkpoint_step,
-    }
-
-    if not isinstance(fabric.strategy, DeepSpeedStrategy):
-        checkpoint_state["_rng_states"] = _collect_rng_states()
-        fabric_save_file = os.path.join(
-            fabric_checkpoint_path, checkpointing_config.fabric_checkpoint_filename
-        )
-    else:
-        # Deepspeed checkpoints create sub-directory with distributed checkpoint file
-        fabric_save_file = fabric_checkpoint_path
-
-    fabric.save(fabric_save_file, checkpoint_state)
-
-    if fabric.global_rank == 0:
-        # Save config in fabric directory
-        config_path = os.path.join(run_path, "training_config.yaml")
-        if not os.path.exists(config_path):
-            # Converting dataclasses to joined dicts and saving to file
-            _training_config = {}
-            for config_name, config in configs.items():
-                _training_config[config_name] = asdict(config)
-            with open(config_path, "w") as f:
-                yaml.dump(_training_config, f)
-
-        # Update latest symlink
-        latest_symlink_path = os.path.join(root_checkpoint_path, "latest")
-        if os.path.lexists(latest_symlink_path):
-            os.remove(latest_symlink_path)
-        os.symlink(
-            f"step_{checkpoint_step}", latest_symlink_path, target_is_directory=True
-        )
-
-    ########################################################
-    #
-    # Push to HuggingFace Hub (if configured)
-    #
-    ########################################################
-
-    if fabric.global_rank == 0:
-        # Push only on rank zero thread
-
-        if checkpointing_config.save_to_hf:
-            repo_id = checkpointing_config.hf_checkpoint.repo_id
-
-            # Upload the HF model
-            hf_model.push_to_hub(
-                repo_id=repo_id,
-                commit_message=f"Saving HF Model -- Step {checkpoint_step}",
-                revision=checkpointing_config.run_name,
-                token=os.getenv("HF_TOKEN"),
-            )
-
-            if checkpoint_step == 0:
-                # Uploading Tokenizer during first step since it never changes
-                tokenizer.push_to_hub(
-                    repo_id=repo_id,
-                    commit_message=f"Saving Tokenizer -- Step {checkpoint_step}",
-                    revision=checkpointing_config.run_name,
-                    token=os.getenv("HF_TOKEN"),
-                )
-
-                # Upload training config, also only in first step
-                upload_file(
-                    path_or_fileobj=config_path,
-                    path_in_repo="training_config.yaml",
-                    repo_id=repo_id,
-                    commit_message=f"Saving Training Config -- Step {checkpoint_step}",
-                    revision=checkpointing_config.run_name,
-                    token=os.getenv("HF_TOKEN"),
-                )
-
-            # Upload the fabric checkpoint directory
-            upload_folder(
-                folder_path=fabric_checkpoint_path,
-                path_in_repo=fabric_checkpoint_dir,
-                repo_id=repo_id,
-                commit_message=f"Saving Fabric Checkpoint -- Step {checkpoint_step}",
-                revision=checkpointing_config.run_name,
-                token=os.getenv("HF_TOKEN"),
-            )
-
-            # Upload logs if requested
-            if upload_logs:
-                logs_path = os.path.join(run_path, logs_dir)
-                upload_folder(
-                    folder_path=logs_path,
-                    path_in_repo=logs_dir,
-                    repo_id=repo_id,
-                    commit_message=f"Saving Logs -- Step {checkpoint_step}",
-                    revision=checkpointing_config.run_name,
-                    token=os.getenv("HF_TOKEN"),
-                )
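Note: a hedged sketch of how auto-resume (see TrainingCheckpointingConfig.auto_resume) could be built on load_checkpoint. Passing the "latest" symlink name as the string step is an assumption based on the directory layout above, and the surrounding trainer objects are assumed to exist.

# Hypothetical auto-resume logic (not part of this commit).
resumed = load_checkpoint(
    checkpointing_config, "latest", fabric, model, optimizer, lr_scheduler
)
if resumed is None:
    start_step = 0  # nothing on disk yet; train from scratch
else:
    model, optimizer, lr_scheduler, start_step = resumed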
 
src/config/__init__.py DELETED
@@ -1,31 +0,0 @@
-"""
-Pico Config Package
-
-The modules of this package are where you can specify the hyperparameters for the Pico model,
-the dataset, the training process, evaluation, etc.
-
-As with anything else in Pico, we've designed for the configuration setup to be as flexible
-as possible. By default the configs are implemented as vanilla dataclasses -- this makes it easy to
-switch to different config management systems if you want, like hydra.
-
-Some things to NOTE:
-- All hyperparameters are initialized with default values, which can be overridden.
-- The default vocab size is set to the size of the OLMo tokenizer.
-"""
-
-# For convenience, we export the config classes here
-from .checkpointing_config import CheckpointingConfig
-from .data_config import DataConfig
-from .evaluation_config import EvaluationConfig
-from .model_config import ModelConfig
-from .monitoring_config import MonitoringConfig
-from .training_config import TrainingConfig
-
-__all__ = [
-    "CheckpointingConfig",
-    "DataConfig",
-    "EvaluationConfig",
-    "ModelConfig",
-    "MonitoringConfig",
-    "TrainingConfig",
-]
 
src/config/_constants.py DELETED
@@ -1,18 +0,0 @@
-"""
-Constants used throughout the codebase
-"""
-
-# Basic Training Constants used throughout the codebase
-VOCAB_SIZE = 50304
-MAX_SEQ_LEN = 2048
-BATCH_SIZE = 1024
-GRADIENT_ACCUMULATION_STEPS = 128
-
-# Directories used to store training runs, checkpoints, logs, and evaluation results
-RUNS_DIR = "runs"
-CHECKPOINTS_DIR = "checkpoints"
-LOGS_DIR = "logs"
-FABRIC_CHECKPOINT_DIR = "fabric_state"
-FABRIC_CHECKPOINT_FILENAME = "checkpoint.pt"
-LEARNING_DYNAMICS_DIR = "learning_dynamics"
-EVAL_RESULTS_DIR = "eval_results"
 
src/config/checkpointing_config.py DELETED
@@ -1,97 +0,0 @@
-"""
-Checkpointing Config
-
-Specifies the hyperparameters for the checkpointing process; checkpointing is used to save
-the model and optimizer states, as well as the learning dynamics metrics.
-"""
-
-from dataclasses import dataclass, field
-from typing import List, Optional
-
-from ._constants import (
-    CHECKPOINTS_DIR,
-    EVAL_RESULTS_DIR,
-    FABRIC_CHECKPOINT_DIR,
-    FABRIC_CHECKPOINT_FILENAME,
-    LEARNING_DYNAMICS_DIR,
-    LOGS_DIR,
-    RUNS_DIR,
-)
-
-
-@dataclass
-class TrainingCheckpointingConfig:
-    # Automatically resume training from the most recent checkpoint
-    auto_resume: bool = True
-
-
-@dataclass
-class EvaluationCheckpointingConfig:
-    # Directory in which evaluation results are saved
-    eval_results_dir: str = EVAL_RESULTS_DIR
-
-
-@dataclass
-class LearningDynamicsCheckpointingConfig:
-    # Suffixes of the layers to compute learning dynamics for
-    layer_suffixes: List[str] = field(
-        default_factory=lambda: [
-            "attention.v_proj",
-            "attention.o_proj",
-            "swiglu.w_2",
-        ]
-    )
-
-    # Sequence index at which to extract hidden states; by default, we extract the hidden states
-    # at the last token of the sequence (-1)
-    sequence_idx: int = -1
-
-    # size of the sub-batch used for extracting learning dynamics states
-    batch_size: int = 8
-
-    # Path to evaluation dataset - used across learning dynamics checkpointing for consistency
-    # NOTE: set to None to disable extracting learning dynamics states for an eval_batch
-    # NOTE: this dataset should be small, ideally just a batch of additional data
-    eval_data: Optional[str] = "pico-lm/pretokenized-paloma-tinsy"
-
-
-@dataclass
-class HuggingFaceCheckpointingConfig:
-    # Should be in the format of <(username or organization name)>/<repo_name>, e.g. pico-lm/demo
-    repo_id: str = ""
-
-    # HuggingFace Collection Slug (specifies a tag for the run)
-    collection_slug: Optional[str] = None
-
-
-@dataclass
-class CheckpointingConfig:
-    # Assign a name to the run
-    run_name: Optional[str] = None
-
-    # Defining checkpointing directories
-    runs_dir: str = RUNS_DIR
-    checkpoints_dir: str = CHECKPOINTS_DIR
-    logs_dir: str = LOGS_DIR
-    fabric_checkpoint_dir: str = FABRIC_CHECKPOINT_DIR
-    fabric_checkpoint_filename: str = FABRIC_CHECKPOINT_FILENAME
-    learning_dynamics_dir: str = LEARNING_DYNAMICS_DIR
-
-    # How often to save checkpoints
-    save_every_n_steps: int = 1000
-
-    # Whether to save checkpoints to HuggingFace
-    save_to_hf: Optional[bool] = False
-    hf_checkpoint: HuggingFaceCheckpointingConfig = field(
-        default_factory=HuggingFaceCheckpointingConfig
-    )
-
-    training: TrainingCheckpointingConfig = field(
-        default_factory=TrainingCheckpointingConfig
-    )
-    evaluation: EvaluationCheckpointingConfig = field(
-        default_factory=EvaluationCheckpointingConfig
-    )
-    learning_dynamics: LearningDynamicsCheckpointingConfig = field(
-        default_factory=LearningDynamicsCheckpointingConfig
-    )
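Note: as stated in src/config/__init__.py, these are plain dataclasses, so overriding defaults is just keyword arguments. A small illustrative construction follows; the run name is made up and the repo id reuses the pico-lm/demo example from the comment above.

config = CheckpointingConfig(
    run_name="pico-demo-run",
    save_every_n_steps=500,
    save_to_hf=True,
    hf_checkpoint=HuggingFaceCheckpointingConfig(repo_id="pico-lm/demo"),
)
assert config.learning_dynamics.batch_size == 8  # nested defaults still apply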
 
src/config/data_config.py DELETED
@@ -1,36 +0,0 @@
-"""
-Data Config
-
-Specifies the hyperparameters for the dataset, dataloader, and tokenizer.
-"""
-
-from dataclasses import dataclass, field
-
-from ._constants import BATCH_SIZE, VOCAB_SIZE
-
-
-@dataclass
-class DatasetConfig:
-    # Defines the HuggingFace name of a dataset
-    name: str = "pico-lm/pretokenized-dolma"
-
-
-@dataclass
-class DataLoaderConfig:
-    # NOTE: You should only change these values jointly with the training config; so that the
-    # sub-batch size is consistent with the gradient accumulation steps
-    batch_size: int = BATCH_SIZE
-
-
-@dataclass
-class TokenizerConfig:
-    # Specify a tokenizer to use
-    name: str = "allenai/OLMo-7B-0724-hf"
-    vocab_size: int = VOCAB_SIZE
-
-
-@dataclass
-class DataConfig:
-    dataset: DatasetConfig = field(default_factory=DatasetConfig)
-    dataloader: DataLoaderConfig = field(default_factory=DataLoaderConfig)
-    tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
 
src/config/evaluation_config.py DELETED
@@ -1,28 +0,0 @@
-"""
-Evaluation Config
-
-Specifies the hyperparameters for the evaluation process, i.e. what metrics to compute, etc.
-"""
-
-from dataclasses import dataclass, field
-from typing import List, Optional
-
-from src.config._constants import MAX_SEQ_LEN
-
-
-@dataclass
-class PalomaEvaluationConfig:
-    dataset_name: str = "pico-lm/pretokenized-paloma-tinsy"
-    dataset_split: str = "val"
-    max_length: int = MAX_SEQ_LEN
-    batch_size: int = 16
-
-
-@dataclass
-class EvaluationConfig:
-    # Evaluation metrics to compute: by default, we compute the perplexity of the model on the paloma dataset
-    metrics: Optional[List[str]] = field(default_factory=lambda: ["paloma"])
-
-    # NOTE: Add other evaluation configs here
-    # Each evaluation metric should have its own config
-    paloma: PalomaEvaluationConfig = field(default_factory=PalomaEvaluationConfig)
 
src/config/model_config.py DELETED
@@ -1,33 +0,0 @@
-"""
-Model Config
-
-Specifies the hyperparameters for the Pico model/model architecture.
-"""
-
-from dataclasses import dataclass
-from typing import Optional
-
-from ._constants import BATCH_SIZE, MAX_SEQ_LEN, VOCAB_SIZE
-
-
-@dataclass
-class ModelConfig:
-    model_type: str = "pico_decoder"
-
-    # Pico Decoder default hyperparameters
-
-    d_model: int = 768
-    n_layers: int = 12
-
-    vocab_size: int = VOCAB_SIZE
-    batch_size: int = BATCH_SIZE
-    max_seq_len: int = MAX_SEQ_LEN
-
-    attention_n_heads: int = 12
-    attention_n_kv_heads: Optional[int] = 4
-
-    activation_hidden_dim: int = 3072
-
-    norm_eps: float = 1e-6
-
-    position_emb_theta: float = 10000.0
 
src/config/monitoring_config.py DELETED
@@ -1,29 +0,0 @@
-"""
-Monitoring Config
-
-Specifies the monitoring process, e.g. how to log metrics and keep track of training progress.
-"""
-
-from dataclasses import dataclass, field
-
-
-@dataclass
-class LoggingConfig:
-    log_level: str = "INFO"
-    log_every_n_steps: int = 100
-
-
-@dataclass
-class WandbConfig:
-    # configure logging to Weights and Biases
-    project: str = ""
-    entity: str = ""
-
-
-@dataclass
-class MonitoringConfig:
-    logging: LoggingConfig = field(default_factory=LoggingConfig)
-
-    # Weights and Biases
-    save_to_wandb: bool = False
-    wandb: WandbConfig = field(default_factory=WandbConfig)
 
src/config/training_config.py DELETED
@@ -1,40 +0,0 @@
-"""
-Training Config
-
-Specifies the hyperparameters for the training process, i.e. the optimizer, learning rate, etc.
-"""
-
-from dataclasses import dataclass, field
-
-from ._constants import GRADIENT_ACCUMULATION_STEPS
-
-
-@dataclass
-class FabricConfig:
-    # Configure nodes/devices for parallelised training
-    num_nodes: int = 1
-    num_devices: int = 1
-    precision: str = "bf16-mixed"
-    # Hardware accelerator to use, can be cpu/cuda/mps etc.
-    accelerator: str = "cuda"
-
-
-@dataclass
-class OptimizationConfig:
-    # Optimizer
-    optimizer: str = "adamw"
-    lr: float = 3e-4
-
-    # Learning Rate Scheduler
-    lr_scheduler: str = "linear_with_warmup"
-    lr_warmup_steps: int = 2500
-
-    # Define number of gradient accumulation steps
-    gradient_accumulation_steps: int = GRADIENT_ACCUMULATION_STEPS
-
-
-@dataclass
-class TrainingConfig:
-    fabric: FabricConfig = field(default_factory=FabricConfig)
-    optimization: OptimizationConfig = field(default_factory=OptimizationConfig)
-    max_steps: int = 200_000
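Note: taken together with _constants.py (BATCH_SIZE = 1024, GRADIENT_ACCUMULATION_STEPS = 128) and the single-device FabricConfig defaults, the per-forward-pass sub-batch presumably works out to 1024 / 128 = 8 sequences; the relationship below is an assumption inferred from the comment in data_config.py, not spelled out in this commit.

BATCH_SIZE = 1024                  # global batch size (from _constants.py)
GRADIENT_ACCUMULATION_STEPS = 128  # from _constants.py
num_devices = 1                    # FabricConfig default

# Assumed relationship: each optimizer step accumulates GRADIENT_ACCUMULATION_STEPS
# forward passes per device, so the sub-batch per forward pass is:
sub_batch_size = BATCH_SIZE // (GRADIENT_ACCUMULATION_STEPS * num_devices)
print(sub_batch_size)  # 8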
 
src/evaluation/__init__.py DELETED
@@ -1,103 +0,0 @@
-"""
-Pico Evaluation Package
-
-This package implements the evaluation pipeline for the Pico language model. It provides
-functionality to evaluate model performance using various metrics and handles the complete
-evaluation workflow.
-
-We recommend that each evaluation metric should have its own config, and should be
-implemented as a module in the `evaluation/tasks` directory that exposes a `run_<metric_name>` function.
-
-NOTE: Out of the box we only support Paloma, but the structure is designed to be flexible and
-you are meant to add whatever metrics you want. One of the main reasons we store out
-the model in the HuggingFace format is so that its easy to use third-party evaluation
-libraries/frameworks.
-"""
-
-import os
-
-import torch
-from lightning.fabric import Fabric
-from torch import nn
-
-from src.config import CheckpointingConfig, EvaluationConfig
-
-from .tasks.paloma import run_paloma_evaluation
-
-
-def run_evaluation(
-    evaluation_config: EvaluationConfig,
-    checkpointing_config: CheckpointingConfig,
-    fabric: Fabric,
-    model: nn.Module,
-) -> None:
-    """Run model evaluation using specified metrics in `evaluation_config`.
-
-    This function orchestrates the complete evaluation pipeline by:
-        1. Resolving the model checkpoint path (either specified or latest) to load the model from;
-           during training, this is the path to the latest checkpoint in the run directory.
-        2. Iterating over each evaluation metric, and running the corresponding evaluation function.
-           NOTE: we suggest you follow the pattern of the Paloma evaluation function, and implement
-           your own evaluation function for each metric in the `evaluation/tasks` directory.
-        3. Aggregating results across all metrics in a dictionary, and returning it.
-
-    Args:
-        evaluation_config (EvaluationConfig): Configuration object containing:
-            - metrics (List[str]): Metrics to evaluate; each metric should have its
-              own config. Currently supported: ["paloma"];
-            - paloma (PalomaConfig): Configuration for Paloma evaluation
-                - max_length (int): Maximum sequence length
-                - limit_eval_examples (Optional[int]): Number of examples to evaluate
-        checkpointing_config (CheckpointingConfig): Configuration object containing:
-        fabric (Fabric): Lightning Fabric instance
-        model (nn.Module): Original model instance
-
-    Returns:
-        Dict[str, float]: Dictionary mapping metric names to their values
-            Example: {"paloma": 3.45}
-
-    Raises:
-        ValueError: If an unsupported evaluation metric is requested
-
-    Example:
-        results = run_evaluation(
-            EvaluationConfig(
-                run_name="experiment_1",
-                metrics=["paloma"],
-                paloma=PalomaConfig(max_length=2048, batch_size=16)
-            )
-        )

-    """
-
-    fabric.barrier()
-
-    model.to("cpu")  # Offloading model to CPU
-
-    evaluation_results = {}
-
-    # NOTE: Evaluation is only run on first processes to enable third-party evaluation libraries
-    # to determine how to handle distributed evaluation.
-    if fabric.global_rank == 0:
-        run_name = checkpointing_config.run_name
-        model_path = f"{os.getcwd()}/{checkpointing_config.runs_dir}/{run_name}/{checkpointing_config.checkpoints_dir}/latest"
-        os.makedirs(model_path, exist_ok=True)
-
-        for metric in evaluation_config.metrics:
-            # NOTE: add your own metrics here
-            if metric == "paloma":
-                evaluation_result = run_paloma_evaluation(
-                    model_path, evaluation_config.paloma
-                )
-            else:
-                raise ValueError(f"Metric {metric} not supported")
-
-            evaluation_results[metric] = evaluation_result
-
-    torch.cuda.empty_cache()
-
-    fabric.barrier()
-
-    model.to(fabric.device)
-
-    return evaluation_results
 
src/evaluation/tasks/paloma.py DELETED
@@ -1,52 +0,0 @@
-"""
-Paloma is a comprehensive evaluation benchmark for large language models (LLMs) that focuses
-on measuring perplexity across diverse text domains.
-
-To evaluate on Paloma, we use the huggingface evaluation framework.
-
-For more details, see: https://huggingface.co/datasets/allenai/paloma
-"""
-
-import evaluate
-from datasets import load_dataset
-from datasets.utils.logging import disable_progress_bar, enable_progress_bar
-
-from src.config.evaluation_config import PalomaEvaluationConfig
-
-
-def run_paloma_evaluation(
-    model_path: str,
-    paloma_config: PalomaEvaluationConfig,
-) -> None:
-    """Run Perplexity evaluation on the Paloma evaluation dataset.
-
-    We use the HuggingFace evaluate library to load in and compute the perplexity metric.
-
-    Args:
-        model_path (str): Path to the model checkpoint to be evaluated
-        paloma_config (PalomaEvaluationConfig): Configuration for Paloma evaluation
-    """
-
-    disable_progress_bar()
-
-    # load custom evaluation space, see https://huggingface.co/spaces/pico-lm/perplexity
-    perplexity = evaluate.load("pico-lm/perplexity")
-
-    dataset = load_dataset(
-        paloma_config.dataset_name, split=paloma_config.dataset_split
-    )["text"]
-
-    # compute perplexity score on Paloma dataset
-    perplexity_result = perplexity.compute(
-        model_id=model_path,
-        predictions=dataset,
-        add_start_token=False,
-        max_length=paloma_config.max_length,
-        batch_size=paloma_config.batch_size,
-        trust_remote_code=True,
-    )
-
-    mean_perplexity = perplexity_result["mean_perplexity"]
-
-    enable_progress_bar()
-    return mean_perplexity
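Note: an illustrative invocation; the checkpoint path is a made-up local path that mirrors the layout produced by save_checkpoint, and the batch size override is arbitrary.

from src.config.evaluation_config import PalomaEvaluationConfig

mean_perplexity = run_paloma_evaluation(
    model_path="runs/pico-demo-run/checkpoints/latest",
    paloma_config=PalomaEvaluationConfig(batch_size=8),
)
print(f"Paloma mean perplexity: {mean_perplexity:.2f}")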
 
src/model/__init__.py DELETED
@@ -1,12 +0,0 @@
-"""
-Model Package
-
-This Package contains Pico models (currently only the Pico Decoder). We plan to implement other
-architectures in the future.
-
-If you have other models you'd like to implement, we recommend you add modules to this package.
-"""
-
-from .pico_decoder import PicoDecoder
-
-__all__ = ["PicoDecoder"]
 
src/model/pico_decoder.py DELETED
@@ -1,911 +0,0 @@
1
- """
2
- Pico Decoder: A Lightweight Causal Transformer Language Model
3
-
4
- Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
-
6
- Everything is written with a modular design for easy modification and experimentation.
7
-
8
- Key features:
9
- - RMSNorm for layer normalization
10
- - Rotary Positional Embeddings (RoPE)
11
- - Multi-head attention with KV-cache support
12
- - SwiGLU activation function
13
- - Residual connections throughout
14
-
15
- - KV-cache for faster autoregressive generation
16
-
17
- References:
18
- - RoPE: https://arxiv.org/abs/2104.09864
19
- - SwiGLU: https://arxiv.org/abs/2002.05202
20
- - LLAMA: https://arxiv.org/abs/2302.13971
21
-
22
- Adapted from:
23
- - OLMO: https://github.com/allenai/OLMo
24
- - LLAMA: https://github.com/meta/llama
25
- """
26
-
27
- from dataclasses import asdict
28
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
-
30
- import torch
31
- import torch.nn as nn
32
- import torch.nn.functional as F
33
-
34
- # Handle PyTorch version compatibility for attention backend
35
- try:
36
- from torch.nn.attention import SDPBackend, sdpa_kernel
37
-
38
- HAS_TORCH_ATTENTION = True
39
- except ImportError:
40
- # Fallback for older PyTorch versions
41
- HAS_TORCH_ATTENTION = False
42
- SDPBackend = None
43
- sdpa_kernel = None
44
-
45
- from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
46
- from transformers.generation import GenerationConfig
47
- from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
48
-
49
- try:
50
- if TYPE_CHECKING:
51
- # We need to do this to avoid importing these when creating the HF-compatible models
52
- from src.config import ModelConfig
53
- except ImportError:
54
- pass
55
-
56
- ########################################################
57
- #
58
- # Layer Normalization
59
- #
60
- ########################################################
61
-
62
-
63
- class RMSNorm(torch.nn.Module):
64
- """Root Mean Square Layer Normalization.
65
-
66
- A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
67
- resulting in improved stability and performance.
68
-
69
- Args:
70
- config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
71
- - config.norm_eps: Small constant for numerical stability
72
- - config.d_model: Model dimension for the weight parameter
73
-
74
- References:
75
- https://arxiv.org/abs/1910.07467
76
- """
77
-
78
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
79
- super().__init__()
80
- self.eps = config.norm_eps
81
- self.weight = nn.Parameter(torch.ones(config.d_model))
82
-
83
- def _norm(self, x: torch.Tensor) -> torch.Tensor:
84
- """
85
- Normalizes the input tensor by its RMS value.
86
- """
87
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
88
-
89
- def forward(self, x: torch.Tensor) -> torch.Tensor:
90
- """
91
- Applies RMS normalization to the input tensor and scales it by the weight parameter.
92
- """
93
- output = self._norm(x.float()).type_as(x)
94
- return output * self.weight
95
-
96
-
97
- ########################################################
98
- #
99
- # Positional Embedding
100
- #
101
- ########################################################
102
-
103
-
104
- class RoPE(nn.Module):
105
- """Rotary Positional Embeddings (RoPE).
106
-
107
- Implements position-dependent rotation of keys and queries in attention mechanism,
108
- allowing better modeling of relative positions in sequences. Uses complex number
109
- operations for efficient rotation.
110
-
111
- Args:
112
- config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
113
- - config.position_emb_theta: Base for frequency computation
114
- - config.d_model: Model dimension
115
- - config.attention_n_heads: Number of attention heads
116
- - config.max_seq_len: Maximum sequence length
117
-
118
- References:
119
- https://arxiv.org/abs/2104.09864
120
- """
121
-
122
- _freqs_cis_tensor: torch.Tensor | None = None
123
-
124
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
125
- super().__init__()
126
-
127
- self.theta = config.position_emb_theta
128
- self.dim = config.d_model // config.attention_n_heads
129
-
130
- max_seq_len = config.max_seq_len
131
-
132
- # only gets set once, and then reused for all RoPE instances
133
- if RoPE._freqs_cis_tensor is None:
134
- RoPE._freqs_cis_tensor = self._setup_freqs_cis(
135
- max_seq_len, self.theta, self.dim
136
- )
137
-
138
- # register _freqs_cis buffer
139
- # can be easily recomputed so persistent=False
140
- self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
141
-
142
- @classmethod
143
- def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
144
- """Setup Frequency Tensor for RoPE Embeddings
145
-
146
- Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
147
-
148
- Note other implementations will use cos and sin directly, but using the complex
149
- number representation is (probably) more efficient:
150
-
151
- e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
152
- """
153
- _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
154
- positions = torch.arange(seq_len)
155
- freqs = torch.outer(positions, _freqs)
156
- return torch.polar(torch.ones_like(freqs), freqs) # complex64
157
-
158
- def get_freqs_cis(
159
- self, input_shape: torch.Size, start_pos: int, end_pos: int
160
- ) -> torch.Tensor:
161
- """Reshape Frequency Tensor for RoPE Embeddings
162
-
163
- Makes the frequency tensor broadcastable with the input tensor.
164
- """
165
- _freqs_cis = self._freqs_cis[start_pos:end_pos]
166
- ndim = len(input_shape)
167
- assert 0 <= 1 < ndim
168
- assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
169
-
170
- # TODO: Check whether this is correct (might be able to remove this)
171
- shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
172
- return _freqs_cis.view(*shape)
173
-
174
- def forward(
175
- self,
176
- queries: torch.Tensor,
177
- keys: torch.Tensor,
178
- start_pos: int = 0,
179
- ) -> Tuple[torch.Tensor, torch.Tensor]:
180
- """Apply RoPE Embeddings to Queries and Keys
181
-
182
- Applies the rotary positional embeddings to the input tensors via complex num multiplication
183
-
184
- NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
185
- """
186
- queries_ = torch.view_as_complex(
187
- queries.float().reshape(*queries.shape[:-1], -1, 2)
188
- )
189
- keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
190
-
191
- input_shape = (
192
- queries_.shape
193
- ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
194
- freqs_start_pos = start_pos
195
- freqs_end_pos = freqs_start_pos + queries_.shape[1]
196
-
197
- freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
198
-
199
- queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
200
- keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
201
- return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
202
-
203
-
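A small sanity check of the complex-number rotation used above, applied to a single query vector at one position (all values are illustrative; the real module builds the full frequency table once and reuses it):

import torch

dim, theta, position = 8, 10000.0, 3
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # (dim / 2,)
angles = position * freqs
freqs_cis = torch.polar(torch.ones_like(angles), angles)          # e^(i * angle), complex64

q = torch.randn(dim)
q_complex = torch.view_as_complex(q.reshape(-1, 2))               # pair up adjacent dims
q_rotated = torch.view_as_real(q_complex * freqs_cis).flatten()   # back to real, shape (dim,)

# A pure rotation preserves the norm of every pair, hence of the whole vector.
print(torch.allclose(q.norm(), q_rotated.norm(), atol=1e-5))      # True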
204
- ########################################################
205
- #
206
- # Attention
207
- #
208
- ########################################################
209
-
210
-
211
- class Attention(nn.Module):
212
- """Multi-head Attention with Grouped Query Attention (GQA) support.
213
-
214
- Implements scaled dot-product attention and supports:
215
- - Grouped Query Attention (GQA)
216
- - Key-Value caching for efficient inference
217
- - RoPE integration
218
-
219
- Args:
220
- config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
221
- - config.attention_n_heads: Number of attention heads
222
- - config.attention_n_kv_heads: Number of key/value heads
223
- - config.d_model: Model dimension
224
- - config.batch_size: Maximum batch size
225
- - config.max_seq_len: Maximum sequence length
226
-
227
- Shape:
228
- - Input: (batch_size, seq_len, d_model)
229
- - Output: (batch_size, seq_len, d_model)
230
- """
231
-
232
- def __init__(
233
- self,
234
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
235
- ):
236
- super().__init__()
237
-
238
- self.n_heads = config.attention_n_heads
239
- self.n_kv_heads = config.attention_n_kv_heads
240
-
241
- self.batch_size = config.batch_size
242
- self.max_seq_len = config.max_seq_len
243
-
244
- d_model = config.d_model
245
- self.head_dim = d_model // self.n_heads
246
-
247
- self.n_rep = self.n_heads // self.n_kv_heads
248
-
249
- self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
250
- self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
251
- self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
252
- self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
253
-
254
- self.rope = RoPE(config)
255
-
256
- def forward(
257
- self,
258
- input: torch.Tensor,
259
- mask: Optional[torch.Tensor] = None,
260
- past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
261
- use_cache: bool = False,
262
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
263
- """Forward pass for the attention mechanism.
264
-
265
- Computes queries, keys, and values for the attention mechanism. Applies rotary positional
266
- embeddings to the queries and keys, and then computes attention scores and outputs.
267
-
268
- For an introduction to the attention mechanism, see:
269
- https://arxiv.org/abs/1706.03762
270
-
271
- A few things to note:
272
- - The past_key_values is used to implement the KV cache, which is used to speed up
273
- generation by caching the KV pairs from previous forward passes. This is useful when doing
274
- tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
275
- modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
276
- its own KV cache - this KV cache is implemented as a tuple.
277
- """
278
- bsz, seq_len, _ = input.shape
279
- _queries, _keys, _values = (
280
- self.q_proj(input),
281
- self.k_proj(input),
282
- self.v_proj(input),
283
- )
284
-
285
- # Reshaping for multi-head attention
286
- queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
287
- keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
288
- values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
289
-
290
- # The start position is used to apply the RoPE embeddings to only the new tokens
291
- # when using the kv_cache in the attention mechanism.
292
- # We want to start from the last position in the cache.
293
- start_pos = 0
294
- if past_key_values is not None and past_key_values[0] is not None:
295
- start_pos = past_key_values[0].shape[1]
296
-
297
- # apply rotary positional embeddings
298
- queries, keys = self.rope(queries, keys, start_pos)
299
-
300
- if (
301
- past_key_values is not None
302
- and past_key_values[0] is not None
303
- and past_key_values[1] is not None
304
- ):
305
- keys = torch.cat([past_key_values[0], keys], dim=1)
306
- values = torch.cat([past_key_values[1], values], dim=1)
307
-
308
- if use_cache:
309
- cached_keys = keys
310
- cached_values = values
311
- else:
312
- cached_keys = None
313
- cached_values = None
314
-
315
- queries = queries.transpose(1, 2)
316
- keys = keys.transpose(1, 2)
317
- values = values.transpose(1, 2)
318
-
319
- apply_gqa = self.n_rep > 1
320
- if apply_gqa and queries.device.type == "mps":
321
- # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
322
- # outside of the kernel to get the same effect.
323
- # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
324
- keys = keys.repeat_interleave(self.n_rep, dim=-3)
325
- values = values.repeat_interleave(self.n_rep, dim=-3)
326
- apply_gqa = False
327
-
328
- if HAS_TORCH_ATTENTION:
329
- backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
330
- with sdpa_kernel(backends=backends):
331
- attn_output = F.scaled_dot_product_attention(
332
- queries.contiguous(),
333
- keys.contiguous(),
334
- values.contiguous(),
335
- attn_mask=mask.to(queries.dtype) if mask is not None else None,
336
- enable_gqa=apply_gqa,
337
- )
338
- else:
339
- # Fallback for older PyTorch versions - use default backend
340
- attn_output = F.scaled_dot_product_attention(
341
- queries.contiguous(),
342
- keys.contiguous(),
343
- values.contiguous(),
344
- attn_mask=mask.to(queries.dtype) if mask is not None else None,
345
- enable_gqa=apply_gqa,
346
- )
347
-
348
- attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
349
- output = self.o_proj(attn_output)
350
-
351
- return output, (cached_keys, cached_values)
352
-
353
-
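A toy sketch of the MPS fallback above: with more query heads than key/value heads, each KV head is shared by n_rep query heads, and repeat_interleave makes that sharing explicit (illustrative shapes only, not the real module):

import torch

n_heads, n_kv_heads, head_dim, seq_len, bsz = 8, 2, 16, 5, 1
n_rep = n_heads // n_kv_heads                            # 4 query heads per KV head

keys = torch.randn(bsz, n_kv_heads, seq_len, head_dim)   # layout after transpose(1, 2)
expanded = keys.repeat_interleave(n_rep, dim=-3)         # repeat along the head axis
print(expanded.shape)                                    # torch.Size([1, 8, 5, 16])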
354
- ########################################################
355
- #
356
- # SwiGLU (Combines MLP and Activation)
357
- #
358
- ########################################################
359
-
360
-
361
- class SwiGLU(nn.Module):
362
- """SwiGLU Activation Function with Linear Projections.
363
-
364
- Implements the SwiGLU activation function combined with linear transformations,
365
- serving as the feed-forward network in transformer blocks.
366
-
367
- Args:
368
- config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
369
- - config.d_model: Model dimension
370
- - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
371
-
372
- References:
373
- https://arxiv.org/abs/2002.05202
374
- """
375
-
376
- def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
377
- super().__init__()
378
-
379
- model_dim = config.d_model
380
- act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
381
-
382
- self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
383
- self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
384
- self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
385
-
386
- def forward(self, x: torch.Tensor) -> torch.Tensor:
387
- return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
388
-
389
-
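The same computation written out with plain tensors, as a functional sketch (toy dimensions; the real module uses bias-free nn.Linear layers):

import torch
import torch.nn.functional as F

d_model, hidden = 16, 64
w_0 = torch.randn(hidden, d_model)   # gate projection
w_1 = torch.randn(hidden, d_model)   # up projection
w_2 = torch.randn(d_model, hidden)   # down projection

x = torch.randn(2, d_model)
out = (F.silu(x @ w_0.T) * (x @ w_1.T)) @ w_2.T
print(out.shape)                     # torch.Size([2, 16])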
390
- ########################################################
391
- #
392
- # PicoDecoderBlock
393
- #
394
- ########################################################
395
-
396
-
397
- class PicoDecoderBlock(nn.Module):
398
- """Single Transformer Block with Attention and Feed-forward layers.
399
-
400
- Implements a standard transformer block with:
401
- - Multi-head attention with normalization and residual connection
402
- - SwiGLU feed-forward network with normalization and residual connection
403
-
404
- Args:
405
- config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
406
- a HuggingFace PicoDecoderHFConfig
407
- """
408
-
409
- def __init__(
410
- self,
411
- config: Union["ModelConfig", "PicoDecoderHFConfig"],
412
- ):
413
- super().__init__()
414
-
415
- self.attention = Attention(config)
416
- self.swiglu = SwiGLU(config)
417
- self.attention_norm = RMSNorm(config)
418
- self.swiglu_norm = RMSNorm(config)
419
-
420
- def forward(
421
- self,
422
- input: torch.Tensor,
423
- mask: Optional[torch.Tensor] = None,
424
- past_key_values: Optional[Tuple[torch.Tensor]] = None,
425
- use_cache: bool = False,
426
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
427
- attention_output, cached_key_values = self.attention(
428
- self.attention_norm(input),
429
- mask=mask,
430
- past_key_values=past_key_values,
431
- use_cache=use_cache,
432
- )
433
- # NOTE: cached_key_values is None if use_cache is False
434
-
435
- h = input + attention_output
436
- out = h + self.swiglu(self.swiglu_norm(h))
437
- return out, cached_key_values
438
-
439
-
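The block above follows the standard pre-norm residual pattern. A toy sketch of just that wiring, with stand-in sub-layers (LayerNorm and Linear are placeholders for the RMSNorm, Attention, and SwiGLU modules defined in this file):

import torch

d_model = 16
norm1, norm2 = torch.nn.LayerNorm(d_model), torch.nn.LayerNorm(d_model)
attn = torch.nn.Linear(d_model, d_model)   # stand-in for the attention sub-layer
mlp = torch.nn.Linear(d_model, d_model)    # stand-in for the SwiGLU sub-layer

x = torch.randn(2, 5, d_model)
h = x + attn(norm1(x))                     # attention sub-layer + residual
out = h + mlp(norm2(h))                    # feed-forward sub-layer + residual
print(out.shape)                           # torch.Size([2, 5, 16])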
440
- ########################################################
441
- #
442
- # Pico Decoder (Causal Transformer Model)
443
- #
444
- ########################################################
445
-
446
-
447
- class PicoDecoder(nn.Module):
448
- """
449
- Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
450
- single autoregressive model.
451
-
452
- For more information on the model, see the classes for the modules that make up the model.
453
- """
454
-
455
- def __init__(
456
- self,
457
- model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
458
- ):
459
- super().__init__()
460
- self.config = model_config
461
-
462
- self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
463
- self.layers = nn.ModuleList(
464
- [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
465
- )
466
- self.output_norm = RMSNorm(self.config)
467
- self.de_embedding_proj = nn.Linear(
468
- self.config.d_model, self.config.vocab_size, bias=False
469
- )
470
-
471
- def convert_to_hf_model(self) -> "PicoDecoderHF":
472
- """Convert the Lightning model to a HuggingFace model."""
473
- # Create HF config without fabric-specific settings
474
- hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
475
-
476
- # Create new HF model
477
- hf_model = PicoDecoderHF(hf_config)
478
-
479
- # Copy state dict, excluding fabric-specific keys
480
- hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
481
-
482
- return hf_model
483
-
484
- def forward(
485
- self,
486
- input_ids: torch.Tensor,
487
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
488
- use_cache: bool = False,
489
- ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
490
- """
491
- This is the forward pass for the entire Pico model. It boils down to:
492
- - Embedding the input ids
493
- - Creating a causal mask
494
- - Processing through the pico layers
495
- - Projecting the output to logits
496
-
497
- NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
498
- generation by caching the KV pairs from previous forward passes. This is useful when doing
499
- tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
500
- modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
501
- its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
502
- KV caches (so a tuple of tuples).
503
- """
504
-
505
- seq_len = input_ids.shape[-1]
506
- h = self.embedding_proj(input_ids)
507
-
508
- # Calculate start position from past cached KV pairs. Remember that each layer has its
509
- # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
510
- # correct layer and then for either the keys or values.
511
- start_pos = 0
512
- if (
513
- past_key_values is not None
514
- and past_key_values[0] is not None
515
- and past_key_values[0][0] is not None
516
- ):
517
- start_pos = past_key_values[0][0].shape[1]
518
-
519
- # Create causal mask for current sequence
520
- mask = None
521
- if seq_len > 1:
522
- mask = torch.full((seq_len, seq_len), float("-inf"))
523
- mask = torch.triu(mask, diagonal=1)
524
-
525
- # If using KV cache, extend mask to cover cached sequence length
526
- if past_key_values is not None:
527
- # Add zeros for cached tokens (we can attend to all of them)
528
- mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
529
-
530
- mask = mask.to(h.device)
531
-
532
- # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
533
- # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
534
- cached_key_values = () if use_cache else None
535
-
536
- # Process through transformer blocks
537
- for idx, layer in enumerate(self.layers):
538
- layer_past_key_values = None
539
- if past_key_values is not None:
540
- try:
541
- # Handle both tuple-based cache and HuggingFace cache objects
542
- if hasattr(past_key_values, "__getitem__") and idx < len(
543
- past_key_values
544
- ):
545
- layer_past_key_values = past_key_values[idx]
546
- except (KeyError, IndexError, TypeError):
547
- # If we can't access the cache properly, just skip it
548
- layer_past_key_values = None
549
-
550
- h, layer_cached_key_values = layer(
551
- h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
552
- )
553
-
554
- if use_cache:
555
- cached_key_values += (layer_cached_key_values,)
556
-
557
- # Final norm and projection
558
- h = self.output_norm(h)
559
- logits = self.de_embedding_proj(h).float()
560
-
561
- return logits, cached_key_values
562
-
563
-
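A small sketch of the mask construction in the forward pass above: the new tokens get a causal (upper-triangular -inf) mask, while every previously cached position stays fully visible (toy sizes):

import torch

seq_len, start_pos = 3, 2                        # 3 new tokens, 2 cached tokens
mask = torch.full((seq_len, seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)              # causal part for the new tokens
mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
print(mask)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])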
564
- ########################################################
565
- #
566
- # HuggingFace Wrapper for the Pico Decoder model.
567
- #
568
- ########################################################
569
-
570
-
571
- class PicoDecoderHFConfig(PretrainedConfig):
572
- """Config class for the Pico Decoder HuggingFace wrapper."""
573
-
574
- model_type = "pico_decoder"
575
-
576
- @classmethod
577
- def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
578
- """
579
- Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
580
- this is because with some kwargs special handling is required and can make this class
581
- brittle.
582
- """
583
- pico_config = cls(**config_dict)
584
-
585
- return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
586
- unused_kwargs = {
587
- key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
588
- }
589
-
590
- if return_unused_kwargs:
591
- return pico_config, unused_kwargs
592
- return pico_config
593
-
594
- @classmethod
595
- def from_dataclass(cls, model_config: "ModelConfig"):
596
- """Initialise from our custom config dataclass."""
597
- return cls.from_dict(asdict(model_config))
598
-
599
-
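A hedged usage sketch of the config class above, assuming it is importable; the field values are hypothetical and only reuse attribute names referenced elsewhere in this file:

config_dict = {
    "d_model": 96,
    "n_layers": 12,
    "vocab_size": 50304,
    "max_seq_len": 2048,
    "batch_size": 1,
    "attention_n_heads": 12,
    "attention_n_kv_heads": 4,
    "activation_hidden_dim": 384,
    "position_emb_theta": 10000.0,
}
hf_config = PicoDecoderHFConfig.from_dict(config_dict)
print(hf_config.d_model)  # 96

# Unknown kwargs are surfaced instead of being silently attached to the config:
hf_config, unused = PicoDecoderHFConfig.from_dict(
    config_dict, return_unused_kwargs=True, some_made_up_flag=True
)
print(unused)  # {'some_made_up_flag': True}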
600
- class PicoDecoderHF(PreTrainedModel, GenerationMixin):
601
- """
602
- HuggingFace wrapper for the Pico model with generation support.
603
-
604
- Many evaluation frameworks require a model to be set up as a HuggingFace model, so we provide a simple
605
- wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
606
- Pico model as well as the model wrapped in this HuggingFace class.
607
-
608
- This also lets you do cool things like:
609
-
610
- `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
611
- """
612
-
613
- config_class = PicoDecoderHFConfig
614
- _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
615
- main_input_name = "input_ids"
616
-
617
- def __init__(self, config: PicoDecoderHFConfig):
618
- super().__init__(config)
619
- self.pico_decoder = PicoDecoder(config)
620
- # Initialize generation config with defaults
621
- self.generation_config = GenerationConfig()
622
- # Set some reasonable defaults for the model
623
- if hasattr(config, "max_position_embeddings"):
624
- self.generation_config.max_length = config.max_position_embeddings
625
- if hasattr(config, "vocab_size"):
626
- self.generation_config.vocab_size = config.vocab_size
627
-
628
- def forward(
629
- self,
630
- input_ids: torch.Tensor,
631
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
632
- use_cache: bool = False,
633
- **kwargs,
634
- ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
635
- """HuggingFace forward pass wrapper.
636
-
637
- Forward pass for the HuggingFace version of the Pico model. Basic wrapper around the
638
- Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
639
- """
640
- logits, past_key_values = self.pico_decoder(
641
- input_ids, past_key_values, use_cache
642
- )
643
- if use_cache:
644
- return CausalLMOutputWithPast(
645
- logits=logits,
646
- past_key_values=past_key_values,
647
- )
648
- else:
649
- return CausalLMOutput(
650
- logits=logits,
651
- )
652
-
653
- def prepare_inputs_for_generation(
654
- self,
655
- input_ids: torch.LongTensor,
656
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
657
- attention_mask: Optional[torch.LongTensor] = None,
658
- **kwargs,
659
- ) -> Dict[str, Any]:
660
- """
661
- Prepare inputs for generation.
662
-
663
- Args:
664
- input_ids: Input token IDs
665
- past_key_values: Cached key-value pairs from previous forward passes
666
- attention_mask: Attention mask for the input
667
- **kwargs: Additional arguments
668
-
669
- Returns:
670
- Dictionary containing prepared inputs
671
- """
672
- # If we have past_key_values, we only need the last token
673
- if past_key_values is not None:
674
- input_ids = input_ids[:, -1:]
675
-
676
- return {
677
- "input_ids": input_ids,
678
- "past_key_values": past_key_values,
679
- "use_cache": True,
680
- }
681
-
682
- def get_input_embeddings(self):
683
- """Get the input embeddings layer."""
684
- return self.pico_decoder.embedding_proj
685
-
686
- def set_input_embeddings(self, value):
687
- """Set the input embeddings layer."""
688
- self.pico_decoder.embedding_proj = value
689
-
690
- def get_output_embeddings(self):
691
- """Get the output embeddings layer."""
692
- return self.pico_decoder.de_embedding_proj
693
-
694
- def set_output_embeddings(self, value):
695
- """Set the output embeddings layer."""
696
- self.pico_decoder.de_embedding_proj = value
697
-
698
- def get_lm_head(self):
699
- """Get the language model head."""
700
- return self.pico_decoder.de_embedding_proj
701
-
702
- def can_generate(self) -> bool:
703
- """Check if the model can generate text."""
704
- return True
705
-
706
- @property
707
- def is_encoder_decoder(self) -> bool:
708
- """Check if the model is an encoder-decoder model."""
709
- return False
710
-
711
- @property
712
- def can_use_cache(self) -> bool:
713
- """Check if the model can use KV cache."""
714
- return True
715
-
716
- def resize_token_embeddings(
717
- self, new_num_tokens: Optional[int] = None
718
- ) -> torch.nn.Embedding:
719
- """Resize token embeddings."""
720
- old_embeddings = self.get_input_embeddings()
721
- if new_num_tokens is None:
722
- new_num_tokens = old_embeddings.num_embeddings
723
-
724
- new_embeddings = torch.nn.Embedding(
725
- new_num_tokens, old_embeddings.embedding_dim
726
- )
727
- new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
728
- old_embeddings.weight.data
729
- )
730
-
731
- self.pico_decoder.embedding_proj = new_embeddings
732
- self.pico_decoder.de_embedding_proj = torch.nn.Linear(
733
- old_embeddings.embedding_dim, new_num_tokens, bias=False
734
- )
735
-
736
- return new_embeddings
737
-
738
-
739
- # Register for auto classes
740
- PicoDecoderHFConfig.register_for_auto_class()
741
- PicoDecoderHF.register_for_auto_class("AutoModel")
742
- PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
743
-
744
-
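With the auto-class registrations above, a saved checkpoint can be loaded through the standard HuggingFace entry points. A hedged usage sketch (the checkpoint path is a placeholder):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "path/to/checkpoint"  # placeholder for a saved run
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
with torch.no_grad():
    output_ids = model.generate(input_ids, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))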
745
- ########################################################
746
- #
747
- # New PicoDecoderForCausalLM class for generation support
748
- #
749
- ########################################################
750
-
751
-
752
- class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
753
- """
754
- PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
755
-
756
- This class is designed to work with existing checkpoints and provides full generation support.
757
- It inherits from the right base classes that HuggingFace expects for text generation.
758
- """
759
-
760
- config_class = PicoDecoderHFConfig
761
- _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
762
- main_input_name = "input_ids"
763
-
764
- def __init__(self, config: PicoDecoderHFConfig):
765
- super().__init__(config)
766
- self.pico_decoder = PicoDecoder(config)
767
- # Initialize generation config with defaults
768
- self.generation_config = GenerationConfig()
769
- # Set some reasonable defaults for the model
770
- if hasattr(config, "max_position_embeddings"):
771
- self.generation_config.max_length = config.max_position_embeddings
772
- if hasattr(config, "vocab_size"):
773
- self.generation_config.vocab_size = config.vocab_size
774
-
775
- def forward(
776
- self,
777
- input_ids: torch.Tensor,
778
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
779
- use_cache: bool = False,
780
- **kwargs,
781
- ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
782
- """Forward pass for text generation."""
783
- logits, past_key_values = self.pico_decoder(
784
- input_ids, past_key_values, use_cache
785
- )
786
- if use_cache:
787
- return CausalLMOutputWithPast(
788
- logits=logits,
789
- past_key_values=past_key_values,
790
- )
791
- else:
792
- return CausalLMOutput(
793
- logits=logits,
794
- )
795
-
796
- def prepare_inputs_for_generation(
797
- self,
798
- input_ids: torch.LongTensor,
799
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
800
- attention_mask: Optional[torch.LongTensor] = None,
801
- **kwargs,
802
- ) -> Dict[str, Any]:
803
- """Prepare inputs for generation."""
804
- # If we have past_key_values, we only need the last token
805
- if past_key_values is not None:
806
- input_ids = input_ids[:, -1:]
807
-
808
- return {
809
- "input_ids": input_ids,
810
- "past_key_values": past_key_values,
811
- "use_cache": True,
812
- }
813
-
814
- def get_input_embeddings(self):
815
- """Get the input embeddings layer."""
816
- return self.pico_decoder.embedding_proj
817
-
818
- def set_input_embeddings(self, value):
819
- """Set the input embeddings layer."""
820
- self.pico_decoder.embedding_proj = value
821
-
822
- def get_output_embeddings(self):
823
- """Get the output embeddings layer."""
824
- return self.pico_decoder.de_embedding_proj
825
-
826
- def set_output_embeddings(self, value):
827
- """Set the output embeddings layer."""
828
- self.pico_decoder.de_embedding_proj = value
829
-
830
- def get_lm_head(self):
831
- """Get the language model head."""
832
- return self.pico_decoder.de_embedding_proj
833
-
834
- def can_generate(self) -> bool:
835
- """Check if the model can generate text."""
836
- return True
837
-
838
- @property
839
- def is_encoder_decoder(self) -> bool:
840
- """Check if the model is an encoder-decoder model."""
841
- return False
842
-
843
- @property
844
- def can_use_cache(self) -> bool:
845
- """Check if the model can use KV cache."""
846
- return True
847
-
848
- def resize_token_embeddings(
849
- self, new_num_tokens: Optional[int] = None
850
- ) -> torch.nn.Embedding:
851
- """Resize token embeddings."""
852
- old_embeddings = self.get_input_embeddings()
853
- if new_num_tokens is None:
854
- new_num_tokens = old_embeddings.num_embeddings
855
-
856
- new_embeddings = torch.nn.Embedding(
857
- new_num_tokens, old_embeddings.embedding_dim
858
- )
859
- new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
860
- old_embeddings.weight.data
861
- )
862
-
863
- self.pico_decoder.embedding_proj = new_embeddings
864
- self.pico_decoder.de_embedding_proj = torch.nn.Linear(
865
- old_embeddings.embedding_dim, new_num_tokens, bias=False
866
- )
867
-
868
- return new_embeddings
869
-
870
- @classmethod
871
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
872
- """
873
- Load a pretrained model from a checkpoint.
874
-
875
- This method handles loading from both the old PicoDecoderHF format and the new format.
876
- """
877
- # First try to load with the new class
878
- try:
879
- return super().from_pretrained(
880
- pretrained_model_name_or_path, *model_args, **kwargs
881
- )
882
- except Exception as e:
883
- print(f"Failed to load with new class: {e}")
884
- print("Attempting to load with legacy class and convert...")
885
-
886
- # Try to load with the old class and convert
887
- try:
888
- from transformers import AutoModel
889
-
890
- old_model = AutoModel.from_pretrained(
891
- pretrained_model_name_or_path,
892
- trust_remote_code=True,
893
- *model_args,
894
- **kwargs,
895
- )
896
-
897
- # Create new model instance
898
- new_model = cls(old_model.config)
899
-
900
- # Copy state dict
901
- new_model.load_state_dict(old_model.state_dict(), strict=False)
902
-
903
- return new_model
904
-
905
- except Exception as e2:
906
- print(f"Failed to convert from legacy format: {e2}")
907
- raise e
908
-
909
-
910
- # Register the new class
911
- PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
 
src/training/trainer.py DELETED
@@ -1,753 +0,0 @@
1
- """
2
- Pico Language Model Trainer
3
-
4
- This Trainer implements a minimalistic end-to-end training pipeline for the Pico language model, with
5
- distributed training support via Lightning Fabric. It provides a modular and configurable training
6
- pipeline with the following features:
7
-
8
- - Configuration Management: YAML-based configuration for all aspects of training
9
- - Distributed Training: Multi-GPU support via Lightning Fabric
10
- - Checkpointing: Regular model saving and training state recovery
11
- - Evaluation: Periodic model evaluation on validation datasets
12
- - Logging: Comprehensive metric tracking and experiment monitoring
13
- - Optimization: Support for gradient accumulation, clipping, and LR scheduling
14
- """
15
-
16
- import logging
17
- import os
18
- import platform
19
- from typing import Any, Dict
20
-
21
- import lightning as L
22
- import psutil
23
- import torch
24
- import torch.nn.functional as F
25
- import yaml
26
- from datasets import Dataset, load_dataset
27
- from lightning.fabric.utilities.rank_zero import rank_zero_only
28
-
29
- from src.checkpointing import (
30
- compute_learning_dynamics_states,
31
- load_checkpoint,
32
- save_checkpoint,
33
- save_evaluation_results,
34
- save_learning_dynamics_states,
35
- )
36
- from src.evaluation import run_evaluation
37
- from src.training.utils import (
38
- initialize_configuration,
39
- initialize_dataloader,
40
- initialize_dataset,
41
- initialize_fabric,
42
- initialize_hf_checkpointing,
43
- initialize_logging,
44
- initialize_lr_scheduler,
45
- initialize_model,
46
- initialize_optimizer,
47
- initialize_run_dir,
48
- initialize_tokenizer,
49
- initialize_wandb,
50
- )
51
- from src.training.utils.logging import pretty_print_yaml_config
52
-
53
-
54
- class Trainer:
55
- def __init__(self, config_path: str):
56
- """
57
- Initializes the Trainer class. This Trainer class implements a `train` method, which is the
58
- main entry point for training the Pico model. Before calling `train`, the Trainer class
59
- initializes the following:
60
-
61
- - Configuration loading and validation
62
- - Model, optimizer, and dataset setup
63
- - Logging and experiment tracking setup
64
- - Checkpoint management
65
-
66
- Args:
67
- config_path (str): Path to the YAML configuration file containing any overrides.
68
- """
69
-
70
- ########################################################
71
- #
72
- # Basic Initialization of Configs, Fabric, Model, Optimizer, etc.
73
- #
74
- ########################################################
75
-
76
- # Setup Config
77
- self.configs = initialize_configuration(config_path)
78
-
79
- # Setup Run Directory (i.e. where we store checkpoints, logs, etc.)
80
- initialize_run_dir(checkpointing_config=self.configs["checkpointing"])
81
-
82
- # Setup Logger
83
- if self.configs["monitoring"].save_to_wandb:
84
- wandb_logger = initialize_wandb(
85
- monitoring_config=self.configs["monitoring"],
86
- checkpointing_config=self.configs["checkpointing"],
87
- )
88
- else:
89
- wandb_logger = None
90
-
91
- # Setup Fabric
92
- self.fabric = initialize_fabric(
93
- training_config=self.configs["training"],
94
- wandb_logger=wandb_logger,
95
- )
96
- L.seed_everything(42, verbose=False)
97
-
98
- # Optimize for Tensor Cores on RTX 5090
99
- if self.fabric.device.type == "cuda":
100
- torch.set_float32_matmul_precision(
101
- "high"
102
- ) # Best performance for Tensor Cores
103
- print(
104
- "Enabled Tensor Core optimization: torch.set_float32_matmul_precision('high')"
105
- )
106
-
107
- # Set up logging
108
- self.logger = initialize_logging(
109
- monitoring_config=self.configs["monitoring"],
110
- checkpointing_config=self.configs["checkpointing"],
111
- fabric=self.fabric,
112
- )
113
-
114
- # Setup Model, Optimizer, and Dataloaders
115
- self.model = initialize_model(model_config=self.configs["model"])
116
- self.optimizer = initialize_optimizer(
117
- training_config=self.configs["training"], model=self.model
118
- )
119
- self.lr_scheduler = initialize_lr_scheduler(
120
- training_config=self.configs["training"], optimizer=self.optimizer
121
- )
122
-
123
- # Wrap model and optimizer with Fabric
124
- self.model, self.optimizer = self.fabric.setup(self.model, self.optimizer)
125
-
126
- # Setup HuggingFace Checkpointing
127
- if self.configs["checkpointing"].save_to_hf:
128
- initialize_hf_checkpointing(
129
- checkpointing_config=self.configs["checkpointing"], fabric=self.fabric
130
- )
131
-
132
- ########################################################
133
- #
134
- # Boilerplate to deal with loading/resuming from checkpoints
135
- #
136
- ########################################################
137
-
138
- self.should_load_checkpoint = self.configs["checkpointing"].training.auto_resume
139
-
140
- # Possibly load a checkpoint
141
- if self.should_load_checkpoint:
142
- resume_checkpoint = load_checkpoint(
143
- checkpointing_config=self.configs["checkpointing"],
144
- checkpoint_step="latest",
145
- fabric=self.fabric,
146
- model=self.model,
147
- optimizer=self.optimizer,
148
- lr_scheduler=self.lr_scheduler,
149
- )
150
-
151
- if resume_checkpoint:
152
- (
153
- self.model,
154
- self.optimizer,
155
- self.lr_scheduler,
156
- self.initial_batch_step,
157
- ) = resume_checkpoint
158
- else:
159
- self.initial_batch_step = 0
160
- else:
161
- self.initial_batch_step = 0
162
-
163
- ########################################################
164
- #
165
- # Initialization of Dataset & DataLoader (possibly fast-forwarding to correct batch)
166
- #
167
- ########################################################
168
-
169
- self.train_dataset, fast_forward_steps = initialize_dataset(
170
- data_config=self.configs["data"],
171
- fabric=self.fabric,
172
- initial_batch_step=self.initial_batch_step,
173
- return_fast_forward_steps=True,
174
- )
175
-
176
- self.train_dataloader = initialize_dataloader(
177
- data_config=self.configs["data"],
178
- training_config=self.configs["training"],
179
- fabric=self.fabric,
180
- dataset=self.train_dataset,
181
- )
182
- self.train_dataloader = self.fabric.setup_dataloaders(
183
- self.train_dataloader, use_distributed_sampler=False
184
- )
185
-
186
- self.tokenizer = initialize_tokenizer(data_config=self.configs["data"])
187
-
188
- # NOTE: We may need to fast-forward the iterator to the correct step so that we can
189
- # continue from the correct batch of data we would have seen had training not
190
- # previously stopped.
191
- train_iterator = iter(self.train_dataloader)
192
- if fast_forward_steps > 0:
193
- fast_forward_sub_steps = (
194
- fast_forward_steps
195
- * self.configs["training"].optimization.gradient_accumulation_steps
196
- )
197
- for _ in range(fast_forward_sub_steps):
198
- next(train_iterator)
199
-
200
- self.train_iterator = train_iterator
201
-
202
- # NOTE: Synchronizing processes after fast-forwarding the iterator
203
- self.fabric.barrier()
204
-
205
- ########################################################
206
- #
207
- # Helper flags used during training for checkpointing and evaluation
208
- #
209
- ########################################################
210
-
211
- # Helper flag to determine if we should evaluate the model
212
- self.should_evaluate = (
213
- self.configs["evaluation"].metrics is not None
214
- and len(self.configs["evaluation"].metrics) > 0
215
- )
216
-
217
- self.should_compute_learning_dynamics = (
218
- self.configs["checkpointing"].learning_dynamics.layer_suffixes is not None
219
- and len(self.configs["checkpointing"].learning_dynamics.layer_suffixes) > 0
220
- )
221
-
222
- if self.should_compute_learning_dynamics:
223
- if self.configs["checkpointing"].learning_dynamics.eval_data is not None:
224
- self.learning_dynamics_eval_dataset = load_dataset(
225
- self.configs["checkpointing"].learning_dynamics.eval_data,
226
- split="val",
227
- )
228
- else:
229
- self.learning_dynamics_eval_dataset = None
230
-
231
- def train(self) -> None:
232
- """Execute the main training pipeline.
233
-
234
- This method orchestrates the complete training process by:
235
- 1. Creating an initial checkpoint to save the starting state and evaluate the model as a
236
- baseline
237
- 2. Running the main training loop via `_training_loop`
238
- 3. Handling final checkpointing and evaluation
239
-
240
- The training progress is tracked through checkpoints and evaluations
241
- at intervals specified in the configuration.
242
- """
243
-
244
- ########################################################
245
- #
246
- # Initial Checkpointing and Evaluation
247
- #
248
- ########################################################
249
-
250
- # Save Initial Checkpoint -- If the checkpoint already exists, this performs a no-op
251
- save_checkpoint(
252
- configs=self.configs,
253
- checkpoint_step=self.initial_batch_step,
254
- fabric=self.fabric,
255
- model=self.model,
256
- optimizer=self.optimizer,
257
- lr_scheduler=self.lr_scheduler,
258
- tokenizer=self.tokenizer,
259
- )
260
-
261
- # Save Initial Evaluation Results
262
- if self.should_evaluate:
263
- if self.initial_batch_step == 0:
264
- evaluation_results = run_evaluation(
265
- evaluation_config=self.configs["evaluation"],
266
- checkpointing_config=self.configs["checkpointing"],
267
- fabric=self.fabric,
268
- model=self.model,
269
- )
270
- self._log_evaluation_results(
271
- evaluation_results, self.initial_batch_step
272
- )
273
- save_evaluation_results(
274
- checkpointing_config=self.configs["checkpointing"],
275
- fabric=self.fabric,
276
- evaluation_results=evaluation_results,
277
- checkpoint_step=self.initial_batch_step,
278
- )
279
- else:
280
- # NOTE: If the run crashed while evaluating, we need to restart the evaluation
281
- eval_results_path = os.path.join(
282
- self.configs["checkpointing"].evaluation.eval_results_dir,
283
- f"step_{self.initial_batch_step}.json",
284
- )
285
- if not os.path.exists(eval_results_path):
286
- evaluation_results = run_evaluation(
287
- evaluation_config=self.configs["evaluation"],
288
- checkpointing_config=self.configs["checkpointing"],
289
- fabric=self.fabric,
290
- model=self.model,
291
- )
292
- self._log_evaluation_results(
293
- evaluation_results, self.initial_batch_step
294
- )
295
- save_evaluation_results(
296
- checkpointing_config=self.configs["checkpointing"],
297
- fabric=self.fabric,
298
- evaluation_results=evaluation_results,
299
- checkpoint_step=self.initial_batch_step,
300
- )
301
-
302
- ########################################################
303
- #
304
- # Main Training Loop (see `_training_loop` for details)
305
- #
306
- ########################################################
307
-
308
- if self.initial_batch_step < self.configs["training"].max_steps:
309
- self._log_training_configuration()
310
- final_step = self._training_loop()
311
- else:
312
- final_step = self.initial_batch_step
313
-
314
- ########################################################
315
- #
316
- # Final Checkpointing and Evaluation
317
- #
318
- ########################################################
319
-
320
- # Save Learning Dynamics States
321
- if self.should_compute_learning_dynamics:
322
- if self.learning_dynamics_eval_dataset is not None:
323
- self.log(f"Step {final_step} -- 📈 Saving Learning Dynamics")
324
- learning_dynamics_val_states = compute_learning_dynamics_states(
325
- checkpointing_config=self.configs["checkpointing"],
326
- fabric=self.fabric,
327
- model=self.model,
328
- dataset=self.learning_dynamics_eval_dataset,
329
- compute_gradients=True,
330
- )
331
- save_learning_dynamics_states(
332
- checkpointing_config=self.configs["checkpointing"],
333
- fabric=self.fabric,
334
- learning_dynamics_states=learning_dynamics_val_states,
335
- checkpoint_step=final_step,
336
- prefix="val",
337
- )
338
-
339
- # Handle checkpointing and final evaluation
340
- if final_step % self.configs["checkpointing"].save_every_n_steps != 0:
341
- self.log(f"Step {final_step} -- 💾 Saving Final Checkpoint")
342
- save_checkpoint(
343
- configs=self.configs,
344
- checkpoint_step=final_step,
345
- fabric=self.fabric,
346
- model=self.model,
347
- optimizer=self.optimizer,
348
- lr_scheduler=self.lr_scheduler,
349
- tokenizer=self.tokenizer,
350
- )
351
-
352
- # Final evaluation
353
- if self.should_evaluate:
354
- evaluation_results = run_evaluation(
355
- evaluation_config=self.configs["evaluation"],
356
- checkpointing_config=self.configs["checkpointing"],
357
- fabric=self.fabric,
358
- model=self.model,
359
- )
360
- self._log_evaluation_results(evaluation_results, final_step)
361
- save_evaluation_results(
362
- checkpointing_config=self.configs["checkpointing"],
363
- checkpoint_step=final_step,
364
- fabric=self.fabric,
365
- evaluation_results=evaluation_results,
366
- )
367
-
368
- self.log(f"🎉 Training complete! Final step: {final_step}")
369
-
370
- if final_step < self.configs["training"].max_steps:
371
- self.log(
372
- f"\t Note: Training stopped before max steps ({self.configs['training'].max_steps})",
373
- level=logging.WARNING,
374
- )
375
-
376
- # Cleanup distributed training
377
- self.fabric.barrier()
378
- if torch.cuda.is_available():
379
- torch.cuda.empty_cache()
380
- if torch.distributed.is_initialized():
381
- torch.distributed.destroy_process_group()
382
-
383
- del self.train_dataloader # NOTE: shutting down worker nodes
384
-
385
- self.fabric.barrier()
386
-
387
- def _training_loop(self) -> int:
388
- """Execute the main training loop.
389
-
390
- This method orchestrates the core training loop and includes the following features:
391
- - Gradient accumulation
392
- - Gradient clipping
393
- - Periodic model evaluation and checkpointing
394
- - Learning Dynamics Checkpointing
395
- - Learning rate scheduling
396
- - Logging of training metrics including loss and learning rate
397
- - Handling of infinite/NaN losses
398
-
399
- Returns:
400
- int: The final step count reached during training.
401
- NOTE: A complete training run should match the configured max_steps.
402
- """
403
- # Setup training loop variables
404
- batch_step = self.initial_batch_step
405
-
406
- # NOTE: these are used to compute the average loss over a training interval.
407
- # This is more accurate than using the loss at the end of the interval.
408
- interval_loss = torch.tensor(0.0, device=self.fabric.device)
409
- interval_steps = torch.tensor(0, device=self.fabric.device)
410
- interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device)
411
-
412
- if self.should_compute_learning_dynamics:
413
- # NOTE: we basically re-construct the full batch here so that we can compute learning dynamics
414
- training_batch = {"input_ids": []}
415
-
416
- # NOTE: determine what sub-batch we should start from
417
- initial_sub_batch_step = (
418
- batch_step
419
- * self.configs["training"].optimization.gradient_accumulation_steps
420
- )
421
-
422
- ###############################################################
423
- #
424
- # Core loop starts here
425
- # NOTE: the ratio between sub_batch_step and batch_step
426
- # is the configured number of gradient_accumulation_steps
427
- # i.e. with 32 configured gradient accumulation steps,
428
- # there are 32 sub_batch_steps for each batch_step
429
- #
430
- ###############################################################
431
-
432
- for sub_batch_step, sub_batch in enumerate(
433
- self.train_iterator, start=initial_sub_batch_step
434
- ):
435
- # NOTE: We want to store the entire training batch whenever we are computing learning dynamics
436
- # and we are at a checkpointing step.
437
- should_store_training_batch = self.should_compute_learning_dynamics and (
438
- batch_step % self.configs["checkpointing"].save_every_n_steps == 0
439
- )
440
-
441
- ########################################################
442
- #
443
- # Forward Pass
444
- #
445
- ########################################################
446
-
447
- _input_ids = torch.tensor(sub_batch["input_ids"], device=self.fabric.device)
448
- input_ids = _input_ids[:, :-1]
449
- labels = _input_ids[:, 1:]
450
-
451
- if should_store_training_batch:
452
- gathered_input_ids = self.fabric.all_gather(_input_ids)
453
-
454
- # NOTE: On multi-GPU, we need to reshape the input_ids to be a 2D tensor; on
455
- # a single GPU, the input_ids are already a 2D tensor.
456
- if self.fabric.world_size > 1:
457
- gathered_input_ids = gathered_input_ids.reshape(
458
- -1, *gathered_input_ids.shape[2:]
459
- )
460
-
461
- training_batch["input_ids"].extend(gathered_input_ids.tolist())
462
-
463
- # Forward pass
464
- model_output, _ = self.model(input_ids)
465
- model_output = model_output.transpose(1, 2)
466
-
467
- ########################################################
468
- #
469
- # Gradient accumulation
470
- #
471
- ########################################################
472
-
473
- should_accumulate_gradients = (sub_batch_step + 1) % self.configs[
474
- "training"
475
- ].optimization.gradient_accumulation_steps != 0
476
-
477
- with self.fabric.no_backward_sync(
478
- self.model, enabled=should_accumulate_gradients
479
- ):
480
- loss = F.cross_entropy(model_output, labels)
481
- self.fabric.backward(
482
- loss
483
- / self.configs["training"].optimization.gradient_accumulation_steps,
484
- model=self.model,
485
- )
486
-
487
- if torch.isnan(loss) or torch.isinf(loss):
488
- interval_inf_or_nan_count += 1
489
- else:
490
- interval_loss += loss.item()
491
- interval_steps += 1
492
-
493
- # NOTE: while we are still accumulating gradients, skip the logging and optimization steps
494
- if should_accumulate_gradients:
495
- continue
496
-
497
- ########################################################
498
- #
499
- # Logging
500
- #
501
- ########################################################
502
-
503
- if batch_step % self.configs["monitoring"].logging.log_every_n_steps == 0:
504
- self._log_training_metrics(
505
- interval_loss=interval_loss,
506
- interval_steps=interval_steps,
507
- interval_inf_or_nan_count=interval_inf_or_nan_count,
508
- batch_step=batch_step,
509
- )
510
- interval_loss = torch.tensor(0.0, device=self.fabric.device)
511
- interval_steps = torch.tensor(0, device=self.fabric.device)
512
- interval_inf_or_nan_count = torch.tensor(0, device=self.fabric.device)
513
-
514
- ########################################################
515
- #
516
- # Learning Dynamics Checkpointing
517
- #
518
- ########################################################
519
-
520
- if batch_step % self.configs["checkpointing"].save_every_n_steps == 0:
521
- if self.should_compute_learning_dynamics:
522
- self.log(f"Step {batch_step} -- 📈 Saving Learning Dynamics")
523
-
524
- # Training Batch Learning Dynamics
525
- training_batch_dataset = Dataset.from_dict(training_batch)
526
-
527
- learning_dynamics_train_states = compute_learning_dynamics_states(
528
- checkpointing_config=self.configs["checkpointing"],
529
- fabric=self.fabric,
530
- model=self.model,
531
- dataset=training_batch_dataset,
532
- compute_gradients=True,
533
- )
534
-
535
- save_learning_dynamics_states(
536
- checkpointing_config=self.configs["checkpointing"],
537
- checkpoint_step=batch_step,
538
- prefix="train",
539
- fabric=self.fabric,
540
- learning_dynamics_states=learning_dynamics_train_states,
541
- learning_dynamics_dataset=training_batch_dataset,
542
- tokenizer=self.tokenizer,
543
- )
544
- training_batch = {
545
- "input_ids": []
546
- } # Resetting training_batch for next training batch
547
-
548
- # Validation Data Learning Dynamics
549
- if self.learning_dynamics_eval_dataset is not None:
550
- learning_dynamics_val_states = compute_learning_dynamics_states(
551
- checkpointing_config=self.configs["checkpointing"],
552
- fabric=self.fabric,
553
- model=self.model,
554
- dataset=self.learning_dynamics_eval_dataset,
555
- compute_gradients=True,
556
- )
557
- save_learning_dynamics_states(
558
- checkpointing_config=self.configs["checkpointing"],
559
- checkpoint_step=batch_step,
560
- prefix="val",
561
- fabric=self.fabric,
562
- learning_dynamics_states=learning_dynamics_val_states,
563
- )
564
-
565
- ########################################################
566
- #
567
- # Optimization step
568
- #
569
- ########################################################
570
-
571
- self.optimizer.step()
572
- self.optimizer.zero_grad()
573
- self.lr_scheduler.step()
574
-
575
- batch_step += 1
576
-
577
- ########################################################
578
- #
579
- # Training Checkpointing and evaluation
580
- #
581
- ########################################################
582
-
583
- if batch_step % self.configs["checkpointing"].save_every_n_steps == 0:
584
- self.log(f"Step {batch_step} -- 💾 Saving Checkpoint")
585
- save_checkpoint(
586
- configs=self.configs,
587
- checkpoint_step=batch_step,
588
- fabric=self.fabric,
589
- model=self.model,
590
- optimizer=self.optimizer,
591
- lr_scheduler=self.lr_scheduler,
592
- tokenizer=self.tokenizer,
593
- )
594
-
595
- if self.should_evaluate:
596
- evaluation_results = run_evaluation(
597
- evaluation_config=self.configs["evaluation"],
598
- checkpointing_config=self.configs["checkpointing"],
599
- fabric=self.fabric,
600
- model=self.model,
601
- )
602
- if evaluation_results is not None:
603
- self._log_evaluation_results(evaluation_results, batch_step)
604
- save_evaluation_results(
605
- checkpointing_config=self.configs["checkpointing"],
606
- fabric=self.fabric,
607
- evaluation_results=evaluation_results,
608
- checkpoint_step=batch_step,
609
- )
610
-
611
- # Break if we've reached training steps
612
- if batch_step >= self.configs["training"].max_steps:
613
- break
614
-
615
- return batch_step
616
-
617
- ########################################################
618
- #
619
- # Trainer Logging Functionalities
620
- #
621
- ########################################################
622
-
623
- def _log_training_metrics(
624
- self,
625
- interval_loss: torch.Tensor,
626
- interval_steps: torch.Tensor,
627
- interval_inf_or_nan_count: torch.Tensor,
628
- batch_step: int,
629
- ):
630
- """
631
- Gathers together the training metrics computed across all processes in distributed training
632
- and logs them in a tree-style format.
633
- """
634
- gathered_interval_loss = self.fabric.all_reduce(
635
- interval_loss, reduce_op="sum"
636
- ).item()
637
- gathered_interval_inf_or_nan_count = self.fabric.all_reduce(
638
- interval_inf_or_nan_count, reduce_op="sum"
639
- ).item()
640
- gathered_interval_steps = self.fabric.all_reduce(
641
- interval_steps, reduce_op="sum"
642
- ).item()
643
-
644
- avg_loss = (
645
- gathered_interval_loss / gathered_interval_steps
646
- if gathered_interval_steps > 0
647
- else float("inf")
648
- )
649
-
650
- self.fabric.log("train/loss", avg_loss, step=batch_step)
651
- self.fabric.log(
652
- "trainer/inf_or_nan_count",
653
- gathered_interval_inf_or_nan_count,
654
- step=batch_step,
655
- )
656
- self.fabric.log(
657
- "trainer/learning_rate",
658
- self.lr_scheduler.get_last_lr()[0],
659
- step=batch_step,
660
- )
661
-
662
- # Log to console in tree format
663
- self.log(f"Step {batch_step} -- 🔄 Training Metrics")
664
- self.log(f"├── Loss: {avg_loss:.4f}")
665
- self.log(f"├── Learning Rate: {self.lr_scheduler.get_last_lr()[0]:.2e}")
666
- self.log(f"└── Inf/NaN count: {gathered_interval_inf_or_nan_count}")
667
-
668
- def _log_evaluation_results(
669
- self, evaluation_results: Dict[str, Any], batch_step: int
670
- ):
671
- """Log model evaluation metrics to experiment tracking system and console."""
672
- self.log(f"Step {batch_step} -- 📊 Evaluation Results")
673
- for i, (metric, result) in enumerate(evaluation_results.items()):
674
- prefix = "└──" if i == len(evaluation_results) - 1 else "├──"
675
- self.log(f"{prefix} {metric}: {result}")
676
- self.fabric.log(f"eval/{metric}", result, step=batch_step)
677
-
678
- def _log_training_configuration(self):
679
- """
680
- Log training configuration details as well as runtime information about the hardware,
681
- software, and batch settings.
682
-
683
- This function is called at the beginning of the training loop to provide a summary of the
684
- training configuration.
685
- """
686
-
687
- total_params = sum(p.numel() for p in self.model.parameters())
688
- trainable_params = sum(
689
- p.numel() for p in self.model.parameters() if p.requires_grad
690
- )
691
- global_batch_size = self.configs["data"].dataloader.batch_size
692
- per_device_batch_size = self.train_dataloader.batch_size
693
- gradient_accumulation_steps = self.configs[
694
- "training"
695
- ].optimization.gradient_accumulation_steps
696
-
697
- device_type = ""
698
- fabric_device = str(self.fabric.device)
699
- if torch.cuda.is_available() and "cuda" in fabric_device:
700
- device_type = torch.cuda.get_device_name(self.fabric.device)
701
- elif torch.backends.mps.is_available() and "mps" in fabric_device:
702
- device_type = "MPS (Apple Silicon)"
703
- else:
704
- device_type = "CPU"
705
-
706
- training_config_path = os.path.join(
707
- self.configs["checkpointing"].runs_dir,
708
- self.configs["checkpointing"].run_name,
709
- "training_config.yaml",
710
- )
711
- if os.path.exists(training_config_path):
712
- self.log("=" * 50)
713
- self.log("✨ Training Configuration")
714
- self.log("=" * 50)
715
- training_config = yaml.safe_load(open(training_config_path, "r"))
716
- pretty_print_yaml_config(self.logger, training_config)
717
-
718
- self.log("=" * 50)
719
- self.log("⛭ Runtime Summary:")
720
- self.log("=" * 50)
721
- self.log(f"Starting from step: {self.initial_batch_step}")
722
-
723
- self.log("Model Setup:")
724
- self.log(f"└─ Total Parameters: {total_params:,}")
725
- self.log(f"└─ Trainable Parameters: {trainable_params:,}")
726
-
727
- self.log("Distributed Setup:")
728
- self.log(f"└─ Number of Devices: {self.fabric.world_size}")
729
- self.log(f"└─ Device Type: {device_type}")
730
- self.log(
731
- f"└─ Available Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
732
- if torch.cuda.is_available()
733
- else f"└─ Available Memory: {psutil.virtual_memory().total / 1e9:.2f} GB"
734
- )
735
-
736
- self.log("Software Setup:")
737
- self.log(f"└─ Python Version: {platform.python_version()}")
738
- self.log(f"└─ PyTorch Version: {torch.__version__}")
739
- self.log(
740
- f"└─ CUDA Version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}"
741
- )
742
- self.log(f"└─ Operating System: {platform.system()} {platform.release()}")
743
-
744
- self.log("Batch Size Configuration:")
745
- self.log(f"└─ Global Batch Size: {global_batch_size}")
746
- self.log(f"└─ Per Device Batch Size: {per_device_batch_size}")
747
- self.log(f"└─ Gradient Accumulation Steps: {gradient_accumulation_steps}")
748
- self.log("=" * 50)
749
-
750
- @rank_zero_only
751
- def log(self, msg: str, level: int = logging.INFO) -> None:
752
- """NOTE: Log messages only from rank zero process."""
753
- self.logger.log(level, msg)
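A minimal sketch of the gradient-accumulation bookkeeping used in _training_loop above, stripped of Fabric and the real dataloader (toy model and data; the loss scaling mirrors the division in the training loop):

import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
gradient_accumulation_steps = 4
batch_step = 0

for sub_batch_step in range(12):                      # 12 sub-batches -> 3 optimizer steps
    x = torch.randn(2, 8)
    loss = model(x).pow(2).mean()
    (loss / gradient_accumulation_steps).backward()   # scale so the accumulated gradient is an average

    still_accumulating = (sub_batch_step + 1) % gradient_accumulation_steps != 0
    if still_accumulating:
        continue                                      # keep accumulating; no optimizer step yet

    optimizer.step()
    optimizer.zero_grad()
    batch_step += 1

print(batch_step)  # 3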
 
src/training/utils/__init__.py DELETED
@@ -1,34 +0,0 @@
- """
- Utility package that contains functions for the training process, e.g. initialization, logging, etc.
- """
-
- # For convenience, we export the initialization functions here
- from .initialization import (
-     initialize_configuration,
-     initialize_dataloader,
-     initialize_dataset,
-     initialize_fabric,
-     initialize_hf_checkpointing,
-     initialize_logging,
-     initialize_lr_scheduler,
-     initialize_model,
-     initialize_optimizer,
-     initialize_run_dir,
-     initialize_tokenizer,
-     initialize_wandb,
- )
-
- __all__ = [
-     "initialize_configuration",
-     "initialize_dataloader",
-     "initialize_dataset",
-     "initialize_fabric",
-     "initialize_hf_checkpointing",
-     "initialize_logging",
-     "initialize_lr_scheduler",
-     "initialize_model",
-     "initialize_optimizer",
-     "initialize_run_dir",
-     "initialize_tokenizer",
-     "initialize_wandb",
- ]
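These re-exports let a training entry point pull everything from one place. A rough usage sketch, assuming the package layout above is importable and that `configs/demo.yaml` (a hypothetical path) holds the overrides:

```python
from src.training.utils import (
    initialize_configuration,
    initialize_fabric,
    initialize_model,
    initialize_optimizer,
)

configs = initialize_configuration("configs/demo.yaml")  # hypothetical override file
fabric = initialize_fabric(configs["training"])
model = initialize_model(configs["model"])
optimizer = initialize_optimizer(configs["training"], model)
```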
src/training/utils/data.py DELETED
@@ -1,35 +0,0 @@
- """
- Utilities for data loading and processing.
- """
-
- from torch.utils.data import IterableDataset
-
-
- class ShardedIterableDataset(IterableDataset):
-     """
-     A super simple implementation of a sharded iterable dataset that enables data parallelism
-     across multiple workers. Ensures that each worker gets a unique shard of the dataset.
-
-     NOTE: Also works fine if there is only one worker.
-     """
-
-     def __init__(self, dataset, rank, world_size):
-         self.dataset = dataset
-         self.rank = rank
-         self.world_size = world_size
-
-     def __iter__(self):
-         iterator = iter(self.dataset)
-         # NOTE: Start by skipping to this worker's shard
-         for _ in range(self.rank):
-             next(iterator)
-
-         # NOTE: Yield every world_size-th item
-         while True:
-             try:
-                 yield next(iterator)
-                 # Skip other workers' samples
-                 for _ in range(self.world_size - 1):
-                     next(iterator)
-             except StopIteration:
-                 break
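The interleaving above is easiest to see with a tiny iterable: each rank skips ahead to its own offset and then takes every `world_size`-th item. A minimal, self-contained sketch of the same round-robin logic, written as a plain generator (no torch dependency) rather than the deleted class:

```python
def shard(iterable, rank, world_size):
    """Yield the items of `iterable` that belong to this rank's shard."""
    it = iter(iterable)
    for _ in range(rank):  # skip to this worker's first sample
        next(it)
    while True:
        try:
            yield next(it)
            for _ in range(world_size - 1):  # skip samples owned by other workers
                next(it)
        except StopIteration:
            break


print(list(shard(range(10), rank=0, world_size=2)))  # [0, 2, 4, 6, 8]
print(list(shard(range(10), rank=1, world_size=2)))  # [1, 3, 5, 7, 9]
```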
src/training/utils/initialization.py DELETED
@@ -1,702 +0,0 @@
- """
- Utilities for initializing components of the training process.
-
- Here, we initialize all of the components that are part of the learning process. From logging
- and checkpointing to the optimizer, the dataset, and the dataloader, this file contains the
- logic for setting up the classes and functions that are used in the training loop.
-
- As always, this code is meant to be basic. We hard-code the obvious defaults, and leave the
- more experimental stuff to you.
- """
-
- import logging
- import math
- import os
- import warnings
- from dataclasses import fields, is_dataclass
- from datetime import datetime
- from typing import Dict, Optional, Union
-
- import lightning as L
- import torch
- import yaml
- from datasets import Dataset, DownloadConfig, load_dataset
- from datasets import config as datasets_config
- from huggingface_hub import add_collection_item, create_branch, create_repo
- from lightning.fabric.loggers import Logger as FabricLogger
- from lightning.fabric.utilities.rank_zero import rank_zero_only
- from torch.utils.data import DataLoader
- from transformers import AutoTokenizer
-
- import wandb
- from src.config import (
-     CheckpointingConfig,
-     DataConfig,
-     EvaluationConfig,
-     ModelConfig,
-     MonitoringConfig,
-     TrainingConfig,
- )
- from src.model import PicoDecoder
- from src.training.utils.io import use_backoff
- from wandb.integration.lightning.fabric import WandbLogger
-
- warnings.filterwarnings(
-     "ignore",
-     message=".*This integration is tested and supported for lightning Fabric.*",
- )
- warnings.filterwarnings(
-     "ignore",
-     message=".*Please report any issues to.*",
- )
-
- ########################################################
- #
- # Basic Initialization
- #
- ########################################################
-
-
- def _apply_config_overrides(config, overrides: dict):
-     """Recursively apply configuration overrides to a dataclass config object.
-
-     Args:
-         config: Base configuration object (must be a dataclass)
-         overrides: Dictionary of override values matching config structure
-
-     Returns:
-         Modified config object with the overrides applied.
-     """
-     for field in fields(config):
-         field_value = getattr(config, field.name)
-         if is_dataclass(field_value):
-             _apply_config_overrides(field_value, overrides.get(field.name, {}))
-         else:
-             if field.name in overrides:
-                 setattr(config, field.name, overrides[field.name])
-     return config
-
-
- def initialize_configuration(
-     config_path: Optional[str] = None,
- ) -> Dict[
-     str,
-     Union[
-         DataConfig,
-         ModelConfig,
-         TrainingConfig,
-         EvaluationConfig,
-         MonitoringConfig,
-         CheckpointingConfig,
-     ],
- ]:
-     """Initialize configuration objects with optional overrides from a YAML file.
-
-     This function initializes all of the configuration objects, and then applies
-     any overrides from the config_path file. If no config_path is provided,
-     the function will use the default configuration objects.
-
-     Args:
-         config_path: Path to a YAML file containing configuration overrides.
-
-     Returns:
-         A dictionary containing the initialized configuration objects.
-     """
-     data_config = DataConfig()
-     model_config = ModelConfig()
-     training_config = TrainingConfig()
-     evaluation_config = EvaluationConfig()
-     monitoring_config = MonitoringConfig()
-     checkpointing_config = CheckpointingConfig()
-
-     if config_path:
-         overrides = yaml.safe_load(open(config_path, "r"))
-         data_config = _apply_config_overrides(data_config, overrides.get("data", {}))
-         model_config = _apply_config_overrides(model_config, overrides.get("model", {}))
-         training_config = _apply_config_overrides(
-             training_config, overrides.get("training", {})
-         )
-         evaluation_config = _apply_config_overrides(
-             evaluation_config, overrides.get("evaluation", {})
-         )
-         monitoring_config = _apply_config_overrides(
-             monitoring_config, overrides.get("monitoring", {})
-         )
-         checkpointing_config = _apply_config_overrides(
-             checkpointing_config, overrides.get("checkpointing", {})
-         )
-
-     configs = {
-         "data": data_config,
-         "model": model_config,
-         "training": training_config,
-         "evaluation": evaluation_config,
-         "monitoring": monitoring_config,
-         "checkpointing": checkpointing_config,
-     }
-
-     return configs
-
-
- def initialize_run_dir(checkpointing_config: CheckpointingConfig) -> str:
-     """Initialize a directory for the current training run.
-
-     Creates a unique directory for storing training, evaluation, and logging artifacts.
-     If no run name is specified in the config, generates a timestamp-based name.
-
-     Args:
-         checkpointing_config: Configuration object containing run settings.
-             NOTE: Must have a 'run_name' attribute that can be None, in which case
-             a timestamp-based name will be generated.
-
-     Returns:
-         str: The path to the run directory.
-     """
-     run_name = checkpointing_config.run_name
-     if run_name is None:
-         run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-         checkpointing_config.run_name = run_name
-
-     run_dir = os.path.join(checkpointing_config.runs_dir, run_name)
-
-     os.makedirs(run_dir, exist_ok=True)
-     return run_dir
-
-
- def initialize_fabric(
-     training_config: TrainingConfig, wandb_logger: Optional[FabricLogger] = None
- ):
-     """Initialize Lightning Fabric for distributed training.
-
-     Sets up a Lightning Fabric instance with the specified configuration for
-     handling distributed training, mixed precision, and logging.
-
-     Args:
-         training_config: Configuration object containing fabric settings
-             (accelerator, precision, devices, etc.).
-         wandb_logger: Optional Weights and Biases logger instance for experiment tracking.
-
-     Returns:
-         L.Fabric: Initialized Lightning Fabric instance.
-
-     Example:
-         >>> fabric = initialize_fabric(training_config, wandb_logger)
-     """
-
-     total_devices = (
-         training_config.fabric.num_devices * training_config.fabric.num_nodes
-     )
-
-     if total_devices > 1:
-         strategy = "deepspeed_stage_2"
-     else:
-         strategy = "auto"  # Sets up SingleDevice Strategy by default
-
-     # NOTE: The strategy is set to use either DeepSpeed (Zero Stage 2) on multi-GPU,
-     # or the SingleDevice Strategy on single-GPU setups. If you'd like to use a different strategy,
-     # you can change the strategy flag in the fabric initialization, but be aware that this might
-     # cause issues with checkpointing, evaluation, etc.
-
-     fabric = L.Fabric(
-         accelerator=training_config.fabric.accelerator,
-         precision=training_config.fabric.precision,
-         devices=training_config.fabric.num_devices,
-         num_nodes=training_config.fabric.num_nodes,
-         loggers=[wandb_logger] if wandb_logger is not None else None,
-         strategy=strategy,
-     )
-
-     fabric.launch()
-
-     return fabric
-
-
- ########################################################
- #
- # Dataset and Tokenization Initialization
- #
- ########################################################
-
-
- @use_backoff(max_retries=20)
- def initialize_dataset(
-     data_config: DataConfig,
-     fabric: L.Fabric,
-     initial_batch_step: Optional[int] = 0,
-     return_fast_forward_steps: bool = False,
- ):
-     """Initialize the dataset based on the given config.
-
-     This function will return a dataset object, and optionally a fast_forward_steps value.
-
-     The fast_forward_steps value is the number of steps that we need to fast-forward an iterator by,
-     so that we can continue from a certain batch of data we would have seen had training not previously
-     stopped. Depending on how the dataset is loaded, the number of steps to fast-forward may be
-     different from the initial_batch_step value.
-
-     NOTE: This functionality is primarily useful for streaming datasets (which for large
-     datasets is most of the time).
-
-     Args:
-         data_config: Configuration object containing dataset settings.
-         fabric: A Lightning Fabric instance.
-         initial_batch_step: The initial batch step to fast-forward to.
-         return_fast_forward_steps: Whether to return the fast-forward steps value.
-
-     Returns:
-         Dataset: Initialized dataset object.
-         Optional[int]: Number of steps to fast-forward the iterator by, if return_fast_forward_steps is True.
-     """
-
-     datasets_config.STREAMING_READ_MAX_RETRIES = 40  # default is 20
-     datasets_config.STREAMING_READ_RETRY_INTERVAL = 10  # default is 5
-     download_config = DownloadConfig(
-         max_retries=20,  # default is 1 and can lead to premature HTTPS errors
-     )
-
-     fast_forward_steps = 0
-
-     if data_config.dataset.name == "pico-lm/pretokenized-dolma":
-         # NOTE: We know that the dataset is sharded into 10,000 shards, so we can easily compute
-         # the data file that we need to load in that contains the batch of data at
-         # initial_batch_step.
-
-         if initial_batch_step is not None:
-             examples_per_shard = 20_480
-             total_shards = 10_000
-             batches_per_shard = examples_per_shard // data_config.dataloader.batch_size
-             shard_idx = initial_batch_step // batches_per_shard
-
-             data_files = [
-                 f"data/train-{str(_shard_idx).zfill(5)}-of-{total_shards}.parquet"
-                 for _shard_idx in range(shard_idx, total_shards)
-             ]
-
-             fast_forward_steps = initial_batch_step % batches_per_shard
-         else:
-             data_files = None
-
-         base_dataset = load_dataset(
-             data_config.dataset.name,
-             split="train",
-             streaming=True,
-             data_files=data_files,
-             download_config=download_config,
-         )
-     else:
-         # NOTE: For other datasets, you might want to add some custom loading logic, especially
-         # to help with loading or fast-forwarding to the correct batch.
-
-         base_dataset = load_dataset(
-             data_config.dataset.name,
-             split="train",
-             streaming=True,
-             download_config=download_config,
-         )
-
-     if data_config.dataset.name == "pico-lm/pretokenized-dolma":
-         from .data import ShardedIterableDataset
-
-         # NOTE: We wrap the dataset in a ShardedIterableDataset, which is a custom class that
-         # allows us to shard an iterable dataset across multiple processes. This is useful for
-         # distributed training, where we want data-parallelism.
-         dataset = ShardedIterableDataset(
-             base_dataset, fabric.global_rank, fabric.world_size
-         )
-     else:
-         dataset = base_dataset
-
-     if return_fast_forward_steps:
-         return dataset, fast_forward_steps
-     else:
-         return dataset
-
-
- def initialize_tokenizer(data_config: DataConfig):
-     """Initialize the tokenizer for text processing.
-
-     This function can be extended to include custom tokenization logic.
-
-     Args:
-         data_config: Configuration object containing tokenizer settings.
-
-     Returns:
-         AutoTokenizer: A HuggingFace tokenizer instance.
-     """
-
-     return AutoTokenizer.from_pretrained(data_config.tokenizer.name)
-
-
- def initialize_dataloader(
-     data_config: DataConfig,
-     training_config: TrainingConfig,
-     fabric: L.Fabric,
-     dataset: Dataset,
- ):
-     """Initialize the DataLoader for efficient batch processing.
-
-     Creates a PyTorch DataLoader that handles batching and data loading for training.
-     Configured specifically for streaming tokenized text datasets.
-
-     You might also want to extend this function to add a sampler, or some sort of custom
-     collate function. For the default dataset, we don't need any of this, because the data are
-     pre-shuffled and pre-tokenized.
-
-     Args:
-         data_config: Configuration object containing dataloader settings.
-         training_config: Configuration object containing training settings.
-         fabric: A Lightning Fabric instance.
-         dataset: A HuggingFace Dataset object containing tokenized text data.
-             Expected to have 'input_ids' field in its items.
-
-     Returns:
-         DataLoader: PyTorch DataLoader instance configured for the dataset.
-     """
-
-     def _collate_fn(batch):
-         return {"input_ids": [entry["input_ids"] for entry in batch]}
-
-     sub_batch_size = data_config.dataloader.batch_size // (
-         fabric.world_size * training_config.optimization.gradient_accumulation_steps
-     )
-
-     # NOTE: We use the sub-batch size for the dataloader, which is the full batch size
-     # divided by the world size and the gradient accumulation steps. This ensures that the
-     # effective (global) batch size is correct.
-
-     return DataLoader(
-         dataset,
-         batch_size=sub_batch_size,
-         shuffle=False,  # Keep sequential for streaming datasets
-         pin_memory=True,  # Speeds up transfer to GPU
-         collate_fn=_collate_fn,
-     )
-
-
- ########################################################
- #
- # Model Initialization
- #
- ########################################################
-
-
- def initialize_model(model_config: ModelConfig):
-     """Initialize the model for training.
-
-     Loads in a given model implemented in the `src.model` package and returns it.
-
-     NOTE: out of the box we currently only support the PicoDecoder model (a causal transformer
-     language model). If you'd like to implement your own model, you can do so by adding a new
-     model class in the `src.model` package, and then adding a new entry here.
-
-     Args:
-         model_config: Configuration object containing model settings.
-
-     Returns:
-         PyTorch model instance.
-
-     """
-     if model_config.model_type == "pico_decoder":
-         return PicoDecoder(model_config)
-     else:
-         raise ValueError(f"Invalid model type: {model_config.model_type}")
-
-
- ########################################################
- #
- # Optimizer and Scheduler
- #
- ########################################################
-
-
- def initialize_optimizer(training_config: TrainingConfig, model: torch.nn.Module):
-     """Initialize the optimizer for model training.
-
-     Creates an optimizer instance based on the configuration settings.
-
-     Add whatever other optimizers you want here.
-
-     Args:
-         training_config: Configuration object containing optimizer settings.
-             Must have:
-                 - optimization.optimizer (str): Name of the optimizer ("adamw")
-                 - optimization.lr (float): Learning rate for the optimizer
-         model: PyTorch model whose parameters will be optimized.
-
-     Returns:
-         torch.optim.Optimizer: Configured optimizer instance.
-
-     """
-
-     if training_config.optimization.optimizer == "adamw":
-         optimizer = torch.optim.AdamW(
-             model.parameters(), lr=training_config.optimization.lr
-         )
-     else:
-         raise ValueError(f"Invalid optimizer: {training_config.optimization.optimizer}")
-
-     return optimizer
-
-
- def initialize_lr_scheduler(
-     training_config: TrainingConfig, optimizer: torch.optim.Optimizer
- ):
-     """Initialize a learning rate scheduler with warmup and decay.
-
-     The default is a learning rate scheduler that implements a linear warmup followed by
-     linear decay. The learning rate increases linearly from 0 to the initial lr
-     during warmup, then decreases linearly to 0 during the remaining steps.
-
-     Add other types of learning rate schedulers here.
-
-     Args:
-         training_config: Configuration object containing optimizer and scheduler settings.
-         optimizer: PyTorch optimizer whose learning rate will be scheduled.
-
-     Returns:
-         torch.optim.lr_scheduler.LambdaLR: Learning rate scheduler instance.
-     """
-
-     if training_config.optimization.lr_scheduler == "linear_with_warmup":
-         # Credit where credit is due:
-         # https://github.com/huggingface/transformers/blob/e71a01a104dd663c730e494eb0b6467bb51df357/src/transformers/optimization.py#L102
-         def _lr_lambda(curr_step, num_warmup_steps, max_steps):
-             if curr_step < num_warmup_steps:
-                 return float(curr_step) / float(max(1, num_warmup_steps))
-             else:
-                 return max(
-                     0.0,
-                     float(max_steps - curr_step)
-                     / float(max(1, max_steps - num_warmup_steps)),
-                 )
-
-         lr_lambda = lambda step: _lr_lambda(  # noqa: E731
-             step,
-             training_config.optimization.lr_warmup_steps,
-             training_config.max_steps,
-         )
-         lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
-             optimizer,
-             lr_lambda,
-         )
-     elif training_config.optimization.lr_scheduler == "cosine":
-         # Cosine decay with warmup: linear warmup followed by cosine decay.
-         # This provides sustained learning over long training runs.
-         def _cosine_lr_lambda(curr_step, num_warmup_steps, max_steps):
-             if curr_step < num_warmup_steps:
-                 # Linear warmup
-                 return float(curr_step) / float(max(1, num_warmup_steps))
-             else:
-                 # Cosine decay to 0.1 * initial_lr (not to 0)
-                 progress = float(curr_step - num_warmup_steps) / float(
-                     max(1, max_steps - num_warmup_steps)
-                 )
-                 return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
-
-         lr_lambda = lambda step: _cosine_lr_lambda(  # noqa: E731
-             step,
-             training_config.optimization.lr_warmup_steps,
-             training_config.max_steps,
-         )
-         lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
-             optimizer,
-             lr_lambda,
-         )
-     else:
-         raise ValueError(
-             f"Invalid learning rate scheduler: {training_config.optimization.lr_scheduler}"
-         )
-
-     return lr_scheduler
-
-
- ########################################################
- #
- # Experiment Monitoring (Logging, Experiment Tracking, etc.)
- #
- ########################################################
-
-
- def _initialize_log_file(checkpointing_config: CheckpointingConfig) -> str:
-     """Create and initialize a timestamped log file in the run's log directory.
-
-     Sets up a log file with a unique timestamp in the run's logging directory.
-     Creates the necessary directory structure if it doesn't exist.
-
-     Directory Structure:
-         {checkpointing_config.runs_dir}/
-         └── {checkpointing_config.run_name}/
-             └── {checkpointing_config.logs_dir}/
-                 └── log_YYYYMMDD_HHMMSS.log
-
-     Args:
-         checkpointing_config: Configuration object containing checkpointing settings.
-
-     Returns:
-         str: Absolute path to the created log file.
-
-     """
-
-     run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name)
-     logs_dir = os.path.join(run_dir, checkpointing_config.logs_dir)
-     os.makedirs(logs_dir, exist_ok=True)
-
-     # datetime stamp
-     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-     log_file_name = f"log_{timestamp}.log"
-     log_file_path = os.path.join(logs_dir, log_file_name)
-
-     open(log_file_path, "w").close()  # Create an empty log file
-
-     return log_file_path
-
-
- @use_backoff()
- def initialize_wandb(
-     monitoring_config: MonitoringConfig, checkpointing_config: CheckpointingConfig
- ):
-     """Initialize Weights and Biases.
-
-     This function initializes Weights and Biases based on the configuration settings.
-
-     Args:
-         monitoring_config: Configuration object containing monitoring settings.
-         checkpointing_config: Configuration object containing checkpointing settings.
-
-     Returns:
-         Optional[WandbLogger]: An experiment tracker instance.
-     """
-
-     assert (
-         monitoring_config.wandb.project is not None
-         and monitoring_config.wandb.project != ""
-     ), "Wandb project must be provided if wandb is to be used."
-     assert (
-         monitoring_config.wandb.entity is not None
-         and monitoring_config.wandb.entity != ""
-     ), "Wandb entity must be provided if wandb is to be used."
-
-     _run_id = None
-     if checkpointing_config.training.auto_resume:
-         # If we are loading a checkpoint, we can try to find the run id of the previous run
-         previous_runs = wandb.Api().runs(
-             path=f"{monitoring_config.wandb.entity}/{monitoring_config.wandb.project}",
-             filters={"display_name": checkpointing_config.run_name},
-         )
-         try:
-             if len(previous_runs) == 1:
-                 _run_id = previous_runs[0].id
-         except ValueError:
-             pass
-
-     wandb_logger = WandbLogger(
-         project=monitoring_config.wandb.project,
-         entity=monitoring_config.wandb.entity,
-         id=_run_id,
-         name=checkpointing_config.run_name,
-     )
-
-     return wandb_logger
-
-
- @rank_zero_only
- def initialize_logging(
-     monitoring_config: MonitoringConfig,
-     checkpointing_config: CheckpointingConfig,
-     fabric: L.Fabric,
- ):
-     """Initialize the logging system with default logging to file and console.
-
-     The default logging system uses a file handler and a stream handler.
-
-     NOTE: this function is only called on rank 0.
-
-     Args:
-         monitoring_config: Configuration object containing monitoring settings.
-         checkpointing_config: Configuration object containing checkpointing settings.
-
-     Returns:
-         logger: Standard Python logger configured for file and console output
-     """
-
-     # ---- Standard Local Logger ---- #
-     logger = logging.getLogger("pico-train")
-     logger.setLevel(logging.INFO)
-
-     # Create file handler
-     log_file_path = _initialize_log_file(checkpointing_config)
-     file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
-     file_handler.setLevel(monitoring_config.logging.log_level)
-
-     # Create formatter and add it to the handler
-     formatter = logging.Formatter(
-         "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
-         datefmt="%Y-%m-%d %H:%M:%S",
-     )
-     file_handler.setFormatter(formatter)
-
-     # Add the handler to the logger
-     logger.addHandler(file_handler)
-
-     # Add a stream handler for console output
-     stream_handler = logging.StreamHandler()
-     stream_handler.setLevel(monitoring_config.logging.log_level)
-     stream_handler.setFormatter(formatter)
-     logger.addHandler(stream_handler)
-
-     return logger
-
-
- ########################################################
- #
- # HuggingFace/Remote Checkpointing
- #
- ########################################################
-
-
- @rank_zero_only
- @use_backoff()
- def initialize_hf_checkpointing(
-     checkpointing_config: CheckpointingConfig, fabric: L.Fabric
- ):
-     """Initialize HuggingFace Checkpointing.
-
-     Creates a HuggingFace repository if it doesn't exist, and creates a branch named after the run.
-
-     NOTE: this function is only called on rank 0.
-
-     Args:
-         checkpointing_config: Configuration object containing checkpointing settings; must have
-             a 'hf_checkpoint' attribute that specifies the HuggingFace repository id and
-             collection slug (if applicable) to save the checkpoint to.
-
-     Raises:
-         RuntimeError: If unable to create the HuggingFace repository after multiple attempts.
-     """
-
-     huggingface_repo_id = checkpointing_config.hf_checkpoint.repo_id
-     assert (
-         huggingface_repo_id is not None and huggingface_repo_id != ""
-     ), "hf_checkpoint.repo_id must be provided."
-
-     repo = create_repo(huggingface_repo_id, exist_ok=True)
-
-     # We can create a repo without a specified namespace (it will default to the username),
-     # however the rest of the HF calls need the fully qualified name.
-     # The fully qualified name is returned by create_repo, so we update the config for later calls.
-     checkpointing_config.hf_checkpoint.repo_id = repo.repo_id
-     huggingface_repo_id = repo.repo_id
-
-     if checkpointing_config.hf_checkpoint.collection_slug:
-         add_collection_item(
-             checkpointing_config.hf_checkpoint.collection_slug,
-             huggingface_repo_id,
-             repo.repo_type,
-             exists_ok=True,
-         )
-
-     create_branch(
-         repo_id=huggingface_repo_id,
-         branch=checkpointing_config.run_name,
-         exist_ok=True,
-     )
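The `linear_with_warmup` branch above boils down to a single multiplier on the base learning rate: it rises linearly to 1.0 over `lr_warmup_steps` and then decays linearly to 0.0 at `max_steps`. A minimal sketch of that lambda plugged into `LambdaLR`, decoupled from the `TrainingConfig` object, with the warmup and step counts chosen arbitrarily for illustration:

```python
import torch


def lr_lambda(step: int, num_warmup_steps: int = 10, max_steps: int = 100) -> float:
    # Linear warmup to 1.0, then linear decay to 0.0 at max_steps.
    if step < num_warmup_steps:
        return float(step) / float(max(1, num_warmup_steps))
    return max(0.0, float(max_steps - step) / float(max(1, max_steps - num_warmup_steps)))


param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for step in range(100):
    optimizer.step()
    scheduler.step()
    if step in (0, 9, 50, 99):
        # Peak lr at the end of warmup, roughly half-decayed mid-run, ~0 near max_steps.
        print(step, scheduler.get_last_lr())
```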
src/training/utils/io.py DELETED
@@ -1,52 +0,0 @@
- """Defines a retry wrapper for io operations."""
-
- import time
- from functools import wraps
-
-
- def use_backoff(max_retries=2, initial_delay=1, backoff_factor=2):
-     """
-     Universal retry wrapper with exponential backoff for any function, but primarily for loading
-     and storing HuggingFace datasets and objects.
-
-     Example usage:
-
-     >>> @use_backoff(max_retries=10, initial_delay=1, backoff_factor=2)
-     >>> def important_io_operation(x):
-     >>>     return x + 1
-
-     Args:
-         max_retries: Maximum number of retry attempts (default: 2)
-         initial_delay: Initial delay between retries in seconds (default: 1)
-         backoff_factor: Multiplier for delay between retries (default: 2)
-
-     Returns:
-         A decorator that will retry the wrapped function up to max_retries times with exponential backoff
-
-     Raises:
-         Exception: If all retries fail
-     """
-
-     def _decorator(fn):
-         @wraps(fn)
-         def wrapper(*args, **kwargs):
-             current_delay = initial_delay
-             last_exception = None
-
-             for attempt in range(max_retries):
-                 try:
-                     return fn(*args, **kwargs)
-                 except Exception as e:
-                     last_exception = e
-                     if attempt < max_retries - 1:  # Don't sleep on the last attempt
-                         time.sleep(current_delay)
-                         current_delay *= backoff_factor
-
-             raise Exception(
-                 f"IO Operation failed after {max_retries} attempts: {str(last_exception)}"
-             )
-
-         return wrapper
-
-     return _decorator
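A quick sketch of the decorator in use, assuming `use_backoff` (defined above) is in scope; the flaky operation and its failure counter are made up for illustration and fail twice before succeeding:

```python
attempts = {"count": 0}


@use_backoff(max_retries=3, initial_delay=0.1, backoff_factor=2)
def flaky_io_operation():
    # Hypothetical IO call that raises on the first two attempts.
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise ConnectionError("transient failure")
    return "ok"


print(flaky_io_operation())  # prints "ok" after sleeping ~0.1s and ~0.2s on the two failures
```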
src/training/utils/logging.py DELETED
@@ -1,48 +0,0 @@
- """
- Miscellaneous logging utilities.
- """
-
- from io import StringIO
-
- import yaml
- from lightning.fabric.utilities.rank_zero import rank_zero_only
- from rich.console import Console
- from rich.panel import Panel
-
-
- @rank_zero_only
- def pretty_print_yaml_config(logger, config: dict) -> None:
-     """
-     Pretty print a config with rich formatting. Assumes that the config has already been converted
-     to a dictionary - this can be done by calling `asdict` on the dataclass or by loading the config
-     from a yaml file.
-
-     NOTE: this function is only called on rank 0.
-
-     Args:
-         logger: Logger object to log the formatted output to.
-         config: Dictionary containing the config to pretty print.
-     """
-     # Create string buffer
-     output = StringIO()
-     console = Console(file=output, force_terminal=False)
-
-     # Convert to YAML string first
-     yaml_str = yaml.dump(
-         config, default_flow_style=False, sort_keys=False, Dumper=yaml.SafeDumper
-     )
-
-     # Create formatted panel
-     panel = Panel(
-         yaml_str,
-         border_style="blue",
-         padding=(0, 1),  # Reduced padding
-         expand=False,  # Don't expand to terminal width
-     )
-
-     # Print to buffer
-     console.print(panel)
-
-     # Log the formatted output
-     for line in output.getvalue().splitlines():
-         logger.info(line)
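A short sketch of calling this helper with a plain dictionary and a standard logger, assuming `rich` and `pyyaml` are installed and `pretty_print_yaml_config` (defined above) is in scope; the `DemoConfig` dataclass is a hypothetical stand-in for the project's config dataclasses:

```python
import logging
from dataclasses import asdict, dataclass

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("pico-train")


@dataclass
class DemoConfig:
    lr: float = 3e-4
    max_steps: int = 100


# Each line of the rich panel (a boxed YAML dump of the config) is logged individually.
pretty_print_yaml_config(logger, asdict(DemoConfig()))
```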