From 35d7a312292b98b60a19bac61508122672d3e191 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Sun, 20 Oct 2024 20:51:01 -0700 Subject: [PATCH] fix kwargs for torch.amp.autocast Signed-off-by: Xin Yao --- transformer_engine/pytorch/distributed.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e3af4d60f8..3eb1dbc7a8 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -259,13 +259,19 @@ def _get_active_autocast_contexts(): gpu_autocast_enabled = torch.is_autocast_enabled("cuda") gpu_autocast_dtype = torch.get_autocast_dtype("cuda") gpu_autocast_ctx = torch.amp.autocast( - "cuda", gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached + "cuda", + enabled=gpu_autocast_enabled, + dtype=gpu_autocast_dtype, + cache_enabled=autocast_cached, ) cpu_autocast_enabled = torch.is_autocast_enabled("cpu") cpu_autocast_dtype = torch.get_autocast_dtype("cpu") cpu_autocast_ctx = torch.amp.autocast( - "cpu", cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached + "cpu", + enabled=cpu_autocast_enabled, + dtype=cpu_autocast_dtype, + cache_enabled=autocast_cached, ) return gpu_autocast_ctx, cpu_autocast_ctx