Skip to content

Commit

Permalink
Fixing E1120 in sharding.py
Browse files Browse the repository at this point in the history
Signed-off-by: Ming Huang <[email protected]>
  • Loading branch information
mingxu1067 committed Jul 9, 2024
1 parent 59a7918 commit 3df91d5
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion transformer_engine/jax/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,13 @@ def get_padded_spec(spec, ndim):
return spec + (None,) * (ndim - len(spec))


def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, kwargs: dict):
def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, kwargs: dict = None):
"""
A wrapper function to invoke lax.p* operations, like psum.
"""
if kwargs is None:
kwargs = {}

if mesh_resource is not None:
_, resource = _get_mesh_info(mesh_resource)
return ops(x, resource, kwargs)
Expand Down

0 comments on commit 3df91d5

Please sign in to comment.