modalities.running_env.fsdp package
Submodules
modalities.running_env.fsdp.device_mesh module
- class modalities.running_env.fsdp.device_mesh.DeviceMeshConfig(**data)[source]
- Bases: BaseModel
- Create a new model by parsing and validating input data from keyword arguments.
- Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.
- self is explicitly positional-only to allow self as a field name.
- Parameters:
- device_type (str) 
- data_parallel_replicate_degree (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Gt(gt=0)])]) 
- data_parallel_shard_degree (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Ge(ge=-1)])]) 
- tensor_parallel_degree (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Gt(gt=0)])]) 
- pipeline_parallel_degree (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Gt(gt=0)])]) 
- context_parallel_degree (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Gt(gt=0)])]) 
- enable_loss_parallel (bool | None) 
- world_size (Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Strict(strict=True), Gt(gt=0)])]) 
 
 - model_config: ClassVar[ConfigDict] = {}
- Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict]. 
 
- class modalities.running_env.fsdp.device_mesh.ParallelismDegrees(value)[source]
- Bases: Enum
 - CP = 'cp'
 - DP_REPLICATE = 'dp_replicate'
 - DP_SHARD = 'dp_shard'
 - PP = 'pp'
 - TP = 'tp'
 
- modalities.running_env.fsdp.device_mesh.get_device_mesh(device_type, data_parallel_replicate_degree, data_parallel_shard_degree, tensor_parallel_degree, pipeline_parallel_degree, context_parallel_degree, enable_loss_parallel, world_size)[source]
- Gets the device mesh for the specified parallelism degrees.
- Parameters:
- device_type (str): The device type.
- data_parallel_replicate_degree (int): The data parallel replicate degree.
- data_parallel_shard_degree (int): The data parallel shard degree.
- tensor_parallel_degree (int): The tensor parallel degree.
- pipeline_parallel_degree (int): The pipeline parallel degree.
- context_parallel_degree (int): The context parallel degree.
- enable_loss_parallel (bool): Whether to enable loss parallelism.
- world_size (int): The world size.
- Returns:
- DeviceMesh: The device mesh.
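The degrees passed to get_device_mesh have to be consistent with world_size. The sketch below illustrates one plausible reading of that arithmetic. It rests on two assumptions that this page does not state: that a data_parallel_shard_degree of -1 (permitted by the ge=-1 constraint on DeviceMeshConfig) means "infer from the remaining ranks", and that the product of all degrees must equal world_size. The helper name resolve_shard_degree is hypothetical and is not part of modalities.

```python
def resolve_shard_degree(world_size, dp_replicate, dp_shard, tp, pp, cp):
    """Hypothetical helper (not part of modalities).

    Returns the data parallel shard degree, inferring it when dp_shard is -1.
    Assumption: the product of all degrees must equal world_size.
    """
    # Ranks consumed by every parallelism dimension except sharding.
    other = dp_replicate * tp * pp * cp

    if dp_shard == -1:
        # Assumption: -1 means "use all remaining ranks for sharding".
        if world_size % other != 0:
            raise ValueError(
                f"world_size={world_size} is not divisible by {other}"
            )
        dp_shard = world_size // other

    if dp_shard * other != world_size:
        raise ValueError(
            f"degrees multiply to {dp_shard * other}, expected {world_size}"
        )
    return dp_shard


# 16 ranks, 2-way replication, 2-way tensor parallelism: 16 // (2 * 2) = 4
print(resolve_shard_degree(16, dp_replicate=2, dp_shard=-1, tp=2, pp=1, cp=1))
```

Under these assumptions, an explicit shard degree that does not multiply out to world_size is rejected rather than silently adjusted.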
 
- modalities.running_env.fsdp.device_mesh.get_parallel_degree(device_mesh, parallelism_methods)[source]
- Gets the number of parallel ranks (i.e., the parallelism degree) from the device mesh for a specific parallelism method.
- Parameters:
- device_mesh (DeviceMesh): The device mesh.
- parallelism_methods (list[ParallelismDegrees]): The parallelism methods.
- Returns:
- int: The number of parallel ranks for the specified parallelism method.
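The following self-contained sketch shows how a degree could be computed from named mesh dimensions. It is not the modalities implementation. It assumes that the ParallelismDegrees values double as mesh dimension names and that the degree for several methods is the product of their dimension sizes; neither is stated on this page. A plain dict stands in for the DeviceMesh, and parallel_degree_sketch is a hypothetical name.

```python
import math
from enum import Enum


class ParallelismDegrees(Enum):
    # Mirrors the members documented above.
    CP = "cp"
    DP_REPLICATE = "dp_replicate"
    DP_SHARD = "dp_shard"
    PP = "pp"
    TP = "tp"


def parallel_degree_sketch(mesh_dim_sizes, parallelism_methods):
    """Hypothetical stand-in for get_parallel_degree.

    mesh_dim_sizes maps a mesh dimension name to its size.
    Assumption: the combined degree is the product of the selected sizes.
    """
    return math.prod(mesh_dim_sizes[m.value] for m in parallelism_methods)


# A dict standing in for a mesh with three named dimensions.
mesh_dim_sizes = {"dp_replicate": 2, "dp_shard": 4, "tp": 2}

dp_degree = parallel_degree_sketch(
    mesh_dim_sizes,
    [ParallelismDegrees.DP_REPLICATE, ParallelismDegrees.DP_SHARD],
)
print(dp_degree)  # 2 * 4 = 8
```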
 
 
modalities.running_env.fsdp.fsdp_auto_wrapper module
- class modalities.running_env.fsdp.fsdp_auto_wrapper.FSDPAutoWrapFactoryTypes(value)[source]
- Bases: LookupEnum