Conversation

@RissyRan RissyRan commented Oct 2, 2025

Description

  • Migrate Gemma3 text layers (Gemma3DecoderLayer and Gemma3ScannableBlock) to NNX (see the illustrative sketch below).
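For context, a minimal sketch of the Linen-to-NNX migration pattern used for these layers (illustrative only; the class, attribute, and dimension names below are hypothetical and not the actual MaxText Gemma3DecoderLayer code):

from flax import nnx
import jax.numpy as jnp

class TinyDecoderLayer(nnx.Module):
  """Toy decoder-style block; NNX builds submodules eagerly in __init__."""

  def __init__(self, features: int, mlp_dim: int, *, rngs: nnx.Rngs):
    # In Linen this would live in setup() or an @nn.compact __call__;
    # in NNX the layer definition is explicit and stateful.
    self.pre_norm = nnx.LayerNorm(features, rngs=rngs)
    self.mlp_in = nnx.Linear(features, mlp_dim, rngs=rngs)
    self.mlp_out = nnx.Linear(mlp_dim, features, rngs=rngs)

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    # Execution is separated from definition: plain Python calls, no bind/apply.
    h = self.pre_norm(x)
    h = nnx.relu(self.mlp_in(h))
    return x + self.mlp_out(h)

# Usage: construct with explicit RNGs, then call like a regular object.
layer = TinyDecoderLayer(features=16, mlp_dim=64, rngs=nnx.Rngs(0))
y = layer(jnp.ones((2, 8, 16)))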

Tests

Training tests

export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d)
export DATASET_PATH=gs://maxtext-dataset

python3 -m MaxText.train MaxText/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \
  run_name=gemma3-4b-after per_device_batch_size=4 enable_checkpointing=false \
  model_name=gemma3-4b ici_fsdp_parallelism=4 steps=10 max_target_length=4096 \
  async_checkpointing=false dataset_type=synthetic dtype=bfloat16 \
  weight_dtype=bfloat16 scan_layers=True attention=flash
# Before

Total memory size: 30.7 GB, Output size: 5.4 GB, Temp size: 25.3 GB, Argument size: 5.4 GB, Host temp size: 0.0 GB.
Memstats: After params initialized:
	Using (GB) 5.69 / 95.74 (5.943179%) on TPU_0(process=0,(0,0,0,0))
	Using (GB) 5.69 / 95.74 (5.943179%) on TPU_1(process=0,(1,0,0,0))
	Using (GB) 5.69 / 95.74 (5.943179%) on TPU_2(process=0,(0,1,0,0))
	Using (GB) 5.69 / 95.74 (5.943179%) on TPU_3(process=0,(1,1,0,0))

completed step: 8, seconds: 1.600, TFLOP/s/device: 244.680, Tokens/s/device: 10239.808, total_weights: 65536, loss: 12.574


# After

Total memory size: 30.7 GB, Output size: 5.4 GB, Temp size: 25.3 GB, Argument size: 5.4 GB, Host temp size: 0.0 GB.
Memstats: After params initialized:
	Using (GB) 5.69 / 95.74 (5.943179%) on TPU_0(process=0,(0,0,0,0))
	Using (GB) 5.69 / 95.74 (5.943179%) on TPU_1(process=0,(1,0,0,0))
	Using (GB) 5.69 / 95.74 (5.943179%) on TPU_2(process=0,(0,1,0,0))
	Using (GB) 5.69 / 95.74 (5.943179%) on TPU_3(process=0,(1,1,0,0))

completed step: 8, seconds: 1.600, TFLOP/s/device: 244.663, Tokens/s/device: 10239.130, total_weights: 65536, loss: 12.619

Decoding tests

Noticed roughly 10 GB lower host memory usage during decoding, per the RAMstats below (32.51 GB before vs. 22.7 GB after).
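For reference, the numbers below came from a MaxText decode run along the following lines (illustrative invocation; the checkpoint path is a placeholder and the exact flags used for this test may differ):

python3 -m MaxText.decode MaxText/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_PATH} run_name=gemma3-4b-decode \
  model_name=gemma3-4b per_device_batch_size=1 \
  max_prefill_predict_length=16 max_target_length=64 \
  load_parameters_path=gs://<your-checkpoint-path>/checkpoints/0/items \
  prompt="I love to" scan_layers=True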


# Before

Memstats: After load_params:
	Using (GB) 3.61 / 95.74 (3.770629%) on TPU_0(process=0,(0,0,0,0))
	Using (GB) 3.61 / 95.74 (3.770629%) on TPU_1(process=0,(1,0,0,0))
	Using (GB) 3.61 / 95.74 (3.770629%) on TPU_2(process=0,(0,1,0,0))
	Using (GB) 3.61 / 95.74 (3.770629%) on TPU_3(process=0,(1,1,0,0))

RAMstats: After load_params:
	Using (GB) 32.51 / 440.83 (7.374725%) -->  Available:405.7

Input `I love to` -> ` cook and I love to eat. I'm always looking for new recipes and ways to make my meals more interesting. I'm also a big fan of healthy eating, so I try to incorporate lots of fruits, vegetables, and lean protein into my diet.`

# After

Memstats: After load_params:
	Using (GB) 3.61 / 95.74 (3.770629%) on TPU_0(process=0,(0,0,0,0))
	Using (GB) 3.61 / 95.74 (3.770629%) on TPU_1(process=0,(1,0,0,0))
	Using (GB) 3.61 / 95.74 (3.770629%) on TPU_2(process=0,(0,1,0,0))
	Using (GB) 3.61 / 95.74 (3.770629%) on TPU_3(process=0,(1,1,0,0))

RAMstats: After load_params:
	Using (GB) 22.7 / 440.83 (5.149377%) -->  Available:415.51

Input `I love to` -> ` cook and I love to eat. I'm always looking for new recipes and ways to make my meals more interesting. I'm also a big fan of healthy eating, so I try to incorporate lots of fruits, vegetables, and lean protein into my diet.`

Updated:

JetStream tests

Before:

Memstats: After load_params:
	Using (GB) 1.81 / 95.74 (1.890537%) on TPU_0(process=0,(0,0,0,0))
	Using (GB) 1.81 / 95.74 (1.890537%) on TPU_1(process=0,(1,0,0,0))
	Using (GB) 1.81 / 95.74 (1.890537%) on TPU_2(process=0,(0,1,0,0))
	Using (GB) 1.81 / 95.74 (1.890537%) on TPU_3(process=0,(1,1,0,0))

RAMstats: After load_params:
	Using (GB) 22.76 / 440.83 (5.162988%) -->  Available:415.28

After:

Memstats: After load_params:
	Using (GB) 1.81 / 95.74 (1.890537%) on TPU_0(process=0,(0,0,0,0))
	Using (GB) 1.81 / 95.74 (1.890537%) on TPU_1(process=0,(1,0,0,0))
	Using (GB) 1.81 / 95.74 (1.890537%) on TPU_2(process=0,(0,1,0,0))
	Using (GB) 1.81 / 95.74 (1.890537%) on TPU_3(process=0,(1,1,0,0))

RAMstats: After load_params:
	Using (GB) 22.7 / 440.83 (5.149377%) -->  Available:415.34

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.


github-actions bot commented Oct 2, 2025

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions bot left a comment


📋 Review Summary

This pull request migrates Gemma3DecoderLayer and Gemma3ScannableBlock from Flax Linen to NNX. The changes are well-structured and follow the imperative style of NNX. The core logic appears to be preserved correctly during the refactoring.

🔍 General Feedback

  • The migration to NNX is clean and improves the code's structure by separating layer definition from execution.
  • Good job on updating the necessary wrappers and configurations in decoders.py to accommodate the new NNX-based layers.
  • I've added a couple of minor suggestions to clean up unused imports and improve docstring consistency.

@RissyRan RissyRan force-pushed the gemma3_text_nnx branch 3 times, most recently from f5ae73f to c7a12b6 on October 2, 2025 at 23:48
@RissyRan RissyRan changed the title from "[WIP] Migrate Gemma3DecoderLayer and Gemma3ScannableBlock to NNX" to "Migrate Gemma3DecoderLayer and Gemma3ScannableBlock to NNX" on Oct 3, 2025

shuningjin commented Oct 3, 2025

The before/after JetStream accuracy differs slightly for gemma3: https://diff.googleplex.com/#key=1I4o3eENf3aa


@shuningjin shuningjin left a comment


Thanks for the new test results! The JetStream accuracy diff could be due to a missing enable_dropout=False in the config, which is out of scope for this PR. I see you increased the profiling time from 200 ms to 1000 ms; that might have contributed to the HLO matching.


@shuningjin shuningjin left a comment


LGTM!

@copybara-service copybara-service bot merged commit d6193a5 into main Oct 13, 2025
33 of 34 checks passed
@copybara-service copybara-service bot deleted the gemma3_text_nnx branch October 13, 2025 18:44