modalities.checkpointing.stateful package

Submodules

modalities.checkpointing.stateful.app_state module

class modalities.checkpointing.stateful.app_state.AppState(model, optimizer, lr_scheduler=None)

Bases: Stateful

This is a useful wrapper for checkpointing the application state (i.e., model, optimizer, and lr scheduler). Since this object complies with the Stateful protocol, DCP automatically calls state_dict/load_state_dict as needed in the dcp.save/dcp.load APIs.

Note: We take advantage of this wrapper to call distributed state dict methods on the model and optimizer.

Note: This class has been copied and adapted from https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
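The Stateful contract that DCP relies on can be illustrated without PyTorch. The following is a simplified, hypothetical model of it: StatefulLike, ToyState, toy_save, and toy_load are illustrative names, not part of modalities or torch. The two toy functions only mimic how dcp.save calls state_dict() on each Stateful value and how dcp.load hands the restored entry back through load_state_dict(); sharding, storage, and in-place tensor loading are deliberately left out.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class StatefulLike(Protocol):
    # Hypothetical mirror of the two methods a torch Stateful object must provide
    def state_dict(self) -> dict[str, Any]: ...

    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...


class ToyState:
    """Hypothetical stand-in for AppState holding one integer of training state."""

    def __init__(self) -> None:
        self.step = 0

    def state_dict(self) -> dict[str, Any]:
        return {"step": self.step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.step = state_dict["step"]


def toy_save(state: dict[str, Any]) -> dict[str, Any]:
    # Like dcp.save (simplified): Stateful values are replaced by their state_dict()
    return {k: v.state_dict() if isinstance(v, StatefulLike) else v for k, v in state.items()}


def toy_load(state: dict[str, Any], stored: dict[str, Any]) -> None:
    # Like dcp.load (simplified): stored entries go back via load_state_dict()
    for k, v in state.items():
        if isinstance(v, StatefulLike):
            v.load_state_dict(stored[k])
```

Because the saver only depends on the two methods, AppState can hide the per-component (and possibly distributed) logic behind them.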

Initializes the AppState object.

Args:

model (nn.Module): The model, which can be a non-sharded, FSDP1, or FSDP2 model.
optimizer (Optimizer): The optimizer, which can be a non-sharded, FSDP1, or FSDP2 optimizer.
lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.

property is_loaded: bool

Returns whether the state dict has been loaded.

Returns:

bool: Flag indicating whether the state dict has been loaded.

load_state_dict(state_dict)

Loads the state dict into the AppState object.

Return type:

None

Parameters:

state_dict (dict[str, Any])

Args:

state_dict (dict[str, Any]): The state dict to load into the AppState object.

Raises:

RuntimeError: If the state dict has already been loaded.
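A minimal sketch of the load-once semantics described by is_loaded and load_state_dict follows. LoadOnceState is a hypothetical stand-in rather than the modalities source, and its error message is invented; it only assumes what is documented here, namely that a second load raises RuntimeError.

```python
from typing import Any


class LoadOnceState:
    """Hypothetical stand-in mirroring AppState's documented load-once behaviour."""

    def __init__(self) -> None:
        self._is_loaded = False
        self.data: dict[str, Any] = {}

    @property
    def is_loaded(self) -> bool:
        # Read-only view of the internal flag, analogous to AppState.is_loaded
        return self._is_loaded

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # A second load is rejected instead of silently overwriting restored state
        if self._is_loaded:
            raise RuntimeError("State dict has already been loaded.")
        self.data = dict(state_dict)
        self._is_loaded = True
```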

property lr_scheduler: LRScheduler
property model: Module
property optimizer: Optimizer
state_dict()

Returns the state dict of the AppState object.

Return type:

dict[str, Any]

Returns:

dict[str, Any]: The state dict of the AppState object.

class modalities.checkpointing.stateful.app_state.LRSchedulerStateRetriever

Bases: StateRetrieverIF

static get_state_dict(app_state)

Returns the state dict of the lr scheduler in the AppState object.

Return type:

dict[str, Any]

Parameters:

app_state (AppState)

Args:

app_state (AppState): The app_state object containing the lr scheduler.

Returns:

dict[str, Any]: The state dict of the lr scheduler in the AppState object.

static load_state_dict_(app_state, state_dict)

Loads the state dict into the lr scheduler in the AppState object.

Return type:

None

Args:

app_state (AppState): The app_state object containing the lr scheduler.
state_dict (dict[str, Any]): The state dict to load into the lr scheduler.
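The retriever pattern for the lr scheduler can be sketched as follows, assuming that the retriever simply delegates to the scheduler's own state_dict()/load_state_dict(). ToyScheduler, ToyAppState, and LRSchedulerStateRetrieverSketch are hypothetical names. A real torch LRScheduler includes last_epoch in its state dict, which is the only field the toy keeps.

```python
from typing import Any


class ToyScheduler:
    """Hypothetical stand-in for a torch LRScheduler that only tracks its step counter."""

    def __init__(self) -> None:
        self.last_epoch = 0

    def step(self) -> None:
        self.last_epoch += 1

    def state_dict(self) -> dict[str, Any]:
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.last_epoch = state_dict["last_epoch"]


class ToyAppState:
    """Hypothetical container exposing the scheduler like AppState.lr_scheduler."""

    def __init__(self, lr_scheduler: ToyScheduler) -> None:
        self.lr_scheduler = lr_scheduler


class LRSchedulerStateRetrieverSketch:
    @staticmethod
    def get_state_dict(app_state: ToyAppState) -> dict[str, Any]:
        # Assumption: schedulers are not sharded, so their own state_dict() suffices
        return app_state.lr_scheduler.state_dict()

    @staticmethod
    def load_state_dict_(app_state: ToyAppState, state_dict: dict[str, Any]) -> None:
        # Mutates app_state in place and returns None
        app_state.lr_scheduler.load_state_dict(state_dict)
```

The trailing underscore in load_state_dict_ presumably follows the PyTorch convention for in-place operations: the method modifies the AppState it receives rather than returning a new one.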

class modalities.checkpointing.stateful.app_state.ModelStateRetriever

Bases: StateRetrieverIF

static get_state_dict(app_state)

Returns the state dict of the model in the AppState object.

Return type:

dict[str, Any]

Parameters:

app_state (AppState)

Args:

app_state (AppState): The app_state object containing the model.

Returns:

dict[str, Any]: The state dict of the model in the AppState object.

static load_state_dict_(app_state, state_dict)

Loads the state dict into the model in the AppState object.

Return type:

None

Args:

app_state (AppState): The app_state object containing the model.
state_dict (dict[str, Any]): The state dict to load into the model.
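For sharded models, PyTorch provides torch.distributed.checkpoint.state_dict.get_model_state_dict and set_model_state_dict, which return and accept state dicts keyed by fully qualified parameter names for non-sharded, FSDP1, and FSDP2 models alike; whether ModelStateRetriever calls exactly these functions is an assumption here. To stay runnable without PyTorch, the sketch below replaces the model with a hypothetical ToyModel whose load_state_dict is strict about keys, mirroring the default strict=True of nn.Module.load_state_dict. ToyAppState and ModelStateRetrieverSketch are likewise hypothetical.

```python
from typing import Any


class ToyModel:
    """Hypothetical stand-in for an nn.Module; parameters are keyed by fully qualified name."""

    def __init__(self) -> None:
        self.params: dict[str, list[float]] = {"embed.weight": [0.0, 0.0], "head.bias": [0.0]}

    def state_dict(self) -> dict[str, Any]:
        # Copy the values so the snapshot does not alias the live parameters
        return {name: list(values) for name, values in self.params.items()}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Strict loading: the key sets must match exactly
        missing = self.params.keys() - state_dict.keys()
        unexpected = state_dict.keys() - self.params.keys()
        if missing or unexpected:
            raise RuntimeError(f"missing={sorted(missing)} unexpected={sorted(unexpected)}")
        self.params = {name: list(values) for name, values in state_dict.items()}


class ToyAppState:
    """Hypothetical container exposing the model like AppState.model."""

    def __init__(self, model: ToyModel) -> None:
        self.model = model


class ModelStateRetrieverSketch:
    @staticmethod
    def get_state_dict(app_state: ToyAppState) -> dict[str, Any]:
        return app_state.model.state_dict()

    @staticmethod
    def load_state_dict_(app_state: ToyAppState, state_dict: dict[str, Any]) -> None:
        app_state.model.load_state_dict(state_dict)
```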

class modalities.checkpointing.stateful.app_state.OptimizerStateRetriever

Bases: StateRetrieverIF

static get_state_dict(app_state)

Returns the state dict of the optimizer in the AppState object.

Return type:

dict[str, Any]

Parameters:

app_state (AppState)

Args:

app_state (AppState): The app_state object containing the optimizer.

Returns:

dict[str, Any]: The state dict of the optimizer in the AppState object.

static load_state_dict_(app_state, state_dict)

Loads the state dict into the optimizer in the AppState object.

Return type:

None

Args:

app_state (AppState): The app_state object containing the optimizer.
state_dict (dict[str, Any]): The state dict to load into the optimizer.
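Optimizer state also depends on the model. In plain PyTorch, Optimizer.state_dict() keys per-parameter state by integer index, whereas torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers) replaces those indices with fully qualified parameter names, which is why it takes the model as an argument. The sketch below reproduces only that re-keying step with hypothetical ToyModel, ToyOptimizer, ToyAppState, and OptimizerStateRetrieverSketch classes; it assumes, without confirmation from this page, that OptimizerStateRetriever relies on the same mechanism.

```python
from typing import Any


class ToyModel:
    """Hypothetical model that only records its parameter names in registration order."""

    def __init__(self) -> None:
        self.param_names = ["embed.weight", "head.bias"]


class ToyOptimizer:
    """Hypothetical optimizer: like torch.optim, per-parameter state is keyed by index."""

    def __init__(self, num_params: int) -> None:
        self.state: dict[int, dict[str, Any]] = {i: {"step": 0} for i in range(num_params)}


class ToyAppState:
    def __init__(self, model: ToyModel, optimizer: ToyOptimizer) -> None:
        self.model = model
        self.optimizer = optimizer


class OptimizerStateRetrieverSketch:
    @staticmethod
    def get_state_dict(app_state: ToyAppState) -> dict[str, Any]:
        # Re-key index -> parameter name using the model (assumed mechanism)
        names = app_state.model.param_names
        return {names[i]: dict(s) for i, s in app_state.optimizer.state.items()}

    @staticmethod
    def load_state_dict_(app_state: ToyAppState, state_dict: dict[str, Any]) -> None:
        # Map parameter names back to this optimizer's indices, in place
        names = app_state.model.param_names
        app_state.optimizer.state = {i: dict(state_dict[n]) for i, n in enumerate(names)}
```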

class modalities.checkpointing.stateful.app_state.StateRetrieverIF

Bases: ABC

State retriever interface for getting and loading the state dicts of models, optimizers, and lr schedulers. Further stateful components can be supported by adding a retriever that implements this interface.

abstractmethod static get_state_dict(app_state)

Returns the state dict of the AppState object.

Return type:

dict[str, Any]

Parameters:

app_state (AppState)

Args:

app_state (AppState): The application state object.

Raises:

NotImplementedError: This abstract method is not implemented and should be overridden in a subclass.

Returns:

dict[str, Any]: The state dict of the AppState object.

abstractmethod static load_state_dict_(app_state, state_dict)

Loads the state dict into the AppState object.

Return type:

None

Args:

app_state (AppState): The application state object.
state_dict (dict[str, Any]): The state dict to load into the AppState object.

Raises:

NotImplementedError: This abstract method is not implemented and should be overridden in a subclass.
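Supporting a new stateful component amounts to subclassing the interface. The sketch below re-creates the documented contract with the abc module (it is not the modalities source) and adds a hypothetical SamplerStateRetriever for a data sampler's position. Because the methods are static, they can be called on the class without instantiating it; the ABC machinery only blocks instantiation, so the abstract bodies still raise NotImplementedError if called directly.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Any


class StateRetrieverSketch(ABC):
    """Simplified re-creation of the documented StateRetrieverIF contract."""

    @staticmethod
    @abstractmethod
    def get_state_dict(app_state: Any) -> dict[str, Any]:
        # Reached only if the abstract method is called directly on the class
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def load_state_dict_(app_state: Any, state_dict: dict[str, Any]) -> None:
        raise NotImplementedError


class SamplerStateRetriever(StateRetrieverSketch):
    """Hypothetical retriever for an extra component: a data sampler's position."""

    @staticmethod
    def get_state_dict(app_state: Any) -> dict[str, Any]:
        return {"position": app_state.sampler_position}

    @staticmethod
    def load_state_dict_(app_state: Any, state_dict: dict[str, Any]) -> None:
        app_state.sampler_position = state_dict["position"]


# Usage: the retriever works on any object exposing a sampler_position attribute
app = SimpleNamespace(sampler_position=42)
print(SamplerStateRetriever.get_state_dict(app))  # {'position': 42}
```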

class modalities.checkpointing.stateful.app_state.StatefulComponents(value)

Bases: Enum

LR_SCHEDULER = 'lr_scheduler'
MODEL = 'model'
OPTIMIZER = 'optimizer'
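StatefulComponents(value) is an ordinary Enum, so StatefulComponents("model") looks a member up by its value. One plausible use, assumed here rather than confirmed by this page, is to key the AppState state dict by these values and dispatch each entry to the matching retriever. The GETTERS table and build_state_dict are hypothetical stand-ins for the three retrievers, and skipping a None lr scheduler is likewise an assumption.

```python
from enum import Enum
from types import SimpleNamespace
from typing import Any, Callable


class StatefulComponents(Enum):
    # Member values copied from the documented enum
    LR_SCHEDULER = "lr_scheduler"
    MODEL = "model"
    OPTIMIZER = "optimizer"


# Hypothetical getters standing in for the three *StateRetriever.get_state_dict methods
GETTERS: dict[StatefulComponents, Callable[[Any], dict[str, Any]]] = {
    StatefulComponents.MODEL: lambda app: dict(app.model),
    StatefulComponents.OPTIMIZER: lambda app: dict(app.optimizer),
    StatefulComponents.LR_SCHEDULER: lambda app: dict(app.lr_scheduler),
}


def build_state_dict(app: Any) -> dict[str, Any]:
    """Assumed layout: one top-level entry per component, keyed by the enum's string value."""
    state: dict[str, Any] = {}
    for component, getter in GETTERS.items():
        # Assumption: an absent (None) optional lr scheduler is simply left out
        if getattr(app, component.value) is not None:
            state[component.value] = getter(app)
    return state


app = SimpleNamespace(model={"w": 1.0}, optimizer={"lr": 0.1}, lr_scheduler=None)
print(build_state_dict(app))  # {'model': {'w': 1.0}, 'optimizer': {'lr': 0.1}}
```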

modalities.checkpointing.stateful.app_state_factory module

class modalities.checkpointing.stateful.app_state_factory.AppStateFactory

Bases: object

Factory class to create AppState objects.

static get_dcp_checkpointed_app_state_(raw_app_state, checkpoint_dir_path)

Loads the checkpointed state dict in-place into the raw AppState object (i.e., an AppState into which no checkpoint has been loaded yet).

Return type:

AppState

Args:

raw_app_state (AppState): The raw AppState object.
checkpoint_dir_path (Path): The path to the checkpoint directory.

Raises:

RuntimeError: If the state dict has already been loaded.

Returns:

AppState: The AppState object with the loaded state dict.

static get_raw_app_state(model, optimizer, lr_scheduler=None)

Creates a new raw AppState object (i.e., one into which no checkpoint has been loaded yet) from an instantiated model, optimizer, and optional learning rate scheduler.

Return type:

AppState

Args:

model (nn.Module): The model, which can be a non-sharded, FSDP1, or FSDP2 model.
optimizer (Optimizer): The optimizer, which can be a non-sharded, FSDP1, or FSDP2 optimizer.
lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.

Returns:

AppState: The AppState object.
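Taken together, the two AppStateFactory methods suggest a two-step workflow: build a raw AppState with get_raw_app_state, then restore a checkpoint into it in place with get_dcp_checkpointed_app_state_. The sketch below imitates that contract with a JSON file instead of DCP; ToyAppState, save_checkpoint, get_checkpointed_app_state_, and the app_state.json file name are all hypothetical. In real code the load step presumably goes through torch.distributed.checkpoint.load with the checkpoint directory as checkpoint_id, but that detail is an assumption.

```python
import json
import tempfile
from pathlib import Path
from typing import Any


class ToyAppState:
    """Hypothetical stand-in for AppState with a single piece of training state."""

    def __init__(self) -> None:
        self.step = 0
        self.is_loaded = False

    def state_dict(self) -> dict[str, Any]:
        return {"step": self.step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        if self.is_loaded:
            raise RuntimeError("State dict has already been loaded.")
        self.step = state_dict["step"]
        self.is_loaded = True


def save_checkpoint(app_state: ToyAppState, checkpoint_dir_path: Path) -> None:
    # Hypothetical single-file stand-in for dcp.save writing into a checkpoint directory
    checkpoint_dir_path.mkdir(parents=True, exist_ok=True)
    payload = {"app": app_state.state_dict()}
    (checkpoint_dir_path / "app_state.json").write_text(json.dumps(payload))


def get_checkpointed_app_state_(raw_app_state: ToyAppState, checkpoint_dir_path: Path) -> ToyAppState:
    # Documented contract: refuse an already loaded state, load in place, return the same object
    if raw_app_state.is_loaded:
        raise RuntimeError("State dict has already been loaded.")
    payload = json.loads((checkpoint_dir_path / "app_state.json").read_text())
    raw_app_state.load_state_dict(payload["app"])
    return raw_app_state


with tempfile.TemporaryDirectory() as tmp:
    trained = ToyAppState()
    trained.step = 5
    save_checkpoint(trained, Path(tmp) / "ckpt")
    # ToyAppState() plays the role of AppStateFactory.get_raw_app_state(...)
    restored = get_checkpointed_app_state_(ToyAppState(), Path(tmp) / "ckpt")
    print(restored.step, restored.is_loaded)  # 5 True
```

Keeping construction and loading separate lets the same raw object be used for a fresh run or for resuming from a checkpoint.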

Module contents