modalities.checkpointing.torch package
Submodules
modalities.checkpointing.torch.torch_checkpoint_loading module
- class modalities.checkpointing.torch.torch_checkpoint_loading.TorchCheckpointLoading(device, precision=None)
Bases: FSDP1CheckpointLoadingIF
Class to load PyTorch model and optimizer checkpoints.
Initializes the TorchCheckpointLoading object.
- Args:
device (torch.device): The device to load the model on.
precision (Optional[PrecisionEnum], optional): If specified, the model checkpoint will be loaded with the given precision. Otherwise, the precision specified in the state_dict will be used. Defaults to None.
- Returns:
None
- Parameters:
device (device)
precision (PrecisionEnum | None)
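The precision rule described above (an explicit `precision` wins, otherwise the checkpoint's own dtype is kept) can be sketched in plain Python. This is a minimal illustration, not the modalities code: the `PrecisionEnum` members and string values below are hypothetical stand-ins, and `resolve_precision` is an illustrative helper name.

```python
from enum import Enum
from typing import Optional


class PrecisionEnum(Enum):
    # Hypothetical stand-in: the real PrecisionEnum may use different members/values.
    FP32 = "float32"
    FP16 = "float16"
    BF16 = "bfloat16"


def resolve_precision(checkpoint_dtype: str, precision: Optional[PrecisionEnum] = None) -> str:
    """Return the dtype the loaded model should end up in (illustrative helper)."""
    if precision is None:
        # No override requested: keep the dtype stored in the checkpoint's state_dict.
        return checkpoint_dtype
    # An explicit precision overrides whatever the checkpoint contains.
    return precision.value


if __name__ == "__main__":
    print(resolve_precision("float32"))                      # float32
    print(resolve_precision("float32", PrecisionEnum.BF16))  # bfloat16
```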
- load_model_checkpoint(model, file_path)
Loads a model checkpoint from the specified file path.
- Args:
model (nn.Module): The model to load the checkpoint into.
file_path (Path): The path to the checkpoint file.
- Returns:
nn.Module: The model with the loaded checkpoint.
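For orientation, a model checkpoint can be loaded along these lines with plain PyTorch. This is a hedged sketch, not the body of `load_model_checkpoint`: it assumes the checkpoint file holds a plain `state_dict` saved with `torch.save`, passes `device` and a `torch.dtype` (as a stand-in for `PrecisionEnum`) as function arguments instead of constructor state, and infers the checkpoint's precision from its first floating-point tensor.

```python
import tempfile
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn


def load_model_checkpoint(
    model: nn.Module,
    file_path: Path,
    device: torch.device,
    precision: Optional[torch.dtype] = None,  # stand-in for PrecisionEnum
) -> nn.Module:
    """Illustrative sketch, not the modalities implementation."""
    # Map all stored tensors onto the target device while deserializing.
    state_dict = torch.load(file_path, map_location=device)
    if precision is None:
        # No override: adopt the dtype of the first floating-point tensor in the checkpoint.
        precision = next(
            (t.dtype for t in state_dict.values() if t.is_floating_point()), None
        )
    if precision is not None:
        # Cast the model first; load_state_dict copies values into the existing
        # parameters and would otherwise keep the model's current dtype.
        model.to(precision)
    model.load_state_dict(state_dict)
    return model.to(device)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "model.pt"
        torch.save(nn.Linear(4, 2).to(torch.bfloat16).state_dict(), path)
        loaded = load_model_checkpoint(nn.Linear(4, 2), path, torch.device("cpu"))
        print(loaded.weight.dtype)  # torch.bfloat16
```

Casting before `load_state_dict` matters because that call copies into the model's existing tensors rather than replacing them, so the parameters keep their current dtype unless the model is converted first.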
- load_optimizer_checkpoint_(optimizer, model, file_path)
Loads the optimizer checkpoint from the specified file path in place.
- Args:
optimizer (Optimizer): The optimizer to load the checkpoint into (in place).
model (nn.Module): The model associated with the optimizer.
file_path (Path): The path to the checkpoint file.
- Raises:
NotImplementedError: This method is not implemented yet. It is reserved for future work.
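Because this method currently raises `NotImplementedError`, optimizer state saved as a plain `optimizer.state_dict()` can still be restored with standard PyTorch calls. The helper below is hypothetical (it is not part of modalities) and assumes a non-sharded optimizer whose state was written with `torch.save`; unlike the documented method, it needs no `model` argument because a plain optimizer already references its parameters.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


def load_optimizer_state_(
    optimizer: torch.optim.Optimizer, file_path: Path, device: torch.device
) -> None:
    """Hypothetical helper: restore a saved optimizer state_dict in place."""
    state_dict = torch.load(file_path, map_location=device)
    # Optimizer.load_state_dict mutates the optimizer and returns None,
    # which is what the trailing underscore signals.
    optimizer.load_state_dict(state_dict)


if __name__ == "__main__":
    model = nn.Linear(3, 1)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    model(torch.randn(2, 3)).sum().backward()
    opt.step()  # populates Adam's per-parameter state
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "optimizer.pt"
        torch.save(opt.state_dict(), path)
        fresh = torch.optim.Adam(model.parameters(), lr=1e-3)
        load_optimizer_state_(fresh, path, torch.device("cpu"))
        print(len(fresh.state))  # 2 (weight and bias)
```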