modalities.checkpointing.torch package
Submodules
modalities.checkpointing.torch.torch_checkpoint_loading module
- class modalities.checkpointing.torch.torch_checkpoint_loading.TorchCheckpointLoading(device, precision=None)
- Bases: FSDP1CheckpointLoadingIF
- Class to load PyTorch model and optimizer checkpoints.
- Initializes the TorchCheckpointLoading object.
- Args:
- device (torch.device): The device to load the model on.
- precision (Optional[PrecisionEnum], optional): If specified, the model checkpoint will be loaded with the given precision. Otherwise, the precision stored in the state_dict is used. Defaults to None.
- Returns:
- None 
 - Parameters:
- device (device) 
- precision (PrecisionEnum | None) 
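The precision rule for the constructor can be illustrated on a plain state_dict. This is a minimal sketch, not the library's implementation. It uses a torch.dtype as a stand-in for PrecisionEnum, whose members are not listed on this page, and the helper name cast_state_dict is hypothetical. When a dtype is given, floating-point tensors are cast to it. When no dtype is given, the dtypes stored in the checkpoint are kept. Integer tensors are left untouched in both cases.

```python
from typing import Dict, Optional

import torch


def cast_state_dict(
    state_dict: Dict[str, torch.Tensor], dtype: Optional[torch.dtype] = None
) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: apply an optional precision override to a state_dict."""
    if dtype is None:
        # No override: keep whatever precision the checkpoint was saved with.
        return state_dict
    # Only floating-point tensors are cast; integer tensors (e.g. counters) keep their dtype.
    return {
        name: t.to(dtype) if t.is_floating_point() else t
        for name, t in state_dict.items()
    }


sd = {"weight": torch.ones(2, 2), "num_batches_tracked": torch.tensor(3)}
overridden = cast_state_dict(sd, torch.bfloat16)
kept = cast_state_dict(sd)
print(overridden["weight"].dtype)  # torch.bfloat16
print(overridden["num_batches_tracked"].dtype)  # torch.int64
print(kept["weight"].dtype)  # torch.float32
```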
 
 - load_model_checkpoint(model, file_path)
- Loads a model checkpoint from the specified file path.
- Args:
- model (nn.Module): The model to load the checkpoint into.
- file_path (Path): The path to the checkpoint file.
- Returns:
- nn.Module: The model with the loaded checkpoint. 
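The following is a hedged sketch of how load_model_checkpoint might behave. It assumes the checkpoint is a plain, non-sharded state_dict saved with torch.save. The name load_model_checkpoint_sketch is hypothetical. Its device and dtype arguments stand in for the values passed to the constructor, with a torch.dtype in place of PrecisionEnum. The actual class may also handle details that this sketch ignores.

```python
import tempfile
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn


def load_model_checkpoint_sketch(
    model: nn.Module,
    file_path: Path,
    device: torch.device,
    dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Hypothetical stand-in for load_model_checkpoint on a plain state_dict."""
    # Map every stored tensor onto the target device while unpickling;
    # weights_only=True restricts unpickling to tensors and primitive containers.
    state_dict = torch.load(file_path, map_location=device, weights_only=True)
    if dtype is None:
        # No override: adopt the floating-point dtype stored in the checkpoint, if any.
        dtype = next(
            (t.dtype for t in state_dict.values() if t.is_floating_point()), None
        )
    model.to(device)
    if dtype is not None:
        # Module.to(dtype=...) only casts floating-point parameters and buffers.
        model.to(dtype=dtype)
    # load_state_dict copies values into the existing parameters, casting as needed.
    model.load_state_dict(state_dict)
    return model


src = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.pt"
    torch.save(src.state_dict(), path)
    bf16_model = load_model_checkpoint_sketch(
        nn.Linear(4, 2), path, torch.device("cpu"), torch.bfloat16
    )
    default_model = load_model_checkpoint_sketch(
        nn.Linear(4, 2), path, torch.device("cpu")
    )

print(bf16_model.weight.dtype)  # torch.bfloat16
print(default_model.weight.dtype)  # torch.float32
```

Casting the module before calling load_state_dict means the copied values land directly in tensors of the requested precision. No second full-precision copy of the model is kept around.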
 
 - load_optimizer_checkpoint_(optimizer, model, file_path)
- Load the optimizer checkpoint from the specified file path (in-place).
- Args:
- optimizer (Optimizer): The optimizer to load the checkpoint into (in-place).
- model (nn.Module): The model associated with the optimizer.
- file_path (Path): The path to the checkpoint file.
- Raises:
- NotImplementedError: This method is not implemented yet. It is reserved for future work.
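This method currently raises NotImplementedError. For plain, non-sharded PyTorch training, an in-place optimizer restore can be sketched with Optimizer.load_state_dict. The sketch assumes the optimizer state was saved with torch.save(optimizer.state_dict(), ...), and the helper name is hypothetical. It drops the model argument because a non-sharded optimizer matches saved state to parameters by their position in param_groups. An FSDP-sharded optimizer would likely need a model-aware conversion that this sketch does not cover.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


def load_optimizer_checkpoint_sketch_(
    optimizer: torch.optim.Optimizer, file_path: Path, device: torch.device
) -> None:
    """Hypothetical in-place restore of a non-sharded optimizer checkpoint."""
    state_dict = torch.load(file_path, map_location=device, weights_only=True)
    # load_state_dict mutates the optimizer in place; state is matched to
    # parameters by their position in param_groups, so no model is needed.
    optimizer.load_state_dict(state_dict)


model = nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(8, 4)).sum().backward()
opt.step()  # populates Adam's per-parameter state (step, exp_avg, exp_avg_sq)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "optimizer.pt"
    torch.save(opt.state_dict(), path)
    restored = torch.optim.Adam(model.parameters(), lr=1e-3)
    load_optimizer_checkpoint_sketch_(restored, path, torch.device("cpu"))

print(len(restored.state))  # 2 (weight and bias)
```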