Source code for modalities.models.weight_tying

import torch.nn as nn


def has_tied_word_embeddings(model: nn.Module) -> bool:
    """Report whether ``model`` declares tied word embeddings.

    The model must expose a ``has_tied_word_embeddings`` attribute. Its
    truthiness is returned as a plain ``bool``.

    Args:
        model: The module to inspect.

    Returns:
        ``bool(model.has_tied_word_embeddings)``.

    Raises:
        TypeError: If ``model`` has no ``has_tied_word_embeddings`` attribute,
            or the attribute is set to ``None``.
    """
    # A missing attribute and an explicit ``None`` are treated the same way:
    # both mean the model never declared its tying status.
    declared_flag = getattr(model, "has_tied_word_embeddings", None)
    if declared_flag is not None:
        return bool(declared_flag)
    raise TypeError(
        f"{type(model).__name__} must define 'has_tied_word_embeddings' to be used with tied-embedding validation."
    )