modalities.checkpointing.torch package

Submodules

modalities.checkpointing.torch.torch_checkpoint_loading module

class modalities.checkpointing.torch.torch_checkpoint_loading.TorchCheckpointLoading(device, precision=None)

Bases: FSDP1CheckpointLoadingIF

Class to load PyTorch model and optimizer checkpoints.

Initializes the TorchCheckpointLoading object.

Args:

device (torch.device): The device to load the model on.

precision (Optional[PrecisionEnum], optional): If specified, the model checkpoint will be loaded with the given precision. Otherwise, the precision stored in the state_dict will be used. Defaults to None.

Returns:

None
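The sketch below illustrates the documented `precision` semantics in plain PyTorch: with `None`, tensors keep the dtype stored in the state_dict; otherwise floating-point tensors are cast. `Precision`, `_DTYPES` and `apply_precision` are hypothetical stand-ins, not part of modalities; the real `PrecisionEnum` member names and the class's internal casting logic may differ. Leaving integer tensors untouched is an assumption of this sketch, not documented behavior.

```python
from enum import Enum
from typing import Optional

import torch


class Precision(Enum):
    """Hypothetical stand-in for PrecisionEnum; real member names may differ."""

    FP32 = "fp32"
    FP16 = "fp16"
    BF16 = "bf16"


# Hypothetical mapping from the stand-in enum to torch dtypes.
_DTYPES = {
    Precision.FP32: torch.float32,
    Precision.FP16: torch.float16,
    Precision.BF16: torch.bfloat16,
}


def apply_precision(
    state_dict: dict[str, torch.Tensor], precision: Optional[Precision] = None
) -> dict[str, torch.Tensor]:
    """Cast floating-point tensors to `precision`; keep stored dtypes when None."""
    if precision is None:
        # No override requested: use the dtypes as saved in the checkpoint.
        return state_dict
    dtype = _DTYPES[precision]
    return {
        name: t.to(dtype) if t.is_floating_point() else t  # leave int tensors as-is
        for name, t in state_dict.items()
    }
```

For example, `apply_precision({"weight": torch.ones(2, 2)}, Precision.BF16)["weight"].dtype` is `torch.bfloat16`, while omitting `precision` leaves it as `torch.float32`.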

load_model_checkpoint(model, file_path)

Loads a model checkpoint from the specified file path.

Return type:

Module

Args:

model (nn.Module): The model to load the checkpoint into.

file_path (Path): The path to the checkpoint file.

Returns:

nn.Module: The model with the loaded checkpoint.
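A minimal plain-PyTorch sketch of what loading a model checkpoint from `file_path` involves, assuming the file holds a state_dict written with `torch.save(model.state_dict(), path)`. The free function `load_model_checkpoint` and the helper `demo_model_roundtrip` are hypothetical: they mirror the documented signature plus an explicit `device`, and are not the library's implementation.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


def load_model_checkpoint(model: nn.Module, file_path: Path, device: torch.device) -> nn.Module:
    """Hypothetical sketch: read a state_dict from disk and load it into `model`."""
    # map_location deserializes every tensor directly onto `device`.
    state_dict = torch.load(file_path, map_location=device)
    # Copies the stored values into the model's existing parameters and buffers.
    model.load_state_dict(state_dict)
    # Make sure the returned model lives on the requested device.
    return model.to(device)


def demo_model_roundtrip() -> bool:
    """Save a small model, load it into a fresh instance, and compare weights."""
    device = torch.device("cpu")
    source = nn.Linear(4, 2)
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "model.pt"
        torch.save(source.state_dict(), path)
        target = load_model_checkpoint(nn.Linear(4, 2), path, device)
    return torch.equal(source.weight, target.weight) and torch.equal(source.bias, target.bias)
```

`demo_model_roundtrip()` should return `True`. Passing `map_location` lets a checkpoint saved on one device be deserialized directly onto another, for example a GPU-saved file on a CPU-only host.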

load_optimizer_checkpoint_(optimizer, model, file_path)

Load the optimizer checkpoint from the specified file path (in-place).

Args:

optimizer (Optimizer): The optimizer to load the checkpoint into (in-place).

model (nn.Module): The model associated with the optimizer.

file_path (Path): The path to the checkpoint file.

Raises:

NotImplementedError: This method is not implemented yet. It is reserved for future work.
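Because this method currently raises, the sketch below shows the generic plain-PyTorch pattern for restoring an unsharded optimizer's state in place with `Optimizer.load_state_dict`. It is not a description of how this class will eventually implement the method; `load_optimizer_state_` and `demo_optimizer_roundtrip` are hypothetical names, and the documented `model` argument is omitted because unsharded optimizer state does not need it.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


def load_optimizer_state_(
    optimizer: torch.optim.Optimizer, file_path: Path, device: torch.device
) -> None:
    """Hypothetical sketch: restore optimizer state in place from a saved state_dict."""
    state_dict = torch.load(file_path, map_location=device)
    # Overwrites param_groups (lr, betas, ...) and per-parameter state in place.
    optimizer.load_state_dict(state_dict)


def demo_optimizer_roundtrip() -> bool:
    """Take one Adam step, save its state, restore into a fresh optimizer, compare."""
    model = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model(torch.randn(3, 4)).sum().backward()
    optimizer.step()  # populates exp_avg / exp_avg_sq for each parameter
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "optimizer.pt"
        torch.save(optimizer.state_dict(), path)
        restored = torch.optim.Adam(model.parameters(), lr=1e-3)
        load_optimizer_state_(restored, path, torch.device("cpu"))
    saved = optimizer.state_dict()["state"][0]["exp_avg"]
    loaded = restored.state_dict()["state"][0]["exp_avg"]
    return torch.equal(saved, loaded)
```

`demo_optimizer_roundtrip()` should return `True`. The fresh optimizer must be built over the same parameters in the same order, since the saved state is keyed by parameter position.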

Module contents