@khatwanimohit (Collaborator) commented Oct 8, 2025

Description

  • Adds tokamax as a dependency in the requirements*.txt files.
  • Uses the ragged_dot API from tokamax, hidden behind the use_tokamax flag.
  • Imports tokamax conditionally, because MaxText's seed-env is pinned at jax==0.7.0 while tokamax requires jax>=0.7.2. A TODO notes that the condition should be removed once MaxText upgrades to jax==0.8.0.

If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/436335556

Tests

Performance and convergence runs were done by @suexu1025; see b/436335556#comment11.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@khatwanimohit force-pushed the mohit/tokamax-gmm branch 4 times, most recently from a83d9c0 to 94521bf, October 9, 2025 00:25
@khatwanimohit force-pushed the mohit/tokamax-gmm branch 5 times, most recently from 0856ca0 to b8b5d6b, October 17, 2025 19:17
@khatwanimohit marked this pull request as ready for review October 17, 2025 19:17
@khatwanimohit changed the title from "[WIP] Add tokamax megablox kernel" to "Add tokamax megablox kernel" Oct 17, 2025
tensorflow-text
tensorflow
tiktoken
tokamax>=0.0.2
Collaborator

Do we plan to include the latest version in the nightly build?

Collaborator Author

Our nightly Docker image build only supports nightly versions of jax, jaxlib, and libtpu. If we want a nightly run, we can prepend pip install git+https://github.com/openxla/tokamax.git to the train command, which should work.

@khatwanimohit requested a review from rdyro October 17, 2025 21:35
shard_optimizer_over_data: False

# Use tokamax library for kernel implementations
use_tokamax: false
Collaborator

It might be good to distinguish specific tokamax kernels, e.g. use_tokamax_gmm, use_tokamax_splash, etc.
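
One way the per-kernel flags suggested here could coexist with the existing umbrella flag is sketched below. This is only an illustration under stated assumptions: KernelFlags and flags_from_config are hypothetical names, and treating use_tokamax as the default for each per-kernel flag is one possible policy, not something this PR implements.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class KernelFlags:
    """Per-kernel switches, using the flag names proposed in the review comment."""
    use_tokamax_gmm: bool = False
    use_tokamax_splash: bool = False


def flags_from_config(raw: dict) -> KernelFlags:
    """Build per-kernel flags from a parsed config dict (hypothetical helper).

    The umbrella use_tokamax value is used as the default for any
    per-kernel flag that the config does not set explicitly.
    """
    umbrella = bool(raw.get("use_tokamax", False))
    return KernelFlags(
        use_tokamax_gmm=bool(raw.get("use_tokamax_gmm", umbrella)),
        use_tokamax_splash=bool(raw.get("use_tokamax_splash", umbrella)),
    )


# Enable tokamax everywhere except the splash attention kernel.
flags = flags_from_config({"use_tokamax": True, "use_tokamax_splash": False})
```

With this policy, existing configs that only set use_tokamax keep working, while individual kernels can be opted out during debugging or perf comparisons.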
