From a2212c0d3ef64877d316014227fbf08d1ab7e332 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Sun, 20 Oct 2024 20:51:01 -0700 Subject: [PATCH] fix kwargs for torch.amp.autocast Signed-off-by: Xin Yao --- transformer_engine/pytorch/distributed.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e3af4d60f8..5e311e8082 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -251,21 +251,27 @@ def in_fp8_activation_recompute_phase() -> bool: def _get_active_autocast_contexts(): """ - Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state - at the time of this function's execution. + Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast + state at the time of this function's execution. """ autocast_cached = torch.is_autocast_cache_enabled() gpu_autocast_enabled = torch.is_autocast_enabled("cuda") gpu_autocast_dtype = torch.get_autocast_dtype("cuda") gpu_autocast_ctx = torch.amp.autocast( - "cuda", gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached + "cuda", + enabled=gpu_autocast_enabled, + dtype=gpu_autocast_dtype, + cache_enabled=autocast_cached, ) cpu_autocast_enabled = torch.is_autocast_enabled("cpu") cpu_autocast_dtype = torch.get_autocast_dtype("cpu") cpu_autocast_ctx = torch.amp.autocast( - "cpu", cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached + "cpu", + enabled=cpu_autocast_enabled, + dtype=cpu_autocast_dtype, + cache_enabled=autocast_cached, ) return gpu_autocast_ctx, cpu_autocast_ctx @@ -274,8 +280,8 @@ def _get_active_autocast_contexts(): def _get_active_autocast_contexts(): """ - Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast state - at the time of this function's execution. + Returns new CPU and GPU torch.amp.autocast(..) contexts that match the active autocast + state at the time of this function's execution. """ autocast_cached = torch.is_autocast_cache_enabled()