import os
import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
from pydantic import BaseModel
# TODO find a solution for github actions
# to install this as a dependency
# from pkg_resources import packaging
from torch.distributed.fsdp import MixedPrecision
from modalities.config.lookup_enum import LookupEnum
def is_running_with_torchrun() -> bool:
    """Returns True if this process was launched via torchrun, detected through the
    LOCAL_RANK environment variable that torchrun sets for every worker."""
    return "LOCAL_RANK" in os.environ
def has_bfloat_support() -> bool:
    """Returns True if bf16 mixed precision is usable on this node: a CUDA build of
    PyTorch, a bf16-capable GPU, and NCCL >= 2.10."""
    return (
        torch.version.cuda is not None
        and torch.cuda.is_available()
        and torch.cuda.is_bf16_supported()
        # TODO find a solution for github actions
        # to install this as a dependency
        # and packaging.version.parse(torch.version.cuda).release >= (11, 0)
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )
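# Example (sketch): initialize the default process group only when launched via
# torchrun; assumes the NCCL backend and that torchrun has set the rendezvous
# environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE):
#     if is_running_with_torchrun():
#         dist.init_process_group(backend="nccl")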
# fp16 requires a gradient scaler in the training loop (see the sketch after this block).
fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)
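# Sketch of the grad-scaler loop mentioned above; assumes an FSDP-wrapped `model`,
# an `optimizer`, and a per-step `loss` (placeholder names, not part of this module):
#     from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
#     scaler = ShardedGradScaler()
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)
#     scaler.update()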
bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)
# Keeps parameters in fp32 while reducing gradients and storing buffers in bf16.
bfSixteen_working = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)
# Megatron-style policy: bf16 parameters with fp32 gradient reduction.
megatron_strategy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    # buffer_dtype=torch.bfloat16,
)
fpThirtytwo = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
no_mixed_precision = None
class MixedPrecisionSettings(LookupEnum):
    """Selectable FSDP1 mixed precision policies (the MixedPrecision instances defined above)."""

    FP_16 = fpSixteen
    BF_16 = bfSixteen
    BF_16_WORKING = bfSixteen_working
    FP_32 = fpThirtytwo
    MIXED_PRECISION_MEGATRON = megatron_strategy
    NO_MIXED_PRECISION = no_mixed_precision
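# Example (sketch): handing a selected policy to FSDP1; assumes `model` is an
# nn.Module and the process group has already been initialized:
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#     if has_bfloat_support():
#         settings = MixedPrecisionSettings.BF_16
#     else:
#         settings = MixedPrecisionSettings.FP_16
#     sharded_model = FSDP(model, mixed_precision=settings.value)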
class PyTorchDtypes(LookupEnum):
    """Torch dtypes available for the FSDP2 mixed precision settings below."""

    FP_16 = torch.float16
    FP_32 = torch.float32
    BF_16 = torch.bfloat16
class FSDP2MixedPrecisionSettings(BaseModel):
    """Mixed precision settings for FSDP2: dtypes used for parameters and for gradient reduction."""

    param_dtype: PyTorchDtypes
    reduce_dtype: PyTorchDtypes
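# Example (sketch): constructing FSDP2 settings programmatically; the two dtypes
# would typically be forwarded to an FSDP2 mixed precision policy elsewhere:
#     settings = FSDP2MixedPrecisionSettings(
#         param_dtype=PyTorchDtypes.BF_16,
#         reduce_dtype=PyTorchDtypes.FP_32,
#     )
#     assert settings.param_dtype.value == torch.bfloat16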