modalities.checkpointing.stateful package
Submodules
modalities.checkpointing.stateful.app_state module
- class modalities.checkpointing.stateful.app_state.AppState(model, optimizer, lr_scheduler=None)
Bases:
Stateful
This is a useful wrapper for checkpointing the application state (i.e., model, optimizer, and lr scheduler). Since this object complies with the Stateful protocol, DCP (PyTorch Distributed Checkpoint) automatically calls state_dict/load_state_dict as needed within the dcp.save/dcp.load APIs.
Note: We take advantage of this wrapper to call distributed state dict methods on the model and optimizer.
Note: This class has been copied and adapted from https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
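The Stateful mechanism described above boils down to duck typing: any object exposing state_dict() and load_state_dict() is treated as stateful, and the checkpointer calls those methods when saving and loading. The torch-free sketch below mirrors that contract. StatefulLike, Counter, ToyAppState, toy_save, and toy_load are hypothetical names, and the toy functions only imitate what dcp.save/dcp.load do with Stateful entries (the PyTorch recipe passes the wrapper inside a dict such as {"app": app_state}).

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class StatefulLike(Protocol):
    """Hypothetical stand-in for the Stateful protocol: any object with
    state_dict() and load_state_dict() satisfies it structurally."""

    def state_dict(self) -> dict[str, Any]: ...

    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...


class Counter:
    """Toy component standing in for a model or optimizer (hypothetical)."""

    def __init__(self) -> None:
        self.step = 0

    def state_dict(self) -> dict[str, Any]:
        return {"step": self.step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.step = state_dict["step"]


class ToyAppState:
    """Wraps several components behind one combined state dict."""

    def __init__(self, model: Counter, optimizer: Counter) -> None:
        self.model = model
        self.optimizer = optimizer

    def state_dict(self) -> dict[str, Any]:
        return {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.model.load_state_dict(state_dict["model"])
        self.optimizer.load_state_dict(state_dict["optimizer"])


def toy_save(state: dict[str, Any]) -> dict[str, Any]:
    # Imitates the checkpointer: Stateful entries are asked for their state_dict().
    return {k: (v.state_dict() if isinstance(v, StatefulLike) else v) for k, v in state.items()}


def toy_load(state: dict[str, Any], saved: dict[str, Any]) -> None:
    # Imitates the checkpointer on load: Stateful entries receive load_state_dict().
    for k, v in state.items():
        if isinstance(v, StatefulLike):
            v.load_state_dict(saved[k])
```

Because the wrapper, not the caller, decides how each component is serialized, the same save/load call works regardless of how the model and optimizer are sharded.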
Initializes the AppState object.
- Args:
model (nn.Module): The model can be a non-sharded, FSDP1, or FSDP2 model.
optimizer (Optimizer): The optimizer can be a non-sharded, FSDP1, or FSDP2 optimizer.
lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
- Parameters:
model (Module)
optimizer (Optimizer)
lr_scheduler (LRScheduler | None)
- property is_loaded: bool
Returns whether the state dict has been loaded.
- Returns:
bool: Flag indicating whether the state dict has been loaded.
- load_state_dict(state_dict)
Loads the state dict into the AppState object.
- Args:
state_dict (dict[str, Any]): The state dict to load into the AppState object.
- Raises:
RuntimeError: If the state dict has already been loaded.
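Together, the is_loaded property and the RuntimeError on a second load_state_dict call form a load-once guard. Below is a minimal sketch of that guard, assuming the flag is set only after a successful load. LoadOnceState and its error message are hypothetical.

```python
from typing import Any


class LoadOnceState:
    """Hypothetical sketch of a state wrapper that may be loaded only once."""

    def __init__(self) -> None:
        self._is_loaded = False
        self.values: dict[str, Any] = {}

    @property
    def is_loaded(self) -> bool:
        # True once load_state_dict() has completed successfully.
        return self._is_loaded

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        if self._is_loaded:
            # Refuse to overwrite state that was already restored from a checkpoint.
            raise RuntimeError("Cannot call load_state_dict twice on the same object.")
        self.values = dict(state_dict)
        self._is_loaded = True
```

Failing loudly on a second load makes it hard to silently mix two checkpoints into one training run.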
- property lr_scheduler: LRScheduler
- class modalities.checkpointing.stateful.app_state.LRSchedulerStateRetriever
Bases:
StateRetrieverIF
- static get_state_dict(app_state)
Returns the state dict of the lr scheduler in the AppState object.
- Args:
app_state (AppState): The app_state object containing the lr scheduler.
- Returns:
dict[str, Any]: The state dict of the lr scheduler in the AppState object.
- class modalities.checkpointing.stateful.app_state.ModelStateRetriever
Bases:
StateRetrieverIF
- static get_state_dict(app_state)
Returns the state dict of the model in the AppState object.
- Args:
app_state (AppState): The app_state object containing the model.
- Returns:
dict[str, Any]: The state dict of the model in the AppState object.
- class modalities.checkpointing.stateful.app_state.OptimizerStateRetriever
Bases:
StateRetrieverIF
- static get_state_dict(app_state)
Returns the state dict of the optimizer in the AppState object.
- Args:
app_state (AppState): The app_state object containing the optimizer.
- Returns:
dict[str, Any]: The state dict of the optimizer in the AppState object.
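The three concrete retrievers share one shape: a static get_state_dict that selects a single component of the AppState and returns its state dict. The sketch below illustrates that pattern with plain objects. All Toy* names, RETRIEVERS, and collect_state are hypothetical. Also, the sketch calls state_dict() directly, whereas the real retrievers operate on possibly sharded (FSDP) components and may use distributed state-dict helpers instead.

```python
from types import SimpleNamespace
from typing import Any


class ToyComponent:
    """Hypothetical stand-in for a model, optimizer, or lr scheduler."""

    def __init__(self, **state: Any) -> None:
        self._state = state

    def state_dict(self) -> dict[str, Any]:
        return dict(self._state)


# Each retriever is stateless: a static method that picks one component
# out of the app state and returns that component's state dict.
class ToyModelStateRetriever:
    @staticmethod
    def get_state_dict(app_state: Any) -> dict[str, Any]:
        return app_state.model.state_dict()


class ToyOptimizerStateRetriever:
    @staticmethod
    def get_state_dict(app_state: Any) -> dict[str, Any]:
        return app_state.optimizer.state_dict()


class ToyLRSchedulerStateRetriever:
    @staticmethod
    def get_state_dict(app_state: Any) -> dict[str, Any]:
        return app_state.lr_scheduler.state_dict()


# Hypothetical registry used to assemble a combined state dict by component name.
RETRIEVERS = {
    "model": ToyModelStateRetriever,
    "optimizer": ToyOptimizerStateRetriever,
    "lr_scheduler": ToyLRSchedulerStateRetriever,
}


def collect_state(app_state: Any) -> dict[str, dict[str, Any]]:
    return {name: retriever.get_state_dict(app_state) for name, retriever in RETRIEVERS.items()}


if __name__ == "__main__":
    demo = SimpleNamespace(
        model=ToyComponent(w=1.0),
        optimizer=ToyComponent(lr=0.1),
        lr_scheduler=ToyComponent(last_epoch=7),
    )
    print(collect_state(demo))
```

Because the retrievers are static and stateless, a caller can extract, for example, only the model weights from an AppState without touching the optimizer or scheduler.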
- class modalities.checkpointing.stateful.app_state.StateRetrieverIF
Bases:
ABC
State retriever interface for loading and getting the state dicts of models, optimizers, and lr schedulers. Other stateful components can be supported by adding a retriever that implements this interface.
- abstractmethod static get_state_dict(app_state)
Returns the state dict of the AppState object.
- Args:
app_state (AppState): The application state object.
- Raises:
NotImplementedError: This abstract method is not implemented and should be overridden in a subclass.
- Returns:
dict[str, Any]: The state dict of the AppState object.
- abstractmethod static load_state_dict_(app_state, state_dict)
Loads the state dict into the AppState object.
- Args:
app_state (AppState): The application state object.
state_dict (dict[str, Any]): The state dict to load into the AppState object.
- Raises:
NotImplementedError: This abstract method is not implemented and should be overridden in a subclass.
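As the interface description suggests, supporting another stateful component means writing a retriever that implements both static methods. Below is a minimal sketch, assuming a hypothetical data sampler that tracks num_consumed_samples. StateRetrieverSketch mirrors the documented method names but is not the modalities class. When abstractmethod is combined with staticmethod, it must be the innermost decorator.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Any


class StateRetrieverSketch(ABC):
    """Mirrors the two documented static methods (hypothetical stand-in)."""

    @staticmethod
    @abstractmethod
    def get_state_dict(app_state: Any) -> dict[str, Any]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def load_state_dict_(app_state: Any, state_dict: dict[str, Any]) -> None:
        raise NotImplementedError


class SamplerStateRetriever(StateRetrieverSketch):
    """Hypothetical retriever for a sampler that tracks consumed samples."""

    @staticmethod
    def get_state_dict(app_state: Any) -> dict[str, Any]:
        return {"num_consumed_samples": app_state.sampler.num_consumed_samples}

    @staticmethod
    def load_state_dict_(app_state: Any, state_dict: dict[str, Any]) -> None:
        # Writes the restored value back into app_state in place.
        app_state.sampler.num_consumed_samples = state_dict["num_consumed_samples"]


if __name__ == "__main__":
    source = SimpleNamespace(sampler=SimpleNamespace(num_consumed_samples=128))
    target = SimpleNamespace(sampler=SimpleNamespace(num_consumed_samples=0))
    SamplerStateRetriever.load_state_dict_(target, SamplerStateRetriever.get_state_dict(source))
    print(target.sampler.num_consumed_samples)
```

Because the base class is abstract, forgetting to override either method causes instantiation to fail with a TypeError, in addition to the NotImplementedError raised if the base implementation is called directly.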
modalities.checkpointing.stateful.app_state_factory module
- class modalities.checkpointing.stateful.app_state_factory.AppStateFactory
Bases:
object
Factory class to create AppState objects.
- static get_dcp_checkpointed_app_state_(raw_app_state, checkpoint_dir_path)
Loads the checkpointed state dict in place into the raw AppState object (i.e., an AppState that has not yet been loaded from a checkpoint).
- Args:
raw_app_state (AppState): The raw AppState object.
checkpoint_dir_path (Path): The path to the checkpoint directory.
- Raises:
RuntimeError: Raises an error if the state dict has already been loaded.
- Returns:
AppState: The AppState object with the loaded state dict.
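The in-place contract of get_dcp_checkpointed_app_state_ can be sketched without torch. The function refuses an already-loaded object, loads the stored state into the raw object, and returns that same object. In this sketch, a single JSON file stands in for the DCP checkpoint directory. ToyAppState, save_checkpoint, and get_checkpointed_app_state_ are hypothetical, and the real factory presumably reads the directory through DCP instead.

```python
import json
import tempfile
from pathlib import Path
from typing import Any


class ToyAppState:
    """Hypothetical stand-in for AppState with a load-once guard."""

    def __init__(self, step: int = 0) -> None:
        self.step = step
        self._is_loaded = False

    @property
    def is_loaded(self) -> bool:
        return self._is_loaded

    def state_dict(self) -> dict[str, Any]:
        return {"step": self.step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        if self._is_loaded:
            raise RuntimeError("State dict has already been loaded.")
        self.step = state_dict["step"]
        self._is_loaded = True


def save_checkpoint(app_state: ToyAppState, checkpoint_dir_path: Path) -> None:
    # Toy replacement for a DCP save: one JSON file instead of sharded files.
    checkpoint_dir_path.mkdir(parents=True, exist_ok=True)
    (checkpoint_dir_path / "app_state.json").write_text(json.dumps(app_state.state_dict()))


def get_checkpointed_app_state_(raw_app_state: ToyAppState, checkpoint_dir_path: Path) -> ToyAppState:
    if raw_app_state.is_loaded:
        raise RuntimeError("Cannot load a checkpoint into an already loaded AppState.")
    saved = json.loads((checkpoint_dir_path / "app_state.json").read_text())
    raw_app_state.load_state_dict(saved)  # mutates the raw object in place
    return raw_app_state  # the same object, returned for convenience


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = Path(tmp) / "ckpt"
        save_checkpoint(ToyAppState(step=42), ckpt)
        print(get_checkpointed_app_state_(ToyAppState(), ckpt).step)
```

Returning the mutated argument lets callers chain the call, but the trailing underscore signals that the passed-in object itself has changed.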
- static get_raw_app_state(model, optimizer, lr_scheduler=None)
Creates a new AppState object (not yet loaded from a checkpoint) from an instantiated model, optimizer, and optional learning rate scheduler.
- Return type:
AppState
- Parameters:
model (Module)
optimizer (Optimizer)
lr_scheduler (LRScheduler | None)
- Args:
model (nn.Module): The model can be a non-sharded, FSDP1, or FSDP2 model.
optimizer (Optimizer): The optimizer can be a non-sharded, FSDP1, or FSDP2 optimizer.
lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
- Returns:
AppState: The AppState object.
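Below is a sketch of get_raw_app_state, assuming it simply forwards its arguments to the AppState constructor. The matching signatures suggest this, but it is an assumption. The sketch illustrates that a raw AppState holds references to the live objects rather than copies, so training updates made after construction are what a later save would capture. Toy, ToyAppState, and ToyAppStateFactory are hypothetical.

```python
from typing import Any, Optional


class Toy:
    """Hypothetical stand-in for a model, optimizer, or lr scheduler."""

    def __init__(self) -> None:
        self.step = 0

    def state_dict(self) -> dict[str, Any]:
        return {"step": self.step}


class ToyAppState:
    def __init__(self, model: Toy, optimizer: Toy, lr_scheduler: Optional[Toy] = None) -> None:
        # Stores references to the live objects; nothing is copied here.
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.is_loaded = False  # "raw": not yet restored from a checkpoint


class ToyAppStateFactory:
    @staticmethod
    def get_raw_app_state(model: Toy, optimizer: Toy, lr_scheduler: Optional[Toy] = None) -> ToyAppState:
        return ToyAppState(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)
```

Keeping construction and checkpoint loading in two separate factory methods lets the same raw object serve both a fresh training run and a resumed one.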