Source code for modalities.trainer

from enum import Enum
from typing import Callable, Optional

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from modalities.batch import DatasetBatch, EvaluationResultBatch, ResultItem
from modalities.checkpointing.stateful.app_state import AppState
from modalities.dataloader.dataloader import LLMDataLoader
from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate
from modalities.logging_broker.publisher import MessagePublisher
from modalities.loss_functions import Loss
from modalities.models.model import model_predict_batch
from modalities.running_env.fsdp.reducer import Reducer
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
from modalities.training.training_progress import TrainingProgress
from modalities.util import Aggregator, TimeRecorder, print_rank_0
from modalities.utils.mfu import MFUCalculatorABC


class ThroughputAggregationKeys(Enum):
    NUM_SAMPLES = "NUM_SAMPLES"
    FORWARD_BACKWARD_TIME = "FORWARD_BACKWARD_TIME"


class Trainer:
    def __init__(
        self,
        global_rank: int,
        progress_publisher: MessagePublisher[ProgressUpdate],
        evaluation_result_publisher: MessagePublisher[EvaluationResultBatch],
        gradient_acc_steps: int,
        global_num_tokens_per_train_step: int,
        num_seen_train_steps: int,
        global_num_seen_tokens: int,
        num_target_steps: int,
        num_target_tokens: int,
        gradient_clipper: GradientClipperIF,
        mfu_calculator: Optional[MFUCalculatorABC] = None,
    ) -> None:
        """
        Initializes the Trainer object.

        Args:
            global_rank (int): The global rank on which the trainer object operates.
            progress_publisher (MessagePublisher[ProgressUpdate]): The publisher for progress updates.
            evaluation_result_publisher (MessagePublisher[EvaluationResultBatch]): The publisher for
                evaluation result batches.
            gradient_acc_steps (int): The number of gradient accumulation steps.
            global_num_tokens_per_train_step (int): The number of global tokens per training step.
            num_seen_train_steps (int): The number of training steps already seen.
            global_num_seen_tokens (int): The number of tokens already seen.
            num_target_steps (int): The target number of training steps.
            num_target_tokens (int): The target number of tokens.
            gradient_clipper (GradientClipperIF): The gradient clipper.
            mfu_calculator (Optional[MFUCalculatorABC]): The MFU calculator.

        Returns:
            None
        """
        self.global_rank = global_rank
        self.progress_publisher = progress_publisher
        self.evaluation_result_publisher = evaluation_result_publisher
        self.gradient_acc_steps = gradient_acc_steps
        self.global_num_tokens_per_train_step = global_num_tokens_per_train_step
        self.num_seen_train_steps = num_seen_train_steps
        self.num_target_steps = num_target_steps
        self.num_target_tokens = num_target_tokens
        self.global_num_seen_tokens = global_num_seen_tokens
        self.gradient_clipper = gradient_clipper
        self.mfu_calculator = mfu_calculator

    @staticmethod
    def _get_num_train_steps_done(micro_batch_id: int, gradient_acc_steps: int) -> int:
        """
        Calculates the number of training steps done based on the micro batch ID and gradient accumulation steps.

        Args:
            micro_batch_id (int): The ID of the current micro batch.
            gradient_acc_steps (int): The number of gradient accumulation steps.

        Returns:
            int: The number of training steps done.
        """
        return (micro_batch_id + 1) // gradient_acc_steps

    def _train_batch(
        self,
        batch: DatasetBatch,
        model: FSDP,
        optimizer: Optimizer,
        scheduler: LRScheduler,
        loss_fun: Loss,
        micro_batch_id: int,
    ) -> tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]:
        """
        Conducts a training step on a batch of data.

        Args:
            batch (DatasetBatch): The input batch of data.
            model (FSDP): The model to train.
            optimizer (Optimizer): The optimizer used for training.
            scheduler (LRScheduler): The learning rate scheduler.
            loss_fun (Loss): The loss function used for training.
            micro_batch_id (int): The ID of the micro batch.

        Returns:
            tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]:
                A tuple containing the following:
                - step_performed (bool): Indicates whether a training step was performed.
                - num_train_steps_done (int): The number of training steps done.
                - loss (torch.Tensor): The computed loss.
                - gradient_norm_score (Optional[torch.Tensor]): The gradient norm score if a training step
                  was performed, otherwise None.
""" result_batch = model_predict_batch(model=model, batch=batch) loss = loss_fun(result_batch) (loss / self.gradient_acc_steps).backward() if (micro_batch_id + 1) % self.gradient_acc_steps == 0: gradient_norm_score = self.gradient_clipper.clip_gradients() optimizer.step() scheduler.step() optimizer.zero_grad() step_performed = True else: step_performed = False gradient_norm_score = None num_train_steps_done = Trainer._get_num_train_steps_done( micro_batch_id=micro_batch_id, gradient_acc_steps=self.gradient_acc_steps ) return step_performed, num_train_steps_done, loss, gradient_norm_score
    def train(
        self,
        app_state: AppState,
        train_loader: LLMDataLoader,
        loss_fun: Loss,
        training_log_interval_in_steps: int,
        evaluation_callback: Callable[[TrainingProgress], None],
        checkpointing_callback: Callable[[TrainingProgress], None],
    ):
        """
        Trains the model.

        Args:
            app_state (AppState): The application state containing the model, optimizer and lr scheduler.
            train_loader (LLMDataLoader): The data loader containing the training data.
            loss_fun (Loss): The loss function used for training.
            training_log_interval_in_steps (int): The interval at which training progress is logged.
            evaluation_callback (Callable[[TrainingProgress], None]): A callback function for evaluation.
            checkpointing_callback (Callable[[TrainingProgress], None]): A callback function for checkpointing.

        Returns:
            None
        """
        model = app_state.model
        optimizer = app_state.optimizer
        lr_scheduler = app_state.lr_scheduler

        model.train()
        cumulated_losses = self._reset_tracked_losses()

        # throughput
        thoughput_aggregator = Aggregator[ThroughputAggregationKeys]()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # batch loop
        batch: DatasetBatch
        # TODO: why do we need a barrier here?
        dist.barrier()
        forward_backward_time_recorder = TimeRecorder()
        forward_backward_time_recorder.start()

        gradient_norm_scores = []
        # run evaluation callback and checkpointing callback before the first optimizer step
        evaluation_callback(num_train_steps_done=self.num_seen_train_steps)
        training_progress = TrainingProgress(
            num_seen_steps_previous_run=self.num_seen_train_steps,
            num_seen_tokens_previous_run=self.global_num_seen_tokens,
            num_seen_steps_current_run=0,
            num_seen_tokens_current_run=0,
            num_target_steps=self.num_target_steps,
            num_target_tokens=self.num_target_tokens,
        )
        checkpointing_callback(training_progress=training_progress)

        num_steps_todo = self.num_target_steps - self.num_seen_train_steps
        num_batches_todo = num_steps_todo * self.gradient_acc_steps
        # Because we might resume training, we add the starting batch id of the data loader
        for _, (micro_batch_id, batch) in zip(range(num_batches_todo), enumerate(train_loader)):
            # Train single batch
            (
                step_performed,
                num_train_steps_done,
                batch_loss,
                gradient_norm_score,
            ) = self._train_batch(
                batch=batch,
                model=model,
                optimizer=optimizer,
                scheduler=lr_scheduler,
                loss_fun=loss_fun,
                micro_batch_id=micro_batch_id,
            )
            forward_backward_time_recorder.stop()
            training_progress.num_seen_steps_current_run = num_train_steps_done
            training_progress.num_seen_tokens_current_run = (
                self.global_num_tokens_per_train_step * num_train_steps_done
            )

            # Save the batch loss
            cumulated_losses[0] += batch_loss.item()
            # This works because we always drop the last batch in case it has fewer samples than the batch size
            cumulated_losses[-1] += 1  # number of local batches

            # gradient norm is already synced across all ranks
            if gradient_norm_score is not None:
                gradient_norm_scores.append(gradient_norm_score.item())

            batch_length_tensor = torch.tensor(len(batch)).to(device)
            thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)
            self._publish_progress(
                progress_publisher=self.progress_publisher,
                num_train_steps_done=training_progress.num_seen_steps_total,
                dataloader_tag=train_loader.dataloader_tag,
            )

            # Check if model performance should be logged
            if training_progress.num_seen_steps_total % training_log_interval_in_steps == 0 and step_performed:
                forward_backward_time = torch.tensor(forward_backward_time_recorder.delta_t).to(device)
                forward_backward_time_recorder.reset()
                thoughput_aggregator.add_value(
                    key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time
                )
                synced_num_samples = thoughput_aggregator.get_all_reduced_value(ThroughputAggregationKeys.NUM_SAMPLES)
                synced_forward_backward_time = thoughput_aggregator.get_all_reduced_value(
                    ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, reduce_operation=dist.ReduceOp.MAX
                )
                synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time
                # TODO: insert reducer from outside so Trainer is independent of FSDP
                # add the loss and gradient norm for the LAST batch
                cumulated_losses[1] = batch_loss.item()

                reduced_losses = Reducer.reduce(
                    tensor=cumulated_losses,
                    operation=dist.ReduceOp.SUM,
                    # 1.) summed batch loss / (num batches * world size)
                    # 2.) last batch loss / world size
                    post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / dist.get_world_size()]),
                )

                train_loss_avg, train_loss_last_batch = (
                    reduced_losses[0],
                    reduced_losses[1],
                )
                losses = {
                    "train loss avg": ResultItem(train_loss_avg, decimal_places=2),
                    "train loss last": ResultItem(train_loss_last_batch, decimal_places=2),
                }

                consumed_tokens = torch.Tensor([training_progress.num_seen_tokens_total])
                metrics = {
                    "consumed tokens": ResultItem(consumed_tokens, 0),
                    "grad norm avg": ResultItem(torch.mean(torch.Tensor(gradient_norm_scores)), 2),
                    "grad norm last": ResultItem(torch.tensor(gradient_norm_scores[-1]), 2),
                }
                gradient_norm_scores = []

                mfu_score = torch.tensor(-1.0)
                if self.mfu_calculator is not None:
                    mfu_score = self.mfu_calculator.compute(num_samples_per_second=synced_num_samples_per_second)

                training_metrics = EvaluationResultBatch(
                    losses=losses,
                    metrics=metrics,
                    # TODO: hardcoded metric key
                    throughput_metrics={
                        "train samples/s": ResultItem(synced_num_samples_per_second, 1),
                        "train mfu (16-bit)": ResultItem(mfu_score, 2),
                        "lr mean": ResultItem(torch.tensor(lr_scheduler.get_last_lr()).mean()),
                    },
                    dataloader_tag=train_loader.dataloader_tag,
                    num_train_steps_done=training_progress.num_seen_steps_total,
                )
                print_rank_0(training_metrics)
                self._publish_evaluation_result(
                    evaluation_result_publisher=self.evaluation_result_publisher,
                    evaluation_result=training_metrics,
                )
                thoughput_aggregator.remove_keys()

                cumulated_losses = self._reset_tracked_losses()
            if step_performed:
                evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total)
                checkpointing_callback(training_progress=training_progress)
            # we start the time recorder here again to also capture the time spent loading
            # via the dataloader.
            forward_backward_time_recorder.start()
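
    # Illustrative note (editorial comment, not part of the original source): assuming
    # Reducer.reduce all-reduces the tensor with the given operation before applying
    # post_processing_fun, the loss reduction above works on
    # cumulated_losses = [summed local batch loss, last batch loss, number of local batches].
    # Worked example with world_size=2 and 10 local batches per rank: summed local losses of
    # 25.0 and 27.0 and last-batch losses of 2.4 and 2.6 give
    #   train loss avg  = (25.0 + 27.0) / (10 + 10) = 2.6
    #   train loss last = (2.4 + 2.6) / 2            = 2.5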
    def _reset_tracked_losses(self):
        # Initializes and returns a tensor representing the cumulated loss and gradient norm.
        # The tensor is initialized with zeros and its device is set based on the availability of CUDA.
        cumulated_loss_and_gradient_norm = torch.zeros(3)
        if torch.cuda.is_available():
            cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to(torch.device("cuda"))
        else:
            cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to("cpu")
        return cumulated_loss_and_gradient_norm

    @staticmethod
    def _publish_progress(
        progress_publisher: MessagePublisher[ProgressUpdate],
        num_train_steps_done: int,
        dataloader_tag: str,
    ):
        # Publishes the progress of the training, i.e., the number of training steps done.
        payload = ProgressUpdate(
            num_steps_done=num_train_steps_done,
            experiment_status=ExperimentStatus.TRAIN,
            dataloader_tag=dataloader_tag,
        )
        progress_publisher.publish_message(payload=payload, message_type=MessageTypes.BATCH_PROGRESS_UPDATE)

    @staticmethod
    def _publish_evaluation_result(
        evaluation_result_publisher: MessagePublisher[EvaluationResultBatch],
        evaluation_result: EvaluationResultBatch,
    ):
        # Publishes the evaluation result.
        evaluation_result_publisher.publish_message(
            payload=evaluation_result, message_type=MessageTypes.EVALUATION_RESULT
        )
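
# Illustrative usage sketch (editorial comment, not part of the original source). The concrete
# publisher, gradient clipper, app state, dataloader, loss function, and callback instances are
# assumed to be constructed elsewhere (typically from the modalities configuration); the numeric
# values below are placeholders.
#
#     trainer = Trainer(
#         global_rank=dist.get_rank(),
#         progress_publisher=progress_publisher,
#         evaluation_result_publisher=evaluation_result_publisher,
#         gradient_acc_steps=4,
#         global_num_tokens_per_train_step=65_536,  # tokens consumed per optimizer step across all ranks
#         num_seen_train_steps=0,
#         global_num_seen_tokens=0,
#         num_target_steps=10_000,
#         num_target_tokens=10_000 * 65_536,
#         gradient_clipper=gradient_clipper,
#         mfu_calculator=None,
#     )
#     trainer.train(
#         app_state=app_state,
#         train_loader=train_loader,
#         loss_fun=loss_fun,
#         training_log_interval_in_steps=10,
#         evaluation_callback=evaluation_callback,
#         checkpointing_callback=checkpointing_callback,
#     )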