Draft
Changes from all commits · 37 commits
2fcff3f
adding llama4 image tiling and basic batching for mm SFT.
NicoGrande Sep 25, 2025
c3a9001
single image tiling support for decode.
NicoGrande Sep 26, 2025
9d59fb9
adding llama4 tiling changes to decode.
NicoGrande Sep 26, 2025
da421cf
small typo in decode.
NicoGrande Sep 26, 2025
d422caf
adding missing params for image masks.
NicoGrande Sep 26, 2025
75f44d7
Mainly internal tokenizer path fix
gobbleturk Sep 22, 2025
8d671b4
Continue pinning Tunix dependency for nightly builds
bvandermoon Sep 23, 2025
3f3a9aa
update sharding to remove extra datacopy
suexu1025 Sep 19, 2025
87b9a17
Add documentation to run SFT with Deepseek-V3 model
SurbhiJainUSC Sep 22, 2025
ca75fe6
Add Gemini CLI for PR review
RissyRan Sep 23, 2025
3fa7c90
GPT-OSS: Add user guide and tests
shuningjin Sep 25, 2025
a818e1e
Moved tunix to setup.sh
Rohan-Bierneni Sep 24, 2025
1899214
Update typo for pyconfig to fix the breakage on the head
Google-ML-Automation Sep 25, 2025
af3803c
Update pinned commit for Tunix
SurbhiJainUSC Sep 26, 2025
2b74960
feat(api_server): Add OpenAI-compatible API server for MaxText models
babyplutokurt Sep 2, 2025
1a4fcfb
remove redundant code block
babyplutokurt Sep 24, 2025
5dc1cd7
add vocab tiling
NuojCheng Sep 8, 2025
4160978
All-in-one commit for new pw_recipe modularization
DannyLiCom Sep 22, 2025
3de8da1
Update Tunix commit in extra_deps_from_github
SurbhiJainUSC Sep 26, 2025
ef24250
Add data pipeline perf in explanations
aireenmei Sep 23, 2025
1d9ef55
Add codeowners for model bring-up
RissyRan Sep 9, 2025
1a33123
[src/MaxText] `pytype` + `pylint` + `pyink` + `codespell`
SamuelMarks Sep 26, 2025
48eb50d
Add `mtc_data_parallelism` config for multi-tier checkpointing.
abhinavclemson Sep 28, 2025
d447f53
adding llama4 image tiling and basic batching for mm SFT.
NicoGrande Sep 25, 2025
e8fa373
Merge branch 'main' into nicogrande/sft-llama4-tiling
NicoGrande Sep 29, 2025
23d3a2e
Merge branch 'main' into nicogrande/sft-llama4-tiling
NicoGrande Sep 30, 2025
b74da04
fixing _apply_embedding call.
NicoGrande Sep 30, 2025
161f338
removing need for llama4 input changes.
NicoGrande Oct 1, 2025
99b2b5c
Merge branch 'main' into nicogrande/sft-llama4-tiling
NicoGrande Oct 1, 2025
cbcd022
fixing pad tile normalization.
NicoGrande Oct 2, 2025
3eff680
fixing merge mm embeddings.
NicoGrande Oct 3, 2025
1bd9f00
Merge branch 'main' into nicogrande/sft-llama4-tiling
NicoGrande Oct 3, 2025
f27daa6
linting fixes.
NicoGrande Oct 3, 2025
0b79904
pyink linting fixes.
NicoGrande Oct 3, 2025
be95c02
Merge branch 'main' into nicogrande/sft-llama4-tiling
NicoGrande Oct 3, 2025
e57680b
adding missing dim.
NicoGrande Oct 3, 2025
a812898
fix setup.sh for MODE=nightly
khatwanimohit Oct 3, 2025
2 changes: 1 addition & 1 deletion README.md
@@ -65,7 +65,7 @@ After installation, you can verify the package is available with `python3 -c "im

## 🔥 Latest news 🔥

* \[September 26, 2025\] Vocabulary tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_vocab_tiling` to unlock more efficient memory usage, see [doc](https://maxtext.readthedocs.io/en/latest/explanations/tiling.html).
* \[September 26, 2025\] Vocabulary tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_vocab_tiling` to unlock more efficient memory usage.
* \[September 24, 2025\] The GPT-OSS family of models (20B, 120B) is now supported.
* \[September 5, 2025\] MaxText has moved to an `src` layout as part of [RESTRUCTURE.md](RESTRUCTURE.md). For existing environments, please run `pip install -e .` from MaxText root.
* \[August 13, 2025\] The Qwen3 2507 MoE family of models is now supported: MoEs: 235B Thinking & 280B Coder as well as existing dense models: 0.6B, 4B, 8B, 14B, and 32B.
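The vocabulary-tiling news item above refers to the `num_vocab_tiling` config, which chunks work over the vocabulary axis so a full [batch, vocab] logits tensor never has to be materialized at once. Below is a minimal NumPy sketch of that idea only; it is not MaxText's implementation, and the streaming log-sum-exp formulation is an assumption used for illustration.

```python
import numpy as np

def tiled_logsumexp(hidden, embedding, num_vocab_tiling=4):
  """Compute log-sum-exp over the vocabulary in tiles to cap peak memory.

  hidden:    [batch, emb_dim] activations before the output projection.
  embedding: [vocab, emb_dim] output embedding table (vocab divisible by num_vocab_tiling).
  """
  vocab = embedding.shape[0]
  tile = vocab // num_vocab_tiling
  running_max = np.full(hidden.shape[0], -np.inf)
  running_sum = np.zeros(hidden.shape[0])
  for i in range(num_vocab_tiling):
    # Only a [batch, tile]-sized slice of logits exists at any time.
    logits_tile = hidden @ embedding[i * tile:(i + 1) * tile].T
    tile_max = logits_tile.max(axis=-1)
    new_max = np.maximum(running_max, tile_max)
    running_sum = running_sum * np.exp(running_max - new_max) + np.exp(
        logits_tile - new_max[:, None]).sum(axis=-1)
    running_max = new_max
  return running_max + np.log(running_sum)
```

Each iteration touches only a vocab/num_vocab_tiling slice of the output projection, which is where the memory saving comes from.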
7 changes: 4 additions & 3 deletions end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh
@@ -28,7 +28,7 @@ export MODEL_BUCKET=gs://maxtext-gemma/unified/gemma3/hf
# Here is an example of qwen3-4b maxtext checkpoint, converted from Qwen/Qwen3-4B
export CKPT_PATH=gs://maxtext-gemma/unified/gemma3/4b/unscanned/2025-08-05-18-18/0/items

# You can upload to huggingface hub or GCS using the HF_CKPT_PATH as base_output_directory
# You can upload to huggingface hub or GCS by uncommenting the HF_CKPT_PATH and using it as base_output_directory
# export HF_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/hf/${idx}
export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx}

@@ -40,9 +40,10 @@ python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_PKG_DIR:-${MA
use_multimodal=${USE_MULTIMODAL} \
scan_layers=false

# Alternatively, if uploaded the converted ckpt, HF requires local storage of model
# Alternatively, if uploaded the converted ckpt, HF requires local storage of model and please uncomment below
# mkdir -p "${LOCAL_PATH}"
# gcloud storage cp -r ${HF_CKPT_PATH} ${LOCAL_PATH}
# gcloud storage cp -r ${HF_CKPT_PATH}/** ${LOCAL_PATH}
# echo "Copied from ${HF_CKPT_PATH} to ${LOCAL_PATH}"

# We also test whether the forward pass logits match the original HF model
# to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu`
23 changes: 11 additions & 12 deletions setup.sh
@@ -180,19 +180,19 @@ fi
if [[ "$MODE" == "stable" || ! -v MODE ]]; then
# Stable mode
if [[ $DEVICE == "tpu" ]]; then


# TODO: Once tunix has support for GPUs, move it from here to requirements.txt
echo "Installing google-tunix for stable TPU environment"
python3 -m uv pip install 'google-tunix>=0.1.0'
echo "Installing stable jax, jaxlib for tpu"
if [[ -n "$JAX_VERSION" ]]; then
echo "Installing stable jax, jaxlib, libtpu version ${JAX_VERSION}"
python3 -m uv pip install jax[tpu]==${JAX_VERSION} -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python3 -m uv pip install -U jax[tpu]==${JAX_VERSION} -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
else
echo "Installing stable jax, jaxlib, libtpu for tpu"
python3 -m uv pip install 'jax[tpu]>0.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python3 -m uv pip install -U 'jax[tpu]>0.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
fi

# TODO: Once tunix has support for GPUs, move it from here to requirements.txt
echo "Installing google-tunix for stable TPU environment"
python3 -m uv pip install 'google-tunix>=0.1.0'

if [[ -n "$LIBTPU_GCS_PATH" ]]; then
# Install custom libtpu
echo "Installing libtpu.so from $LIBTPU_GCS_PATH to $libtpu_path"
@@ -232,12 +232,15 @@ elif [[ $MODE == "nightly" ]]; then
export NVTE_FRAMEWORK=jax
python3 -m uv pip install https://github.com/NVIDIA/TransformerEngine/archive/9d031f.zip
elif [[ $DEVICE == "tpu" ]]; then
echo "Installing nightly tensorboard plugin profile"
python3 -m uv pip install tbp-nightly --upgrade
# Installing tunix
python3 -m uv pip install 'git+https://github.com/google/tunix.git'
echo "Installing jax-nightly, jaxlib-nightly"
# Install jax-nightly
python3 -m uv pip install --pre -U jax -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
# Install jaxlib-nightly
python3 -m uv pip install --pre -U jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

if [[ -n "$LIBTPU_GCS_PATH" ]]; then
# Install custom libtpu
echo "Installing libtpu.so from $LIBTPU_GCS_PATH to $libtpu_path"
@@ -250,10 +253,6 @@ elif [[ $MODE == "nightly" ]]; then
echo "Installing libtpu-nightly"
python3 -m uv pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
fi
echo "Installing nightly tensorboard plugin profile"
python3 -m uv pip install tbp-nightly --upgrade
# Installing tunix
python3 -m uv pip install 'git+https://github.com/google/tunix.git'
fi
echo "Installing nightly tensorboard plugin profile"
python3 -m uv pip install tbp-nightly --upgrade
11 changes: 10 additions & 1 deletion src/MaxText/configs/models/llama4-17b-128e.yml
@@ -41,4 +41,13 @@ temperature_tuning: True
# Chunk attention is used on all RoPE layers
# otherwise, on NoPE layers, use global attention
chunk_attn_window_size: 8192
image_size_for_vit: 336

# Multimodal flags (need to set use_multimodal=true)
image_size_for_vit: 336
num_channels_for_vit: 3
patch_size_for_vit: 14
hidden_size_for_vit: 1408
intermediate_size_for_vit: 5632
num_hidden_layers_for_vit: 34
num_attention_heads_for_vit: 16
image_placeholder: "<|image|>"
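The new ViT settings fix the per-tile geometry: with `image_size_for_vit: 336` and `patch_size_for_vit: 14`, each tile is split into (336 / 14)² = 576 patches of 3 × 14 × 14 raw values, each projected to a 1408-dim hidden vector. A quick back-of-the-envelope check (illustrative only; any extra per-tile special tokens are not covered by this diff):

```python
image_size_for_vit = 336
patch_size_for_vit = 14
num_channels_for_vit = 3
hidden_size_for_vit = 1408

patches_per_side = image_size_for_vit // patch_size_for_vit        # 24
patches_per_tile = patches_per_side ** 2                           # 576 vision tokens per tile
values_per_patch = num_channels_for_vit * patch_size_for_vit ** 2  # 588 raw pixel values per patch
print(patches_per_tile, values_per_patch, hidden_size_for_vit)
```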
12 changes: 11 additions & 1 deletion src/MaxText/configs/models/llama4-17b-16e.yml
@@ -41,4 +41,14 @@ temperature_tuning: True
# Chunk attention is used on all RoPE layers
# otherwise, on NoPE layers, use global attention
chunk_attn_window_size: 8192
image_size_for_vit: 336

# Multimodal flags (need to set use_multimodal=true)
image_size_for_vit: 336
num_channels_for_vit: 3
patch_size_for_vit: 14
hidden_size_for_vit: 1408
intermediate_size_for_vit: 5632
num_hidden_layers_for_vit: 34
num_attention_heads_for_vit: 16
image_placeholder: "<|image|>"

3 changes: 2 additions & 1 deletion src/MaxText/decode.py
@@ -99,7 +99,7 @@ def main(argv: Sequence[str]) -> None:

text = config.prompt
prefill_length = config.max_prefill_predict_length
# processor_output = multimodal_utils.PreprocessorOutput()
processor_outputs = multimodal_utils.PreprocessorOutput()
if config.use_multimodal:
image_path = config.image_path.split(",")
images = [multimodal_utils.load_image_from_path(p) for p in image_path]
@@ -151,6 +151,7 @@ def main(argv: Sequence[str]) -> None:
params=params,
padded_tokens=tokens,
images=np.stack([po.pixel_values for po in processor_outputs]) if config.use_multimodal else None,
image_masks=np.stack([po.pixel_mask for po in processor_outputs]) if config.use_multimodal else None,
true_length=true_length,
rng=rng_prefill,
slot=i,
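The prefill call above now passes both `images` and `image_masks`, stacked from the per-prompt preprocessor outputs. Below is a hedged, self-contained sketch of that batching step using a stand-in dataclass; the field names `pixel_values` and `pixel_mask` come from the diff, while the shapes are assumptions for illustration.

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class PreprocessorOutput:      # stand-in for multimodal_utils.PreprocessorOutput
  pixel_values: np.ndarray     # e.g. [num_tiles, channels, height, width]
  pixel_mask: np.ndarray       # e.g. [num_tiles]: 1 for real tiles, 0 for padded tiles

processor_outputs = [
    PreprocessorOutput(np.zeros((4, 3, 336, 336)), np.array([1, 1, 1, 0])),
    PreprocessorOutput(np.zeros((4, 3, 336, 336)), np.array([1, 1, 0, 0])),
]

# Batch the per-prompt outputs exactly as the prefill call does.
images = np.stack([po.pixel_values for po in processor_outputs])     # [batch, tiles, C, H, W]
image_masks = np.stack([po.pixel_mask for po in processor_outputs])  # [batch, tiles]
print(images.shape, image_masks.shape)
```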
42 changes: 39 additions & 3 deletions src/MaxText/input_pipeline/_input_pipeline_utils.py
@@ -106,8 +106,10 @@ def prepare_text_for_image_fusion(example, column_name, model_name):
example[column_name], model_name, processor_output=example["images"]
)
if isinstance(example["images"], list):
example["image_masks"] = [image.pixel_mask for image in example["images"]]
example["images"] = [image.pixel_values for image in example["images"]]
else:
example["image_masks"] = example["images"].pixel_mask
example["images"] = example["images"].pixel_values
return example

@@ -255,6 +257,7 @@ def map(self, element):
"inputs": np.asarray(inputs[: self.max_target_length], dtype=np.int32),
"targets": np.asarray(targets[: self.max_target_length], dtype=np.int32),
"images": element["images"],
"image_masks": element["image_masks"],
}


@@ -434,6 +437,25 @@ def _pad_text(self, x, max_length, pad_id):
pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
return np.pad(x, pad_amount, constant_values=pad_id)[: self.max_length]

def _pad_image_mask(self, image_masks):
"""Pads the input image_masks array to match the maximum required number of images per example."""
image_offsets = multimodal_utils.get_image_offsets(self.model_name, None)
max_num_images = (self.max_length // image_offsets) - 1 # -1 to reserve space for at least one text token
if self.max_num_images_per_example > 0:
max_num_images = min(self.max_num_images_per_example, max_num_images)
assert (
image_masks.shape[0] <= max_num_images
), f"Number of image masks {image_masks.shape[0]} exceeds the maximum allowed {max_num_images}"
if image_masks.shape[0] < max_num_images:
pad_size = max_num_images - image_masks.shape[0]
pad_shape = (pad_size,)
pad_image_masks = np.zeros(pad_shape, dtype=image_masks.dtype)
if image_masks is not None and image_masks.size > 0:
image_masks = np.concatenate([image_masks, pad_image_masks], axis=0)
else:
image_masks = pad_image_masks
return image_masks

def _pad_image(self, images):
"""Pads the input images array to match the maximum required number of images per example.

@@ -487,13 +509,27 @@ def map(self, element: dict[str, np.ndarray]):
element[f"{data_column}_position"] = np.arange(element[data_column].shape[0], dtype=np.int32)
if self.add_true_length:
element[f"{data_column}_true_length"] = np.array([element[data_column].shape[0]], dtype=np.int32)

for key, _ in element.items():
if key == "images":
if isinstance(element["images"], list):
assert self.model_name is not None, "model_name must be provided when padding images"
if isinstance(element["images"], list) and self.model_name is None:
raise ValueError("model_name must be provided when padding images")
elif isinstance(element["images"], list):
element["images"] = self._pad_image(np.asarray(element["images"]))
else:
elif element["images"].ndim == 3:
element["images"] = np.asarray(element["images"])[None, ...]
else:
# Do not add extra image dimension for image tiling case
element["images"] = np.asarray(element["images"])

elif key == "image_masks" and element["image_masks"] is not None:
if isinstance(element["image_masks"], list) and self.model_name is None:
raise ValueError("model_name must be provided when padding image masks")
elif isinstance(element["image_masks"], list):
element["image_masks"] = self._pad_image_mask(np.asarray(element["image_masks"]))
else:
element["image_masks"] = np.asarray(element["image_masks"])

elif "true_length" not in key:
element[key] = self._pad_text(element[key], self.max_length, self.pad_id)
return element
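The padding logic added in `_pad_image_mask` derives an image budget per example from the sequence length and the per-image token footprint, optionally caps it with `max_num_images_per_example`, and zero-pads the 1-D tile mask up to that budget. A standalone sketch of the same rule, with an assumed `image_offsets` of 576 tokens per image:

```python
import numpy as np

def pad_image_mask(image_masks, max_length=4096, image_offsets=576,
                   max_num_images_per_example=0):
  """Zero-pad a [num_images] mask to the maximum images allowed per example."""
  # Reserve space for at least one text token.
  max_num_images = (max_length // image_offsets) - 1
  if max_num_images_per_example > 0:
    max_num_images = min(max_num_images_per_example, max_num_images)
  assert image_masks.shape[0] <= max_num_images, "too many images for this sequence length"
  pad_size = max_num_images - image_masks.shape[0]
  if pad_size > 0:
    image_masks = np.concatenate([image_masks, np.zeros(pad_size, dtype=image_masks.dtype)])
  return image_masks

# 4096 // 576 - 1 = 6 image slots, so a 3-image mask is padded with three zeros.
print(pad_image_mask(np.array([1, 1, 1])))  # -> [1 1 1 0 0 0]
```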
8 changes: 6 additions & 2 deletions src/MaxText/layers/decoders.py
@@ -383,9 +383,9 @@ def get_decoder_layers(self):
case DecoderBlockType.DEEPSEEK:
return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]
case DecoderBlockType.GEMMA:
return [gemma.GemmaDecoderLayer]
return [gemma.GemmaDecoderLayerToLinen]
case DecoderBlockType.GEMMA2:
return [gemma2.Gemma2DecoderLayer]
return [gemma2.Gemma2DecoderLayerToLinen]
case DecoderBlockType.GEMMA3:
return [gemma3.Gemma3DecoderLayer]
case DecoderBlockType.GPT3:
@@ -528,6 +528,7 @@ def _apply_embedding(
model_mode,
image_embeddings=None,
bidirectional_mask=None,
image_masks=None,
):
"""Applies token and positional embeddings to the input tokens."""
cfg = self.config
@@ -541,6 +542,7 @@ text_embeddings=y,
text_embeddings=y,
vision_embeddings=image_embeddings,
mask=bidirectional_mask,
image_masks=image_masks,
)
# TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed
else:
@@ -635,6 +637,7 @@ def __call__(
page_state: None | page_manager.PageState = None,
bidirectional_mask: None | Any = None,
image_embeddings: None | jnp.ndarray = None,
image_masks: None | jnp.ndarray = None,
):
cfg = self.config
mesh = self.mesh
@@ -649,6 +652,7 @@ model_mode,
model_mode,
image_embeddings,
bidirectional_mask,
image_masks,
)

policy = self.get_remat_policy()
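`image_masks` is now threaded through `Decoder.__call__` and `_apply_embedding` into the text/vision embedding merge. The sketch below shows one plausible way such a merge can use the mask, zeroing embeddings from padded tiles before scattering vision tokens into placeholder positions; this is an assumption about the downstream behavior, not the MaxText implementation.

```python
import numpy as np

def merge_mm_embeddings(text_embeddings, vision_embeddings, mask, image_masks):
  """Place vision token embeddings at image-placeholder positions.

  text_embeddings:   [seq, emb]
  vision_embeddings: [num_tiles, tokens_per_tile, emb]
  mask:              [seq] bool, True at image-placeholder positions
  image_masks:       [num_tiles], 1 for real tiles, 0 for padded tiles
  """
  # Drop embeddings that came from padded tiles before scattering.
  vision = vision_embeddings * image_masks[:, None, None]
  vision = vision.reshape(-1, vision.shape[-1])
  merged = text_embeddings.copy()
  merged[mask] = vision[: mask.sum()]
  return merged

text = np.ones((8, 4))
vision = np.full((1, 2, 4), 5.0)
mask = np.array([False, True, True, False, False, False, False, False])
print(merge_mm_embeddings(text, vision, mask, np.array([1]))[:4, 0])  # -> [1. 5. 5. 1.]
```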
2 changes: 1 addition & 1 deletion src/MaxText/layers/encoders.py
@@ -43,7 +43,7 @@ def get_vision_encoder_layers(self):
elif self.config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
from MaxText.layers import llama4 # pylint: disable=import-outside-toplevel

return [llama4.Llama4VisionModel, llama4.Llama4MultiModalProjector]
return [llama4.llama4visionmodel_as_linen, llama4.llama4multimodalprojector_as_linen]
else:
raise ValueError(f"No VisionEncoder implemented for {self.config.model_name} yet")
