modalities.models package
Subpackages
- modalities.models.coca package
- Submodules
- modalities.models.coca.attention_pooling module
- modalities.models.coca.coca_model module
CoCa
CoCaConfig
CoCaConfig.bias_attn_pool
CoCaConfig.epsilon_attn_pool
CoCaConfig.model_config
CoCaConfig.n_pool_head
CoCaConfig.n_vision_queries
CoCaConfig.prediction_key
CoCaConfig.text_cls_prediction_key
CoCaConfig.text_decoder_config
CoCaConfig.text_embd_prediction_key
CoCaConfig.vision_cls_prediction_key
CoCaConfig.vision_embd_prediction_key
CoCaConfig.vision_encoder_config
TextDecoderConfig
TextDecoderConfig.activation
TextDecoderConfig.attention_config
TextDecoderConfig.bias
TextDecoderConfig.block_size
TextDecoderConfig.dropout
TextDecoderConfig.epsilon
TextDecoderConfig.ffn_hidden
TextDecoderConfig.model_config
TextDecoderConfig.n_embd
TextDecoderConfig.n_head
TextDecoderConfig.n_layer_multimodal_text
TextDecoderConfig.n_layer_text
TextDecoderConfig.prediction_key
TextDecoderConfig.sample_key
TextDecoderConfig.vocab_size
- modalities.models.coca.collator module
- modalities.models.coca.multi_modal_decoder module
- modalities.models.coca.text_decoder module
- Module contents
- modalities.models.components package
- modalities.models.gpt2 package
- Submodules
- modalities.models.gpt2.collator module
- modalities.models.gpt2.gpt2_model module
AttentionConfig
AttentionImplementation
CausalSelfAttention
GPT2Block
GPT2LLM
GPT2LLMConfig
GPT2LLMConfig.activation_type
GPT2LLMConfig.attention_config
GPT2LLMConfig.attention_implementation
GPT2LLMConfig.attention_norm_config
GPT2LLMConfig.bias
GPT2LLMConfig.check_divisibility()
GPT2LLMConfig.dropout
GPT2LLMConfig.ffn_hidden
GPT2LLMConfig.ffn_norm_config
GPT2LLMConfig.lm_head_norm_config
GPT2LLMConfig.model_config
GPT2LLMConfig.n_embd
GPT2LLMConfig.n_head_kv
GPT2LLMConfig.n_head_q
GPT2LLMConfig.n_layer
GPT2LLMConfig.poe_type
GPT2LLMConfig.prediction_key
GPT2LLMConfig.sample_key
GPT2LLMConfig.sequence_length
GPT2LLMConfig.use_meta_device
GPT2LLMConfig.use_weight_tying
GPT2LLMConfig.validate_sizes()
GPT2LLMConfig.vocab_size
IdentityTransform
LayerNormWrapperConfig
LayerNorms
PositionTypes
QueryKeyValueTransform
QueryKeyValueTransformType
RotaryTransform
TransformerMLP
manual_scaled_dot_product_attention()
- modalities.models.gpt2.pretrained_gpt_model module
- Module contents
- modalities.models.huggingface package
- Submodules
- modalities.models.huggingface.huggingface_model module
HuggingFaceModelTypes
HuggingFacePretrainedModel
HuggingFacePretrainedModelConfig
HuggingFacePretrainedModelConfig.huggingface_prediction_subscription_key
HuggingFacePretrainedModelConfig.kwargs
HuggingFacePretrainedModelConfig.model_args
HuggingFacePretrainedModelConfig.model_config
HuggingFacePretrainedModelConfig.model_name
HuggingFacePretrainedModelConfig.model_type
HuggingFacePretrainedModelConfig.prediction_key
HuggingFacePretrainedModelConfig.sample_key
- Module contents
- modalities.models.huggingface_adapters package
- modalities.models.vision_transformer package
- Submodules
- modalities.models.vision_transformer.vision_transformer_model module
ImagePatchEmbedding
VisionTransformer
VisionTransformerBlock
VisionTransformerConfig
VisionTransformerConfig.add_cls_token
VisionTransformerConfig.attention_config
VisionTransformerConfig.bias
VisionTransformerConfig.dropout
VisionTransformerConfig.img_size
VisionTransformerConfig.model_config
VisionTransformerConfig.n_classes
VisionTransformerConfig.n_embd
VisionTransformerConfig.n_head
VisionTransformerConfig.n_img_channels
VisionTransformerConfig.n_layer
VisionTransformerConfig.patch_size
VisionTransformerConfig.patch_stride
VisionTransformerConfig.prediction_key
VisionTransformerConfig.sample_key
- Module contents
Submodules
modalities.models.model module
- class modalities.models.model.ActivationType(value)[source]
Bases:
Enum
Enum class representing different activation types.
- Attributes:
GELU (str): GELU activation type.
SWIGLU (str): SWIGLU activation type.
- GELU = 'gelu'
- SWIGLU = 'swiglu'
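A minimal usage sketch, assuming the enum is importable as documented above; the members carry the lowercase string values used in model configurations.

```python
from modalities.models.model import ActivationType

# Members map onto the lowercase strings shown above.
assert ActivationType.GELU.value == "gelu"
assert ActivationType("swiglu") is ActivationType.SWIGLU
```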
- class modalities.models.model.NNModel(seed=None, weight_decay_groups=None)[source]
Bases:
Module
NNModel class to define a base model.
Initializes an NNModel object.
- Args:
seed (int, optional): The seed value for random number generation. Defaults to None.
weight_decay_groups (Optional[WeightDecayGroups], optional): The weight decay groups. Defaults to None.
- abstractmethod forward(inputs)[source]
Forward pass of the model.
- Args:
inputs (dict[str, torch.Tensor]): A dictionary containing input tensors.
- Returns:
dict[str, torch.Tensor]: A dictionary containing output tensors.
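To illustrate the dict-in / dict-out contract of `forward`, here is a hypothetical subclass; the class name, keys, and projection layer are made up for the example and are not part of the library.

```python
import torch
from torch import nn
from modalities.models.model import NNModel


class ToyModel(NNModel):
    """Hypothetical NNModel subclass illustrating the dictionary-based forward contract."""

    def __init__(self, sample_key: str = "sample", prediction_key: str = "prediction"):
        super().__init__(seed=42)
        self.sample_key = sample_key
        self.prediction_key = prediction_key
        self.proj = nn.Linear(16, 16)

    def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Read the input tensor by its sample key; return predictions under the prediction key.
        return {self.prediction_key: self.proj(inputs[self.sample_key])}
```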
- class modalities.models.model.SwiGLU(n_embd, ffn_hidden, bias)[source]
Bases:
Module
SwiGLU class to define the SwiGLU activation function.
Initializes the SwiGLU object.
- Args:
n_embd (int): The number of embedding dimensions.
ffn_hidden (int): The number of hidden dimensions in the feed-forward network. Best practice: 4 * n_embd (https://arxiv.org/pdf/1706.03762).
bias (bool): Whether to include bias terms in the linear layers.
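The docstring does not spell out the computation, but a common SwiGLU feed-forward formulation (Shazeer, 2020) looks like the sketch below; the layer names and exact implementation in the library may differ.

```python
import torch
from torch import nn


class SwiGLUSketch(nn.Module):
    """Illustrative SwiGLU feed-forward block, not the library's exact implementation."""

    def __init__(self, n_embd: int, ffn_hidden: int, bias: bool):
        super().__init__()
        self.w_gate = nn.Linear(n_embd, ffn_hidden, bias=bias)  # gated branch
        self.w_up = nn.Linear(n_embd, ffn_hidden, bias=bias)    # linear branch
        self.w_down = nn.Linear(ffn_hidden, n_embd, bias=bias)  # projection back to n_embd

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU(x) = (SiLU(x W_gate) * x W_up) W_down
        return self.w_down(nn.functional.silu(self.w_gate(x)) * self.w_up(x))
```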
- modalities.models.model.model_predict_batch(model, batch)[source]
Predicts the output for a batch of samples using the given model.
- Return type:
InferenceResultBatch
- Parameters:
model (Module)
batch (DatasetBatch)
- Args:
model (nn.Module): The model used for prediction.
batch (DatasetBatch): The batch of samples to be predicted.
- Returns:
InferenceResultBatch: The batch of inference results containing the predicted targets and predictions.
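A hedged usage sketch: `model` and `batch` are placeholders for objects produced elsewhere in a training or evaluation setup and are not constructed here.

```python
from modalities.models.model import model_predict_batch

# Hypothetical objects: `model` is an nn.Module, `batch` a DatasetBatch from the dataloader.
inference_result_batch = model_predict_batch(model=model, batch=batch)
```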
modalities.models.model_factory module
- class modalities.models.model_factory.GPT2ModelFactory[source]
Bases:
object
- static get_gpt2_model(sample_key, prediction_key, poe_type, sequence_length, vocab_size, n_layer, n_head_q, n_head_kv, n_embd, ffn_hidden, dropout, bias, activation_type, attention_implementation, attention_config, attention_norm_config, ffn_norm_config, lm_head_norm_config, use_weight_tying, use_meta_device=False, seed=None)[source]
- Return type:
GPT2LLM
- Parameters:
sample_key (str)
prediction_key (str)
poe_type (PositionTypes)
sequence_length (int)
vocab_size (int)
n_layer (int)
n_head_q (int)
n_head_kv (int)
n_embd (int)
ffn_hidden (int)
dropout (float)
bias (bool)
activation_type (ActivationType)
attention_implementation (AttentionImplementation)
attention_config (AttentionConfig)
attention_norm_config (LayerNormWrapperConfig)
ffn_norm_config (LayerNormWrapperConfig)
lm_head_norm_config (LayerNormWrapperConfig)
use_weight_tying (bool)
use_meta_device (bool | None)
seed (int)
- class modalities.models.model_factory.ModelFactory[source]
Bases:
object
Model factory class to create models.
- static get_activation_checkpointed_fsdp1_model(model, activation_checkpointing_modules)[source]
Apply activation checkpointing to the given model (in-place operation).
- Return type:
- Parameters:
model (FullyShardedDataParallel)
activation_checkpointing_modules (list[str])
- Args:
model (FSDP1): The FSDP1-wrapped model to apply activation checkpointing to.
activation_checkpointing_modules (list[str]): List of module names to apply activation checkpointing to.
- Raises:
ValueError: Activation checkpointing can only be applied to FSDP1-wrapped models!
- Returns:
FSDP1: The model with activation checkpointing applied.
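This step can be approximated with PyTorch's built-in activation-checkpointing helper; the sketch below selects blocks by class name, which is an assumption rather than the library's exact selection logic.

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


def apply_ac_by_class_name(model: nn.Module, activation_checkpointing_modules: list[str]) -> None:
    """Sketch: wrap every submodule whose class name is listed (in-place, like the method above)."""
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpoint_wrapper,
        check_fn=lambda submodule: type(submodule).__name__ in activation_checkpointing_modules,
    )
```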
- static get_compiled_model(model, block_names, fullgraph, debug=False)[source]
Apply torch.compile to each transformer block, which makes compilation efficient due to the repeated block structure. Alternatively, one can compile the whole model (after applying DP). Inspired by: https://github.com/pytorch/torchtitan/blob/6b2912a9b53464bfef744e62100716271b2b248f/torchtitan/parallelisms/parallelize_llama.py#L275
- Return type:
Module
- Parameters:
model (Module)
block_names (list[str])
fullgraph (bool)
debug (bool | None)
- Note:
With fullgraph=True, we enforce each block to be compiled as a whole, which raises an error on graph breaks and maximizes speedup.
- Args:
model (nn.Module): The model to be compiled.
block_names (list[str]): List of block names to be compiled individually.
fullgraph (bool): Flag enforcing the block to be compiled without graph breaks.
debug (Optional[bool]): Flag to enable debug mode. Default is False.
- Returns:
nn.Module: The compiled model.
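A sketch of the per-block compilation idea described above; the traversal and replacement strategy here is an assumption, not the library's code.

```python
import torch
from torch import nn


def compile_blocks(model: nn.Module, block_names: list[str], fullgraph: bool) -> nn.Module:
    """Sketch: compile each named transformer block individually instead of the whole model."""
    for module in model.modules():
        for child_name, child in module.named_children():
            if type(child).__name__ in block_names:
                # fullgraph=True raises on graph breaks instead of silently splitting the graph.
                setattr(module, child_name, torch.compile(child, fullgraph=fullgraph))
    return model
```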
- static get_fsdp1_checkpointed_model(checkpoint_loading, checkpoint_path, model)[source]
Loads an FSDP1-checkpointed model from the given checkpoint path.
- Return type:
Module
- Parameters:
checkpoint_loading (FSDP1CheckpointLoadingIF)
checkpoint_path (Path)
model (Module)
- Args:
checkpoint_loading (FSDP1CheckpointLoadingIF): The checkpoint loading approach used to load the model checkpoint.
checkpoint_path (Path): The path to the checkpoint file.
model (nn.Module): The model to be loaded with the checkpoint.
- Returns:
nn.Module: The loaded wrapped model.
- static get_fsdp1_wrapped_model(model, sync_module_states, block_names, mixed_precision_settings, sharding_strategy)[source]
Get the FSDP1-wrapped model.
- Return type:
- Parameters:
model (Module)
sync_module_states (bool)
block_names (list[str])
mixed_precision_settings (MixedPrecisionSettings)
sharding_strategy (ShardingStrategy)
- Args:
model (nn.Module): The original model to be wrapped.
sync_module_states (bool): Whether to synchronize module states across ranks.
block_names (list[str]): List of block names.
mixed_precision_settings (MixedPrecisionSettings): Mixed precision settings.
sharding_strategy (ShardingStrategy): Sharding strategy.
- Returns:
FSDP1: The FSDP1-wrapped model.
- Note:
`FSDPTransformerAutoWrapPolicyFactory` is hardcoded and should be passed in instead. Different auto wrap policies may be supported in the future.
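For orientation, a plain PyTorch FSDP1 wrap with a transformer auto-wrap policy looks roughly like the sketch below; mixed precision and the library's `FSDPTransformerAutoWrapPolicyFactory` are omitted, and the helper itself is hypothetical.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def wrap_with_fsdp1(model: nn.Module, block_types: set[type[nn.Module]],
                    sync_module_states: bool, sharding_strategy: ShardingStrategy) -> FSDP:
    """Sketch: FSDP1 wrapping with a transformer auto-wrap policy (mixed precision omitted)."""
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls=block_types
    )
    return FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        sharding_strategy=sharding_strategy,
        sync_module_states=sync_module_states,
    )
```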
- static get_fsdp2_wrapped_model(model, block_names, device_mesh, mixed_precision_settings, reshard_after_forward)[source]
Get the FSDP2-wrapped model.
Based on https://github.com/pytorch/torchtitan/blob/de9fd2b9ea7e763c9182e0df81fc32c2618cc0b6/torchtitan/parallelisms/parallelize_llama.py#L459 and https://github.com/pytorch/torchtitan/blob/43584e0a4e72645e25cccd05d86f9632587a8beb/docs/fsdp.md. NOTE: Torch Titan already implements pipeline parallelism; we skip that here for now.
- Return type:
- Parameters:
model (Module)
block_names (list[str])
device_mesh (DeviceMesh)
mixed_precision_settings (FSDP2MixedPrecisionSettings)
reshard_after_forward (bool)
- Args:
model (nn.Module): The original model to be wrapped.
block_names (list[str]): List of block names.
device_mesh (DeviceMesh): The device mesh.
mixed_precision_settings (FSDP2MixedPrecisionSettings): Mixed precision settings.
reshard_after_forward (bool): Whether to reshard after forward.
- Returns:
FSDP2: The FSDP2-wrapped model.
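A rough sketch of per-block FSDP2 sharding in the torchtitan style; the `fully_shard` import path assumes a recent PyTorch, and selecting blocks by class name (with mixed precision omitted) is an assumption rather than the library's exact logic.

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # assumed: PyTorch >= 2.4
from torch.distributed.device_mesh import DeviceMesh


def shard_per_block(model: nn.Module, block_names: list[str], device_mesh: DeviceMesh,
                    reshard_after_forward: bool) -> nn.Module:
    """Sketch: shard each transformer block, then the root module."""
    for module in model.modules():
        if type(module).__name__ in block_names:
            fully_shard(module, mesh=device_mesh, reshard_after_forward=reshard_after_forward)
    fully_shard(model, mesh=device_mesh, reshard_after_forward=reshard_after_forward)
    return model
```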
- static get_weight_initialized_model(model, model_initializer)[source]
Initializes the given model with weights using the provided model initializer. The model can be on a meta device.
- Return type:
Module
- Parameters:
model (Module)
model_initializer (ModelInitializationIF)
- Args:
model (nn.Module): The model to be initialized with weights.
model_initializer (ModelInitializationIF): The model initializer object.
- Returns:
nn.Module: The initialized model.
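The note about meta devices refers to the common pattern of building the model without allocating real storage and materializing weights afterwards; a hedged sketch follows, where `build_model()` and `model_initializer` are placeholders for objects defined by your setup.

```python
import torch
from modalities.models.model_factory import ModelFactory

# Hypothetical: build_model() constructs the model; no real memory is allocated on the meta device.
with torch.device("meta"):
    model = build_model()

# Materialize and initialize the parameters via the provided initializer.
model = ModelFactory.get_weight_initialized_model(model=model, model_initializer=model_initializer)
```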
modalities.models.utils module
- class modalities.models.utils.ModelTypeEnum(value)[source]
Bases:
Enum
Enumeration class representing different types of models.
- Attributes:
MODEL (str): Represents a regular model.
CHECKPOINTED_MODEL (str): Represents a checkpointed model.
- CHECKPOINTED_MODEL = 'checkpointed_model'
- MODEL = 'model'
- modalities.models.utils.get_model_from_config(config, model_type)[source]
Retrieves a model from the given configuration based on the specified model type.
- Args:
config (dict): The configuration dictionary.
model_type (ModelTypeEnum): The type of the model to retrieve.
- Returns:
Any: The model object based on the specified model type.
- Raises:
NotImplementedError: If the model type is not supported.
- Parameters:
config (dict)
model_type (ModelTypeEnum)
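A minimal usage sketch; `config` is assumed to be the already-parsed configuration dictionary defining the model component and is not shown here.

```python
from modalities.models.utils import ModelTypeEnum, get_model_from_config

# Build the regular (non-checkpointed) model component from the configuration.
model = get_model_from_config(config=config, model_type=ModelTypeEnum.MODEL)
```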