modalities.training.activation_checkpointing package

Submodules

modalities.training.activation_checkpointing.activation_checkpointing module

class modalities.training.activation_checkpointing.activation_checkpointing.ActivationCheckpointing

Bases: object

In eager (normal) mode, every module stores all of its activations during the forward pass and reuses them to compute gradients in the backward pass. The memory footprint of the activations therefore accumulates during the forward pass and peaks right before the backward pass starts. During the backward pass, activations are freed as soon as they are no longer needed.
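To make the accumulation concrete, here is a toy model (not part of modalities) that counts stored activations in abstract units. It assumes every layer produces the same number of activation units and frees all of them as soon as its backward step has finished.

```python
def eager_memory_trace(num_layers: int, acts_per_layer: int) -> list[int]:
    """Return the number of stored activation units after each forward and backward layer step."""
    stored = 0
    trace = []
    # Forward pass: every layer keeps all of its activations for the backward pass.
    for _ in range(num_layers):
        stored += acts_per_layer
        trace.append(stored)
    # Backward pass runs in reverse; a layer's activations are freed once its gradients are computed.
    for _ in range(num_layers):
        stored -= acts_per_layer
        trace.append(stored)
    return trace


trace = eager_memory_trace(num_layers=4, acts_per_layer=3)
print(trace)       # [3, 6, 9, 12, 9, 6, 3, 0]
print(max(trace))  # 12, reached at the end of the forward pass
```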

With activation checkpointing (AC), the checkpointed regions store only their input and output activations, not their intermediate activations, which reduces the overall memory footprint. In the backward pass, the intermediate activations are recomputed, trading a lower memory footprint for a higher compute cost. For a GPT model, the checkpointed regions are typically the transformer blocks.
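The same trade-off can be estimated with a toy model (not part of modalities). It assumes each checkpointed layer keeps `boundary_acts` units at its boundary and that its intermediates exist only while that one layer is being recomputed in the backward pass; `full_ac_cost` is a hypothetical helper.

```python
def full_ac_cost(num_layers: int, acts_per_layer: int, boundary_acts: int = 1) -> tuple[int, int]:
    """Toy estimate of (peak stored activation units, layer forward executions) with full AC."""
    # Worst case in the backward pass: all boundary activations are still alive, plus the
    # recomputed intermediates of the single layer whose gradients are being computed.
    peak = num_layers * boundary_acts + acts_per_layer
    # Every layer runs its forward twice: once normally and once for recomputation.
    forward_executions = 2 * num_layers
    return peak, forward_executions


print(full_ac_cost(num_layers=4, acts_per_layer=3))  # (7, 8)
```

For the same 4 layers with 3 units each, storing everything would peak at 4 × 3 = 12 units with only 4 layer forwards, so in this toy model full AC saves memory at the price of doubling the forward compute.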

Selective AC is an additional variant that gives more granular control over the AC process. It saves only the activations of certain, typically compute-intensive, operations and recomputes the activations of all other operations. This reduces the memory footprint while keeping the additional compute cost moderate.
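A rough illustration of why saving only compute-intensive ops pays off: the per-op memory and compute costs below are invented numbers for a hypothetical transformer block, and `selective_op_cost` is a hypothetical helper, not a modalities function.

```python
# Hypothetical (op name, output memory units, compute units) for one transformer block.
BLOCK_OPS = [
    ("layer_norm", 1, 1),
    ("mm", 1, 8),         # attention input projection
    ("attention", 1, 10),
    ("mm", 1, 8),         # attention output projection
    ("add", 1, 1),
    ("layer_norm", 1, 1),
    ("mm", 4, 16),        # MLP up-projection
    ("gelu", 4, 1),
    ("mm", 1, 16),        # MLP down-projection
    ("add", 1, 1),
]


def selective_op_cost(ops, save_ops):
    """Return (saved memory units, recompute cost) when only ops in save_ops keep their outputs."""
    saved_memory = sum(mem for name, mem, _ in ops if name in save_ops)
    recompute_cost = sum(cost for name, _, cost in ops if name not in save_ops)
    return saved_memory, recompute_cost


all_ops = {name for name, _, _ in BLOCK_OPS}
print(selective_op_cost(BLOCK_OPS, save_ops=set()))                # (0, 63): recompute everything
print(selective_op_cost(BLOCK_OPS, save_ops=all_ops))              # (16, 0): save everything
print(selective_op_cost(BLOCK_OPS, save_ops={"mm", "attention"}))  # (8, 5)
```

With these made-up numbers, saving only matmuls and attention halves the stored memory relative to saving everything while leaving just 5 of 63 compute units to recompute.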

The implementation is heavily inspired by the TorchTitan implementation: https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/models/llama/parallelize_llama.py#L294

SAVE_DICT = {
    'ops._c10d_functional.reduce_scatter_tensor.default': <OpOverload(op='_c10d_functional.reduce_scatter_tensor', overload='default')>,
    'ops.aten._scaled_dot_product_efficient_attention.default': <OpOverload(op='aten._scaled_dot_product_efficient_attention', overload='default')>,
    'ops.aten._scaled_dot_product_flash_attention.default': <OpOverload(op='aten._scaled_dot_product_flash_attention', overload='default')>,
    'ops.aten.max.default': <OpOverload(op='aten.max', overload='default')>,
    'ops.aten.mm.default': <OpOverload(op='aten.mm', overload='default')>
}

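SAVE_DICT names the ops whose outputs are kept under selective op AC. In recent PyTorch releases, such a list is typically turned into a policy for `torch.utils.checkpoint.create_selective_checkpoint_contexts`, which asks a policy function about each op and expects a `CheckpointPolicy` value. The plain-Python sketch below only mimics that per-op decision using the key strings above so that it runs without PyTorch; the op trace is hypothetical, and the strings "must_save" and "prefer_recompute" stand in for `CheckpointPolicy.MUST_SAVE` and `CheckpointPolicy.PREFER_RECOMPUTE`.

```python
SAVE_OPS = {
    "ops._c10d_functional.reduce_scatter_tensor.default",
    "ops.aten._scaled_dot_product_efficient_attention.default",
    "ops.aten._scaled_dot_product_flash_attention.default",
    "ops.aten.max.default",
    "ops.aten.mm.default",
}


def policy(op_name: str) -> str:
    # Keep the outputs of listed ops; everything else is recomputed in the backward pass.
    return "must_save" if op_name in SAVE_OPS else "prefer_recompute"


# Hypothetical op sequence of one attention sub-block.
trace = [
    "ops.aten.native_layer_norm.default",
    "ops.aten.mm.default",
    "ops.aten._scaled_dot_product_flash_attention.default",
    "ops.aten.mm.default",
    "ops.aten.add.Tensor",
]
decisions = [(op, policy(op)) for op in trace]
print(sum(decision == "must_save" for _, decision in decisions))  # 3
```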
static apply_activation_checkpointing_(ac_variant, layers_fqn, model, ac_fun_params)

Applies activation checkpointing to a given model. There are three variants of activation checkpointing:

  1. FULL_ACTIVATION_CHECKPOINTING: applies activation checkpointing to all layers. In this case, only the inputs and outputs of each layer are saved, but not the intermediate activations.

  2. SELECTIVE_LAYER_ACTIVATION_CHECKPOINTING: applies activation checkpointing to every ac_freq-th layer. It is similar to FULL_ACTIVATION_CHECKPOINTING, but only every ac_freq-th layer is checkpointed, i.e., saves only its inputs and outputs; the remaining layers store all of their activations as usual.

  3. SELECTIVE_OP_ACTIVATION_CHECKPOINTING: applies activation checkpointing to all layers, but saves the activations of certain operations. These operations are usually compute-intensive, so their activations are saved rather than recomputed in the backward pass. All remaining operations are recomputed in the backward pass.
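The variants differ in which layers get wrapped. The sketch below, with the hypothetical helper `layers_to_checkpoint`, shows one way to compute the wrapped layer indices from the variant string and ac_freq. It assumes that "every ac_freq-th layer" counts from 1 (so layers ac_freq, 2·ac_freq, ... are wrapped); the exact offset used by modalities is not stated here.

```python
def layers_to_checkpoint(variant: str, num_layers: int, ac_freq: int = 1) -> list[int]:
    """Return the (0-based) indices of the layers that get wrapped in activation checkpointing."""
    if variant in ("full_activation_checkpointing", "selective_op_activation_checkpointing"):
        # Both variants wrap every layer; they differ only in what is saved inside a wrapped layer.
        return list(range(num_layers))
    if variant == "selective_layer_activation_checkpointing":
        if ac_freq < 1:
            raise ValueError("ac_freq must be >= 1")
        # Assumption: every ac_freq-th layer, counting from 1, is wrapped.
        return [i for i in range(num_layers) if (i + 1) % ac_freq == 0]
    raise ValueError(f"Unknown activation checkpointing variant: {variant}")


print(layers_to_checkpoint("selective_layer_activation_checkpointing", num_layers=8, ac_freq=2))  # [1, 3, 5, 7]
```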

Args:

    ac_variant (ActivationCheckpointingVariants): The activation checkpointing variant to use.
    layers_fqn (str): The fully qualified name (FQN) of the layers to apply activation checkpointing to.
    model (nn.Module): The model to apply activation checkpointing to (in place).
    ac_fun_params (ActivationCheckpointedModelConfig.*Params): The parameters for the activation checkpointing function.

Raises:

    ValueError: If the activation checkpointing variant is not recognized or if layers_fqn does not reference a ModuleList.
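The ModuleList condition can be pictured as a dotted-path lookup followed by a type check. On a real PyTorch model, `nn.Module.get_submodule(target)` performs the dotted lookup. The sketch below instead uses `functools.reduce` over `getattr` with a stand-in `ModuleList` class and a `SimpleNamespace` model so that it runs without PyTorch; the FQN "transformer.h" and the helper `resolve_layers` are hypothetical.

```python
from functools import reduce
from types import SimpleNamespace


class ModuleList(list):
    """Stand-in for torch.nn.ModuleList so the sketch runs without PyTorch."""


def resolve_layers(model, layers_fqn: str) -> ModuleList:
    # Walk the dotted path, e.g. "transformer.h" -> model.transformer.h
    layers = reduce(getattr, layers_fqn.split("."), model)
    if not isinstance(layers, ModuleList):
        raise ValueError(f"{layers_fqn} does not reference a ModuleList")
    return layers


model = SimpleNamespace(
    transformer=SimpleNamespace(h=ModuleList(["block0", "block1"]), wte="embedding")
)
print(resolve_layers(model, "transformer.h"))  # ['block0', 'block1']
```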

modalities.training.activation_checkpointing.activation_checkpointing.apply_activation_checkpointing_fsdp1_inplace(model, activation_checkpointing_modules)

modalities.training.activation_checkpointing.activation_checkpointing.is_module_to_apply_activation_checkpointing(submodule, activation_checkpointing_modules)
Return type: bool
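This page does not document how is_module_to_apply_activation_checkpointing decides. With FSDP1, a predicate of this shape is commonly passed as `check_fn` to PyTorch's `apply_activation_checkpointing` (from `torch.distributed.algorithms._checkpoint.checkpoint_wrapper`), which wraps every submodule for which it returns True. The sketch below assumes, purely for illustration, that a submodule qualifies when it is an instance of one of the listed types; `is_module_to_apply_ac`, `Block`, and `Embedding` are hypothetical names.

```python
from functools import partial


def is_module_to_apply_ac(submodule: object, ac_module_types: list[type]) -> bool:
    # Assumption: a submodule qualifies if it is an instance of any listed type.
    return isinstance(submodule, tuple(ac_module_types))


class Block: ...
class Embedding: ...


# Bind the module types so the predicate takes a single submodule, as a check_fn would.
check_fn = partial(is_module_to_apply_ac, ac_module_types=[Block])
modules = [Embedding(), Block(), Block()]
print([check_fn(m) for m in modules])  # [False, True, True]
```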

modalities.training.activation_checkpointing.activation_checkpointing_variants module

class modalities.training.activation_checkpointing.activation_checkpointing_variants.ActivationCheckpointingVariants(value)

Bases: Enum

Enum for the different activation checkpointing variants.

FULL_ACTIVATION_CHECKPOINTING = 'full_activation_checkpointing'
SELECTIVE_LAYER_ACTIVATION_CHECKPOINTING = 'selective_layer_activation_checkpointing'
SELECTIVE_OP_ACTIVATION_CHECKPOINTING = 'selective_op_activation_checkpointing'
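Because the members carry string values, a variant can be looked up from a config string with standard Enum value lookup. The snippet below defines a local mirror of the enum documented above so that it runs on its own.

```python
from enum import Enum


class ActivationCheckpointingVariants(Enum):
    # Local mirror of the documented enum so the snippet is self-contained.
    FULL_ACTIVATION_CHECKPOINTING = "full_activation_checkpointing"
    SELECTIVE_LAYER_ACTIVATION_CHECKPOINTING = "selective_layer_activation_checkpointing"
    SELECTIVE_OP_ACTIVATION_CHECKPOINTING = "selective_op_activation_checkpointing"


# Look up a member by its value, e.g. a string read from a config file.
variant = ActivationCheckpointingVariants("selective_op_activation_checkpointing")
print(variant.name)  # SELECTIVE_OP_ACTIVATION_CHECKPOINTING

# Unknown strings raise ValueError.
try:
    ActivationCheckpointingVariants("partial")
except ValueError as err:
    print(err)
```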

Module contents