From 39a9fe6d86f6f2703377fcc18d4241b5a2b3783a Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Tue, 22 Oct 2024 20:39:14 -0700
Subject: [PATCH] check torch version inside functions

Signed-off-by: Xin Yao
---
 transformer_engine/pytorch/distributed.py | 34 ++++++++-----------------
 transformer_engine/pytorch/utils.py       | 17 ++++--------
 2 files changed, 17 insertions(+), 34 deletions(-)

diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index 5e311e8082..b659048b2b 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -245,17 +245,16 @@ def in_fp8_activation_recompute_phase() -> bool:
     return _FP8_ACTIVATION_RECOMPUTE_PHASE
 
 
-TORCH_MAJOR = int(torch.__version__.split(".")[0])
-TORCH_MINOR = int(torch.__version__.split(".")[1])
-if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
-
-    def _get_active_autocast_contexts():
-        """
-        Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast
-        state at the time of this function's execution.
-        """
-        autocast_cached = torch.is_autocast_cache_enabled()
+def _get_active_autocast_contexts():
+    """
+    Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state
+    at the time of this function's execution.
+    """
+    autocast_cached = torch.is_autocast_cache_enabled()
 
+    TORCH_MAJOR = int(torch.__version__.split(".")[0])
+    TORCH_MINOR = int(torch.__version__.split(".")[1])
+    if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
         gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
         gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
         gpu_autocast_ctx = torch.amp.autocast(
@@ -273,18 +272,7 @@ def _get_active_autocast_contexts():
             dtype=cpu_autocast_dtype,
             cache_enabled=autocast_cached,
         )
-
-        return gpu_autocast_ctx, cpu_autocast_ctx
-
-else:
-
-    def _get_active_autocast_contexts():
-        """
-        Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast
-        state at the time of this function's execution.
-        """
-        autocast_cached = torch.is_autocast_cache_enabled()
-
+    else:
         gpu_autocast_enabled = torch.is_autocast_enabled()
         gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
         gpu_autocast_ctx = torch.cuda.amp.autocast(
@@ -297,7 +285,7 @@ def _get_active_autocast_contexts():
             cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
         )
 
-        return gpu_autocast_ctx, cpu_autocast_ctx
+    return gpu_autocast_ctx, cpu_autocast_ctx
 
 
 class _CheckpointFunction(torch.autograd.Function):
diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py
index 9d7695675e..935838ad3a 100644
--- a/transformer_engine/pytorch/utils.py
+++ b/transformer_engine/pytorch/utils.py
@@ -307,16 +307,11 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool:
     return device1 == device2
 
 
-TORCH_MAJOR = int(torch.__version__.split(".")[0])
-TORCH_MINOR = int(torch.__version__.split(".")[1])
-if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
-
-    def torch_get_autocast_gpu_dtype() -> torch.dtype:
-        """Get PyTorch autocast GPU dtype."""
+def torch_get_autocast_gpu_dtype() -> torch.dtype:
+    """Get PyTorch autocast GPU dtype."""
+    TORCH_MAJOR = int(torch.__version__.split(".")[0])
+    TORCH_MINOR = int(torch.__version__.split(".")[1])
+    if TORCH_MAJOR == 2 and TORCH_MINOR >= 4:
        return torch.get_autocast_dtype("cuda")
-
-else:
-
-    def torch_get_autocast_gpu_dtype() -> torch.dtype:
-        """Get PyTorch autocast GPU dtype."""
+    else:
         return torch.get_autocast_gpu_dtype()
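
For illustration, a minimal sketch of the call-time version gate that both hunks adopt. It is not part of the patch: the helper name _torch_version is hypothetical, and the tuple comparison slightly generalizes the patch's TORCH_MAJOR == 2 and TORCH_MINOR >= 4 test. The point is that the branch now reflects whatever torch is live when the function runs, instead of being baked in when the module is first imported.

    import torch

    def _torch_version() -> tuple:
        # Hypothetical helper: parse major/minor of torch.__version__ at call time.
        major, minor = torch.__version__.split(".")[:2]
        return (int(major), int(minor))

    def torch_get_autocast_gpu_dtype() -> torch.dtype:
        """Get PyTorch autocast GPU dtype (sketch mirroring the patched utils.py)."""
        if _torch_version() >= (2, 4):
            # torch >= 2.4: device-generic introspection API.
            return torch.get_autocast_dtype("cuda")
        # Older torch: GPU-specific API that 2.4 deprecates in favor
        # of the device-generic form above.
        return torch.get_autocast_gpu_dtype()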
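
On the distributed.py side, a hedged usage sketch of what _get_active_autocast_contexts is for: build fresh contexts that mirror the caller's current autocast state, then re-enter them later. The wrapper below is illustrative only; it assumes the import path from this patch and is not the actual _CheckpointFunction logic.

    import torch

    from transformer_engine.pytorch.distributed import _get_active_autocast_contexts

    def replay_under_current_autocast(fn, *args):
        # Snapshot GPU/CPU autocast contexts matching the state active
        # right now (enabled flag, dtype, cache setting).
        gpu_autocast_ctx, cpu_autocast_ctx = _get_active_autocast_contexts()
        # Re-enter that state, e.g. when recomputing a checkpointed forward.
        with gpu_autocast_ctx, cpu_autocast_ctx:
            return fn(*args)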