modalities.models package

Subpackages

Submodules

modalities.models.model module

class modalities.models.model.ActivationType(value)[source]

Bases: str, Enum

Enum class representing different activation types.

Attributes:

GELU (str): GELU activation type.

SWIGLU (str): SWIGLU activation type.

GELU = 'gelu'
SWIGLU = 'swiglu'
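
Because ActivationType derives from both str and Enum, its members can be looked up from a plain string (for example a value read from a config file) and compare equal to that string. The minimal sketch below re-declares the enum locally so that it runs without modalities installed; it mirrors the documented members and is an illustration, not the library source.

```python
from enum import Enum


class ActivationType(str, Enum):
    # Local re-declaration mirroring the documented members (illustration only).
    GELU = "gelu"
    SWIGLU = "swiglu"


# Look a member up by its value, e.g. a string read from a config file.
activation = ActivationType("swiglu")

# The str mixin makes members compare equal to their plain string values.
is_plain_string_equal = activation == "swiglu"
member_names = [member.name for member in ActivationType]
```
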
class modalities.models.model.NNModel(seed=None, weight_decay_groups=None)[source]

Bases: Module

NNModel class to define a base model.

Initializes an NNModel object.

Args:

seed (int, optional): The seed value for random number generation. Defaults to None.

weight_decay_groups (Optional[WeightDecayGroups], optional): The weight decay groups. Defaults to None.

abstractmethod forward(inputs)[source]

Forward pass of the model.

Return type:

dict[str, Tensor]

Parameters:

inputs (dict[str, Tensor])

Args:

inputs (dict[str, torch.Tensor]): A dictionary containing input tensors.

Returns:

dict[str, torch.Tensor]: A dictionary containing output tensors.

get_parameters()[source]

Returns a dictionary of the model’s parameters.

Return type:

dict[str, Tensor]

Returns:

A dictionary where the keys are the parameter names and the values are the corresponding parameter tensors.

property weight_decay_groups: dict[str, list[str]]

Returns the weight decay groups.

Returns:

WeightDecayGroups: The weight decay groups.

class modalities.models.model.SwiGLU(n_embd, ffn_hidden, bias)[source]

Bases: Module

SwiGLU class to define the SwiGLU activation function.

Initializes the SwiGLU object.

Args:

n_embd (int): The number of embedding dimensions.

ffn_hidden (int): The number of hidden dimensions in the feed-forward network. Best practice: 4 * n_embd (https://arxiv.org/pdf/1706.03762).

bias (bool): Whether to include bias terms in the linear layers.

forward(x)[source]

Forward pass of the SwiGLU module.

Return type:

Tensor

Parameters:

x (Tensor)

Args:

x (torch.Tensor): Input tensor.

Returns:

torch.Tensor: Output tensor.
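
As a rough illustration of what a SwiGLU feed-forward block computes, the sketch below uses the common formulation W_out(SiLU(W_gate x) * (W_value x)) in pure Python, without bias terms. The weight names (w_gate, w_value, w_out), the toy sizes, and the omission of bias are assumptions made for this example; the actual module is a torch Module and its internal layer layout may differ.

```python
import math


def silu(z: float) -> float:
    # SiLU (swish): z * sigmoid(z)
    return z / (1.0 + math.exp(-z))


def matvec(matrix: list[list[float]], vector: list[float]) -> list[float]:
    # Plain matrix-vector product: each row of `matrix` dotted with `vector`.
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def swiglu(x, w_gate, w_value, w_out):
    # Gate branch: SiLU(W_gate x); value branch: W_value x.
    gate = [silu(g) for g in matvec(w_gate, x)]
    value = matvec(w_value, x)
    # Elementwise product in the hidden space, then project back to n_embd.
    hidden = [g * v for g, v in zip(gate, value)]
    return matvec(w_out, hidden)


n_embd, ffn_hidden = 2, 4  # toy sizes (hypothetical)
w_gate = [[0.5, -0.5], [1.0, 0.0], [0.0, 1.0], [0.25, 0.25]]  # ffn_hidden x n_embd
w_value = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 1.0]]   # ffn_hidden x n_embd
w_out = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]          # n_embd x ffn_hidden

y = swiglu([1.0, 2.0], w_gate, w_value, w_out)
y_zero = swiglu([0.0, 0.0], w_gate, w_value, w_out)
```

The output has the same dimensionality as the input (n_embd), and without bias terms a zero input maps to a zero output because SiLU(0) = 0.
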

modalities.models.model.model_predict_batch(model, batch)[source]

Predicts the output for a batch of samples using the given model.

Return type:

InferenceResultBatch

Args:

model (nn.Module): The model used for prediction.

batch (DatasetBatch): The batch of samples to be predicted.

Returns:

InferenceResultBatch: The batch of inference results containing the predicted targets and predictions.
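
The flow described above can be sketched without torch or modalities installed. In the example below, DatasetBatch and InferenceResultBatch are simplified stand-ins whose field names (samples, targets, predictions) are assumptions, and DoublingModel is a hypothetical toy model; the sketch only shows the idea of running the forward pass on the samples and pairing the predictions with the targets.

```python
from dataclasses import dataclass


@dataclass
class DatasetBatch:  # stand-in; field names are assumptions
    samples: dict[str, list[float]]
    targets: dict[str, list[float]]


@dataclass
class InferenceResultBatch:  # stand-in; field names are assumptions
    targets: dict[str, list[float]]
    predictions: dict[str, list[float]]


class DoublingModel:
    # Hypothetical toy model: maps a dict of inputs to a dict of outputs.
    def forward(self, inputs: dict[str, list[float]]) -> dict[str, list[float]]:
        return {"logits": [2.0 * v for v in inputs["input_ids"]]}


def model_predict_batch(model, batch: DatasetBatch) -> InferenceResultBatch:
    # Run the forward pass on the samples and pair the predictions with the targets.
    predictions = model.forward(batch.samples)
    return InferenceResultBatch(targets=batch.targets, predictions=predictions)


batch = DatasetBatch(
    samples={"input_ids": [1.0, 2.0]},
    targets={"target_ids": [2.0, 4.0]},
)
result = model_predict_batch(DoublingModel(), batch)
```
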

modalities.models.model_factory module

class modalities.models.model_factory.GPT2ModelFactory[source]

Bases: object

static get_gpt2_model(sample_key, prediction_key, poe_type, sequence_length, vocab_size, n_layer, n_head_q, n_head_kv, n_embd, ffn_hidden, dropout, bias, activation_type, attention_implementation, attention_config, attention_norm_config, ffn_norm_config, lm_head_norm_config, use_weight_tying, use_meta_device=False, seed=None)[source]
Return type:

GPT2LLM

static get_gpt2_tensor_parallelized_model(model, device_mesh)[source]
Return type:

Module

class modalities.models.model_factory.ModelFactory[source]

Bases: object

Model factory class to create models.

static get_activation_checkpointed_fsdp1_model_(model, activation_checkpointing_modules)[source]

Apply activation checkpointing to the given FSDP1-wrapped model (in-place operation).

Return type:

FullyShardedDataParallel

Args:

model (FSDP1): The FSDP1-wrapped model to apply activation checkpointing to.

activation_checkpointing_modules (list[str]): List of module names to apply activation checkpointing to.

Raises:

ValueError: Activation checkpointing can only be applied to FSDP1-wrapped models!

Returns:

FSDP1: The model with activation checkpointing applied.

static get_activation_checkpointed_fsdp2_model_(ac_variant, layers_fqn, model, ac_fun_params)[source]

FSDP2 variant for applying activation checkpointing to the given model (in-place operation).

Return type:

Module

Important: When using FSDP2, we always first apply activation checkpointing to the model and then wrap it with FSDP2.

Args:

ac_variant (ActivationCheckpointingVariants): The activation checkpointing variant to use.

layers_fqn (str): Fully qualified name (FQN) of the layers to apply activation checkpointing to.

model (nn.Module): The (unwrapped) model to apply activation checkpointing to.

ac_fun_params (ACM.FullACParams | ACM.SelectiveLayerACParams | ACM.SelectiveOpACParams): The parameters for the activation checkpointing function, depending on the variant.

Raises:

ValueError: Activation checkpointing can only be applied to unwrapped nn.Module models

Returns:

nn.Module: The model with activation checkpointing applied.

static get_compiled_model(model, block_names, fullgraph, debug=False)[source]

Apply torch.compile to each transformer block, which makes compilation efficient due to repeated structure. Alternatively, one can compile the whole model (after applying DP).

Inspired by: https://github.com/pytorch/torchtitan/blob/6b2912a9b53464bfef744e62100716271b2b248f/torchtitan/parallelisms/parallelize_llama.py#L275

Return type:

Module

Note: With fullgraph=True, we enforce the block to be compiled as a whole, which raises an error on graph breaks and maximizes speedup.

Args:

model (nn.Module): The model to be compiled.

block_names (list[str]): List of block names to be compiled individually.

fullgraph (bool): Flag enforcing the block to be compiled without graph breaks.

debug (Optional[bool]): Flag to enable debug mode. Default is False.

Returns:

nn.Module: The compiled model.

static get_debugging_enriched_model(model, logging_dir_path, tracked_ranks=None, log_interval_steps=1)[source]

Enriches the model with debugging hooks to log tensor statistics during forward and backward passes. During the forward pass, it logs the input and output tensors of each module, as well as the parameters. Similarly, during the backward pass, it logs the gradients of the output tensors.

Return type:

Module

Parameters:
  • model (Module)

  • logging_dir_path (Path)

  • tracked_ranks (Set[int] | None)

  • log_interval_steps (int)

The following tensor statistics are logged:
  • global shape

  • local shape

  • is_dtensor (whether the tensor is a DTensor)

  • nan count

  • inf count

  • mean

  • std

  • min

  • max

The statistics are written to a JSONL file in the specified logging directory.

Args:

model (nn.Module): The model to be enriched with debugging hooks.

logging_dir_path (Path): The directory path where the tensor statistics will be logged.

tracked_ranks (Optional[Set[int]]): A set of ranks to track. If provided, only these ranks will log the statistics. If None, all ranks will log the statistics.

log_interval_steps (int): The interval in steps at which to log the tensor statistics. Default is 1.
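
To make the logged statistics concrete, the following pure-Python sketch computes a comparable summary for a flat list of values and serializes it as one JSONL line. The record keys, the hypothetical module name, and the choice to compute mean, std, min, and max over the finite values only (using the population standard deviation) are assumptions for this example; the hooks installed by the library operate on tensors and may define these fields differently.

```python
import json
import math


def tensor_statistics(values: list[float]) -> dict:
    # Count non-finite entries, then summarize the finite ones.
    nan_count = sum(1 for v in values if math.isnan(v))
    inf_count = sum(1 for v in values if math.isinf(v))
    finite = [v for v in values if math.isfinite(v)]
    if finite:
        mean = sum(finite) / len(finite)
        # Population standard deviation over finite values (an assumption).
        std = math.sqrt(sum((v - mean) ** 2 for v in finite) / len(finite))
        lowest, highest = min(finite), max(finite)
    else:
        mean = std = lowest = highest = None
    return {
        "shape": [len(values)],
        "nan_count": nan_count,
        "inf_count": inf_count,
        "mean": mean,
        "std": std,
        "min": lowest,
        "max": highest,
    }


stats = tensor_statistics([1.0, 2.0, 3.0, float("nan"), float("inf")])

# One JSON object per line is what makes a file JSONL.
record = {"step": 0, "module": "hypothetical.block.0", **stats}
jsonl_line = json.dumps(record)
```
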

static get_fsdp1_checkpointed_model(checkpoint_loading, checkpoint_path, model)[source]

Loads a FSDP1 checkpointed model from the given checkpoint path.

Return type:

FullyShardedDataParallel

Args:

checkpoint_loading (FSDP1CheckpointLoadingIF): The checkpoint loading approach used to load the model checkpoint.

checkpoint_path (Path): The path to the checkpoint file.

model (nn.Module): The model to be loaded with the checkpoint.

Returns:

nn.Module: The loaded wrapped model.

get_fsdp1_wrapped_model(sync_module_states, block_names, mixed_precision_settings, sharding_strategy)[source]

Get the FSDP1-wrapped model.

Return type:

FullyShardedDataParallel

Args:

model (nn.Module): The original model to be wrapped.

sync_module_states (bool): Whether to synchronize module states across ranks.

block_names (list[str]): List of block names.

mixed_precision_settings (MixedPrecisionSettings): Mixed precision settings.

sharding_strategy (ShardingStrategy): Sharding strategy.

Returns:

FSDP1: The FSDP1-wrapped model.

Note:

'FSDPTransformerAutoWrapPolicyFactory' is currently hardcoded and should be passed in instead. Different auto wrap policies may be supported in the future.

static get_fsdp2_wrapped_model(model, block_names, device_mesh, mixed_precision_settings, reshard_after_forward)[source]

Get the FSDP2-wrapped model.

Based on https://github.com/pytorch/torchtitan/blob/de9fd2b9ea7e763c9182e0df81fc32c2618cc0b6/torchtitan/parallelisms/parallelize_llama.py#L459 and https://github.com/pytorch/torchtitan/blob/43584e0a4e72645e25cccd05d86f9632587a8beb/docs/fsdp.md

Note: TorchTitan already implements pipeline parallelism. We skip that here for now.

Return type:

FSDPModule

Args:

model (nn.Module): The original model to be wrapped.

block_names (list[str]): List of block names.

device_mesh (DeviceMesh): The device mesh.

mixed_precision_settings (FSDP2MixedPrecisionSettings): Mixed precision settings.

reshard_after_forward (bool): Whether to reshard after forward.

Returns:

FSDP2: The FSDP2-wrapped model.

static get_weight_initialized_model(model, model_initializer)[source]

Initializes the given model with weights using the provided model initializer. The model can be on a meta device.

Return type:

Module

Args:

model (nn.Module): The model to be initialized with weights.

model_initializer (ModelInitializationIF): The model initializer object.

Returns:

nn.Module: The initialized model.

modalities.models.utils module

class modalities.models.utils.ModelTypeEnum(value)[source]

Bases: Enum

Enumeration class representing different types of models.

Attributes:

MODEL (str): Represents a regular model.

CHECKPOINTED_MODEL (str): Represents a checkpointed model.

CHECKPOINTED_MODEL = 'checkpointed_model'
MODEL = 'model'
modalities.models.utils.get_model_from_config(config, model_type)[source]

Retrieves a model from the given configuration based on the specified model type.

Args:

config (dict): The configuration dictionary.

model_type (ModelTypeEnum): The type of the model to retrieve.

Returns:

Any: The model object based on the specified model type.

Raises:

NotImplementedError: If the model type is not supported.
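
The dispatch-on-enum behaviour, including the NotImplementedError for unsupported types, might look roughly like the sketch below. It is a simplification under stated assumptions: the helper name select_model, the idea that the built components are already available in a dict keyed by the enum values, and the placeholder component strings are all hypothetical. The real function builds the model from the configuration rather than looking it up.

```python
from enum import Enum
from typing import Any


class ModelTypeEnum(Enum):
    # Local re-declaration mirroring the documented members (illustration only).
    MODEL = "model"
    CHECKPOINTED_MODEL = "checkpointed_model"


def select_model(components: dict[str, Any], model_type: Any) -> Any:
    # Hypothetical helper: pick the component matching the requested model type.
    if model_type is ModelTypeEnum.MODEL:
        return components["model"]
    if model_type is ModelTypeEnum.CHECKPOINTED_MODEL:
        return components["checkpointed_model"]
    raise NotImplementedError(f"Model type {model_type!r} is not supported.")


# Placeholder strings stand in for real model objects.
components = {"model": "fresh-model", "checkpointed_model": "restored-model"}
selected = select_model(components, ModelTypeEnum.CHECKPOINTED_MODEL)
```
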


Module contents