modalities.nn.model_initialization package
Submodules
modalities.nn.model_initialization.composed_initialization module
modalities.nn.model_initialization.initialization_if module
modalities.nn.model_initialization.initialization_routines module
- class modalities.nn.model_initialization.initialization_routines.InitializationRoutines
Bases: object
- static get_plain_initialization(mean, std, parameter_name_regexes, hidden_dim=None)
Initializes the weights of a model by sampling from a normal distribution. NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers. For other layer types, the class must be subclassed and extended.
- Args:
mean (float): Mean of the normal distribution.
std (float): Standard deviation of the normal distribution. If set to "auto", an appropriate value is selected as per the plain initialization described in https://arxiv.org/abs/2312.16903
hidden_dim (Optional[int]): Hidden dimension of the attention layer. Defaults to None.
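The parameter_name_regexes argument presumably selects which parameters are drawn from the normal distribution. The following standard-library sketch illustrates that idea only; it does not use the modalities API. The helper `init_matching_parameters`, the toy parameter names, the use of full-string matching, and the plain Python lists standing in for tensors are all assumptions made for illustration.

```python
import random
import re


def init_matching_parameters(params, mean, std, parameter_name_regexes, seed=0):
    """Hypothetical helper (not part of modalities): overwrite every parameter
    whose name fully matches one of the regexes with samples from N(mean, std**2)."""
    rng = random.Random(seed)
    patterns = [re.compile(p) for p in parameter_name_regexes]
    initialized = []
    for name, values in params.items():
        # Assumption: a parameter is selected if ANY regex matches its full name.
        if any(p.fullmatch(name) for p in patterns):
            # Replace the values in place so the caller's lists are updated.
            values[:] = [rng.gauss(mean, std) for _ in values]
            initialized.append(name)
    return initialized


# Toy "model": parameter name -> flat list of weights (stand-in for tensors).
params = {
    "transformer.wte.weight": [0.0] * 4,
    "transformer.h.0.attn.c_proj.weight": [0.0] * 4,
    "transformer.h.0.ln_1.bias": [0.0] * 4,
}

touched = init_matching_parameters(
    params, mean=0.0, std=0.02, parameter_name_regexes=[r".*\.weight"]
)
```

In this sketch only the two `.weight` entries are re-sampled; the bias entry is left untouched because no regex matches its name.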
- static get_scaled_embed_initialization(mean, parameter_name_regexes)
Implementation of scaled weight initialization for embeddings; see https://arxiv.org/abs/2312.16903. The standard deviation is fixed to sqrt(0.4).
- Args:
mean (float): Mean of the normal distribution.
parameter_name_regexes (list[str], optional): List of parameter name regexes to which the initialization should be applied. Defaults to None.
- Returns:
WeightInitializationIF: Weight initialization object
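For orientation, a standard deviation of sqrt(0.4) is roughly 0.632, which corresponds to a variance of 0.4. The sketch below uses only the standard library; `sample_embedding_row` is a hypothetical helper, not part of modalities, and it assumes each value is drawn independently from a normal distribution with that standard deviation.

```python
import math
import random
import statistics

EMBED_STD = math.sqrt(0.4)  # about 0.632, i.e. a variance of 0.4


def sample_embedding_row(dim, mean=0.0, seed=0):
    """Hypothetical helper: draw one embedding row with std sqrt(0.4)."""
    rng = random.Random(seed)
    return [rng.gauss(mean, EMBED_STD) for _ in range(dim)]


row = sample_embedding_row(dim=10_000)
# Population standard deviation of the samples; should be close to EMBED_STD.
empirical_std = statistics.pstdev(row)
```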
- static get_scaled_initialization(mean, std, num_layers, parameter_name_regexes)
Implementation of scaled weight initialization, as defined in https://arxiv.org/abs/2312.16903.
- Args:
mean (float): Mean of the normal distribution.
std (float): Standard deviation of the normal distribution used to initialize the other weights.
num_layers (int): Number of layers in the model, used to downscale std.
parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization should be applied.
- Returns:
WeightInitializationIF: Weight initialization object
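This page does not spell out how num_layers enters the computation. A common convention for this kind of depth-dependent scaling is to divide std by sqrt(2 * num_layers); the sketch below assumes that rule, which should be checked against the library source before relying on it. `scaled_std` is a hypothetical helper, not a modalities function.

```python
import math


def scaled_std(std, num_layers):
    """Hypothetical helper: downscale std by sqrt(2 * num_layers).

    The divisor is an assumed rule, not confirmed by this page.
    """
    if num_layers <= 0:
        raise ValueError("num_layers must be positive")
    return std / math.sqrt(2 * num_layers)


# With std=0.02 and 12 layers the scaled value is 0.02 / sqrt(24).
value = scaled_std(0.02, 12)
```

Under this assumption, deeper models get a smaller standard deviation for the matched parameters, while all other weights keep the unscaled std.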
- class modalities.nn.model_initialization.initialization_routines.NamedParameterwiseNormalInitialization(mean, std, parameter_name_regexes)
Bases: ModelInitializationIF
- Parameters:
mean (float)
std (float)
parameter_name_regexes (RegexFilter)
- class modalities.nn.model_initialization.initialization_routines.PlainInitializationConfig(**data)
Bases: BaseModel
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- class modalities.nn.model_initialization.initialization_routines.ScaledEmbedInitializationConfig(**data)
Bases: BaseModel
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- class modalities.nn.model_initialization.initialization_routines.ScaledInitializationConfig(**data)
Bases: BaseModel
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.
modalities.nn.model_initialization.parameter_name_filters module
- class modalities.nn.model_initialization.parameter_name_filters.RegexFilter(**data)
Bases: BaseModel
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.