modalities.running_env.fsdp package
Submodules
modalities.running_env.fsdp.device_mesh module
- class modalities.running_env.fsdp.device_mesh.DeviceMeshConfig(**data)
Bases: BaseModel
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
device_type (str)
data_parallel_replicate_degree (int; required, strict, > 0)
data_parallel_shard_degree (int; required, strict, >= -1)
tensor_parallel_degree (int; required, strict, > 0)
pipeline_parallel_degree (int; required, strict, > 0)
context_parallel_degree (int; required, strict, > 0)
enable_loss_parallel (bool | None)
world_size (int; required, strict, > 0)
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.
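The field constraints listed for DeviceMeshConfig can be illustrated with a small standard-library sketch. MeshSettingsSketch and _strict_int are hypothetical names introduced here for illustration only; they are not part of modalities or Pydantic, and the real class relies on Pydantic's own validation, whose error types and messages will differ.

```python
from dataclasses import dataclass
from typing import Optional


def _strict_int(name: str, value: object, minimum: int) -> int:
    """Hypothetical helper: accept only real ints that are >= minimum."""
    # No coercion: strings, floats and bools are rejected in this sketch.
    if isinstance(value, bool) or not isinstance(value, int):
        raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    if value < minimum:
        raise ValueError(f"{name} must be >= {minimum}, got {value}")
    return value


@dataclass(frozen=True)
class MeshSettingsSketch:
    """Hypothetical stand-in mirroring the documented fields of DeviceMeshConfig."""

    device_type: str
    data_parallel_replicate_degree: int
    data_parallel_shard_degree: int
    tensor_parallel_degree: int
    pipeline_parallel_degree: int
    context_parallel_degree: int
    enable_loss_parallel: Optional[bool]
    world_size: int

    def __post_init__(self) -> None:
        # Fields documented as "> 0" must be at least 1.
        for name in (
            "data_parallel_replicate_degree",
            "tensor_parallel_degree",
            "pipeline_parallel_degree",
            "context_parallel_degree",
            "world_size",
        ):
            _strict_int(name, getattr(self, name), minimum=1)
        # The shard degree is documented as ">= -1", so -1 is accepted.
        _strict_int(
            "data_parallel_shard_degree",
            self.data_parallel_shard_degree,
            minimum=-1,
        )


settings = MeshSettingsSketch(
    device_type="cuda",
    data_parallel_replicate_degree=1,
    data_parallel_shard_degree=-1,
    tensor_parallel_degree=2,
    pipeline_parallel_degree=1,
    context_parallel_degree=1,
    enable_loss_parallel=None,
    world_size=8,
)
```

Passing tensor_parallel_degree=0 or world_size="8" to this sketch raises ValueError or TypeError respectively, which roughly corresponds to the cases where the real model would raise a ValidationError.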
- class modalities.running_env.fsdp.device_mesh.ParallelismDegrees(value)
Bases: Enum
- CP = 'cp'
- DP_REPLICATE = 'dp_replicate'
- DP_SHARD = 'dp_shard'
- PP = 'pp'
- TP = 'tp'
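As a quick illustration of how an Enum with these members behaves, the sketch below re-declares the five members locally. ParallelismDegreesSketch is a hypothetical mirror, not the modalities class, and how modalities consumes the string values is not stated in this reference.

```python
from enum import Enum


class ParallelismDegreesSketch(Enum):
    """Hypothetical local mirror of the documented ParallelismDegrees members."""

    CP = "cp"
    DP_REPLICATE = "dp_replicate"
    DP_SHARD = "dp_shard"
    PP = "pp"
    TP = "tp"


# Look a member up by its string value, e.g. when reading a name from a config.
member = ParallelismDegreesSketch("tp")

# Collect all string values in declaration order.
all_values = [degree.value for degree in ParallelismDegreesSketch]
```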
- modalities.running_env.fsdp.device_mesh.get_device_mesh(device_type, data_parallel_replicate_degree, data_parallel_shard_degree, tensor_parallel_degree, pipeline_parallel_degree, context_parallel_degree, enable_loss_parallel, world_size)
Gets the device mesh for the specified parallelism degrees.
- Parameters:
device_type (str): The device type.
data_parallel_replicate_degree (int): The data parallel replicate degree.
data_parallel_shard_degree (int): The data parallel shard degree.
tensor_parallel_degree (int): The tensor parallel degree.
pipeline_parallel_degree (int): The pipeline parallel degree.
context_parallel_degree (int): The context parallel degree.
enable_loss_parallel (bool): Whether to enable loss parallelism.
world_size (int): The world size.
- Returns:
DeviceMesh: The device mesh.
- Return type:
DeviceMesh
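The relationship between the parallelism degrees and the world size can be sketched without any distributed setup. The example below rests on two assumptions that this reference does not confirm: that a data_parallel_shard_degree of -1 means "infer the shard degree from the remaining ranks", and that the product of all degrees must equal world_size. Both follow a convention common in similar training frameworks. resolve_degrees is a hypothetical helper, not a modalities function, and it does not build a DeviceMesh.

```python
from math import prod


def resolve_degrees(
    world_size: int,
    dp_replicate: int,
    dp_shard: int,
    tp: int,
    pp: int,
    cp: int,
) -> dict:
    """Hypothetical helper: fill in dp_shard == -1 and check the degree product."""
    # Ranks consumed by every dimension except data-parallel sharding.
    fixed = dp_replicate * tp * pp * cp
    if dp_shard == -1:
        # ASSUMPTION: -1 means "use all remaining ranks for sharding".
        if world_size % fixed != 0:
            raise ValueError(
                f"world_size={world_size} is not divisible by {fixed}"
            )
        dp_shard = world_size // fixed
    degrees = {
        "dp_replicate": dp_replicate,
        "dp_shard": dp_shard,
        "cp": cp,
        "tp": tp,
        "pp": pp,
    }
    # ASSUMPTION: every rank belongs to exactly one coordinate of the mesh.
    if prod(degrees.values()) != world_size:
        raise ValueError(
            f"product of degrees {degrees} does not equal world_size={world_size}"
        )
    return degrees


# 8 ranks, tensor parallel degree 2, shard degree inferred as 8 // 2 = 4.
resolved = resolve_degrees(world_size=8, dp_replicate=1, dp_shard=-1, tp=2, pp=1, cp=1)
```

Under these assumptions, a configuration such as world_size=8 with dp_shard=3 and tp=2 would be rejected, because 3 * 2 does not equal 8.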
modalities.running_env.fsdp.fsdp_auto_wrapper module
- class modalities.running_env.fsdp.fsdp_auto_wrapper.FSDPAutoWrapFactoryTypes(value)
Bases: LookupEnum