from datetime import datetime
from enum import Enum
from typing import Callable, Optional
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from modalities.batch import DatasetBatch, EvaluationResultBatch, ResultItem
from modalities.checkpointing.stateful.app_state import AppState
from modalities.dataloader.dataloader import LLMDataLoader
from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate
from modalities.logging_broker.publisher import MessagePublisher
from modalities.loss_functions import Loss
from modalities.models.model import model_predict_batch
from modalities.running_env.fsdp.reducer import Reducer
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
from modalities.training.training_progress import TrainingProgress
from modalities.util import Aggregator, TimeRecorder, print_rank_0
from modalities.utils.mfu import MFUCalculatorABC
class ThroughputAggregationKeys(Enum):
    """Keys under which the trainer collects the values needed to compute the training throughput."""

    # Number of processed samples; the trainer adds the length of every batch under this key.
    NUM_SAMPLES = "NUM_SAMPLES"
    # Elapsed time taken from the forward/backward time recorder at every logging interval.
    FORWARD_BACKWARD_TIME = "FORWARD_BACKWARD_TIME"
class Trainer:
    def __init__(
        self,
        global_rank: int,
        progress_publisher: MessagePublisher[ProgressUpdate],
        evaluation_result_publisher: MessagePublisher[EvaluationResultBatch],
        gradient_acc_steps: int,
        global_num_tokens_per_train_step: int,
        dp_degree: int,
        num_seen_train_steps: int,
        global_num_seen_tokens: int,
        num_target_steps: int,
        num_target_tokens: int,
        gradient_clipper: GradientClipperIF,
        mfu_calculator: Optional[MFUCalculatorABC] = None,
    ) -> None:
        """
        Initializes the Trainer object.
        Args:
            global_rank (int): The global rank to which operates the trainer object.
            progress_publisher (MessagePublisher[ProgressUpdate]): The publisher for progress updates.
            evaluation_result_publisher (MessagePublisher[EvaluationResultBatch]):
                The publisher for evaluation result batches.
            gradient_acc_steps (int): The number of gradient accumulation steps.
            global_num_tokens_per_train_step (int): The number of global tokens per training step.
            num_seen_train_steps (int): The number of training steps already seen.
            global_num_seen_tokens (int): The number of tokens already seen.
            num_target_steps (int): The target number of training steps.
            num_target_tokens (int): The target number of tokens.
            gradient_clipper (GradientClipperIF): The gradient clipper.
            mfu_calculator (Optional[MFUCalculatorABC]): The MFU calculator.
        Returns:
            None
        """
        self.global_rank = global_rank
        self.progress_publisher = progress_publisher
        self.evaluation_result_publisher = evaluation_result_publisher
        self.gradient_acc_steps = gradient_acc_steps
        self.global_num_tokens_per_train_step = global_num_tokens_per_train_step
        self.dp_degree = dp_degree
        self.num_seen_train_steps = num_seen_train_steps
        self.num_target_steps = num_target_steps
        self.num_target_tokens = num_target_tokens
        self.global_num_seen_tokens = global_num_seen_tokens
        self.gradient_clipper = gradient_clipper
        self.mfu_calculator = mfu_calculator
    @staticmethod
    def _get_num_train_steps_done(micro_batch_id: int, gradient_acc_steps: int) -> int:
        """
        Calculates the number of training steps done based on the micro batch ID and gradient accumulation steps.
        Args:
            micro_batch_id (int): The ID of the current micro batch.
            gradient_acc_steps (int): The number of gradient accumulation steps.
        Returns:
            int: The number of training steps done.
        """
        return (micro_batch_id + 1) // gradient_acc_steps
    def _train_batch(
        self,
        batch: DatasetBatch,
        model: FSDP,
        optimizer: Optimizer,
        scheduler: LRScheduler,
        loss_fun: Loss,
        micro_batch_id: int,
    ) -> tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]:
        """
        Runs forward and backward pass on a single micro batch and, once gradient_acc_steps
        micro batches have been accumulated, performs the optimizer and scheduler step.

        Args:
            batch (DatasetBatch): The input batch of data.
            model (FSDP): The model to train.
            optimizer (Optimizer): The optimizer used for training.
            scheduler (LRScheduler): The learning rate scheduler.
            loss_fun (Loss): The loss function used for training.
            micro_batch_id (int): The zero-based ID of the micro batch.

        Returns:
            tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]:
                A tuple containing the following:
                    - step_performed (bool): Indicates whether an optimizer step was performed.
                    - num_train_steps_done (int): The number of training steps done.
                    - loss (torch.Tensor): The computed (unscaled) loss of this micro batch.
                    - gradient_norm_score (Optional[torch.Tensor]): The value returned by the gradient
                        clipper if an optimizer step was performed, otherwise None.
        """
        predictions = model_predict_batch(model=model, batch=batch)
        loss = loss_fun(predictions)
        # Only the loss used for the backward pass is divided by the number of accumulation steps;
        # the loss returned to the caller stays unscaled.
        scaled_loss = loss / self.gradient_acc_steps
        scaled_loss.backward()
        # An optimizer step is due whenever a full group of gradient_acc_steps micro batches is complete.
        step_performed = (micro_batch_id + 1) % self.gradient_acc_steps == 0
        gradient_norm_score = None
        if step_performed:
            gradient_norm_score = self.gradient_clipper.clip_gradients()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        num_train_steps_done = Trainer._get_num_train_steps_done(
            micro_batch_id=micro_batch_id, gradient_acc_steps=self.gradient_acc_steps
        )
        return step_performed, num_train_steps_done, loss, gradient_norm_score
    def train(
        self,
        app_state: AppState,
        train_loader: LLMDataLoader,
        loss_fun: Loss,
        training_log_interval_in_steps: int,
        evaluation_callback: Callable[[int], None],
        checkpointing_callback: Callable[[TrainingProgress], None],
    ) -> None:
        """
        Trains the model until the target number of training steps is reached or the
        data loader is exhausted, whichever happens first.

        Both callbacks are invoked once before the first optimizer step and again after every
        performed optimizer step. Training metrics are published whenever an optimizer step was
        performed and the total number of seen steps is a multiple of training_log_interval_in_steps.

        Args:
            app_state (AppState): The application state containing the model, optimizer and lr scheduler.
            train_loader (LLMDataLoader): The data loader containing the training data.
            loss_fun (Loss): The loss function used for training.
            training_log_interval_in_steps (int): The interval at which training progress is logged.
            evaluation_callback (Callable[[int], None]): A callback function for evaluation. It is
                called with the keyword argument `num_train_steps_done` (number of steps done in total).
            checkpointing_callback (Callable[[TrainingProgress], None]): A callback function for
                checkpointing. It is called with the keyword argument `training_progress`.

        Returns:
            None
        """
        model = app_state.model
        optimizer = app_state.optimizer
        lr_scheduler = app_state.lr_scheduler
        model.train()
        # Layout as used below: [0] summed batch losses, [1] loss of the last batch, [-1] number of local batches.
        cumulated_losses = self._reset_tracked_losses()
        # throughput
        throughput_aggregator = Aggregator[ThroughputAggregationKeys]()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # batch loop
        batch: DatasetBatch
        # TODO: why do we need a barrier here?
        # dist.barrier()
        forward_backward_time_recorder = TimeRecorder()
        forward_backward_time_recorder.start()
        gradient_norm_scores = []
        # run evaluation callback and checkpointing callback before the first optimizer step
        evaluation_callback(num_train_steps_done=self.num_seen_train_steps)
        training_progress = TrainingProgress(
            num_seen_steps_previous_run=self.num_seen_train_steps,
            num_seen_tokens_previous_run=self.global_num_seen_tokens,
            num_seen_steps_current_run=0,
            num_seen_tokens_current_run=0,
            num_target_steps=self.num_target_steps,
            num_target_tokens=self.num_target_tokens,
        )
        checkpointing_callback(training_progress=training_progress)
        num_steps_todo = self.num_target_steps - self.num_seen_train_steps
        num_batches_todo = num_steps_todo * self.gradient_acc_steps
        # micro_batch_id is zero-based and counts only the micro batches of the current run. Steps of a
        # previous run enter via training_progress (num_seen_steps_previous_run), not via the batch id.
        # range() is the first zip argument so that no surplus batch is fetched from the data loader
        # once num_batches_todo micro batches have been processed.
        for _, (micro_batch_id, batch) in zip(range(num_batches_todo), enumerate(train_loader)):
            # Train single batch
            (
                step_performed,
                num_train_steps_done,
                batch_loss,
                gradient_norm_score,
            ) = self._train_batch(
                batch=batch,
                model=model,
                optimizer=optimizer,
                scheduler=lr_scheduler,
                loss_fun=loss_fun,
                micro_batch_id=micro_batch_id,
            )
            forward_backward_time_recorder.stop()
            training_progress.num_seen_steps_current_run = num_train_steps_done
            training_progress.num_seen_tokens_current_run = self.global_num_tokens_per_train_step * num_train_steps_done
            # Save the batch loss
            cumulated_losses[0] += batch_loss.detach().item()
            # This works, because we always drop the last batch in case it has less samples than the batch size
            cumulated_losses[-1] += 1  # number of local batches
            # gradient norm is already synced across all ranks
            if gradient_norm_score is not None:
                gradient_norm_scores.append(gradient_norm_score.item())
            batch_length_tensor = torch.tensor(len(batch)).to(device)
            throughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)
            self._publish_progress(
                progress_publisher=self.progress_publisher,
                num_train_steps_done=training_progress.num_seen_steps_total,
                dataloader_tag=train_loader.dataloader_tag,
            )
            # Check if model performance should be logged
            if training_progress.num_seen_steps_total % training_log_interval_in_steps == 0 and step_performed:
                forward_backward_time = torch.tensor(forward_backward_time_recorder.delta_t).to(device)
                forward_backward_time_recorder.reset()
                throughput_aggregator.add_value(
                    key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time
                )
                # we only want to sync the num samples across data parallel ranks
                # so we divide the world size by the dp degree
                synced_num_samples = throughput_aggregator.get_all_reduced_value(
                    ThroughputAggregationKeys.NUM_SAMPLES
                ) / (dist.get_world_size() / self.dp_degree)
                # the slowest rank determines the throughput, hence the MAX reduction
                synced_forward_backward_time = throughput_aggregator.get_all_reduced_value(
                    ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, reduce_operation=dist.ReduceOp.MAX
                )
                synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time
                # TODO: insert reducer from outside so Trainer is independent of FSDP
                # add the loss and gradient norm for the LAST batch
                cumulated_losses[1] = batch_loss.item()
                reduced_losses = Reducer.reduce(
                    tensor=cumulated_losses,
                    operation=dist.ReduceOp.SUM,
                    # 1.) summed batch loss / (num batches * world size)
                    # 2.) last batch loss / world size
                    post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / dist.get_world_size()]),
                )
                train_loss_avg, train_loss_last_batch = (
                    reduced_losses[0],
                    reduced_losses[1],
                )
                losses = {
                    "train loss avg": ResultItem(train_loss_avg, decimal_places=2),
                    "train loss last": ResultItem(train_loss_last_batch, decimal_places=2),
                }
                consumed_tokens = torch.tensor(training_progress.num_seen_tokens_total)
                metrics = {
                    "consumed tokens": ResultItem(consumed_tokens, 0),
                    "grad norm avg": ResultItem(torch.mean(torch.Tensor(gradient_norm_scores)), 2),
                    "grad norm last": ResultItem(torch.tensor(gradient_norm_scores[-1]), 2),
                }
                gradient_norm_scores = []
                # -1.0 marks the MFU as "not available" when no calculator was provided
                mfu_score = torch.tensor(-1.0)
                if self.mfu_calculator is not None:
                    mfu_score = self.mfu_calculator.compute(num_samples_per_second=synced_num_samples_per_second)
                if device.type == "cuda":
                    peak_memory_MB = torch.cuda.max_memory_allocated(device) / 1024**2  # in MB
                    torch.cuda.reset_peak_memory_stats(device)
                else:
                    # The CUDA memory statistics only accept CUDA devices. On the CPU fallback we report
                    # the same "not available" sentinel as for the MFU instead of raising an error.
                    peak_memory_MB = -1.0
                training_metrics = EvaluationResultBatch(
                    losses=losses,
                    metrics=metrics,
                    # TODO: hardcoded metric key
                    throughput_metrics={
                        "train samples/s": ResultItem(synced_num_samples_per_second, 1),
                        "train mfu (16-bit)": ResultItem(mfu_score, 2),
                        "lr mean": ResultItem(torch.tensor(lr_scheduler.get_last_lr()).mean()),
                        "peak memory rank 0 (MB)": ResultItem(torch.tensor(peak_memory_MB), 2),
                    },
                    dataloader_tag=train_loader.dataloader_tag,
                    num_train_steps_done=training_progress.num_seen_steps_total,
                )
                print_rank_0(f"{datetime.now().isoformat(timespec='seconds')} | {training_metrics}")
                self._publish_evaluation_result(
                    evaluation_result_publisher=self.evaluation_result_publisher,
                    evaluation_result=training_metrics,
                )
                # start the next logging interval with fresh throughput values and losses
                throughput_aggregator.remove_keys()
                cumulated_losses = self._reset_tracked_losses()
            if step_performed:
                evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total)
                checkpointing_callback(training_progress=training_progress)
            # We restart the time recorder here to also capture the time spent loading
            # the next batch via the data loader.
            forward_backward_time_recorder.start()
    def _reset_tracked_losses(self):
        # Initializes and returns a tensor representing the cumulated loss and gradient norm.
        # The tensor is initialized with zeros and its device is set based on the availability of CUDA.
        cumulated_loss_and_gradient_norm = torch.zeros(3)
        if torch.cuda.is_available():
            cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to(torch.device("cuda"))
        else:
            cumulated_loss_and_gradient_norm = cumulated_loss_and_gradient_norm.to("cpu")
        return cumulated_loss_and_gradient_norm
    @staticmethod
    def _publish_progress(
        progress_publisher: MessagePublisher[ProgressUpdate],
        num_train_steps_done: int,
        dataloader_tag: str,
    ) -> None:
        # Wraps the number of training steps done into a ProgressUpdate with status TRAIN
        # and publishes it as a batch progress update message.
        progress_update = ProgressUpdate(
            num_steps_done=num_train_steps_done,
            experiment_status=ExperimentStatus.TRAIN,
            dataloader_tag=dataloader_tag,
        )
        progress_publisher.publish_message(
            payload=progress_update,
            message_type=MessageTypes.BATCH_PROGRESS_UPDATE,
        )
    @staticmethod
    def _publish_evaluation_result(
        evaluation_result_publisher: MessagePublisher[EvaluationResultBatch],
        evaluation_result: EvaluationResultBatch,
    ) -> None:
        # Hands the given evaluation result over to the publisher as an EVALUATION_RESULT message.
        message_type = MessageTypes.EVALUATION_RESULT
        evaluation_result_publisher.publish_message(payload=evaluation_result, message_type=message_type)