modalities.models.gpt2 package
Submodules
modalities.models.gpt2.collator module
- class modalities.models.gpt2.collator.CollateFnIF[source]
Bases: ABC
CollateFnIF class to define a collate function interface.
- class modalities.models.gpt2.collator.GPT2LLMCollateFn(sample_key, target_key)[source]
Bases: CollateFnIF
GPT2LLMCollateFn class to define a collate function for GPT2 language model.
Initializes the Collator object.
- Args:
sample_key (str): The key for accessing the sample data.
target_key (str): The key for accessing the target data.
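A rough usage sketch (not taken from the library's documentation): assuming the collator follows the standard collate_fn contract of being callable on a list of sample dicts, and assuming the illustrative key names and sample layout below, it can be passed straight to a PyTorch DataLoader.

import torch
from torch.utils.data import DataLoader

from modalities.models.gpt2.collator import GPT2LLMCollateFn

# Hypothetical toy samples: each one holds a fixed-length sequence of token
# ids under the sample key (the layout is an assumption for illustration).
samples = [{"input_ids": torch.randint(0, 50_000, (16,))} for _ in range(32)]

collate_fn = GPT2LLMCollateFn(sample_key="input_ids", target_key="target_ids")

# Assuming the instance is callable on a list of samples (the usual
# collate_fn contract), it plugs directly into a DataLoader.
loader = DataLoader(samples, batch_size=8, collate_fn=collate_fn)
batch = next(iter(loader))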
modalities.models.gpt2.gpt2_model module
- class modalities.models.gpt2.gpt2_model.AttentionConfig(**data)[source]
Bases: BaseModel
Configuration class for attention mechanism.
- Attributes:
qkv_transforms (list[QueryKeyValueTransformConfig]): List of configurations for query-key-value transforms.
Create a new model by parsing and validating input data from keyword arguments.
Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
qkv_transforms (list[QueryKeyValueTransformConfig])
- class QueryKeyValueTransformConfig(**data)[source]
Bases: BaseModel
Configuration class for QueryKeyValueTransform.
- Attributes:
type_hint (QueryKeyValueTransformType): The type hint for the transform.
config (RotaryTransformConfig | IdentityTransformConfig): The configuration for the transform.
Create a new model by parsing and validating input data from keyword arguments.
Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
type_hint (QueryKeyValueTransformType)
config (RotaryTransformConfig | IdentityTransformConfig)
- class IdentityTransformConfig(**data)[source]
Bases: BaseModel
IdentityTransformConfig class.
Create a new model by parsing and validating input data from keyword arguments.
Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class RotaryTransformConfig(**data)[source]
Bases: BaseModel
Configuration class for RotaryTransform.
- Attributes:
n_embd (int): Size of the embedding dimension.
n_head (int): Number of attention heads.
seq_length_dim (int): Dimension of the sequence length.
base_freq (int): Base frequency for RoPE.
Create a new model by parsing and validating input data from keyword arguments.
Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
n_embd (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Ge(ge=0)])])
n_head (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Ge(ge=0)])])
seq_length_dim (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True)])])
base_freq (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Ge(ge=10000)])])
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- classmethod parse_sharding_strategy_by_name(name)[source]
Parses a QueryKeyValueTransform by its name.
- Args:
name (str): The name of the query-key-value transform.
- Returns:
QueryKeyValueTransformType: The parsed transform type.
- type_hint: QueryKeyValueTransformType
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- qkv_transforms: list[QueryKeyValueTransformConfig]
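As a minimal construction sketch (assuming the nested-class layout shown in this listing, with illustrative hyperparameter values):

from modalities.models.gpt2.gpt2_model import AttentionConfig, QueryKeyValueTransformType

QKVTransformConfig = AttentionConfig.QueryKeyValueTransformConfig

attention_config = AttentionConfig(
    qkv_transforms=[
        QKVTransformConfig(
            type_hint=QueryKeyValueTransformType.RotaryTransform,
            config=QKVTransformConfig.RotaryTransformConfig(
                n_embd=768,       # illustrative embedding dimension
                n_head=12,        # illustrative number of heads
                seq_length_dim=-2,
                base_freq=10000,  # must satisfy the ge=10000 constraint
            ),
        )
    ]
)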
- class modalities.models.gpt2.gpt2_model.AttentionImplementation(value)[source]
Enum class representing different implementations of attention.
- Attributes:
MANUAL (str): Manual attention implementation.
PYTORCH_FLASH (str): PyTorch’s flash attention implementation.
DAO_FLASH (str): DAO’s flash attention implementation.
- DAO_FLASH = 'dao_flash'
- MANUAL = 'manual'
- PYTORCH_FLASH = 'pytorch_flash'
- class modalities.models.gpt2.gpt2_model.CausalSelfAttention(n_head_q, n_head_kv, n_embd, attention_config, attention_impl, bias, dropout)[source]
Bases: Module
Causal Self Attention class.
Initializes the CausalSelfAttention object.
- Args:
n_head_q (int): Number of attention heads for queries.
n_head_kv (int): Number of attention heads for keys and values.
n_embd (int): Size of the embedding dimension.
attention_config (AttentionConfig): The attention configuration.
attention_impl (AttentionImplementation): The attention implementation.
bias (bool): Whether to include bias in linear layers.
dropout (float): Dropout rate.
- Returns:
None
- Parameters:
n_head_q (int)
n_head_kv (int)
n_embd (int)
attention_config (AttentionConfig)
attention_impl (AttentionImplementation)
bias (bool)
dropout (float)
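A minimal forward-pass sketch, assuming an empty qkv_transforms list is a valid configuration and using illustrative sizes:

import torch

from modalities.models.gpt2.gpt2_model import (
    AttentionConfig,
    AttentionImplementation,
    CausalSelfAttention,
)

attn = CausalSelfAttention(
    n_head_q=12,
    n_head_kv=12,   # equal query/key-value heads, i.e. plain multi-head attention
    n_embd=768,
    attention_config=AttentionConfig(qkv_transforms=[]),
    attention_impl=AttentionImplementation.PYTORCH_FLASH,
    bias=False,
    dropout=0.0,
)

x = torch.randn(2, 16, 768)  # (B, T, n_embd)
y = attn(x)                  # expected output shape: (2, 16, 768)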
- classmethod execute_attention(q, k, v, dropout, attention_impl)[source]
Executes attention mechanism based on the specified implementation.
- Return type: Tensor
- Parameters:
q (Tensor)
k (Tensor)
v (Tensor)
dropout (float)
attention_impl (AttentionImplementation)
- Args:
q (torch.Tensor): The query tensor.
k (torch.Tensor): The key tensor.
v (torch.Tensor): The value tensor.
dropout (float): The dropout rate.
attention_impl (AttentionImplementation): The attention implementation to use.
- Returns:
torch.Tensor: The output tensor.
- Raises:
NotImplementedError: If the specified attention implementation is not supported.
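A small sketch of calling the classmethod directly, assuming the usual (batch, heads, seq_len, head_dim) tensor layout:

import torch

from modalities.models.gpt2.gpt2_model import AttentionImplementation, CausalSelfAttention

q = torch.randn(2, 8, 16, 64)  # (batch, heads, seq_len, head_dim), layout assumed
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

out = CausalSelfAttention.execute_attention(
    q, k, v, dropout=0.0, attention_impl=AttentionImplementation.PYTORCH_FLASH
)
# Expected output shape: (2, 8, 16, 64)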
- static execute_qkv_transforms(q, k, v, qkv_transforms, n_head_q)[source]
Applies a series of transformations to the query, key, and value tensors.
- Return type: tuple[Tensor, Tensor, Tensor]
- Parameters:
q (Tensor)
k (Tensor)
v (Tensor)
qkv_transforms (ModuleList)
n_head_q (int)
- Args:
q (torch.Tensor): The query tensors.
k (torch.Tensor): The key tensors.
v (torch.Tensor): The value tensors.
qkv_transforms (nn.ModuleList): A list of transformation modules to be applied to q, k, and v.
n_head_q (int): The number of heads for the query tensors.
- Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the transformed query, key, and value tensors.
- forward(x)[source]
Forward pass of the CausalSelfAttention module.
- Args:
x (torch.Tensor): Input tensor of shape (B, T, n_embd)
- Returns:
torch.Tensor: Output tensor of shape (B, T, n_embd), representing the output projection.
- projection(x)[source]
Applies projections to the input tensor to get queries, keys, and values.
- Args:
x (torch.Tensor): The input tensor.
- Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the query, key, and value tensors.
- classmethod repeat_kv_heads(q, k, v)[source]
Repeats the key-value (k, v) heads if the number of query (q) heads is different.
- Args:
q (torch.Tensor): The query tensor of shape (B, nh_q, T, hs).
k (torch.Tensor): The key tensor of shape (B, nh_kv, T, hs).
v (torch.Tensor): The value tensor of shape (B, nh_kv, T, hs).
- Returns:
tuple: A tuple containing the repeated key tensor (k) and the repeated value tensor (v).
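The repetition is the standard grouped-query-attention trick; the following standalone sketch illustrates the idea with repeat_interleave (a conceptual illustration, not necessarily this method's exact implementation):

import torch

B, T, hs = 2, 16, 64
nh_q, nh_kv = 8, 2  # more query heads than key-value heads

q = torch.randn(B, nh_q, T, hs)
k = torch.randn(B, nh_kv, T, hs)
v = torch.randn(B, nh_kv, T, hs)

# Each key-value head is shared by nh_q // nh_kv query heads,
# so k and v are repeated along the head dimension.
n_rep = nh_q // nh_kv
k_rep = k.repeat_interleave(n_rep, dim=1)  # (B, nh_q, T, hs)
v_rep = v.repeat_interleave(n_rep, dim=1)  # (B, nh_q, T, hs)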
- class modalities.models.gpt2.gpt2_model.GPT2Block(n_embd, bias, n_head_q, n_head_kv, activation_type, attention_impl, attention_config, dropout, ffn_hidden, attention_norm, ffn_norm)[source]
Bases: Module
GPT2Block class.
Initializes the GPT2Block.
- Args:
n_embd (int): The embedding dimension.
bias (bool): Whether to include bias in the model.
n_head_q (int): The number of attention heads for queries.
n_head_kv (int): The number of attention heads for keys and values.
activation_type (ActivationType): The type of activation function to use.
attention_impl (AttentionImplementation): The implementation of the attention mechanism.
attention_config (AttentionConfig): The configuration for the attention mechanism.
dropout (float): The dropout rate.
ffn_hidden (int): The size of the hidden layer in the feed-forward network.
attention_norm (nn.Module): The normalization layer for attention.
ffn_norm (nn.Module): The normalization layer for the feed-forward network.
- Parameters:
n_embd (int)
bias (bool)
n_head_q (int)
n_head_kv (int)
activation_type (ActivationType)
attention_impl (AttentionImplementation)
attention_config (AttentionConfig)
dropout (float)
ffn_hidden (int)
attention_norm (Module)
ffn_norm (Module)
- class modalities.models.gpt2.gpt2_model.GPT2LLM(sample_key, prediction_key, poe_type, sequence_length, vocab_size, n_layer, n_head_q, n_head_kv, n_embd, ffn_hidden, dropout, bias, activation_type, attention_implementation, attention_config, attention_norm_config, ffn_norm_config, lm_head_norm_config, use_weight_tying, seed=None)[source]
Bases: NNModel
GPT2LLM class.
Initializes the GPT2LLM object.
- Args:
sample_key (str): The sample key.
prediction_key (str): The prediction key.
poe_type (PositionTypes): The position type.
sequence_length (int): The sequence length.
vocab_size (int): The vocabulary size.
n_layer (int): The number of layers.
n_head_q (int): The number of query heads.
n_head_kv (int): The number of key-value heads.
n_embd (int): The embedding dimension.
ffn_hidden (int): The hidden dimension of the feed-forward network.
dropout (float): The dropout rate.
bias (bool): Whether to include bias in linear layers.
activation_type (ActivationType): The activation type.
attention_implementation (AttentionImplementation): The attention implementation.
attention_config (AttentionConfig): The attention configuration.
attention_norm_config (LayerNormWrapperConfig): Config for the attention normalization module.
ffn_norm_config (LayerNormWrapperConfig): Config for the feed-forward network normalization module.
lm_head_norm_config (LayerNormWrapperConfig): Config for the language model head normalization module.
seed (int, optional): The random seed. Defaults to None.
use_weight_tying (bool): Whether to use weight tying.
- Parameters:
sample_key (str)
prediction_key (str)
poe_type (PositionTypes)
sequence_length (int)
vocab_size (int)
n_layer (int)
n_head_q (int)
n_head_kv (int)
n_embd (int)
ffn_hidden (int)
dropout (float)
bias (bool)
activation_type (ActivationType)
attention_implementation (AttentionImplementation)
attention_config (AttentionConfig)
attention_norm_config (LayerNormWrapperConfig)
ffn_norm_config (LayerNormWrapperConfig)
lm_head_norm_config (LayerNormWrapperConfig)
use_weight_tying (bool)
seed (int)
- forward(inputs)[source]
Forward pass of the GPT2LLM module.
- Args:
- inputs (dict[str, torch.Tensor]): A dictionary containing input tensors.
sample_key (str): Key for the input tensor containing token ids.
- Returns:
- dict[str, torch.Tensor]: A dictionary containing output tensors.
prediction_key (str): Key for the output tensor containing logits.
- forward_impl(inputs)[source]
Forward pass implementation of the GPT2LLM module.
- Args:
- inputs (dict[str, torch.Tensor]): A dictionary containing input tensors.
sample_key (str): Key for the input tensor containing token ids.
- Returns:
- dict[str, torch.Tensor]: A dictionary containing output tensors.
prediction_key (str): Key for the output tensor containing logits.
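A sketch of the forward contract (the key names are illustrative; they are fixed by sample_key and prediction_key at construction time):

import torch

from modalities.models.gpt2.gpt2_model import GPT2LLM


def compute_logits(model: GPT2LLM, token_ids: torch.Tensor) -> torch.Tensor:
    # Assumes the model was built with sample_key="input_ids" and
    # prediction_key="logits" (both names are illustrative).
    outputs = model({"input_ids": token_ids})  # token_ids: (B, T) int tensor
    return outputs["logits"]                   # expected: (B, T, vocab_size)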
- class modalities.models.gpt2.gpt2_model.GPT2LLMConfig(**data)[source]
Bases: BaseModel
Configuration class for GPT2LLM model.
- Args:
sample_key (str): The key for the samples.
prediction_key (str): The key for the predictions.
use_meta_device (bool, optional): Whether to use meta device. Defaults to False.
poe_type (PositionTypes): The type of position encoding.
sequence_length (int): The length of the sequence.
vocab_size (int): The size of the vocabulary.
n_layer (int): The number of layers.
n_head_q (int): The number of attention heads for queries.
n_head_kv (int): The number of attention heads for keys and values.
n_embd (int): The embedding size.
ffn_hidden (int): The hidden size of the feed-forward network.
dropout (float): The dropout rate.
bias (bool): Whether to use bias in linear layers.
attention_config (AttentionConfig): The attention configuration.
attention_implementation (AttentionImplementation): The attention implementation.
activation_type (ActivationType): The activation type.
attention_norm_config (LayerNormWrapperConfig): Config for normalization of the attention.
ffn_norm_config (LayerNormWrapperConfig): Config for normalization of the feed-forward network.
lm_head_norm_config (LayerNormWrapperConfig): Config for normalization of the language model head.
use_weight_tying (bool): Whether to use weight tying.
Create a new model by parsing and validating input data from keyword arguments.
Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
sample_key (str)
prediction_key (str)
use_meta_device (bool | None)
poe_type (PositionTypes)
sequence_length (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Ge(ge=1)])])
vocab_size (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Ge(ge=1)])])
n_layer (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Ge(ge=1)])])
n_head_q (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Ge(ge=1)])])
n_head_kv (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Ge(ge=1)])])
n_embd (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Ge(ge=1)])])
ffn_hidden (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Ge(ge=1)])])
dropout (Annotated[float, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Ge(ge=0.0)])])
bias (bool)
attention_config (AttentionConfig)
attention_implementation (AttentionImplementation)
activation_type (ActivationType)
attention_norm_config (LayerNormWrapperConfig)
ffn_norm_config (LayerNormWrapperConfig)
lm_head_norm_config (LayerNormWrapperConfig)
use_weight_tying (bool)
- activation_type: ActivationType
- attention_config: AttentionConfig
- attention_implementation: AttentionImplementation
- attention_norm_config: LayerNormWrapperConfig
- check_divisibility()[source]
Check if the value of n_head_q is divisible by n_head_kv.
- Return type: GPT2LLMConfig
- Raises:
ValueError: If n_head_q is not divisible by n_head_kv.
- Returns:
GPT2LLMConfig: The current instance of GPT2LLMConfig.
- ffn_norm_config: LayerNormWrapperConfig
- lm_head_norm_config: LayerNormWrapperConfig
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- poe_type: PositionTypes
- class modalities.models.gpt2.gpt2_model.IdentityTransform(*args, **kwargs)[source]
Bases: QueryKeyValueTransform
IdentityTransform class which does not apply any transform.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- class modalities.models.gpt2.gpt2_model.LayerNormWrapperConfig(**data)[source]
Bases: BaseModel
Create a new model by parsing and validating input data from keyword arguments.
Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
norm_type (LayerNorms)
config (LayerNormConfig | RMSLayerNormConfig)
- config: LayerNormConfig | RMSLayerNormConfig
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- norm_type: LayerNorms
- class modalities.models.gpt2.gpt2_model.LayerNorms(value)[source]
Bases: LookupEnum
Enum lookup class for LayerNorms.
- Attributes:
RMSNorm: RMSLayerNorm class.
LayerNorm: nn.LayerNorm class.
- layer_norm(normalized_shape, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None) = <class 'torch.nn.modules.normalization.LayerNorm'>
- class modalities.models.gpt2.gpt2_model.PositionTypes(value)[source]
Enum class representing different position types.
- Attributes:
ABSOLUTE (str): Represents the absolute position type.
NOPE (str): Represents the NoPE (no positional embeddings) position type.
- ABSOLUTE = 'ABSOLUTE'
- NOPE = 'NOPE'
- class modalities.models.gpt2.gpt2_model.QueryKeyValueTransform(*args, **kwargs)[source]
Bases: Module
Query Key Value Transform base class.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- abstractmethod forward(q, k, v)[source]
Perform forward pass for transforming queries/keys/values.
- Args:
q (torch.Tensor): The query tensor.
k (torch.Tensor): The key tensor.
v (torch.Tensor): The value tensor.
- Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the output tensors.
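A minimal sketch of a custom subclass (the name and behaviour are hypothetical, shown only to illustrate the interface):

import torch

from modalities.models.gpt2.gpt2_model import QueryKeyValueTransform


class ScaleQueriesTransform(QueryKeyValueTransform):
    # Hypothetical transform: rescales the queries, passes keys and values through.

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.scale = scale

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return q * self.scale, k, v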
- class modalities.models.gpt2.gpt2_model.QueryKeyValueTransformType(value)[source]
Bases: Enum
Enum class representing different types of query-key-value transform.
- Attributes:
IdentityTransform: Represents the identity transform.
RotaryTransform: Represents the rotary transform.
- IdentityTransform(*args, **kwargs) = <class 'modalities.models.gpt2.gpt2_model.IdentityTransform'>
- Return type: None
- class modalities.models.gpt2.gpt2_model.RotaryTransform(n_embd, n_head, seq_length_dim=-2, base_freq=10000)[source]
Bases: QueryKeyValueTransform
RotaryTransform class which implements rotary positional embeddings.
- Source: https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py
The corresponding code is included here because the “@torch.jit.script” decorator used in the xFormers implementation causes a conflict; it has been removed in this implementation.
Initializes the RotaryTransform object.
- Args:
n_embd (int): The size of the embedding dimension.
n_head (int): The number of attention heads.
seq_length_dim (int, optional): The dimension along which the sequence length is defined. Defaults to -2.
base_freq (int): Base frequency for RoPE. Defaults to 10000.
- apply_rotary_pos_emb(x, cos, sin)[source]
Applies rotary positional embedding to the input tensor.
- Args:
x (torch.Tensor): Input tensor.
cos (torch.Tensor): Cosine values for rotary positional embedding.
sin (torch.Tensor): Sine values for rotary positional embedding.
- Returns:
torch.Tensor: Tensor after applying rotary positional embedding.
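A usage sketch, assuming the (B, n_head, T, head_dim) tensor layout with the sequence length on dimension -2 (the default seq_length_dim):

import torch

from modalities.models.gpt2.gpt2_model import RotaryTransform

B, n_head, T, head_dim = 2, 12, 16, 64
rope = RotaryTransform(n_embd=n_head * head_dim, n_head=n_head)

q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_head, T, head_dim)
v = torch.randn(B, n_head, T, head_dim)

# Shapes are preserved; typically only q and k receive the rotary embedding.
q_rot, k_rot, v_rot = rope(q, k, v)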
- class modalities.models.gpt2.gpt2_model.TransformerMLP(n_embd, ffn_hidden, bias, dropout)[source]
Bases: Module
TransformerMLP class.
Initializes the TransformerMLP class.
- Args:
n_embd (int): The size of the input embedding.
ffn_hidden (int): The size of the hidden layer in the feed-forward network.
bias (bool): Whether to include bias terms in the linear layers.
dropout (float): The dropout probability.
- Returns:
None
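A minimal sketch with illustrative sizes (ffn_hidden of 4 * n_embd, as is common for GPT-2-style models), assuming the module maps (B, T, n_embd) back to (B, T, n_embd):

import torch

from modalities.models.gpt2.gpt2_model import TransformerMLP

mlp = TransformerMLP(n_embd=768, ffn_hidden=3072, bias=True, dropout=0.0)

x = torch.randn(2, 16, 768)  # (B, T, n_embd)
y = mlp(x)                   # expected output shape: (2, 16, 768)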
- modalities.models.gpt2.gpt2_model.manual_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)[source]
Compute scaled dot product attention.
- Return type: Tensor
- Args:
query (torch.Tensor): The query tensor of shape (batch_size, num_queries, query_dim).
key (torch.Tensor): The key tensor of shape (batch_size, num_keys, key_dim).
value (torch.Tensor): The value tensor of shape (batch_size, num_values, value_dim).
attn_mask (torch.Tensor, optional): The attention mask tensor of shape (num_queries, num_keys). Defaults to None.
dropout_p (float, optional): The dropout probability. Defaults to 0.0.
is_causal (bool, optional): Whether the attention is causal or not. Defaults to False.
scale (float, optional): The scaling factor. Defaults to None.
- Returns:
torch.Tensor: The attention output tensor of shape (batch_size, num_queries, value_dim).
- Note:
Taken from https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
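Since the note states the implementation follows PyTorch's reference, a quick sanity-check sketch is to compare it against torch.nn.functional.scaled_dot_product_attention (assuming matching semantics):

import torch
from torch.nn import functional as F

from modalities.models.gpt2.gpt2_model import manual_scaled_dot_product_attention

q = torch.randn(2, 8, 16, 64)  # (batch, heads, queries, head_dim)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

manual_out = manual_scaled_dot_product_attention(q, k, v, is_causal=True)
torch_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The two outputs should agree up to numerical tolerance.
torch.testing.assert_close(manual_out, torch_out, rtol=1e-4, atol=1e-4)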