Skip to content

Commit

Permalink
Adding sharding utils for further development of RingAttention
Browse files Browse the repository at this point in the history
Signed-off-by: Ming Huang <[email protected]>
  • Loading branch information
mingxu1067 committed Jul 9, 2024
1 parent 33dbf62 commit 59a7918
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 3 deletions.
62 changes: 61 additions & 1 deletion tests/jax/test_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
#
# See LICENSE for license information.

import jax
import pytest
import numpy as np

from transformer_engine.jax.flax import extend_logical_axis_rules
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
from transformer_engine.jax.sharding import get_group_of_mesh_axis, get_rank_of_mesh_axis
from transformer_engine.jax.sharding import global_shard_guard, num_of_devices
from transformer_engine.jax.sharding import MeshResource

LOGICAL_RULES = [
[(("a1", None), ("a2", "ma2")), False],
Expand All @@ -22,6 +26,45 @@
MeshResource("data", "model"),
]

# Parametrization cases for the mesh rank/group test. Each entry is a 5-tuple:
#   (number of devices needed,
#    mesh shape,
#    mesh axis names,
#    expected rank per axis, indexed by device id,
#    expected group per axis, indexed by device id)
MESH_INFO = [
    # 2 x 2 mesh
    (
        4,
        (2, 2),
        ("a1", "a2"),
        {"a1": [0, 0, 1, 1], "a2": [0, 1, 0, 1]},
        {"a1": [0, 1, 0, 1], "a2": [0, 0, 1, 1]},
    ),
    # 4 x 1 mesh
    (
        4,
        (4, 1),
        ("a1", "a2"),
        {"a1": [0, 1, 2, 3], "a2": [0, 0, 0, 0]},
        {"a1": [0, 0, 0, 0], "a2": [0, 1, 2, 3]},
    ),
    # 1 x 4 mesh
    (
        4,
        (1, 4),
        ("a1", "a2"),
        {"a1": [0, 0, 0, 0], "a2": [0, 1, 2, 3]},
        {"a1": [0, 1, 2, 3], "a2": [0, 0, 0, 0]},
    ),
    # 2 x 2 x 2 mesh
    (
        8,
        (2, 2, 2),
        ("a1", "a2", "a3"),
        {
            "a1": [0, 0, 0, 0, 1, 1, 1, 1],
            "a2": [0, 0, 1, 1, 0, 0, 1, 1],
            "a3": [0, 1, 0, 1, 0, 1, 0, 1],
        },
        {
            "a1": [0, 1, 2, 3, 0, 1, 2, 3],
            "a2": [0, 1, 0, 1, 2, 3, 2, 3],
            "a3": [0, 0, 1, 1, 2, 2, 3, 3],
        },
    ),
]


class TestShardingSideAPI:

Expand All @@ -36,3 +79,20 @@ def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
assert not need_assert
except AssertionError as ae:
assert need_assert, f"{ae.args}"

@pytest.mark.parametrize("mesh_info", MESH_INFO)
def test_get_rank_and_group_of_mesh_axis(self, mesh_info):
    """
    Check get_rank_of_mesh_axis and get_group_of_mesh_axis against the
    reference tables of one MESH_INFO entry, for every device index and
    every mesh axis. Skipped when fewer devices are available than needed.
    """
    required_devices, shape, axis_names, expected_rank, expected_group = mesh_info

    if required_devices > num_of_devices():
        pytest.skip("Not enough devices for this test.")

    # Lay the first `required_devices` devices out in the requested mesh shape.
    device_grid = np.asarray(jax.devices()[:required_devices]).reshape(*shape)
    mesh = jax.sharding.Mesh(device_grid, axis_names)

    for device_index in range(required_devices):
        for axis_name in axis_names:
            actual_rank = get_rank_of_mesh_axis(device_index, axis_name, mesh)
            assert actual_rank == expected_rank[axis_name][device_index]
            actual_group = get_group_of_mesh_axis(device_index, axis_name, mesh)
            assert actual_group == expected_group[axis_name][device_index]
77 changes: 75 additions & 2 deletions transformer_engine/jax/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ def get_padded_spec(spec, ndim):
return spec + (None,) * (ndim - len(spec))


