modalities.models package

Submodules

modalities.models.model module

class modalities.models.model.ActivationType(value)[source]

Bases: str, Enum

Enum class representing different activation types.

Attributes:

GELU (str): GELU activation type.
SWIGLU (str): SWIGLU activation type.

GELU = 'gelu'
SWIGLU = 'swiglu'
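
Because ActivationType mixes in str, its members behave like plain strings and can be constructed from their values; a minimal usage sketch:

from modalities.models.model import ActivationType

activation = ActivationType("swiglu")      # look up a member by its string value
assert activation is ActivationType.SWIGLU
assert activation == "swiglu"              # str mixin: members compare equal to their values
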
class modalities.models.model.NNModel(seed=None, weight_decay_groups=None)[source]

Bases: Module

NNModel class to define a base model.

Initializes an NNModel object.

Args:

seed (int, optional): The seed value for random number generation. Defaults to None.
weight_decay_groups (Optional[WeightDecayGroups], optional): The weight decay groups. Defaults to None.

abstractmethod forward(inputs)[source]

Forward pass of the model.

Return type:

dict[str, Tensor]

Parameters:

inputs (dict[str, Tensor])

Args:

inputs (dict[str, torch.Tensor]): A dictionary containing input tensors.

Returns:

dict[str, torch.Tensor]: A dictionary containing output tensors.

get_parameters()[source]

Returns a dictionary of the model’s parameters.

Return type:

dict[str, Tensor]

Returns:

A dictionary where the keys are the parameter names and the values are the corresponding parameter tensors.

property weight_decay_groups: dict[str, list[str]]

Returns the weight decay groups.

Returns:

WeightDecayGroups: The weight decay groups.
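
A minimal sketch of how a concrete subclass might implement the abstract interface; the layer sizes and the dictionary keys ("sample", "logits") are illustrative assumptions, not part of the API:

import torch
from torch import nn

from modalities.models.model import NNModel


class TinyLM(NNModel):
    def __init__(self, vocab_size: int, n_embd: int, seed: int = 42):
        super().__init__(seed=seed)
        self.embedding = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        hidden = self.embedding(inputs["sample"])   # (batch, sequence, n_embd)
        return {"logits": self.lm_head(hidden)}     # (batch, sequence, vocab_size)


model = TinyLM(vocab_size=100, n_embd=32)
params = model.get_parameters()   # dict mapping parameter names to tensors
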

class modalities.models.model.SwiGLU(n_embd, ffn_hidden, bias)[source]

Bases: Module

SwiGLU class to define the SwiGLU activation function.

Initializes the SwiGLU object.

Args:

n_embd (int): The number of embedding dimensions.
ffn_hidden (int): The number of hidden dimensions in the feed-forward network. Best practice: 4 * n_embd (https://arxiv.org/pdf/1706.03762).
bias (bool): Whether to include bias terms in the linear layers.

forward(x)[source]

Forward pass of the SwiGLU module.

Return type:

Tensor

Parameters:

x (Tensor)

Args:

x (torch.Tensor): Input tensor.

Returns:

torch.Tensor: Output tensor.
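
A usage sketch of the documented constructor and forward pass; the 4 * n_embd hidden size follows the best practice noted above, and the assumption that the block projects back to the embedding dimension is marked in the comments:

import torch

from modalities.models.model import SwiGLU

n_embd = 64
swiglu = SwiGLU(n_embd=n_embd, ffn_hidden=4 * n_embd, bias=False)

x = torch.randn(2, 16, n_embd)   # (batch, sequence, embedding)
y = swiglu(x)                    # expected to keep the trailing dimension at n_embd
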

modalities.models.model.model_predict_batch(model, batch)[source]

Predicts the output for a batch of samples using the given model.

Return type:

InferenceResultBatch

Args:

model (nn.Module): The model used for prediction.
batch (DatasetBatch): The batch of samples to be predicted.

Returns:

InferenceResultBatch: The batch of inference results containing the predicted targets and predictions.
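
A usage sketch; building the model and the DatasetBatch is assumed to happen elsewhere (e.g. via the configured data loader):

from modalities.models.model import model_predict_batch

result = model_predict_batch(model=model, batch=batch)
# result is an InferenceResultBatch holding the targets and the model predictions
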

modalities.models.model_factory module

class modalities.models.model_factory.GPT2ModelFactory[source]

Bases: object

static get_gpt2_model(sample_key, prediction_key, poe_type, sequence_length, vocab_size, n_layer, n_head_q, n_head_kv, n_embd, ffn_hidden, dropout, bias, activation_type, attention_implementation, attention_config, attention_norm_config, ffn_norm_config, lm_head_norm_config, use_weight_tying, use_meta_device=False, seed=None)[source]

Return type:

GPT2LLM

class modalities.models.model_factory.ModelFactory[source]

Bases: object

Model factory class to create models.

static get_activation_checkpointed_fsdp1_model(model, activation_checkpointing_modules)[source]

Apply activation checkpointing to the given model (in-place operation).

Return type:

FullyShardedDataParallel

Args:

model (FSDP1): The FSDP1-wrapped model to apply activation checkpointing to.
activation_checkpointing_modules (list[str]): List of module names to apply activation checkpointing to.

Raises:

ValueError: Activation checkpointing can only be applied to FSDP1-wrapped models!

Returns:

FSDP1: The model with activation checkpointing applied.
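
A sketch of the underlying PyTorch mechanism (not the modalities-internal code): selected submodules are wrapped with non-reentrant checkpointing so that their activations are recomputed during the backward pass instead of being stored. The block name used in check_fn is an assumption:

from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
apply_activation_checkpointing(
    fsdp1_model,                                                      # the FSDP1-wrapped model
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=lambda module: type(module).__name__ in ["GPT2Block"],   # assumed block class name
)
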

static get_compiled_model(model, block_names, fullgraph, debug=False)[source]

Apply torch.compile to each transformer block, which makes compilation efficient due to the repeated structure. Alternatively, one can compile the whole model (after applying DP). Inspired by: https://github.com/pytorch/torchtitan/blob/6b2912a9b53464bfef744e62100716271b2b248f/torchtitan/parallelisms/parallelize_llama.py#L275

Return type:

Module

Note:

With fullgraph=True, we enforce the block to be compiled as a whole, which raises an error on graph breaks and maximizes speedup.

Args:

model (nn.Module): The model to be compiled.
block_names (list[str]): List of block names to be compiled individually.
fullgraph (bool): Flag enforcing the block to be compiled without graph breaks.
debug (Optional[bool]): Flag to enable debug mode. Default is False.

Returns:

nn.Module: The compiled model.
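
A sketch of the per-block compilation idea described above, assuming block_names holds the class names of the repeated transformer blocks:

import torch
from torch import nn


def compile_blocks(model: nn.Module, block_names: list[str], fullgraph: bool) -> nn.Module:
    # Collect (parent, attribute name, block) triples first, then swap each block for its
    # torch.compile-d counterpart, so the module tree is not mutated while iterating it.
    targets = [
        (parent, name, child)
        for parent in model.modules()
        for name, child in parent.named_children()
        if type(child).__name__ in block_names
    ]
    for parent, name, child in targets:
        setattr(parent, name, torch.compile(child, fullgraph=fullgraph))
    return model
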

static get_fsdp1_checkpointed_model(checkpoint_loading, checkpoint_path, model)[source]

Loads an FSDP1-checkpointed model from the given checkpoint path.

Return type:

FullyShardedDataParallel

Args:

checkpoint_loading (FSDP1CheckpointLoadingIF): The checkpoint loading approach used to load the model checkpoint.
checkpoint_path (Path): The path to the checkpoint file.
model (nn.Module): The model to be loaded with the checkpoint.

Returns:

nn.Module: The loaded wrapped model.

get_fsdp1_wrapped_model(model, sync_module_states, block_names, mixed_precision_settings, sharding_strategy)[source]

Get the FSDP1-wrapped model.

Return type:

FullyShardedDataParallel

Args:

model (nn.Module): The original model to be wrapped.
sync_module_states (bool): Whether to synchronize module states across ranks.
block_names (list[str]): List of block names.
mixed_precision_settings (MixedPrecisionSettings): Mixed precision settings.
sharding_strategy (ShardingStrategy): Sharding strategy.

Returns:

FSDP1: The FSDP1-wrapped model.

Note:

FSDPTransformerAutoWrapPolicyFactory is hardcoded and should be passed in instead. Different auto wrap policies may be supported in the future.
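
A sketch of FSDP1 wrapping with the settings listed above, using a transformer auto-wrap policy as the note suggests; nn.TransformerEncoderLayer stands in for the model's actual block class, and an initialized distributed process group is assumed:

import functools

import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP1
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},  # stand-in for the model's block class
)
fsdp1_model = FSDP1(
    model,
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
    sync_module_states=True,
    device_id=torch.cuda.current_device(),
)
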

static get_fsdp2_wrapped_model(model, block_names, device_mesh, mixed_precision_settings, reshard_after_forward)[source]

Get the FSDP2-wrapped model.

Based on https://github.com/pytorch/torchtitan/blob/de9fd2b9ea7e763c9182e0df81fc32c2618cc0b6/torchtitan/parallelisms/parallelize_llama.py#L459 and https://github.com/pytorch/torchtitan/blob/43584e0a4e72645e25cccd05d86f9632587a8beb/docs/fsdp.md. NOTE: Torch Titan already implements pipeline parallelism; we skip that here for now.

Return type:

FSDPModule

Args:

model (nn.Module): The original model to be wrapped.
block_names (list[str]): List of block names.
device_mesh (DeviceMesh): The device mesh.
mixed_precision_settings (FSDP2MixedPrecisionSettings): Mixed precision settings.
reshard_after_forward (bool): Whether to reshard after forward.

Returns:

FSDP2: The FSDP2-wrapped model.
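
A sketch of the torchtitan-style FSDP2 wrapping described above, assuming PyTorch's composable fully_shard API (exposed as torch.distributed.fsdp.fully_shard in recent releases; older versions ship it under torch.distributed._composable.fsdp) and an already initialized process group. The transformer_blocks attribute is an assumed module layout:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("dp",))
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

# Shard each repeated block first, then the root module, mirroring the torchtitan recipe.
for block in model.transformer_blocks:  # assumed attribute holding the block modules
    fully_shard(block, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=True)
fsdp2_model = fully_shard(model, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=True)
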

static get_weight_initialized_model(model, model_initializer)[source]

Initializes the given model with weights using the provided model initializer. The model can be on a meta device.

Return type:

Module

Args:

model (nn.Module): The model to be initialized with weights.
model_initializer (ModelInitializationIF): The model initializer object.

Returns:

nn.Module: The initialized model.
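
A sketch of the meta-device flow this method supports: the model is first built on the meta device without allocating storage, then materialized and initialized. The initializer step is an assumption about how a ModelInitializationIF implementation is used:

import torch
from torch import nn

with torch.device("meta"):
    model = nn.Linear(1024, 1024, bias=False)   # stand-in for a real model built from a config

model = model.to_empty(device="cuda")           # allocate real, still uninitialized storage
# A ModelInitializationIF implementation would now write the actual weights in-place,
# which is what get_weight_initialized_model(model, model_initializer) takes care of.
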

modalities.models.utils module

class modalities.models.utils.ModelTypeEnum(value)[source]

Bases: Enum

Enumeration class representing different types of models.

Attributes:

MODEL (str): Represents a regular model.
CHECKPOINTED_MODEL (str): Represents a checkpointed model.

CHECKPOINTED_MODEL = 'checkpointed_model'
MODEL = 'model'
modalities.models.utils.get_model_from_config(config, model_type)[source]

Retrieves a model from the given configuration based on the specified model type.

Args:

config (dict): The configuration dictionary.
model_type (ModelTypeEnum): The type of the model to retrieve.

Returns:

Any: The model object based on the specified model type.

Raises:

NotImplementedError: If the model type is not supported.
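
A usage sketch; the YAML loading and the config layout are assumptions based on the documented signature:

import yaml

from modalities.models.utils import ModelTypeEnum, get_model_from_config

with open("config.yaml") as f:
    config = yaml.safe_load(f)

model = get_model_from_config(config=config, model_type=ModelTypeEnum.MODEL)
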

Module contents