Source code for modalities.inference.inference

#!/usr/bin/env python3

from typing import Optional

from pydantic import FilePath

from modalities.config.component_factory import ComponentFactory
from modalities.config.config import ProcessGroupBackendType, load_app_config_dict
from modalities.config.instantiation_models import TextGenerationInstantiationModel
from modalities.inference.text.config import TextInferenceComponentConfig
from modalities.inference.text.inference_component import TextInferenceComponent
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
from modalities.running_env.cuda_env import CudaEnv
from modalities.running_env.env_utils import is_running_with_torchrun


def generate_text(config_path: FilePath, registry: Optional[Registry] = None):
    """Build the text generation components from the given config file and run inference."""
    config_dict = load_app_config_dict(config_path)

    # Fall back to the default component registry if none was provided.
    if registry is None:
        registry = Registry(COMPONENTS)

    # Register the text inference component so it can be instantiated from the config.
    registry.add_entity(
        component_key="inference_component",
        variant_key="text",
        component_type=TextInferenceComponent,
        component_config_type=TextInferenceComponentConfig,
    )
    component_factory = ComponentFactory(registry=registry)

    # When launched via torchrun, initialize the CUDA/NCCL process group before building components.
    if is_running_with_torchrun():
        with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
            components = component_factory.build_components(
                config_dict=config_dict,
                components_model_type=TextGenerationInstantiationModel,
            )
    else:
        components = component_factory.build_components(
            config_dict=config_dict,
            components_model_type=TextGenerationInstantiationModel,
        )

    text_inference_component = components.text_inference_component
    text_inference_component.run()
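
A minimal usage sketch, assuming a text-generation YAML config exists on disk; the config path below is a hypothetical placeholder, not part of the module above:

    from pathlib import Path

    from modalities.inference.inference import generate_text

    # Builds all components described in the config and runs text generation.
    # Replace the path with your own text-generation config file.
    generate_text(config_path=Path("configs/text_generation_config.yaml"))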