def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, kwargs: dict = None):
    """
    A wrapper function to invoke lax.p* operations, like psum.

    Parameters
    ----------
    x: the value handed to `ops`.
    ops: the collective to invoke; it is called as `ops(x, resource, **kwargs)`.
    mesh_resource: the mesh resource to operate over. When None, no collective
        is invoked and `x` is returned unchanged.
    kwargs: optional dict of extra keyword options expanded into the call to
        `ops`. It defaults to None (no extra options), so existing
        three-argument callers keep working.

    Returns
    -------
    The result of `ops`, or `x` itself when `mesh_resource` is None.
    """
    if mesh_resource is None:
        return x

    # Only the second element returned by _get_mesh_info is used here; it is
    # what gets passed to `ops` as its second positional argument.
    _, resource = _get_mesh_info(mesh_resource)
    # Expand the options as keywords rather than passing the dict positionally.
    return ops(x, resource, **(kwargs or {}))


Expand All @@ -149,6 +149,79 @@ def num_of_devices():
return len(jax.devices())


def get_device_ids():
    """
    Get an ID list of the detected devices with a proper sharding.

    The list is `jnp.arange(num_of_devices())`, constrained with a
    PartitionSpec built from `get_all_mesh_axes()`.
    """
    device_count = num_of_devices()
    ids = jnp.arange(device_count)
    # NOTE(review): presumably this shards the single dimension of `ids`
    # across every mesh axis -- confirm against get_all_mesh_axes().
    all_axes_spec = PartitionSpec(get_all_mesh_axes())
    return jax.lax.with_sharding_constraint(ids, all_axes_spec)


def get_mesh_axis_size(axis, mesh=None):
    """
    Get the size of the given axis of a mesh.

    Parameters
    ----------
    axis: name of the mesh axis to look up; it must be a key of `mesh.shape`.
    mesh: the mesh to query. If None, it is replaced by the global mesh
        (`_PXLA_THREAD_RESOURCES.env.physical_mesh`).

    Returns
    -------
    `mesh.shape[axis]`.

    Raises
    ------
    AssertionError: if `axis` is not an axis of the mesh.
    """
    if mesh is None:
        mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh

    assert axis in mesh.shape, f"{axis} is not an axis of the given mesh {mesh.shape}"
    return mesh.shape[axis]


def get_group_of_mesh_axis(device_id, axis, mesh=None):
    """
    Get the group index of a device along the given mesh axis.

    With `post` being the product of the sizes of all axes listed after
    `axis` in `mesh.axis_names`, the result is
    `device_id % post + device_id // (axis_size * post) * post`.
    NOTE(review): this reads as "the device's position with its coordinate
    along `axis` removed", assuming `device_id` is the row-major index of the
    device in the mesh (as in the unit test) -- confirm for other layouts.

    If the mesh is None, it is replaced by the global mesh.
    """
    target_mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh if mesh is None else mesh
    size_of_axis = get_mesh_axis_size(axis, target_mesh)

    # Multiply together the sizes of the axes that come after `axis`.
    trailing_size = 1
    passed_axis = False
    for name in target_mesh.axis_names:
        if passed_axis:
            trailing_size *= target_mesh.shape[name]
        elif name == axis:
            passed_axis = True

    inner_offset = device_id % trailing_size
    outer_block = device_id // (size_of_axis * trailing_size)
    return inner_offset + outer_block * trailing_size


def get_rank_of_mesh_axis(device_id, axis, mesh=None):
    """
    Get the rank of a device along the given mesh axis.

    With `post` being the product of the sizes of all axes listed after
    `axis` in `mesh.axis_names`, the result is
    `(device_id // post) % axis_size`.
    NOTE(review): this is the device's coordinate along `axis`, assuming
    `device_id` is the row-major index of the device in the mesh (as in the
    unit test) -- confirm for other layouts.

    If the mesh is None, it is replaced by the global mesh.
    """
    target_mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh if mesh is None else mesh
    size_of_axis = get_mesh_axis_size(axis, target_mesh)

    # Multiply together the sizes of the axes that come after `axis`.
    trailing_size = 1
    passed_axis = False
    for name in target_mesh.axis_names:
        if passed_axis:
            trailing_size *= target_mesh.shape[name]
        elif name == axis:
            passed_axis = True

    return (device_id // trailing_size) % size_of_axis


@dataclass
class MeshResource:
"""
Expand Down

0 comments on commit 59a7918

Please sign in to comment.