Source code for modalities.models.parallelism.pipeline_parallelism_configs

from typing import Annotated

from pydantic import BaseModel, Field

from modalities.config.pydantic_if_types import (
    PydanticDeviceMeshIFType,
    PydanticLossIFType,
    PydanticPipelineStageType,
    PydanticPipelineType,
    PydanticPytorchModuleType,
    PydanticStagesGeneratorType,
)
from modalities.models.parallelism.pipeline_parallelism import PipelineSelectionTypes


class FQNsPerStageGeneratorConfig(BaseModel):
    """Configuration schema that declares no fields of its own.

    The class only inherits from ``BaseModel``; nothing in this module
    adds attributes or validators to it.
    """

    # TODO duplicate
    # NOTE(review): the TODO above does not say what this class duplicates;
    # confirm against the other config modules before removing or merging it.
class StagedPipelineConfig(BaseModel):
    """Pydantic schema bundling the inputs of a staged pipeline component.

    The component that consumes this config is not defined in this module,
    so the notes below describe only what the schema itself validates.
    """

    # Validated as a PyTorch module type; presumably the full, not yet
    # partitioned model (verify against the consumer).
    whole_model: PydanticPytorchModuleType

    # Validated as a stages-generator type.
    stages_generator: PydanticStagesGeneratorType

    # Validated as a device-mesh interface type.
    device_mesh: PydanticDeviceMeshIFType

    # Strict integer (no coercion from other types), must be >= 0.
    local_rank: Annotated[int, Field(strict=True, ge=0)]

    # Free-form string; this schema does not restrict it to known schedule names.
    pp_schedule_name: str

    # Strict integer, must be >= 1.
    num_layers_per_stage: Annotated[int, Field(strict=True, ge=1)]
class ScheduledPipelineConfig(BaseModel):
    """Pydantic schema bundling the inputs of a scheduled pipeline component.

    Only per-field constraints are declared here; the schema performs no
    cross-field validation (for example between ``batch_size`` and
    ``microbatch_size``).
    """

    # Validated as a loss interface type.
    loss_fn: PydanticLossIFType

    # Free-form string; this schema does not restrict it to known schedule names.
    pp_schedule_name: str

    # Strict integer (no coercion from other types), must be >= 1.
    batch_size: Annotated[int, Field(strict=True, ge=1)]

    # Strict integer, must be >= 1.
    microbatch_size: Annotated[int, Field(strict=True, ge=1)]

    # Strict integer, must be >= 2, so a value of 1 is rejected.
    pp_degree: Annotated[int, Field(strict=True, ge=2)]

    # Validated as a pipeline type.
    pipeline: PydanticPipelineType
class ComponentSelectorFromPipelineConfig(BaseModel):
    """Pydantic schema pairing a pipeline with a selection type.

    How the selection is applied is decided by the consumer of this
    config, which is not part of this module.
    """

    # Validated as a pipeline type.
    pipeline: PydanticPipelineType

    # One of the values of PipelineSelectionTypes (imported from
    # modalities.models.parallelism.pipeline_parallelism).
    selection_type: PipelineSelectionTypes
class PipelineConfig(BaseModel):
    """Pydantic schema bundling a pipeline stage, a model part and a schedule.

    ``pp_stage`` and ``model_part`` are required; ``pp_schedule`` may be
    omitted.
    """

    # Validated as a pipeline-stage type.
    pp_stage: PydanticPipelineStageType

    # Validated as a PyTorch module type.
    model_part: PydanticPytorchModuleType

    # Optional; defaults to None when no schedule is supplied.
    pp_schedule: PydanticPipelineType | None = None