modalities.checkpointing.stateful package
Submodules
modalities.checkpointing.stateful.app_state module
- class modalities.checkpointing.stateful.app_state.AppState(model, optimizer, lr_scheduler=None)
Bases: Stateful
This wrapper bundles the application state (i.e., model, optimizer, and lr scheduler) for checkpointing. Since this object complies with the Stateful protocol, DCP (torch.distributed.checkpoint) automatically calls state_dict/load_state_dict as needed in the dcp.save/load APIs.
Note: We use this wrapper to call the distributed state dict methods on the model and optimizer.
Note: This class has been copied and adapted from https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
Initializes the AppState object.
- Args:
model (nn.Module): The model, which can be a non-sharded, FSDP1, or FSDP2 model.
optimizer (Optimizer): The optimizer, which can be a non-sharded, FSDP1, or FSDP2 optimizer.
lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
- Parameters:
model (Module)
optimizer (Optimizer)
lr_scheduler (LRScheduler | None)
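The wrapper pattern described above can be sketched without PyTorch. This is a minimal, torch-free illustration, assuming toy components that expose `state_dict()`/`load_state_dict()` the way `nn.Module`, `Optimizer`, and `LRScheduler` do. `StatefulLike`, `ToyComponent`, `AppStateSketch`, and the keys `"model"`, `"optimizer"`, `"lr_scheduler"` are hypothetical and not necessarily what `AppState` uses internally; the real class additionally routes the model and optimizer through PyTorch's distributed state dict utilities.

```python
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class StatefulLike(Protocol):
    """Hypothetical structural stand-in for DCP's Stateful protocol."""

    def state_dict(self) -> dict[str, Any]: ...

    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...


class ToyComponent:
    """Hypothetical stand-in for a model, optimizer, or lr scheduler."""

    def __init__(self, value: int) -> None:
        self.value = value

    def state_dict(self) -> dict[str, Any]:
        return {"value": self.value}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.value = state_dict["value"]


class AppStateSketch:
    """Bundles all training components behind a single Stateful-style object."""

    def __init__(
        self,
        model: ToyComponent,
        optimizer: ToyComponent,
        lr_scheduler: Optional[ToyComponent] = None,
    ) -> None:
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def state_dict(self) -> dict[str, Any]:
        # One nested entry per component; the scheduler entry is optional.
        sd = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()}
        if self.lr_scheduler is not None:
            sd["lr_scheduler"] = self.lr_scheduler.state_dict()
        return sd

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Forward each nested entry to the component that owns it.
        self.model.load_state_dict(state_dict["model"])
        self.optimizer.load_state_dict(state_dict["optimizer"])
        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(state_dict["lr_scheduler"])


app_state = AppStateSketch(ToyComponent(1), ToyComponent(2), ToyComponent(3))
assert isinstance(app_state, StatefulLike)  # structural check: both methods are present
print(app_state.state_dict())
# {'model': {'value': 1}, 'optimizer': {'value': 2}, 'lr_scheduler': {'value': 3}}
```

In the linked PyTorch recipe, such an object is wrapped in a dict (e.g., `{"app": app_state}`) and passed to `dcp.save(..., checkpoint_id=...)`, which then calls `state_dict()` on it.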
- property is_loaded: bool
Returns whether the state dict has been loaded.
- Returns:
bool: Flag indicating whether the state dict has been loaded.
- load_state_dict(state_dict)
Loads the state dict into the AppState object.
- Args:
state_dict (dict[str, Any]): The state dict to load into the AppState object.
- Raises:
RuntimeError: If the state dict has already been loaded.
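The load-once contract documented by `is_loaded` and `load_state_dict` can be illustrated with a small stand-alone sketch. It assumes the guard is a private boolean flag that is flipped after the first successful load; `LoadOnceState` and `_is_loaded` are hypothetical names, and the actual `AppState` also forwards the loaded entries to the model, optimizer, and lr scheduler.

```python
from typing import Any


class LoadOnceState:
    """Hypothetical sketch of a state holder that may be loaded only once."""

    def __init__(self) -> None:
        self._is_loaded = False
        self.loaded: dict[str, Any] = {}

    @property
    def is_loaded(self) -> bool:
        # Read-only view of the guard flag.
        return self._is_loaded

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        if self._is_loaded:
            raise RuntimeError("Cannot call load_state_dict twice on the same object.")
        self.loaded = dict(state_dict)  # stand-in for forwarding to the components
        self._is_loaded = True


state = LoadOnceState()
state.load_state_dict({"model": {}})
print(state.is_loaded)  # True
try:
    state.load_state_dict({"model": {}})
except RuntimeError as err:
    print(f"second load rejected: {err}")
```

Rejecting a second load keeps already restored weights and optimizer state from being silently overwritten if the same object is loaded twice.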
- property lr_scheduler: LRScheduler
- class modalities.checkpointing.stateful.app_state.LRSchedulerStateRetriever
Bases: StateRetrieverIF
- static get_state_dict(app_state)
Returns the state dict of the lr scheduler in the AppState object.
- Args:
app_state (AppState): The app_state object containing the lr scheduler.
- Returns:
dict[str, Any]: The state dict of the lr scheduler in the AppState object.
- class modalities.checkpointing.stateful.app_state.ModelStateRetriever
Bases: StateRetrieverIF
- static get_state_dict(app_state)
Returns the state dict of the model in the AppState object.
- Args:
app_state (AppState): The app_state object containing the model.
- Returns:
dict[str, Any]: The state dict of the model in the AppState object.
- class modalities.checkpointing.stateful.app_state.OptimizerStateRetriever
Bases: StateRetrieverIF
- static get_state_dict(app_state)
Returns the state dict of the optimizer in the AppState object.
- Args:
app_state (AppState): The app_state object containing the optimizer.
- Returns:
dict[str, Any]: The state dict of the optimizer in the AppState object.
- class modalities.checkpointing.stateful.app_state.StateRetrieverIF
Bases: ABC
State retriever interface for loading and getting state dicts of models, optimizers and lr schedulers. Other stateful components can be added as needed by having the retriever implement this interface.
- abstractmethod static get_state_dict(app_state)
Returns the state dict of the component (e.g., model, optimizer, or lr scheduler) that this retriever handles within the AppState object.
- Args:
app_state (AppState): The application state object.
- Raises:
NotImplementedError: This abstract method is not implemented and should be overridden in a subclass.
- Returns:
dict[str, Any]: The state dict of the AppState object.
- abstractmethod static load_state_dict_(app_state, state_dict)
Loads the state dict into the component that this retriever handles within the AppState object.
- Args:
app_state (AppState): The application state object.
state_dict (dict[str, Any]): The state dict to load into the AppState object.
- Raises:
NotImplementedError: This abstract method is not implemented and should be overridden in a subclass.
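Following the interface description above, a retriever is a class of static methods that reads and writes one component of the application state. The sketch below mirrors the documented `get_state_dict`/`load_state_dict_` pair using `abc`. `Component`, `SimpleAppState`, the `RETRIEVERS` mapping, and `DataloaderStateRetriever` (an example of an additional stateful component) are hypothetical; the real retrievers presumably delegate to PyTorch's distributed state dict helpers rather than calling the components directly.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Component:
    """Hypothetical component with a PyTorch-like state dict API."""

    state: dict[str, Any] = field(default_factory=dict)

    def state_dict(self) -> dict[str, Any]:
        return dict(self.state)

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.state = dict(state_dict)


@dataclass
class SimpleAppState:
    """Hypothetical application state holding the components."""

    lr_scheduler: Component
    dataloader: Component


class StateRetrieverIF(ABC):
    # abstractmethod must be the innermost decorator when combined with staticmethod.
    @staticmethod
    @abstractmethod
    def get_state_dict(app_state: SimpleAppState) -> dict[str, Any]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def load_state_dict_(app_state: SimpleAppState, state_dict: dict[str, Any]) -> None:
        raise NotImplementedError


class LRSchedulerStateRetriever(StateRetrieverIF):
    @staticmethod
    def get_state_dict(app_state: SimpleAppState) -> dict[str, Any]:
        return app_state.lr_scheduler.state_dict()

    @staticmethod
    def load_state_dict_(app_state: SimpleAppState, state_dict: dict[str, Any]) -> None:
        app_state.lr_scheduler.load_state_dict(state_dict)


class DataloaderStateRetriever(StateRetrieverIF):
    """Hypothetical retriever for an extra stateful component."""

    @staticmethod
    def get_state_dict(app_state: SimpleAppState) -> dict[str, Any]:
        return app_state.dataloader.state_dict()

    @staticmethod
    def load_state_dict_(app_state: SimpleAppState, state_dict: dict[str, Any]) -> None:
        app_state.dataloader.load_state_dict(state_dict)


# Map each top-level checkpoint key to the retriever responsible for it.
RETRIEVERS: dict[str, type[StateRetrieverIF]] = {
    "lr_scheduler": LRSchedulerStateRetriever,
    "dataloader": DataloaderStateRetriever,
}

app = SimpleAppState(Component({"last_epoch": 3}), Component({"seen": 128}))
checkpoint = {key: r.get_state_dict(app) for key, r in RETRIEVERS.items()}
print(checkpoint)  # {'lr_scheduler': {'last_epoch': 3}, 'dataloader': {'seen': 128}}

# Round trip: load every entry back into a fresh application state.
restored = SimpleAppState(Component(), Component())
for key, r in RETRIEVERS.items():
    r.load_state_dict_(restored, checkpoint[key])
```

Because the retrievers only hold static methods, they are used as classes and never instantiated; adding a component means adding one retriever and one mapping entry.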
modalities.checkpointing.stateful.app_state_factory module
- class modalities.checkpointing.stateful.app_state_factory.AppStateFactory
Bases: object
Factory class to create AppState objects.
- static get_dcp_checkpointed_app_state_(raw_app_state, checkpoint_dir_path)
Loads the checkpointed state dict in-place into the raw AppState object (i.e., an AppState into which no checkpoint has been loaded yet).
- Args:
raw_app_state (AppState): The raw AppState object.
checkpoint_dir_path (Path): The path to the checkpoint directory.
- Raises:
RuntimeError: Raises an error if the state dict has already been loaded.
- Returns:
AppState: The AppState object with the loaded state dict.
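The in-place loading step can be sketched as a torch-free emulation under stated assumptions: a JSON file named `app_state.json` in the checkpoint directory stands in for DCP's checkpoint format, and `TinyAppState`, `save_app_state`, and `load_checkpointed_app_state_` are hypothetical. In the real factory, the corresponding step is presumably a `torch.distributed.checkpoint.load` call, which fills the existing state in place and then calls `load_state_dict` on the Stateful object.

```python
import json
import tempfile
from pathlib import Path
from typing import Any


class TinyAppState:
    """Hypothetical app state: named component states plus a load guard."""

    def __init__(self, components: dict[str, dict[str, Any]]) -> None:
        self.components = components
        self.is_loaded = False

    def state_dict(self) -> dict[str, Any]:
        return {name: dict(sd) for name, sd in self.components.items()}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        if self.is_loaded:
            raise RuntimeError("State dict has already been loaded.")
        for name, sd in state_dict.items():
            self.components[name] = dict(sd)
        self.is_loaded = True


def save_app_state(app_state: TinyAppState, checkpoint_dir_path: Path) -> None:
    """Writes the state dict as JSON (stand-in for a DCP save)."""
    checkpoint_dir_path.mkdir(parents=True, exist_ok=True)
    (checkpoint_dir_path / "app_state.json").write_text(json.dumps(app_state.state_dict()))


def load_checkpointed_app_state_(raw_app_state: TinyAppState, checkpoint_dir_path: Path) -> TinyAppState:
    """Loads the checkpoint into raw_app_state in place and returns the same object."""
    if raw_app_state.is_loaded:
        # Fail fast before touching the checkpoint files.
        raise RuntimeError("Cannot load a checkpoint into an already loaded AppState.")
    state_dict = json.loads((checkpoint_dir_path / "app_state.json").read_text())
    raw_app_state.load_state_dict(state_dict)
    return raw_app_state


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp) / "step_100"
    save_app_state(TinyAppState({"model": {"w": 0.5}}), ckpt_dir)
    raw = TinyAppState({"model": {"w": 0.0}})
    loaded = load_checkpointed_app_state_(raw, ckpt_dir)
    print(loaded is raw, loaded.components)  # True {'model': {'w': 0.5}}
```

The trailing underscore follows the PyTorch convention for in-place operations (e.g., `Tensor.add_`): the returned object is the same one that was passed in.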
- static get_raw_app_state(model, optimizer, lr_scheduler=None)
Creates a new AppState object (not yet loaded from a checkpoint) from an instantiated model, optimizer, and optional learning rate scheduler.
- Return type:
AppState
- Parameters:
model (Module)
optimizer (Optimizer)
lr_scheduler (LRScheduler | None)
- Args:
model (nn.Module): The model, which can be a non-sharded, FSDP1, or FSDP2 model.
optimizer (Optimizer): The optimizer, which can be a non-sharded, FSDP1, or FSDP2 optimizer.
lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
- Returns:
AppState: The AppState object.
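A minimal sketch of the factory's construction step, assuming it simply forwards its arguments to the wrapper's constructor. `AppStateStub` and `AppStateFactorySketch` are hypothetical stand-ins, and plain objects replace the PyTorch model, optimizer, and scheduler. A freshly created object is "raw": nothing has been loaded into it yet.

```python
from typing import Any, Optional


class AppStateStub:
    """Hypothetical stand-in for AppState that only records its components."""

    def __init__(self, model: Any, optimizer: Any, lr_scheduler: Optional[Any] = None) -> None:
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.is_loaded = False  # raw: no checkpoint has been loaded yet


class AppStateFactorySketch:
    @staticmethod
    def get_raw_app_state(model: Any, optimizer: Any, lr_scheduler: Optional[Any] = None) -> AppStateStub:
        # No checkpoint I/O happens here; loading is a separate, explicit step.
        return AppStateStub(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)


raw = AppStateFactorySketch.get_raw_app_state(model="model", optimizer="optimizer")
print(raw.lr_scheduler, raw.is_loaded)  # None False
```

Keeping creation separate from loading lets a fresh run use the raw object directly, while a resumed run passes it to `get_dcp_checkpointed_app_state_`.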