import os
import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
from pydantic import BaseModel
# TODO find a solution for github actions
# to install this as a dependency
# from pkg_resources import packaging
from torch.distributed.fsdp import MixedPrecision
from modalities.config.lookup_enum import LookupEnum
def is_running_with_torchrun():
    # Check for one of the environment variables set by torchrun.
    return "LOCAL_RANK" in os.environ
def has_bfloat_support():
    # bfloat16 mixed precision needs a CUDA build of PyTorch, a GPU with native
    # bf16 support and NCCL >= 2.10 for bf16 collectives.
    return (
        torch.version.cuda
        and torch.cuda.is_available()
        and torch.cuda.is_bf16_supported()
        # TODO find a solution for github actions
        # to install this as a dependency
        # and packaging.version.parse(torch.version.cuda).release >= (11, 0)
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )
# fp16 requires a GradScaler in the training loop.
fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

# Keeps parameters in full precision; gradient reduction and buffers in bf16.
bfSixteen_working = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Megatron-style: bf16 parameters with gradient reduction in full precision.
megatron_strategy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    # buffer_dtype=torch.bfloat16,
)

# Full fp32, i.e. effectively no mixed precision.
fpThirtytwo = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
no_mixed_precision = None
class MixedPrecisionSettings(LookupEnum):
    FP_16 = fpSixteen
    BF_16 = bfSixteen
    BF_16_WORKING = bfSixteen_working
    FP_32 = fpThirtytwo
    MIXED_PRECISION_MEGATRON = megatron_strategy
    NO_MIXED_PRECISION = no_mixed_precision
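

# Example (illustrative sketch, not part of this module): picking a policy at
# runtime and handing its underlying MixedPrecision config to FSDP. `model` and
# the FSDP import are assumptions; fp16 additionally needs a GradScaler.
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#   mp_settings = MixedPrecisionSettings.BF_16 if has_bfloat_support() else MixedPrecisionSettings.FP_16
#   sharded_model = FSDP(model, mixed_precision=mp_settings.value)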
class PyTorchDtypes(LookupEnum):
    FP_16 = torch.float16
    FP_32 = torch.float32
    BF_16 = torch.bfloat16
class FSDP2MixedPrecisionSettings(BaseModel):
    param_dtype: PyTorchDtypes
    reduce_dtype: PyTorchDtypes
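

# Example (illustrative sketch, not part of this module): since this is a
# pydantic model, the settings can be constructed directly or validated from
# parsed config data. The dtype choices below are arbitrary examples, not
# project defaults.
#
#   settings = FSDP2MixedPrecisionSettings(
#       param_dtype=PyTorchDtypes.BF_16,
#       reduce_dtype=PyTorchDtypes.FP_32,
#   )
#   settings.param_dtype.value  # -> torch.bfloat16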