modalities.models package
Subpackages
- modalities.models.coca package
- Submodules
- modalities.models.coca.attention_pooling module
- modalities.models.coca.coca_model module
- CoCa
- CoCaConfig
- CoCaConfig.bias_attn_pool
- CoCaConfig.epsilon_attn_pool
- CoCaConfig.model_config
- CoCaConfig.n_pool_head
- CoCaConfig.n_vision_queries
- CoCaConfig.prediction_key
- CoCaConfig.text_cls_prediction_key
- CoCaConfig.text_decoder_config
- CoCaConfig.text_embd_prediction_key
- CoCaConfig.vision_cls_prediction_key
- CoCaConfig.vision_embd_prediction_key
- CoCaConfig.vision_encoder_config
 
- TextDecoderConfig
- TextDecoderConfig.activation
- TextDecoderConfig.attention_config
- TextDecoderConfig.bias
- TextDecoderConfig.block_size
- TextDecoderConfig.dropout
- TextDecoderConfig.epsilon
- TextDecoderConfig.ffn_hidden
- TextDecoderConfig.model_config
- TextDecoderConfig.n_embd
- TextDecoderConfig.n_head
- TextDecoderConfig.n_layer_multimodal_text
- TextDecoderConfig.n_layer_text
- TextDecoderConfig.prediction_key
- TextDecoderConfig.sample_key
- TextDecoderConfig.vocab_size
 
 
- modalities.models.coca.collator module
- modalities.models.coca.multi_modal_decoder module
- modalities.models.coca.text_decoder module
- Module contents
 
- modalities.models.components package
- modalities.models.gpt2 package
- Submodules
- modalities.models.gpt2.collator module
- modalities.models.gpt2.gpt2_model module
- AttentionConfig
- AttentionImplementation
- CausalSelfAttention
- GPT2Block
- GPT2LLM
- GPT2LLMConfig
- GPT2LLMConfig.activation_type
- GPT2LLMConfig.attention_config
- GPT2LLMConfig.attention_implementation
- GPT2LLMConfig.attention_norm_config
- GPT2LLMConfig.bias
- GPT2LLMConfig.check_divisibility()
- GPT2LLMConfig.dropout
- GPT2LLMConfig.enforce_swiglu_hidden_dim_multiple_of
- GPT2LLMConfig.ffn_hidden
- GPT2LLMConfig.ffn_norm_config
- GPT2LLMConfig.lm_head_norm_config
- GPT2LLMConfig.model_config
- GPT2LLMConfig.n_embd
- GPT2LLMConfig.n_head_kv
- GPT2LLMConfig.n_head_q
- GPT2LLMConfig.n_layer
- GPT2LLMConfig.poe_type
- GPT2LLMConfig.prediction_key
- GPT2LLMConfig.sample_key
- GPT2LLMConfig.seed
- GPT2LLMConfig.sequence_length
- GPT2LLMConfig.use_meta_device
- GPT2LLMConfig.use_weight_tying
- GPT2LLMConfig.validate_sizes()
- GPT2LLMConfig.vocab_size
 
- IdentityTransform
- LayerNormWrapperConfig
- LayerNorms
- PositionTypes
- QueryKeyValueTransform
- QueryKeyValueTransformType
- RotaryTransform
- TransformerMLP
- manual_scaled_dot_product_attention()
 
- Module contents
 
- modalities.models.huggingface package
- Submodules
- modalities.models.huggingface.huggingface_model module
- HuggingFaceModelTypes
- HuggingFacePretrainedModel
- HuggingFacePretrainedModelConfig
- HuggingFacePretrainedModelConfig.huggingface_prediction_subscription_key
- HuggingFacePretrainedModelConfig.kwargs
- HuggingFacePretrainedModelConfig.model_args
- HuggingFacePretrainedModelConfig.model_config
- HuggingFacePretrainedModelConfig.model_name
- HuggingFacePretrainedModelConfig.model_type
- HuggingFacePretrainedModelConfig.prediction_key
- HuggingFacePretrainedModelConfig.sample_key
 
 
- Module contents
 
- modalities.models.huggingface_adapters package
- modalities.models.vision_transformer package
- Submodules
- modalities.models.vision_transformer.vision_transformer_model module
- ImagePatchEmbedding
- VisionTransformer
- VisionTransformerBlock
- VisionTransformerConfig
- VisionTransformerConfig.add_cls_token
- VisionTransformerConfig.attention_config
- VisionTransformerConfig.bias
- VisionTransformerConfig.dropout
- VisionTransformerConfig.img_size
- VisionTransformerConfig.model_config
- VisionTransformerConfig.n_classes
- VisionTransformerConfig.n_embd
- VisionTransformerConfig.n_head
- VisionTransformerConfig.n_img_channels
- VisionTransformerConfig.n_layer
- VisionTransformerConfig.patch_size
- VisionTransformerConfig.patch_stride
- VisionTransformerConfig.prediction_key
- VisionTransformerConfig.sample_key
 
 
- Module contents
 
Submodules
modalities.models.model module
- class modalities.models.model.ActivationType(value)[source]
- Enum class representing different activation types.
- Attributes:
- GELU (str): GELU activation type.
- SWIGLU (str): SWIGLU activation type.
 - GELU = 'gelu'
 - SWIGLU = 'swiglu'
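
A minimal usage sketch of the enum documented above; it only relies on the member values listed here.

```python
from modalities.models.model import ActivationType

# Members map to lowercase string values, so configs may reference either form.
assert ActivationType.SWIGLU.value == "swiglu"
assert ActivationType("gelu") is ActivationType.GELU
```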
 
- class modalities.models.model.NNModel(seed=None, weight_decay_groups=None)[source]
- Bases: Module
- NNModel class to define a base model.
- Initializes an NNModel object.
- Args:
- seed (int, optional): The seed value for random number generation. Defaults to None.
- weight_decay_groups (Optional[WeightDecayGroups], optional): The weight decay groups. Defaults to None.
 - abstractmethod forward(inputs)[source]
- Forward pass of the model.
- Args:
- inputs (dict[str, torch.Tensor]): A dictionary containing input tensors.
- Returns:
- dict[str, torch.Tensor]: A dictionary containing output tensors.
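
A hedged sketch of subclassing NNModel and implementing the abstract forward; the class name TinyLM and the dictionary keys "sample" and "logits" are illustrative assumptions, not part of the library.

```python
import torch
from torch import nn

from modalities.models.model import NNModel


class TinyLM(NNModel):
    """Illustrative subclass; the keys "sample" and "logits" are assumptions."""

    def __init__(self, vocab_size: int, n_embd: int, seed: int | None = None):
        super().__init__(seed=seed)
        self.embedding = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        hidden = self.embedding(inputs["sample"])  # (batch, seq, n_embd)
        return {"logits": self.lm_head(hidden)}    # (batch, seq, vocab_size)


model = TinyLM(vocab_size=100, n_embd=32, seed=42)
logits = model({"sample": torch.randint(0, 100, (2, 8))})["logits"]  # shape (2, 8, 100)
```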
 
 
- class modalities.models.model.SwiGLU(n_embd, ffn_hidden, bias, enforce_swiglu_hidden_dim_multiple_of=256)[source]
- Bases: Module
- SwiGLU class to define the SwiGLU activation function.
- Initializes the SwiGLU object.
- Args:
- n_embd (int): The number of embedding dimensions.
- ffn_hidden (int): The number of hidden dimensions in the feed-forward network. Best practice: 4 * n_embd (https://arxiv.org/pdf/1706.03762).
- bias (bool): Whether to include bias terms in the linear layers.
- enforce_swiglu_hidden_dim_multiple_of (int): The multiple that the hidden dimension is enforced to be divisible by. This is required for FSDP + TP, as the combination does not (yet) support uneven sharding. Defaults to 256.
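
A hedged instantiation sketch; the dimensions are illustrative, and it is assumed here that the module maps an input of width n_embd back to n_embd, as in a standard transformer FFN.

```python
import torch

from modalities.models.model import SwiGLU

# ffn_hidden = 4 * n_embd follows the best-practice note above; the effective
# hidden width is assumed to be rounded to a multiple of 256 internally, per
# the enforce_swiglu_hidden_dim_multiple_of argument.
swiglu = SwiGLU(n_embd=512, ffn_hidden=2048, bias=False, enforce_swiglu_hidden_dim_multiple_of=256)
out = swiglu(torch.randn(2, 16, 512))  # expected shape: (2, 16, 512)
```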
 
- modalities.models.model.model_predict_batch(model, batch)[source]
- Predicts the output for a batch of samples using the given model.
- Return type: InferenceResultBatch
- Parameters:
- model (Module) 
- batch (DatasetBatch) 
- Args:
- model (nn.Module): The model used for prediction.
- batch (DatasetBatch): The batch of samples to be predicted.
- Returns:
- InferenceResultBatch: The batch of inference results containing the predicted targets and predictions.
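
A hedged usage sketch of model_predict_batch; the DatasetBatch import path, its keyword names (samples/targets), and the dictionary keys are assumptions made for illustration.

```python
import torch
from torch import nn

from modalities.batch import DatasetBatch  # assumed import path
from modalities.models.model import model_predict_batch


class DictWrapper(nn.Module):
    """Toy model whose forward consumes and produces key-to-tensor dictionaries."""

    def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return {"logits": inputs["sample"].float()}


batch = DatasetBatch(
    samples={"sample": torch.randint(0, 10, (2, 8))},  # assumed keyword name
    targets={"target": torch.randint(0, 10, (2, 8))},  # assumed keyword name
)
result = model_predict_batch(model=DictWrapper(), batch=batch)  # InferenceResultBatch
```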
 
modalities.models.model_factory module
- class modalities.models.model_factory.GPT2ModelFactory[source]
- Bases: object
 - static get_gpt2_model(sample_key, prediction_key, poe_type, sequence_length, vocab_size, n_layer, n_head_q, n_head_kv, n_embd, ffn_hidden, dropout, bias, activation_type, attention_implementation, attention_config, attention_norm_config, ffn_norm_config, lm_head_norm_config, use_weight_tying, use_meta_device=False, seed=None, enforce_swiglu_hidden_dim_multiple_of=256)[source]
- Parameters:
- sample_key (str) 
- prediction_key (str) 
- poe_type (PositionTypes) 
- sequence_length (int) 
- vocab_size (int) 
- n_layer (int) 
- n_head_q (int) 
- n_head_kv (int) 
- n_embd (int) 
- ffn_hidden (int) 
- dropout (float) 
- bias (bool) 
- activation_type (ActivationType) 
- attention_implementation (AttentionImplementation) 
- attention_config (AttentionConfig) 
- attention_norm_config (LayerNormWrapperConfig) 
- ffn_norm_config (LayerNormWrapperConfig) 
- lm_head_norm_config (LayerNormWrapperConfig) 
- use_weight_tying (bool) 
- use_meta_device (bool | None) 
- seed (int | None) 
- enforce_swiglu_hidden_dim_multiple_of (int) 
 
 
 - static get_gpt2_tensor_parallelized_model(model, device_mesh)[source]
- Parameters:
- model (GPT2LLM) 
- device_mesh (DeviceMesh) 
 
 
 
- class modalities.models.model_factory.ModelFactory[source]
- Bases: object
- Model factory class to create models.
 - static get_activation_checkpointed_fsdp1_model_(model, activation_checkpointing_modules)[source]
- Apply activation checkpointing to the given FSDP1-wrapped model (in-place operation).
- Return type: FSDP1
- Parameters:
- model (FullyShardedDataParallel) 
- activation_checkpointing_modules (list[str])
 
 - Args:
- model (FSDP1): The FSDP1-wrapped model to apply activation checkpointing to.
- activation_checkpointing_modules (list[str]): List of module names to apply activation checkpointing to.
- Raises:
- ValueError: Activation checkpointing can only be applied to FSDP1-wrapped models! 
- Returns:
- FSDP1: The model with activation checkpointing applied. 
 
 - static get_activation_checkpointed_fsdp2_model_(ac_variant, layers_fqn, model, ac_fun_params)[source]
- FSDP2 variant for applying activation checkpointing to the given model (in-place operation).
- Return type: nn.Module
- Parameters:
- ac_variant (ActivationCheckpointingVariants) 
- layers_fqn (str) 
- model (Module) 
- ac_fun_params (FullACParams | SelectiveLayerACParams | SelectiveOpACParams) 
 
 - Important: When using FSDP2, we always first apply activation checkpointing to the model and then wrap it with FSDP2.
- Args:
- ac_variant (ActivationCheckpointingVariants): The activation checkpointing variant to use.
- layers_fqn (str): Fully qualified name (FQN) of the layers to apply activation checkpointing to.
- model (nn.Module): The (unwrapped) model to apply activation checkpointing to.
- ac_fun_params (ACM.FullACParams | ACM.SelectiveLayerACParams | ACM.SelectiveOpACParams): The parameters for the activation checkpointing function, depending on the variant.
- Raises:
- ValueError: Activation checkpointing can only be applied to unwrapped nn.Module models 
- Returns:
- nn.Module: The model with activation checkpointing applied. 
 
 - static get_compiled_model(model, block_names, fullgraph, debug=False)[source]
- Apply torch.compile to each transformer block, which makes compilation efficient due to the repeated block structure. Alternatively, one can compile the whole model (after applying DP). Inspired by: https://github.com/pytorch/torchtitan/blob/6b2912a9b53464bfef744e62100716271b2b248f/torchtitan/parallelisms/parallelize_llama.py#L275
- Return type: nn.Module
- Parameters:
- model (Module) 
- block_names (list[str]) 
- fullgraph (bool) 
- debug (bool | None) 
 - Note: With fullgraph=True, we enforce the block to be compiled as a whole, which raises an error on graph breaks and maximizes speedup.
- Args:
- model (nn.Module): The model to be compiled.
- block_names (list[str]): List of block names to be compiled individually.
- fullgraph (bool): Flag enforcing the block to be compiled without graph breaks.
- debug (Optional[bool]): Flag to enable debug mode. Default is False.
- Returns:
- nn.Module: The compiled model. 
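
A generic, runnable sketch of the per-block compilation idea described above, using plain PyTorch modules rather than the modalities factory:

```python
import torch
from torch import nn


class Block(nn.Module):
    """Stand-in for a repeated transformer block (here just an MLP with a residual)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.net(x)


blocks = nn.ModuleList([Block(64) for _ in range(4)])

# Compile each block individually instead of the whole model; fullgraph=True
# errors on graph breaks, which maximizes the speedup per block.
for idx, block in enumerate(blocks):
    blocks[idx] = torch.compile(block, fullgraph=True)

x = torch.randn(2, 16, 64)
for block in blocks:
    x = block(x)  # the first call per block triggers compilation
```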
 
 - static get_debugging_enriched_model(model, logging_dir_path, tracked_ranks=None, log_interval_steps=1)[source]
- Enriches the model with debugging hooks to log tensor statistics during forward and backward passes. During the forward pass, it logs the input and output tensors of each module, as well as the parameters. Similarly, during the backward pass, it logs the gradients of the output tensors.
- Parameters:
- model (Module) 
- logging_dir_path (Path) 
- tracked_ranks (Set[int] | None) 
- log_interval_steps (int) 
 - The following tensor statistics are logged:
- global shape 
- local shape 
- is_dtensor (whether the tensor is a DTensor) 
- nan count 
- inf count 
- mean 
- std 
- min 
- max 
 
 - The statistics are written to a JSONL file in the specified logging directory.
- Args:
- model (nn.Module): The model to be enriched with debugging hooks.
- logging_dir_path (Path): The directory path where the tensor statistics will be logged.
- tracked_ranks (Optional[Set[int]]): A set of ranks to track. If provided, only these ranks will log the statistics. If None, all ranks will log the statistics.
- log_interval_steps (int): The interval in steps at which to log the tensor statistics. Default is 1.
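
A generic sketch of the underlying technique (forward hooks that write per-module tensor statistics to a JSONL file); this is not the modalities implementation, and the output file name is arbitrary.

```python
import json

import torch
from torch import nn


def tensor_stats(t: torch.Tensor) -> dict:
    """Compute a subset of the statistics listed above for a single tensor."""
    t = t.detach().float()
    return {
        "shape": list(t.shape),
        "nan_count": int(torch.isnan(t).sum()),
        "inf_count": int(torch.isinf(t).sum()),
        "mean": t.mean().item(),
        "std": t.std().item(),
        "min": t.min().item(),
        "max": t.max().item(),
    }


model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4))

with open("tensor_stats.jsonl", "w") as log_file:
    def log_output(module: nn.Module, inputs, output) -> None:
        record = {"module": type(module).__name__, "output": tensor_stats(output)}
        log_file.write(json.dumps(record) + "\n")

    handles = [m.register_forward_hook(log_output) for m in model]
    model(torch.randn(2, 8))
    for handle in handles:
        handle.remove()
```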
 
 - static get_fsdp1_checkpointed_model(checkpoint_loading, checkpoint_path, model)[source]
- Loads an FSDP1-checkpointed model from the given checkpoint path.
- Return type: nn.Module
- Parameters:
- checkpoint_loading (FSDP1CheckpointLoadingIF) 
- checkpoint_path (Path) 
- model (Module) 
 
 - Args:
- checkpoint_loading (FSDP1CheckpointLoadingIF): The checkpoint loading approach used to load the model checkpoint.
- checkpoint_path (Path): The path to the checkpoint file.
- model (nn.Module): The model to be loaded with the checkpoint.
- Returns:
- nn.Module: The loaded wrapped model. 
 
 - get_fsdp1_wrapped_model(model, sync_module_states, block_names, mixed_precision_settings, sharding_strategy)[source]
- Get the FSDP1-wrapped model.
- Return type: FSDP1
- Parameters:
- model (Module) 
- sync_module_states (bool) 
- block_names (list[str]) 
- mixed_precision_settings (MixedPrecisionSettings) 
- sharding_strategy (ShardingStrategy) 
 
 - Args:
- model (nn.Module): The original model to be wrapped.
- sync_module_states (bool): Whether to synchronize module states across ranks.
- block_names (list[str]): List of block names.
- mixed_precision_settings (MixedPrecisionSettings): Mixed precision settings.
- sharding_strategy (ShardingStrategy): Sharding strategy.
- Returns:
- FSDP1: The FSDP1-wrapped model. 
- Note:
- `FSDPTransformerAutoWrapPolicyFactory` is currently hardcoded and should instead be passed in. Different auto wrap policies may be supported in the future.
 
 - static get_fsdp2_wrapped_model(model, block_names, device_mesh, mixed_precision_settings, reshard_after_forward)[source]
- Get the FSDP2-wrapped model.
- Based on https://github.com/pytorch/torchtitan/blob/de9fd2b9ea7e763c9182e0df81fc32c2618cc0b6/torchtitan/parallelisms/parallelize_llama.py#L459 and https://github.com/pytorch/torchtitan/blob/43584e0a4e72645e25cccd05d86f9632587a8beb/docs/fsdp.md
- NOTE: Torch Titan already implements pipeline parallelism. We skip that here for now.
- Return type: FSDP2
- Parameters:
- model (Module) 
- device_mesh (DeviceMesh) 
- mixed_precision_settings (FSDP2MixedPrecisionSettings) 
- reshard_after_forward (bool) 
 
 - Args:
- model (nn.Module): The original model to be wrapped.
- block_names (list[str]): List of block names.
- device_mesh (DeviceMesh): The device mesh.
- mixed_precision_settings (FSDP2MixedPrecisionSettings): Mixed precision settings.
- reshard_after_forward (bool): Whether to reshard after forward.
- Returns:
- FSDP2: The FSDP2-wrapped model. 
 
 - static get_weight_initialized_model(model, model_initializer)[source]
- Initializes the given model with weights using the provided model initializer. The model can be on a meta device.
- Return type: nn.Module
- Parameters:
- model (Module) 
- model_initializer (ModelInitializationIF) 
 
 - Args:
- model (nn.Module): The model to be initialized with weights.
- model_initializer (ModelInitializationIF): The model initializer object.
- Returns:
- nn.Module: The initialized model. 
 
 
modalities.models.utils module
- class modalities.models.utils.ModelTypeEnum(value)[source]
- Bases: Enum
- Enumeration class representing different types of models.
- Attributes:
- MODEL (str): Represents a regular model.
- CHECKPOINTED_MODEL (str): Represents a checkpointed model.
 - CHECKPOINTED_MODEL = 'checkpointed_model'
 - MODEL = 'model'
 
- modalities.models.utils.get_model_from_config(config, model_type)[source]
- Retrieves a model from the given configuration based on the specified model type.
- Args:
- config (dict): The configuration dictionary.
- model_type (ModelTypeEnum): The type of the model to retrieve.
- Returns:
- Any: The model object based on the specified model type. 
- Raises:
- NotImplementedError: If the model type is not supported. 
 - Parameters:
- config (dict) 
- model_type (ModelTypeEnum)