From 3df91d5ea55127ae93bf32fd5584550e34f92a9c Mon Sep 17 00:00:00 2001 From: Ming Huang Date: Tue, 9 Jul 2024 23:36:38 +0000 Subject: [PATCH] Fixing E1120 in sharding.py Signed-off-by: Ming Huang --- transformer_engine/jax/sharding.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index b6c606e3cd..555e818891 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -132,10 +132,13 @@ def get_padded_spec(spec, ndim): return spec + (None,) * (ndim - len(spec)) -def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, kwargs: dict): +def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, kwargs: dict = None): """ A wrapper function to invoke lax.p* operations, like psum. """ + if kwargs is None: + kwargs = {} + if mesh_resource is not None: _, resource = _get_mesh_info(mesh_resource) return ops(x, resource, kwargs)