#!/usr/bin/env python3
from typing import Optional
from pydantic import FilePath
from modalities.config.component_factory import ComponentFactory
from modalities.config.config import ProcessGroupBackendType, load_app_config_dict
from modalities.config.instantiation_models import TextGenerationInstantiationModel
from modalities.inference.text.config import TextInferenceComponentConfig
from modalities.inference.text.inference_component import TextInferenceComponent
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
from modalities.running_env.cuda_env import CudaEnv
from modalities.running_env.env_utils import is_running_with_torchrun
[docs]
def generate_text(config_path: FilePath, registry: Optional[Registry] = None):
    """Build the text inference component described by a config file and run it.

    The config at ``config_path`` is loaded with ``load_app_config_dict`` and
    instantiated through a ``ComponentFactory`` using
    ``TextGenerationInstantiationModel`` as the components model. The resulting
    ``text_inference_component`` is then started via its ``run()`` method.

    Args:
        config_path: Path of the config file to load.
        registry: Registry used to resolve components. When ``None``, a
            ``Registry`` over ``COMPONENTS`` is created. In both cases the
            ``("inference_component", "text")`` entry is added to it via
            ``add_entity`` before any component is built, so a caller-supplied
            registry receives that call as well.
    """
    app_config = load_app_config_dict(config_path)

    if registry is None:
        registry = Registry(COMPONENTS)
    # Make the text inference component resolvable from the config.
    registry.add_entity(
        component_key="inference_component",
        variant_key="text",
        component_type=TextInferenceComponent,
        component_config_type=TextInferenceComponentConfig,
    )
    factory = ComponentFactory(registry=registry)

    def _build():
        # Single place for the instantiation call shared by both launch modes.
        return factory.build_components(
            config_dict=app_config,
            components_model_type=TextGenerationInstantiationModel,
        )

    if not is_running_with_torchrun():
        built = _build()
    else:
        # Under torchrun the components are built inside a CudaEnv (nccl backend).
        with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
            built = _build()

    # run() is invoked after the CudaEnv block (if any) has been exited.
    built.text_inference_component.run()