[JAX] Sharding Utils #1003

Draft
wants to merge 3 commits into
base: main

mingxu1067
Collaborator

@mingxu1067 mingxu1067 commented Jul 9, 2024

Description

Adds sharding utilities, such as getting the rank or group of a specific device, to support further development of parallelism.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

  1. Adding a device list with all detected devices and a proper sharding.
  2. Adding two functions to get rank and group ID of the given mesh axis and device ID.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@mingxu1067 mingxu1067 requested review from aaronp24, zlsh80826, denera and phu0ngng and removed request for aaronp24 July 9, 2024 23:32
@mingxu1067 mingxu1067 force-pushed the mingh/sharding_utils branch from e498099 to 3df91d5 Compare July 9, 2024 23:42
Collaborator

@denera denera left a comment

LGTM, pending CI.

@zlsh80826
Collaborator

Hi @mingxu1067, do you have documentation on how group and size are defined here? It seems they follow some rules, but I'm not understanding them well.

@mingxu1067
Collaborator Author

Hi @mingxu1067, do you have documentation on how group and size are defined here? It seems they follow some rules, but I'm not understanding them well.

I followed the MPI concepts of group and rank when designing these functions. In short, rank is the local ID of a device along the given axis; it specifies which part of that axis the device belongs to. For example, if we have Mesh={'TP': 8}, then the TP rank of GPU-0 is 0. Similarly, GPU-1 has rank 1 and GPU-2 has rank 2. group specifies which GPUs are in the same group to perform communications. For example, if we have Mesh={'DP': 2, 'TP': 4}, then GPUs 0-3 are in the same TP group to perform AllReduce. Similarly, GPUs 4-7 are in the same TP group to perform AllReduce.
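To make the rank and group definitions above concrete, here is a minimal sketch in plain Python. It is not the code from this PR: the helper names `rank_of` and `group_of` are hypothetical, and it assumes device IDs are laid out over the mesh in row-major order (last axis varies fastest), which matches the GPU 0-3 / GPU 4-7 example.

```python
from math import prod


def _coords(mesh_shape, device_id):
    """Row-major coordinates of a device in the mesh (assumed layout)."""
    sizes = list(mesh_shape.values())
    if not 0 <= device_id < prod(sizes):
        raise ValueError(f"device_id {device_id} is outside the mesh")
    coords = []
    # Peel off the fastest-varying (last) axis first.
    for size in reversed(sizes):
        coords.append(device_id % size)
        device_id //= size
    return dict(zip(mesh_shape.keys(), reversed(coords)))


def rank_of(mesh_shape, axis, device_id):
    """Hypothetical helper: position of the device along `axis`."""
    return _coords(mesh_shape, device_id)[axis]


def group_of(mesh_shape, axis, device_id):
    """Hypothetical helper: ID of the communication group along `axis`.

    Devices that share all coordinates except the one on `axis`
    end up in the same group.
    """
    coords = _coords(mesh_shape, device_id)
    group_id = 0
    for name, size in mesh_shape.items():
        if name == axis:
            continue
        group_id = group_id * size + coords[name]
    return group_id


mesh = {"DP": 2, "TP": 4}
tp_ranks = [rank_of(mesh, "TP", d) for d in range(8)]
tp_groups = [group_of(mesh, "TP", d) for d in range(8)]
```

Under these assumptions, GPUs 0-3 share one TP group and GPUs 4-7 share another, and each GPU's TP rank is its position inside its group.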

@mingxu1067
Collaborator Author

/te-ci jax

@mingxu1067 mingxu1067 marked this pull request as draft July 15, 2024 16:20
@mingxu1067
Collaborator Author

Per @zlsh80826's suggestion, jax.lax.axis_index() might be a built-in way to obtain the rank. Testing.
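As a rough illustration of the idea, the sketch below queries `jax.lax.axis_index()` for two named axes. It uses nested `jax.vmap` with `axis_name` as a stand-in for a real device mesh so that it runs on a single device; the axis names `DP` and `TP` and the 2x4 shape are assumptions taken from the example in this thread, and the behavior under `pmap` or `shard_map` on real hardware is what still needs testing.

```python
import jax
import jax.numpy as jnp


def axis_ranks(_):
    # axis_index returns the index of the current instance along a named axis.
    return jax.lax.axis_index("DP"), jax.lax.axis_index("TP")


# Outer mapped axis plays the role of DP, inner mapped axis the role of TP.
mapped = jax.vmap(jax.vmap(axis_ranks, axis_name="TP"), axis_name="DP")

# A 2x4 dummy input gives the mapped axes their sizes (DP=2, TP=4).
dp_rank, tp_rank = mapped(jnp.zeros((2, 4)))
```

Each entry of `tp_rank` is the rank along `TP` for that position in the 2x4 grid, and `dp_rank` likewise for `DP`.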

3 participants