modalities.models package
Subpackages
- modalities.models.coca package
- Submodules
- modalities.models.coca.attention_pooling module
- modalities.models.coca.coca_model module
CoCa
CoCaConfig
CoCaConfig.bias_attn_pool
CoCaConfig.epsilon_attn_pool
CoCaConfig.model_config
CoCaConfig.n_pool_head
CoCaConfig.n_vision_queries
CoCaConfig.prediction_key
CoCaConfig.text_cls_prediction_key
CoCaConfig.text_decoder_config
CoCaConfig.text_embd_prediction_key
CoCaConfig.vision_cls_prediction_key
CoCaConfig.vision_embd_prediction_key
CoCaConfig.vision_encoder_config
TextDecoderConfig
TextDecoderConfig.activation
TextDecoderConfig.attention_config
TextDecoderConfig.bias
TextDecoderConfig.block_size
TextDecoderConfig.dropout
TextDecoderConfig.epsilon
TextDecoderConfig.ffn_hidden
TextDecoderConfig.model_config
TextDecoderConfig.n_embd
TextDecoderConfig.n_head
TextDecoderConfig.n_layer_multimodal_text
TextDecoderConfig.n_layer_text
TextDecoderConfig.prediction_key
TextDecoderConfig.sample_key
TextDecoderConfig.vocab_size
- modalities.models.coca.collator module
- modalities.models.coca.multi_modal_decoder module
- modalities.models.coca.text_decoder module
- Module contents
- modalities.models.components package
- modalities.models.gpt2 package
- Submodules
- modalities.models.gpt2.collator module
- modalities.models.gpt2.gpt2_model module
AttentionConfig
AttentionImplementation
CausalSelfAttention
GPT2Block
GPT2LLM
GPT2LLMConfig
GPT2LLMConfig.activation_type
GPT2LLMConfig.attention_config
GPT2LLMConfig.attention_implementation
GPT2LLMConfig.attention_norm_config
GPT2LLMConfig.bias
GPT2LLMConfig.check_divisibility()
GPT2LLMConfig.dropout
GPT2LLMConfig.ffn_hidden
GPT2LLMConfig.ffn_norm_config
GPT2LLMConfig.lm_head_norm_config
GPT2LLMConfig.model_config
GPT2LLMConfig.n_embd
GPT2LLMConfig.n_head_kv
GPT2LLMConfig.n_head_q
GPT2LLMConfig.n_layer
GPT2LLMConfig.poe_type
GPT2LLMConfig.prediction_key
GPT2LLMConfig.sample_key
GPT2LLMConfig.sequence_length
GPT2LLMConfig.use_meta_device
GPT2LLMConfig.use_weight_tying
GPT2LLMConfig.validate_sizes()
GPT2LLMConfig.vocab_size
IdentityTransform
LayerNormWrapperConfig
LayerNorms
PositionTypes
QueryKeyValueTransform
QueryKeyValueTransformType
RotaryTransform
TransformerMLP
manual_scaled_dot_product_attention()
- Module contents
- modalities.models.huggingface package
- Submodules
- modalities.models.huggingface.huggingface_model module
HuggingFaceModelTypes
HuggingFacePretrainedModel
HuggingFacePretrainedModelConfig
HuggingFacePretrainedModelConfig.huggingface_prediction_subscription_key
HuggingFacePretrainedModelConfig.kwargs
HuggingFacePretrainedModelConfig.model_args
HuggingFacePretrainedModelConfig.model_config
HuggingFacePretrainedModelConfig.model_name
HuggingFacePretrainedModelConfig.model_type
HuggingFacePretrainedModelConfig.prediction_key
HuggingFacePretrainedModelConfig.sample_key
- Module contents
- modalities.models.huggingface_adapters package
- modalities.models.vision_transformer package
- Submodules
- modalities.models.vision_transformer.vision_transformer_model module
ImagePatchEmbedding
VisionTransformer
VisionTransformerBlock
VisionTransformerConfig
VisionTransformerConfig.add_cls_token
VisionTransformerConfig.attention_config
VisionTransformerConfig.bias
VisionTransformerConfig.dropout
VisionTransformerConfig.img_size
VisionTransformerConfig.model_config
VisionTransformerConfig.n_classes
VisionTransformerConfig.n_embd
VisionTransformerConfig.n_head
VisionTransformerConfig.n_img_channels
VisionTransformerConfig.n_layer
VisionTransformerConfig.patch_size
VisionTransformerConfig.patch_stride
VisionTransformerConfig.prediction_key
VisionTransformerConfig.sample_key
- Module contents
Submodules
modalities.models.model module
- class modalities.models.model.ActivationType(value)[source]
Bases:
Enum
Enum class representing different activation types.
- Attributes:
GELU (str): GELU activation type.
SWIGLU (str): SWIGLU activation type.
- GELU = 'gelu'
- SWIGLU = 'swiglu'
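- Example:
A minimal sketch of selecting an activation type from a string value, e.g. as read from a config file; the variable names are purely illustrative.
```python
from modalities.models.model import ActivationType

# Construct the enum member from its documented string value.
activation = ActivationType("swiglu")
assert activation is ActivationType.SWIGLU
```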
- class modalities.models.model.NNModel(seed=None, weight_decay_groups=None)[source]
Bases:
Module
NNModel class to define a base model.
Initializes an NNModel object.
- Args:
seed (int, optional): The seed value for random number generation. Defaults to None.
weight_decay_groups (Optional[WeightDecayGroups], optional): The weight decay groups. Defaults to None.
- abstractmethod forward(inputs)[source]
Forward pass of the model.
- Args:
inputs (dict[str, torch.Tensor]): A dictionary containing input tensors.
- Returns:
dict[str, torch.Tensor]: A dictionary containing output tensors.
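- Example:
A minimal sketch of a custom model built on NNModel; the layer sizes and the dictionary keys ("input_ids", "logits") are illustrative assumptions, not part of the API.
```python
import torch
from torch import nn

from modalities.models.model import NNModel


class TinyLM(NNModel):
    """Toy subclass used only to illustrate the dict-in / dict-out contract."""

    def __init__(self, vocab_size: int = 128, n_embd: int = 32):
        super().__init__()  # seed and weight_decay_groups keep their defaults
        self.embedding = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        hidden = self.embedding(inputs["input_ids"])
        return {"logits": self.lm_head(hidden)}


model = TinyLM()
out = model({"input_ids": torch.randint(0, 128, (2, 16))})
```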
- class modalities.models.model.SwiGLU(n_embd, ffn_hidden, bias)[source]
Bases:
Module
SwiGLU class to define the SwiGLU activation function.
Initializes the SwiGLU object.
- Args:
n_embd (int): The number of embedding dimensions.
ffn_hidden (int): The number of hidden dimensions in the feed-forward network. Best practice: 4 * n_embd (https://arxiv.org/pdf/1706.03762).
bias (bool): Whether to include bias terms in the linear layers.
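- Example:
A minimal sketch of using SwiGLU as a feed-forward layer; that forward maps a tensor of shape (..., n_embd) back to (..., n_embd) is an assumption based on typical transformer FFN blocks, not something stated in this reference.
```python
import torch

from modalities.models.model import SwiGLU

n_embd = 64
# ffn_hidden = 4 * n_embd follows the best practice noted above.
ffn = SwiGLU(n_embd=n_embd, ffn_hidden=4 * n_embd, bias=True)
x = torch.randn(2, 8, n_embd)
y = ffn(x)  # assumed to preserve the trailing embedding dimension
```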
- modalities.models.model.model_predict_batch(model, batch)[source]
Predicts the output for a batch of samples using the given model.
- Return type:
InferenceResultBatch
- Parameters:
model (Module)
batch (DatasetBatch)
- Args:
model (nn.Module): The model used for prediction.
batch (DatasetBatch): The batch of samples to be predicted.
- Returns:
InferenceResultBatch: The batch of inference results containing the predicted targets and predictions.
modalities.models.model_factory module
- class modalities.models.model_factory.GPT2ModelFactory[source]
Bases:
object
- static get_gpt2_model(sample_key, prediction_key, poe_type, sequence_length, vocab_size, n_layer, n_head_q, n_head_kv, n_embd, ffn_hidden, dropout, bias, activation_type, attention_implementation, attention_config, attention_norm_config, ffn_norm_config, lm_head_norm_config, use_weight_tying, use_meta_device=False, seed=None)[source]
- Return type:
GPT2LLM
- Parameters:
sample_key (str)
prediction_key (str)
poe_type (PositionTypes)
sequence_length (int)
vocab_size (int)
n_layer (int)
n_head_q (int)
n_head_kv (int)
n_embd (int)
ffn_hidden (int)
dropout (float)
bias (bool)
activation_type (ActivationType)
attention_implementation (AttentionImplementation)
attention_config (AttentionConfig)
attention_norm_config (LayerNormWrapperConfig)
ffn_norm_config (LayerNormWrapperConfig)
lm_head_norm_config (LayerNormWrapperConfig)
use_weight_tying (bool)
use_meta_device (bool | None)
seed (int)
- static get_gpt2_tensor_parallelized_model(model, device_mesh)[source]
- Return type:
- Parameters:
model (GPT2LLM)
device_mesh (DeviceMesh)
- class modalities.models.model_factory.ModelFactory[source]
Bases:
object
Model factory class to create models.
- static get_activation_checkpointed_fsdp1_model_(model, activation_checkpointing_modules)[source]
Apply activation checkpointing to the given FSDP1-wrapped model (in-place operation).
- Return type:
FSDP1
- Parameters:
model (FullyShardedDataParallel)
activation_checkpointing_modules (list[str])
- Args:
model (FSDP1): The FSDP1-wrapped model to apply activation checkpointing to.
activation_checkpointing_modules (list[str]): List of module names to apply activation checkpointing to.
- Raises:
ValueError: Activation checkpointing can only be applied to FSDP1-wrapped models!
- Returns:
FSDP1: The model with activation checkpointing applied.
- static get_activation_checkpointed_fsdp2_model_(ac_variant, layers_fqn, model, ac_fun_params)[source]
FSDP2 variant for applying activation checkpointing to the given model (in-place operation).
- Return type:
nn.Module
- Parameters:
ac_variant (ActivationCheckpointingVariants)
layers_fqn (str)
model (Module)
ac_fun_params (FullACParams | SelectiveLayerACParams | SelectiveOpACParams)
- Important:
When using FSDP2, we always first apply activation checkpointing to the model and then wrap it with FSDP2.
- Args:
ac_variant (ActivationCheckpointingVariants): The activation checkpointing variant to use.
layers_fqn (str): Fully qualified name (FQN) of the layers to apply activation checkpointing to.
model (nn.Module): The (unwrapped) model to apply activation checkpointing to.
ac_fun_params (ACM.FullACParams | ACM.SelectiveLayerACParams | ACM.SelectiveOpACParams): The parameters for the activation checkpointing function, depending on the variant.
- Raises:
ValueError: Activation checkpointing can only be applied to unwrapped nn.Module models
- Returns:
nn.Module: The model with activation checkpointing applied.
- static get_compiled_model(model, block_names, fullgraph, debug=False)[source]
Apply torch.compile to each transformer block, which makes compilation efficient due to the repeated block structure. Alternatively, one can compile the whole model (after applying DP).
Inspired by: https://github.com/pytorch/torchtitan/blob/6b2912a9b53464bfef744e62100716271b2b248f/torchtitan/parallelisms/parallelize_llama.py#L275
- Return type:
nn.Module
- Parameters:
model (Module)
block_names (list[str])
fullgraph (bool)
debug (bool | None)
- Note:
With fullgraph=True, we enforce the block to be compiled as a whole, which raises an error on graph breaks and maximizes speedup.
- Args:
model (nn.Module): The model to be compiled.
block_names (list[str]): List of block names to be compiled individually.
fullgraph (bool): Flag enforcing the block to be compiled without graph breaks.
debug (Optional[bool]): Flag to enable debug mode. Default is False.
- Returns:
nn.Module: The compiled model.
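- Example:
A minimal sketch of compiling the blocks of a toy model; the assumption that block_names refers to the class names of the block modules should be checked against the configuration docs, and the toy classes are purely illustrative.
```python
import torch
from torch import nn

from modalities.models.model_factory import ModelFactory


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.ff(x))


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock() for _ in range(2)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x


# Each ToyBlock is compiled individually; fullgraph=True errors on graph breaks.
compiled = ModelFactory.get_compiled_model(model=ToyModel(), block_names=["ToyBlock"], fullgraph=True)
```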
- static get_debugging_enriched_model(model, logging_dir_path, tracked_ranks=None, log_interval_steps=1)[source]
Enriches the model with debugging hooks to log tensor statistics during forward and backward passes. During the forward pass, it logs the input and output tensors of each module, as well as the parameters. Similarly, during the backward pass, it logs the gradients of the output tensors.
- Return type:
- Parameters:
model (Module)
logging_dir_path (Path)
tracked_ranks (Set[int] | None)
log_interval_steps (int)
- The following tensor statistics are logged:
global shape
local shape
is_dtensor (whether the tensor is a DTensor)
nan count
inf count
mean
std
min
max
The statistics are written to a JSONL file in the specified logging directory.
- Args:
model (nn.Module): The model to be enriched with debugging hooks.
logging_dir_path (Path): The directory path where the tensor statistics will be logged.
tracked_ranks (Optional[Set[int]]): A set of ranks to track. If provided, only these ranks will log the statistics. If None, all ranks will log the statistics.
log_interval_steps (int): The interval in steps at which to log the tensor statistics. Default is 1.
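- Example:
A minimal sketch of enriching a toy model with debugging hooks; in a real run this is called on a distributed, rank-initialized process, and the logging directory and interval shown here are purely illustrative.
```python
from pathlib import Path

from torch import nn

from modalities.models.model_factory import ModelFactory

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
debug_model = ModelFactory.get_debugging_enriched_model(
    model=model,
    logging_dir_path=Path("./tensor_stats"),  # JSONL statistics are written here
    tracked_ranks={0},        # only rank 0 logs; None would log on all ranks
    log_interval_steps=10,    # log every 10 steps instead of every step
)
```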
- static get_fsdp1_checkpointed_model(checkpoint_loading, checkpoint_path, model)[source]
Loads a FSDP1 checkpointed model from the given checkpoint path.
- Return type:
nn.Module
- Parameters:
checkpoint_loading (FSDP1CheckpointLoadingIF)
checkpoint_path (Path)
model (Module)
- Args:
checkpoint_loading (FSDP1CheckpointLoadingIF): The checkpoint loading approach used to load the model checkpoint.
checkpoint_path (Path): The path to the checkpoint file.
model (nn.Module): The model to be loaded with the checkpoint.
- Returns:
nn.Module: The loaded wrapped model.
- static get_fsdp1_wrapped_model(model, sync_module_states, block_names, mixed_precision_settings, sharding_strategy)[source]
Get the FSDP1-wrapped model.
- Return type:
FSDP1
- Parameters:
model (Module)
sync_module_states (bool)
block_names (list[str])
mixed_precision_settings (MixedPrecisionSettings)
sharding_strategy (ShardingStrategy)
- Args:
model (nn.Module): The original model to be wrapped.
sync_module_states (bool): Whether to synchronize module states across ranks.
block_names (list[str]): List of block names.
mixed_precision_settings (MixedPrecisionSettings): Mixed precision settings.
sharding_strategy (ShardingStrategy): Sharding strategy.
- Returns:
FSDP1: The FSDP1-wrapped model.
- Note:
`FSDPTransformerAutoWrapPolicyFactory` is hardcoded and should be passed in instead. Different auto wrap policies may be supported in the future.
- static get_fsdp2_wrapped_model(model, block_names, device_mesh, mixed_precision_settings, reshard_after_forward)[source]
Get the FSDP2-wrapped model.
Based on https://github.com/pytorch/torchtitan/blob/de9fd2b9ea7e763c9182e0df81fc32c2618cc0b6/torchtitan/parallelisms/parallelize_llama.py#L459 and https://github.com/pytorch/torchtitan/blob/43584e0a4e72645e25cccd05d86f9632587a8beb/docs/fsdp.md
NOTE: Torch Titan already implements pipeline parallelism. We skip that here for now.
- Return type:
FSDP2
- Parameters:
model (Module)
block_names (list[str])
device_mesh (DeviceMesh)
mixed_precision_settings (FSDP2MixedPrecisionSettings)
reshard_after_forward (bool)
- Args:
model (nn.Module): The original model to be wrapped.
block_names (list[str]): List of block names.
device_mesh (DeviceMesh): The device mesh.
mixed_precision_settings (FSDP2MixedPrecisionSettings): Mixed precision settings.
reshard_after_forward (bool): Whether to reshard after forward.
- Returns:
FSDP2: The FSDP2-wrapped model.
- static get_weight_initialized_model(model, model_initializer)[source]
Initializes the given model with weights using the provided model initializer. The model can be on a meta device.
- Return type:
nn.Module
- Parameters:
model (Module)
model_initializer (ModelInitializationIF)
- Args:
model (nn.Module): The model to be initialized with weights.
model_initializer (ModelInitializationIF): The model initializer object.
- Returns:
nn.Module: The initialized model.
modalities.models.utils module
- class modalities.models.utils.ModelTypeEnum(value)[source]
Bases:
Enum
Enumeration class representing different types of models.
- Attributes:
MODEL (str): Represents a regular model.
CHECKPOINTED_MODEL (str): Represents a checkpointed model.
- CHECKPOINTED_MODEL = 'checkpointed_model'
- MODEL = 'model'
- modalities.models.utils.get_model_from_config(config, model_type)[source]
Retrieves a model from the given configuration based on the specified model type.
- Args:
config (dict): The configuration dictionary.
model_type (ModelTypeEnum): The type of the model to retrieve.
- Returns:
Any: The model object based on the specified model type.
- Raises:
NotImplementedError: If the model type is not supported.
- Parameters:
config (dict)
model_type (ModelTypeEnum)
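- Example:
A minimal sketch of building a model from a YAML config; the config path is a placeholder and the file is assumed to contain a valid modalities model configuration.
```python
from pathlib import Path

import yaml

from modalities.models.utils import ModelTypeEnum, get_model_from_config

# Hypothetical config file describing a modalities model.
config = yaml.safe_load(Path("configs/my_model_config.yaml").read_text())
model = get_model_from_config(config=config, model_type=ModelTypeEnum.MODEL)
```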