modalities.nn.model_initialization package
Submodules
modalities.nn.model_initialization.composed_initialization module
- class modalities.nn.model_initialization.composed_initialization.ComposedInitializationRoutines[source]
Bases:
object
- static get_composed_model_initializer(model_type, weight_init_type, mean, std, hidden_dim=None, num_layers=None)[source]
This routine initializes a model with plain, scaled, or scaled_embed initialization. Plain initialization is always performed first. For scaled_embed, scaled initialization is also performed, after plain and before scaled_embed.
- Return type:
- Parameters:
model_type (SupportWeightInitModels)
weight_init_type (WeightInitTypes)
mean (float)
std (float | str)
hidden_dim (int | None)
num_layers (int)
- Args:
model_type (SupportWeightInitModels): Model type enum referencing the model (e.g., "gpt2").
weight_init_type (WeightInitTypes): The initialization method we want to perform.
mean (float): Mean of the normal distribution.
std (float | str): Standard deviation of the plain normal distribution.
hidden_dim (Optional[int], optional): Hidden dimension size of the model (required for plain if std="auto"). Defaults to None.
num_layers (int, optional): Number of layers in the model (required for scaled and scaled_embed only). Defaults to None.
- Returns:
ModelInitializationIF: The Weight Initializer performing the initialization as specified.
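The ordering described above (plain first, then scaled, then scaled_embed) can be illustrated with a small sketch. This is not the library's code: the helper name `stages_for` is hypothetical, and the string values are taken from the description above rather than from the `WeightInitTypes` enum itself.

```python
def stages_for(weight_init_type: str) -> list[str]:
    """Return the initialization stages in the order they would be applied."""
    order = ["plain", "scaled", "scaled_embed"]
    if weight_init_type not in order:
        raise ValueError(f"unknown weight_init_type: {weight_init_type!r}")
    # Every stage up to and including the requested one is run.
    return order[: order.index(weight_init_type) + 1]


print(stages_for("plain"))         # ['plain']
print(stages_for("scaled"))        # ['plain', 'scaled']
print(stages_for("scaled_embed"))  # ['plain', 'scaled', 'scaled_embed']
```

Because later stages run after plain, any parameters they match would end up with the later stage's values.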
- static get_model_initializer_wrapper(model_initializers)[source]
- Return type:
- Parameters:
model_initializers (list[ModelInitializationIF])
- class modalities.nn.model_initialization.composed_initialization.ComposedModelInitializationConfig(**data)[source]
Bases:
BaseModel
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError (pydantic_core.ValidationError) if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
model_type (SupportWeightInitModels)
weight_init_type (WeightInitTypes)
mean (float)
std (Annotated[float, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Ge(ge=0.0)])] | str)
hidden_dim (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Gt(gt=0)])] | None)
num_layers (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Gt(gt=0)])] | None)
- model_config: ClassVar[ConfigDict] = {'protected_namespaces': ()}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- model_type: SupportWeightInitModels
- weight_init_type: WeightInitTypes
- class modalities.nn.model_initialization.composed_initialization.ModelInitializerWrapper(model_initializers)[source]
Bases:
ModelInitializationIF
- Parameters:
model_initializers (list[ModelInitializationIF])
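A wrapper over a list of initializers presumably applies them one after another to the same model. The sketch below shows that pattern with toy classes. Everything in it is an assumption made for illustration: the method name `initialize_in_place`, the `SequentialInitializer` and `Fill` classes, and the use of a plain dict in place of a real model.

```python
from typing import Protocol


class Initializer(Protocol):
    # Hypothetical interface; the real ModelInitializationIF may differ.
    def initialize_in_place(self, model: dict) -> None: ...


class SequentialInitializer:
    """Applies each wrapped initializer to the model, in list order."""

    def __init__(self, initializers: list[Initializer]) -> None:
        self.initializers = initializers

    def initialize_in_place(self, model: dict) -> None:
        for initializer in self.initializers:
            initializer.initialize_in_place(model)


class Fill:
    """Toy initializer: sets every parameter whose name starts with `prefix`."""

    def __init__(self, prefix: str, value: float) -> None:
        self.prefix = prefix
        self.value = value

    def initialize_in_place(self, model: dict) -> None:
        for name in model:
            if name.startswith(self.prefix):
                model[name] = self.value


model = {"embed.weight": 0.0, "mlp.weight": 0.0}
SequentialInitializer([Fill("", 1.0), Fill("mlp", 2.0)]).initialize_in_place(model)
print(model)  # {'embed.weight': 1.0, 'mlp.weight': 2.0}
```

With sequential application, the order of the list matters: a later initializer overwrites whatever an earlier one wrote to the same parameter.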
- class modalities.nn.model_initialization.composed_initialization.ModelInitializerWrapperConfig(**data)[source]
Bases:
BaseModel
- Parameters:
model_initializers (list[Annotated[ModelInitializationIF, PydanticThirdPartyTypeIF]])
- model_config: ClassVar[ConfigDict] = {'protected_namespaces': ()}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- model_initializers: list[Annotated[ModelInitializationIF]]
modalities.nn.model_initialization.initialization_if module
modalities.nn.model_initialization.initialization_routines module
- class modalities.nn.model_initialization.initialization_routines.InitializationRoutines[source]
Bases:
object
- static get_plain_initialization(mean, std, parameter_name_regexes, hidden_dim=None)[source]
Initializes the weights of a model by sampling from a normal distribution. NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers. For other layer types, the initialization must be subclassed and extended.
- Return type:
- Parameters:
- Args:
mean (float): Mean of the normal distribution.
std (float): Standard deviation of the normal distribution. If set to "auto", an appropriate value is selected as per the plain initialization described in https://arxiv.org/abs/2312.16903.
hidden_dim (Optional[int]): Hidden dimension of the attention layer. Defaults to None.
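How std="auto" might be resolved can be sketched as follows. The formula sqrt(2 / (5 * hidden_dim)) is an assumption here: it is the "small init" commonly used as the plain baseline in the referenced line of work, and this page does not confirm that the library uses exactly this value. The helper name `resolve_plain_std` is hypothetical.

```python
import math


def resolve_plain_std(std, hidden_dim=None):
    """Return a numeric std; derive it from hidden_dim when std == "auto"."""
    if std == "auto":
        if hidden_dim is None:
            raise ValueError('hidden_dim is required when std="auto"')
        # Assumed formula, not confirmed by this page.
        return math.sqrt(2 / (5 * hidden_dim))
    return float(std)


print(resolve_plain_std(0.02))
print(resolve_plain_std("auto", hidden_dim=768))
```

This also shows why hidden_dim is only needed in the "auto" case: an explicit numeric std is used as given.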
- static get_scaled_embed_initialization(mean, parameter_name_regexes)[source]
Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903. The standard deviation is fixed to sqrt(0.4).
- Return type:
- Parameters:
- Args:
mean (float): Mean of the normal distribution.
parameter_name_regexes (list[str], optional): List of parameter name regexes to which the initialization should be applied. Defaults to None.
- Returns:
WeightInitializationIF: Weight initialization object
- static get_scaled_initialization(mean, std, num_layers, parameter_name_regexes)[source]
Implementation of scaled weight initialization, as defined in https://arxiv.org/abs/2312.16903.
- Return type:
- Parameters:
- Args:
mean (float): Mean of the normal distribution.
std (float): Standard deviation of the normal distribution used to initialize the other weights.
num_layers (int): Number of layers in the model, which we use to downscale std.
parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization should be applied.
- Returns:
WeightInitializationIF: Weight initialization object
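The downscaling of std by the number of layers can be sketched as below. The divisor sqrt(2 * num_layers) is an assumption based on the scaled initialization commonly used for residual output projections. This page only states that num_layers is used to downscale std, so treat the exact factor as illustrative. The helper name `scaled_std` is hypothetical.

```python
import math


def scaled_std(std: float, num_layers: int) -> float:
    """Downscale std by sqrt(2 * num_layers) (assumed factor)."""
    if num_layers <= 0:
        raise ValueError("num_layers must be positive")
    return std / math.sqrt(2 * num_layers)


print(scaled_std(0.02, num_layers=8))  # 0.005
```

Under this assumption, deeper models get a smaller std for the matched parameters, which limits how much each residual branch contributes at initialization.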
- class modalities.nn.model_initialization.initialization_routines.NamedParameterwiseNormalInitialization(mean, std, parameter_name_regexes)[source]
Bases:
ModelInitializationIF
- Parameters:
mean (float)
std (float)
parameter_name_regexes (RegexFilter)
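The idea of initializing only those parameters whose names match a set of regexes can be sketched without any deep learning framework. The following is an illustration under stated assumptions, not the class's implementation: parameters are modeled as a dict of name to list of floats, a plain list of regex strings stands in for RegexFilter, full-string matching via `re.fullmatch` is assumed, and the parameter names and the function `init_matching` are made up for the example.

```python
import random
import re


def init_matching(params, regexes, mean, std, seed=0):
    """Resample matching parameters from N(mean, std); return matched names."""
    rng = random.Random(seed)
    touched = []
    for name, values in params.items():
        # Assumption: a regex must match the whole parameter name.
        if any(re.fullmatch(rx, name) for rx in regexes):
            values[:] = [rng.gauss(mean, std) for _ in values]
            touched.append(name)
    return touched


params = {
    "transformer.h.0.attn.c_proj.weight": [0.0] * 4,
    "transformer.h.0.attn.c_proj.bias": [0.0] * 4,
    "transformer.wte.weight": [0.0] * 4,
}
touched = init_matching(
    params, [r".*\.c_proj\.weight", r".*\.wte\.weight"], mean=0.0, std=0.02
)
print(touched)
```

Parameters that match no regex (here the bias) are left unchanged, which is what lets several such initializers be layered on top of each other.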
- class modalities.nn.model_initialization.initialization_routines.PlainInitializationConfig(**data)[source]
Bases:
BaseModel
- Parameters:
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class modalities.nn.model_initialization.initialization_routines.ScaledEmbedInitializationConfig(**data)[source]
Bases:
BaseModel
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- class modalities.nn.model_initialization.initialization_routines.ScaledInitializationConfig(**data)[source]
Bases:
BaseModel
- Parameters:
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
modalities.nn.model_initialization.parameter_name_filters module
- class modalities.nn.model_initialization.parameter_name_filters.RegexFilter(**data)[source]
Bases:
BaseModel
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].