modalities.models.gpt2 package

Submodules

modalities.models.gpt2.collator module

class modalities.models.gpt2.collator.CollateFnIF[source]

Bases: ABC

CollateFnIF class to define a collate function interface.

class modalities.models.gpt2.collator.GPT2LLMCollateFn(sample_key, target_key)[source]

Bases: CollateFnIF

GPT2LLMCollateFn class to define a collate function for GPT2 language model.

Initializes the Collator object.

Args:

sample_key (str): The key for accessing the sample data.
target_key (str): The key for accessing the target data.

Parameters:
  • sample_key (str)

  • target_key (str)
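
As an illustration of how such a collate function is typically used for causal language modeling, here is a minimal self-contained sketch (the shift-by-one target construction and the key names are assumptions for illustration, not taken from the source):

    import torch

    # Hypothetical sketch of a GPT2-style collate function: stack the per-sample
    # token tensors and derive next-token targets by shifting the inputs by one
    # position (assumed behaviour, shown for illustration only).
    def collate(batch, sample_key="input_ids", target_key="target_ids"):
        tokens = torch.stack([item[sample_key] for item in batch])  # (B, T)
        return {
            sample_key: tokens[:, :-1],   # model inputs
            target_key: tokens[:, 1:],    # next-token targets
        }

    samples = [{"input_ids": torch.randint(0, 100, (9,))} for _ in range(4)]
    collated = collate(samples)
    assert collated["input_ids"].shape == (4, 8)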

modalities.models.gpt2.gpt2_model module

class modalities.models.gpt2.gpt2_model.AttentionConfig(**data)[source]

Bases: BaseModel

Configuration class for attention mechanism.

Attributes:

qkv_transforms (list[QueryKeyValueTransformConfig]): List of configurations for query-key-value transforms.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

qkv_transforms (list[QueryKeyValueTransformConfig])

class QueryKeyValueTransformConfig(**data)[source]

Bases: BaseModel

Configuration class for QueryKeyValueTransform.

Attributes:

type_hint (QueryKeyValueTransformType): The type hint for the transform.
config (RotaryTransformConfig | IdentityTransformConfig): The configuration for the transform.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:
  • type_hint (QueryKeyValueTransformType)

  • config (RotaryTransformConfig | IdentityTransformConfig)

class IdentityTransformConfig(**data)[source]

Bases: BaseModel

IdentityTransformConfig class.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

model_config: ClassVar[ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class RotaryTransformConfig(**data)[source]

Bases: BaseModel

Configuration class for RotaryTransform.

Attributes:

n_embd (int): The size of the embedding dimension.
n_head (int): The number of attention heads.
seq_length_dim (int): The dimension along which the sequence length is defined.
base_freq (int): Base frequency for RoPE.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:
  • n_embd (int, strict, >= 0)

  • n_head (int, strict, >= 0)

  • seq_length_dim (int, strict)

  • base_freq (int, strict, >= 10000)

base_freq: Annotated[int]
model_config: ClassVar[ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

n_embd: Annotated[int]
n_head: Annotated[int]
seq_length_dim: Annotated[int]
config: RotaryTransformConfig | IdentityTransformConfig
model_config: ClassVar[ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

classmethod parse_sharding_strategy_by_name(name)[source]

Parses a QueryKeyValueTransform by its name.

Args:

name (str): The name of the sharding strategy.

Returns:

QueryKeyValueTransformType: The parsed sharding strategy.

type_hint: QueryKeyValueTransformType
model_config: ClassVar[ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

qkv_transforms: list[QueryKeyValueTransformConfig]
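
As a hedged example, an AttentionConfig could be instantiated from plain dictionaries (pydantic parses nested data); passing the transform name as a string assumes it is accepted by the parse_sharding_strategy_by_name validator documented above, and all field values are illustrative:

    from modalities.models.gpt2.gpt2_model import AttentionConfig

    # Illustrative values; "RotaryTransform" is assumed to be parsed into
    # QueryKeyValueTransformType by the validator documented above.
    attention_config = AttentionConfig(
        qkv_transforms=[
            {
                "type_hint": "RotaryTransform",
                "config": {
                    "n_embd": 768,
                    "n_head": 12,
                    "seq_length_dim": -2,
                    "base_freq": 10000,  # must be >= 10000 per the field constraint
                },
            }
        ]
    )
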
class modalities.models.gpt2.gpt2_model.AttentionImplementation(value)[source]

Bases: str, Enum

Enum class representing different implementations of attention.

Attributes:

MANUAL (str): Manual attention implementation.
PYTORCH_FLASH (str): PyTorch's flash attention implementation.
DAO_FLASH (str): DAO's flash attention implementation.

DAO_FLASH = 'dao_flash'
MANUAL = 'manual'
PYTORCH_FLASH = 'pytorch_flash'
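
Because the enum derives from str, members can be looked up by their string value, for example:

    from modalities.models.gpt2.gpt2_model import AttentionImplementation

    impl = AttentionImplementation("pytorch_flash")
    assert impl is AttentionImplementation.PYTORCH_FLASH
    assert impl.value == "pytorch_flash"
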
class modalities.models.gpt2.gpt2_model.CausalSelfAttention(n_head_q, n_head_kv, n_embd, attention_config, attention_impl, bias, dropout)[source]

Bases: Module

Causal Self Attention class.

Initializes the CausalSelfAttention object.

Args:

n_head_q (int): Number of attention heads for queries.
n_head_kv (int): Number of attention heads for keys and values.
n_embd (int): Size of the embedding dimension.
attention_config (AttentionConfig): The attention configuration.
attention_impl (AttentionImplementation): The attention implementation.
bias (bool): Whether to include bias in linear layers.
dropout (float): Dropout rate.

Returns:

None
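
A hedged usage sketch (the concrete sizes, the empty qkv_transforms list, and the chosen attention implementation are illustrative assumptions):

    import torch
    from modalities.models.gpt2.gpt2_model import (
        AttentionConfig,
        AttentionImplementation,
        CausalSelfAttention,
    )

    attn = CausalSelfAttention(
        n_head_q=12,
        n_head_kv=4,                      # grouped-query attention: fewer KV heads
        n_embd=768,
        attention_config=AttentionConfig(qkv_transforms=[]),  # no q/k/v transform (assumed valid)
        attention_impl=AttentionImplementation.PYTORCH_FLASH,
        bias=False,
        dropout=0.0,
    )

    x = torch.randn(2, 16, 768)           # (B, T, n_embd)
    y = attn(x)                           # expected shape: (B, T, n_embd)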

classmethod execute_attention(q, k, v, dropout, attention_impl)[source]

Executes attention mechanism based on the specified implementation.

Return type:

Tensor

Args:

cls (object): The class object.
q (torch.Tensor): The query tensor.
k (torch.Tensor): The key tensor.
v (torch.Tensor): The value tensor.
dropout (float): The dropout rate.
attention_impl (AttentionImplementation): The attention implementation to use.

Returns:

torch.Tensor: The output tensor.

Raises:

NotImplementedError: If the specified attention implementation is not supported.

static execute_qkv_transforms(q, k, v, qkv_transforms, n_head_q)[source]

Applies a series of transformations to the query, key, and value tensors.

Return type:

tuple[Tensor, Tensor, Tensor]

Args:

q (torch.Tensor): The query tensors.
k (torch.Tensor): The key tensors.
v (torch.Tensor): The value tensors.
qkv_transforms (nn.ModuleList): A list of transformation modules to be applied to q, k, and v.
n_head_q (int): The number of heads for the query tensors.

Returns:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the transformed query, key, and value tensors.

forward(x)[source]

Forward pass of the CausalSelfAttention module.

Return type:

Tensor

Parameters:

x (Tensor)

Args:

x (torch.Tensor): Input tensor of shape (B, T, n_embd)

Returns:

torch.Tensor: Output tensor of shape (B, T, n_embd), representing the output projection.

projection(x)[source]

Applies projections to the input tensor to get queries, keys, and values.

Return type:

tuple[Tensor, Tensor, Tensor]

Parameters:

x (Tensor)

Args:

x (torch.Tensor): The input tensor.

Returns:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the query, key, and value tensors.

classmethod repeat_kv_heads(q, k, v)[source]

Repeats the key-value (k, v) heads if the number of query (q) heads is different.

Args:

cls (class): The class object.
q (torch.Tensor): The query tensor of shape (B, nh_q, T, hs).
k (torch.Tensor): The key tensor of shape (B, nh_kv, T, hs).
v (torch.Tensor): The value tensor of shape (B, nh_kv, T, hs).

Returns:

tuple: A tuple containing the repeated key tensor (k) and the repeated value tensor (v).
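
To illustrate the idea behind head repetition, here is a self-contained sketch (shapes follow the (B, n_head, T, head_dim) layout from the docstring above; this is not necessarily the exact implementation):

    import torch

    B, T, head_dim = 2, 16, 64
    n_head_q, n_head_kv = 12, 4

    q = torch.randn(B, n_head_q, T, head_dim)    # (B, nh_q, T, hs)
    k = torch.randn(B, n_head_kv, T, head_dim)   # (B, nh_kv, T, hs)
    v = torch.randn(B, n_head_kv, T, head_dim)

    # Repeat each key/value head so that k and v match the number of query heads.
    n_rep = n_head_q // n_head_kv
    k = k.repeat_interleave(n_rep, dim=1)
    v = v.repeat_interleave(n_rep, dim=1)
    assert k.shape == q.shape == v.shape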

class modalities.models.gpt2.gpt2_model.GPT2Block(n_embd, bias, n_head_q, n_head_kv, activation_type, attention_impl, attention_config, dropout, ffn_hidden, attention_norm, ffn_norm)[source]

Bases: Module

GPT2Block class.

Initializes the GPT2Block.

Args:

n_embd (int): The embedding dimension.
bias (bool): Whether to include bias in the model.
n_head_q (int): The number of attention heads for queries.
n_head_kv (int): The number of attention heads for keys and values.
activation_type (ActivationType): The type of activation function to use.
attention_impl (AttentionImplementation): The implementation of the attention mechanism.
attention_config (AttentionConfig): The configuration for the attention mechanism.
dropout (float): The dropout rate.
ffn_hidden (int): The size of the hidden layer in the feed-forward network.
attention_norm (nn.Module): The normalization layer for attention.
ffn_norm (nn.Module): The normalization layer for the feed-forward network.

forward(x)[source]

Forward pass of the GPT2Block.

Return type:

Tensor

Parameters:

x (Tensor)

Args:

x (torch.Tensor): Input tensor.

Returns:

torch.Tensor: Output tensor.
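
A minimal sketch of the residual wiring such a block typically performs (the pre-norm ordering is an assumption; the class composes attention_norm, ffn_norm, CausalSelfAttention, and TransformerMLP as documented above):

    def block_forward(x, attention, mlp, attention_norm, ffn_norm):
        # Attention sub-layer with a residual connection around the normalized input.
        x = x + attention(attention_norm(x))
        # Feed-forward sub-layer with a residual connection around the normalized input.
        x = x + mlp(ffn_norm(x))
        return x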

class modalities.models.gpt2.gpt2_model.GPT2LLM(sample_key, prediction_key, poe_type, sequence_length, vocab_size, n_layer, n_head_q, n_head_kv, n_embd, ffn_hidden, dropout, bias, activation_type, attention_implementation, attention_config, attention_norm_config, ffn_norm_config, lm_head_norm_config, use_weight_tying, seed=None)[source]

Bases: NNModel

GPT2LLM class.

Initializes the GPT2LLM object.

Args:

sample_key (str): The sample key.
prediction_key (str): The prediction key.
poe_type (PositionTypes): The position type.
sequence_length (int): The sequence length.
vocab_size (int): The vocabulary size.
n_layer (int): The number of layers.
n_head_q (int): The number of query heads.
n_head_kv (int): The number of key-value heads.
n_embd (int): The embedding dimension.
ffn_hidden (int): The hidden dimension of the feed-forward network.
dropout (float): The dropout rate.
bias (bool): Whether to include bias in linear layers.
activation_type (ActivationType): The activation type.
attention_implementation (AttentionImplementation): The attention implementation.
attention_config (AttentionConfig): The attention configuration.
attention_norm_config (LayerNormWrapperConfig): Config for the attention normalization module.
ffn_norm_config (LayerNormWrapperConfig): Config for the feed-forward network normalization module.
lm_head_norm_config (LayerNormWrapperConfig): Config for the language model head normalization module.
seed (int, optional): The random seed. Defaults to None.
use_weight_tying (bool): Whether to use weight tying.

forward(inputs)[source]

Forward pass of the GPT2LLM module.

Return type:

dict[str, Tensor]

Parameters:

inputs (dict[str, Tensor])

Args:
inputs (dict[str, torch.Tensor]): A dictionary containing input tensors.
  • sample_key (str): Key for the input tensor containing token ids.

Returns:
dict[str, torch.Tensor]: A dictionary containing output tensors.
  • prediction_key (str): Key for the output tensor containing logits.
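
For example, given a model instance constructed with sample_key="input_ids" and prediction_key="logits" (hypothetical key names and sizes):

    import torch

    # `model` is a GPT2LLM instance assumed to have been constructed elsewhere
    # with sample_key="input_ids" and prediction_key="logits".
    token_ids = torch.randint(0, 50304, (2, 256))   # (B, T) token ids
    outputs = model({"input_ids": token_ids})
    logits = outputs["logits"]                      # expected shape: (B, T, vocab_size)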

forward_impl(inputs)[source]

Forward pass implementation of the GPT2LLM module.

Return type:

dict[str, Tensor]

Parameters:

inputs (dict[str, Tensor])

Args:
inputs (dict[str, torch.Tensor]): A dictionary containing input tensors.
  • sample_key (str): Key for the input tensor containing token ids.

Returns:
dict[str, torch.Tensor]: A dictionary containing output tensors.
  • prediction_key (str): Key for the output tensor containing logits.

class modalities.models.gpt2.gpt2_model.GPT2LLMConfig(**data)[source]

Bases: BaseModel

Configuration class for GPT2LLM model.

Args:

sample_key (str): The key for the samples.
prediction_key (str): The key for the predictions.
use_meta_device (bool, optional): Whether to use the meta device. Defaults to False.
poe_type (PositionTypes): The type of position encoding.
sequence_length (int): The length of the sequence.
vocab_size (int): The size of the vocabulary.
n_layer (int): The number of layers.
n_head_q (int): The number of attention heads for queries.
n_head_kv (int): The number of attention heads for keys and values.
n_embd (int): The embedding size.
ffn_hidden (int): The hidden size of the feed-forward network.
dropout (float): The dropout rate.
bias (bool): Whether to use bias in linear layers.
attention_config (AttentionConfig): The attention configuration.
attention_implementation (AttentionImplementation): The attention implementation.
activation_type (ActivationType): The activation type.
attention_norm_config (LayerNormWrapperConfig): Config for normalization of the attention.
ffn_norm_config (LayerNormWrapperConfig): Config for normalization of the feed-forward network.
lm_head_norm_config (LayerNormWrapperConfig): Config for normalization of the language model head.
use_weight_tying (bool): Whether to use weight tying.

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:
  • sample_key (str)

  • prediction_key (str)

  • use_meta_device (bool | None)

  • poe_type (PositionTypes)

  • sequence_length (int, strict, >= 1)

  • vocab_size (int, strict, >= 1)

  • n_layer (int, strict, >= 1)

  • n_head_q (int, strict, >= 1)

  • n_head_kv (int, strict, >= 1)

  • n_embd (int, strict, >= 1)

  • ffn_hidden (int, strict, >= 1)

  • dropout (float, strict, >= 0.0)

  • bias (bool)

  • attention_config (AttentionConfig)

  • attention_implementation (AttentionImplementation)

  • activation_type (ActivationType)

  • attention_norm_config (LayerNormWrapperConfig)

  • ffn_norm_config (LayerNormWrapperConfig)

  • lm_head_norm_config (LayerNormWrapperConfig)

  • use_weight_tying (bool)

activation_type: ActivationType
attention_config: AttentionConfig
attention_implementation: AttentionImplementation
attention_norm_config: LayerNormWrapperConfig
bias: bool
check_divisibility()[source]

Check if the value of n_head_q is divisible by n_head_kv.

Return type:

GPT2LLMConfig

Raises:

ValueError: If n_head_q is not divisible by n_head_kv.

Returns:

GPT2LLMConfig: The current instance of GPT2LLMConfig.

dropout: Annotated[float]
ffn_hidden: Annotated[int]
ffn_norm_config: LayerNormWrapperConfig
lm_head_norm_config: LayerNormWrapperConfig
model_config: ClassVar[ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

n_embd: Annotated[int]
n_head_kv: Annotated[int]
n_head_q: Annotated[int]
n_layer: Annotated[int]
poe_type: PositionTypes
prediction_key: str
sample_key: str
sequence_length: Annotated[int]
use_meta_device: Optional[bool]
use_weight_tying: bool
validate_sizes()[source]

Validates the sizes of the GPT2 model parameters.

Return type:

GPT2LLMConfig

Returns:

GPT2LLMConfig: The current instance of GPT2LLMConfig object.

Raises:

ValueError: If any of the parameters (ffn_hidden, vocab_size, n_embd) is not divisible by 128.

vocab_size: Annotated[int]
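
The two validators above encode simple arithmetic constraints; the following illustrative check mirrors them for a hypothetical configuration:

    # Documented constraints: n_head_q must be divisible by n_head_kv, and
    # ffn_hidden, vocab_size and n_embd must each be divisible by 128.
    n_head_q, n_head_kv = 12, 4
    assert n_head_q % n_head_kv == 0

    for name, value in {"ffn_hidden": 3072, "vocab_size": 50304, "n_embd": 768}.items():
        assert value % 128 == 0, f"{name} must be divisible by 128"
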
class modalities.models.gpt2.gpt2_model.IdentityTransform(*args, **kwargs)[source]

Bases: QueryKeyValueTransform

IdentityTransform class which does not apply any transform.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(q, k, v)[source]

Forward pass of the IdentityTransform which does not apply any transform.

Return type:

tuple[Tensor, Tensor, Tensor]

Args:

q (torch.Tensor): The query tensor.
k (torch.Tensor): The key tensor.
v (torch.Tensor): The value tensor.

Returns:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The tensors q, k, and v.

class modalities.models.gpt2.gpt2_model.LayerNormWrapperConfig(**data)[source]

Bases: BaseModel

Create a new model by parsing and validating input data from keyword arguments.

Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:
  • config (LayerNormConfig | RMSLayerNormConfig)

  • norm_type (LayerNorms)

config: LayerNormConfig | RMSLayerNormConfig
model_config: ClassVar[ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

norm_type: LayerNorms
class modalities.models.gpt2.gpt2_model.LayerNorms(value)[source]

Bases: LookupEnum

Enum lookup class for LayerNorms.

Attributes:

RMSNorm: RMSLayerNorm class.
LayerNorm: nn.LayerNorm class.

layer_norm(normalized_shape, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None) = <class 'torch.nn.modules.normalization.LayerNorm'>

rms_norm(ndim, bias=True, epsilon=1e-05) = <class 'modalities.models.components.layer_norms.RMSLayerNorm'>
class modalities.models.gpt2.gpt2_model.PositionTypes(value)[source]

Bases: str, Enum

Enum class representing different position types.

Attributes:

ABSOLUTE (str): Represents the absolute position type.
NOPE (str): Represents the NoPE (no positional embeddings) position type.

ABSOLUTE = 'ABSOLUTE'
NOPE = 'NOPE'
class modalities.models.gpt2.gpt2_model.QueryKeyValueTransform(*args, **kwargs)[source]

Bases: Module

Query Key Value Transform base class.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

abstractmethod forward(q, k, v)[source]

Perform forward pass for transforming queries/keys/values.

Return type:

tuple[Tensor, Tensor, Tensor]

Args:

q (torch.Tensor): The query tensor.
k (torch.Tensor): The key tensor.
v (torch.Tensor): The value tensor.

Returns:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the output tensors.

class modalities.models.gpt2.gpt2_model.QueryKeyValueTransformType(value)[source]

Bases: Enum

Enum class representing different types of query-key-value transform.

Attributes:

IdentityTransform: Represents the identity transform.
RotaryTransform: Represents the rotary transform.

IdentityTransform(*args, **kwargs) = <class 'modalities.models.gpt2.gpt2_model.IdentityTransform'>

RotaryTransform(n_embd, n_head, seq_length_dim=-2, base_freq=10000) = <class 'modalities.models.gpt2.gpt2_model.RotaryTransform'>
Parameters:
  • n_embd (int)

  • n_head (int)

  • seq_length_dim (int)

  • base_freq (int)

class modalities.models.gpt2.gpt2_model.RotaryTransform(n_embd, n_head, seq_length_dim=-2, base_freq=10000)[source]

Bases: QueryKeyValueTransform

RotaryTransform class which implements rotary positional embeddings.

Source: https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py

We added the corresponding code here because there is a conflict with the "@torch.jit.script" decorator used in the xformers implementation; the decorator has been removed in this implementation.

Initializes the RotaryTransform object.

Args:

n_embd (int): The size of the embedding dimension.
n_head (int): The number of attention heads.
seq_length_dim (int, optional): The dimension along which the sequence length is defined. Defaults to -2.
base_freq (int): Base frequency for RoPE. Defaults to 10000.

Parameters:
  • n_embd (int)

  • n_head (int)

  • seq_length_dim (int)

  • base_freq (int)

apply_rotary_pos_emb(x, cos, sin)[source]

Applies rotary positional embedding to the input tensor.

Args:

x (torch.Tensor): Input tensor.
cos (torch.Tensor): Cosine values for rotary positional embedding.
sin (torch.Tensor): Sine values for rotary positional embedding.

Returns:

torch.Tensor: Tensor after applying rotary positional embedding.

forward(q, k, v)[source]

Forward pass of the RotaryTransform module.

Return type:

tuple[Tensor, Tensor, Tensor]

Args:

q (torch.Tensor): Query tensor.
k (torch.Tensor): Key tensor.
v (torch.Tensor): Value tensor.

Returns:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing the modified query tensor, key tensor, and value tensor.

rotate_half(x)[source]

Rearranges tensor elements.

Args:

x (torch.Tensor): The input tensor.

Returns:

torch.Tensor: The output tensor.
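
A self-contained sketch of the standard rotary-embedding math these two helpers implement (following the xformers formulation referenced above; shapes and names are illustrative):

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Swap the two halves of the last dimension and negate the second half.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # Standard RoPE update: combine the tensor and its rotated half with
        # per-position cosine/sine tables.
        return (x * cos) + (rotate_half(x) * sin)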

class modalities.models.gpt2.gpt2_model.TransformerMLP(n_embd, ffn_hidden, bias, dropout)[source]

Bases: Module

TransformerMLP class.

Initializes the TransformerMLP class.

Args:

n_embd (int): The size of the input embedding.
ffn_hidden (int): The size of the hidden layer in the feed-forward network.
bias (bool): Whether to include bias terms in the linear layers.
dropout (float): The dropout probability.

Returns:

None

forward(x)[source]

Forward pass of the TransformerMLP module.

Return type:

Tensor

Parameters:

x (Tensor)

Args:

x (torch.Tensor): Input tensor.

Returns:

torch.Tensor: Output tensor.
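
A minimal sketch of such a feed-forward block (the choice of GELU as the non-linearity is an assumption for illustration):

    import torch
    import torch.nn as nn

    class SketchMLP(nn.Module):
        """Illustrative two-layer feed-forward block: expand, activate, project, dropout."""

        def __init__(self, n_embd: int, ffn_hidden: int, bias: bool, dropout: float):
            super().__init__()
            self.c_fc = nn.Linear(n_embd, ffn_hidden, bias=bias)
            self.act = nn.GELU()                       # assumed activation
            self.c_proj = nn.Linear(ffn_hidden, n_embd, bias=bias)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.dropout(self.c_proj(self.act(self.c_fc(x))))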

modalities.models.gpt2.gpt2_model.manual_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)[source]

Compute scaled dot product attention.

Return type:

Tensor

Args:

query (torch.Tensor): The query tensor of shape (batch_size, num_queries, query_dim).
key (torch.Tensor): The key tensor of shape (batch_size, num_keys, key_dim).
value (torch.Tensor): The value tensor of shape (batch_size, num_values, value_dim).
attn_mask (torch.Tensor, optional): The attention mask tensor of shape (num_queries, num_keys). Defaults to None.
dropout_p (float, optional): The dropout probability. Defaults to 0.0.
is_causal (bool, optional): Whether the attention is causal or not. Defaults to False.
scale (float, optional): The scaling factor. Defaults to None.

Returns:

torch.Tensor: The attention weights tensor of shape (batch_size, num_queries, num_keys).

Note:

Taken from https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
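
For reference, the PyTorch documentation linked above gives an equivalent manual implementation along the following lines, reproduced here as a sketch:

    import math
    import torch

    def scaled_dot_product_attention(query, key, value, attn_mask=None,
                                     dropout_p=0.0, is_causal=False, scale=None):
        # Reference formulation from the PyTorch documentation.
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)

        if is_causal:
            assert attn_mask is None
            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        # Softmax over scaled query-key similarities, then weight the values.
        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        return attn_weight @ value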

modalities.models.gpt2.pretrained_gpt_model module

Module contents