From 4235bfe38b45e9299876b04e644ecbade32979a1 Mon Sep 17 00:00:00 2001 From: yifeis-nv Date: Wed, 28 Aug 2024 10:57:12 +0800 Subject: [PATCH] Fix param input order for cudagraph Signed-off-by: yifeis-nv --- transformer_engine/pytorch/graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index e2642bc360..e8cdb67c6c 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -171,8 +171,8 @@ def _make_graphed_callables( ] else: per_callable_module_params = [] - for c in callables: - for i in range(num_microbatches): + for i in range(num_microbatches): + for c in callables: per_callable_module_params.append( tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () )