59 changes: 45 additions & 14 deletions src/MaxText/model_creation_utils.py
@@ -137,33 +137,64 @@ def create_sharded_state():
model = nnx.merge(graphdef, sharded_state)

if config.load_parameters_path:
target_for_restore = jax.tree.map(
lambda v: v.value,
sharded_state,
is_leaf=lambda n: isinstance(n, nnx.Variable),
)

try:
ckptr = ocp.Checkpointer(
ocp.PyTreeCheckpointHandler(
restore_concurrent_gb=None,
save_concurrent_gb=None,
use_ocdbt=True,
use_zarr3=True,
restore_concurrent_gb=config.checkpoint_storage_concurrent_gb,
save_concurrent_gb=config.checkpoint_storage_concurrent_gb,
use_ocdbt=config.checkpoint_storage_use_ocdbt,
use_zarr3=config.checkpoint_storage_use_zarr3,
)
)

# This is a memory optimization. We don't want to restore the entire checkpoint - only the params.
# Rather than passing the entire abstract state, which could unnecessarily restore opt_state and
# waste memory, we instead restore the params field of the checkpoint (which itself may be a dictionary
# containing a key named 'params').
restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)

# Get the structure of the checkpoint stored at `config.load_parameters_path`
metadata = ckptr.metadata(config.load_parameters_path)

is_nnx_checkpoint = True
if (
"params" in metadata.item_metadata.tree.keys()
and "params" in metadata.item_metadata.tree.get("params", {}).keys()
):
# Structure of a Linen checkpoint: {'params': {'params': {'decoder': ...}}}
is_nnx_checkpoint = False
target_for_restore = jax.tree.map(
lambda v: v.value,
sharded_state,
is_leaf=lambda n: hasattr(n, "value"),
)

item_to_restore = {"params": {"params": target_for_restore}}
restore_args = {"params": {"params": ocp.checkpoint_utils.construct_restore_args(target_for_restore)}}
else:
# Structure of an NNX checkpoint: {'decoder': {'value': ...}}
target_for_restore = jax.tree.map(
lambda v: {"value": v.value},
sharded_state,
is_leaf=lambda n: isinstance(n, nnx.Variable),
)
item_to_restore = target_for_restore
restore_args = ocp.checkpoint_utils.construct_restore_args(target_for_restore)

restored = ckptr.restore(
epath.Path(config.load_parameters_path),
item={"params": {"params": target_for_restore}},
item=item_to_restore,
transforms={},
restore_args={"params": {"params": restore_args}},
restore_args=restore_args,
)
checkpoint = restored["params"]["params"]

if is_nnx_checkpoint:
checkpoint = jax.tree.map(
lambda v: v["value"],
restored,
is_leaf=lambda x: isinstance(x, dict) and "value" in x and not isinstance(x.get("value"), dict),
)
else:
checkpoint = restored["params"]["params"]

if checkpoint:
nnx.update(model, checkpoint)
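To make the two checkpoint layouts concrete, here is a minimal, self-contained sketch of the format detection and unwrapping logic introduced in this hunk. The helper names (is_linen_checkpoint, unwrap_nnx_restore) are hypothetical, and plain dicts stand in for the real Orbax metadata and restored trees.

# Hypothetical sketch; plain dicts stand in for Orbax metadata/restored trees.
import jax


def is_linen_checkpoint(metadata_tree: dict) -> bool:
  """A Linen checkpoint nests params twice: {'params': {'params': {...}}}."""
  return "params" in metadata_tree and "params" in metadata_tree.get("params", {})


def unwrap_nnx_restore(restored):
  """NNX leaves are restored as {'value': array}; strip the wrapper."""
  return jax.tree.map(
      lambda v: v["value"],
      restored,
      is_leaf=lambda x: isinstance(x, dict) and "value" in x and not isinstance(x.get("value"), dict),
  )


# Toy trees mirroring the two layouts described in the diff comments.
linen_tree = {"params": {"params": {"decoder": {"kernel": 1.0}}}}
nnx_tree = {"decoder": {"kernel": {"value": 1.0}}}

assert is_linen_checkpoint(linen_tree)
assert not is_linen_checkpoint(nnx_tree)
assert unwrap_nnx_restore(nnx_tree) == {"decoder": {"kernel": 1.0}}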
45 changes: 28 additions & 17 deletions src/MaxText/vllm_decode.py
@@ -15,12 +15,12 @@
An example script to perform decoding using vLLM with a MaxText model.

Example command:
python3 -m MaxText.vllm_decode MaxText/configs/sft.yml \
python3 -m MaxText.vllm_decode MaxText/configs/base.yml \
model_name=llama3.1-8b tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
tokenizer_type=huggingface hf_access_token=<your_hf_token> \
load_parameters_path=<your_checkpoint_path> \
per_device_batch_size=1 run_name=vllm_decode_test \
use_chat_template=True prompt="Suggest some famous landmarks in London." \
per_device_batch_size=1 run_name=vllm_decode_test max_target_length=64 \
use_chat_template=False prompt="Suggest some famous landmarks in London." \
decode_sampling_temperature=0.0 decode_sampling_nucleus_p=1.0 decode_sampling_top_k=0.0
"""

@@ -50,48 +50,59 @@ def decode(
# Wrap the model for Tunix
tunix_model = TunixMaxTextAdapter(base_model=model)

# Load the tokenizer and format the prompt
model_tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path, token=config.hf_access_token)
model_tokenizer.bos_token = None
# Load the tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
config.tokenizer_path,
token=config.hf_access_token,
model_max_length=config.max_target_length,
)
tokenizer.bos_token = None

# Format the prompt using chat template if specified
prompts = [config.prompt]
if config.use_chat_template:
# Format the prompt using chat template if specified
messages = [
{"role": "user", "content": config.prompt},
]
formatted_prompt_string = model_tokenizer.apply_chat_template(
input_with_chat_template = tokenizer.apply_chat_template(
messages,
tokenize=False, # Set to False to get the string
add_generation_prompt=True,
add_special_tokens=False, # Prevent adding special tokens
)
print("Formatted prompt string:", formatted_prompt_string)
prompts = [formatted_prompt_string]
prompts = [input_with_chat_template]

max_prompt_length = max(len(tokenizer.encode(p)) for p in prompts)
max_tokens_to_generate = config.max_target_length - max_prompt_length

# Create vLLM rollout for inference
rollout_config = base_rollout.RolloutConfig(
max_tokens_to_generate=config.max_target_length - config.max_prefill_predict_length,
max_prompt_length=config.max_prefill_predict_length,
max_tokens_to_generate=max_tokens_to_generate,
max_prompt_length=max_prompt_length,
temperature=config.decode_sampling_temperature,
top_p=config.decode_sampling_nucleus_p,
top_k=config.decode_sampling_top_k,
)
vllm_rollout = VllmRollout(
model=tunix_model,
tokenizer=model_tokenizer,
cache_config_or_size=config.max_target_length, # Max sequence length
tokenizer=tokenizer,
# The cache_config_or_size sets the absolute maximum sequence length.
# We add 256 as a safety buffer to account for tokens added by other
# special formatting that are not counted in max_prompt_length.
cache_config_or_size=max_prompt_length + max_tokens_to_generate + 256,
mesh=mesh,
model_version=config.tokenizer_path,
hbm_utilization=0.8,
init_with_random_weights=True, # Use random weights
# Initialize the vLLM model with random weights to speed up bootstrap time;
# the actual model weights are loaded later.
init_with_random_weights=True,
tpu_backend_type="jax",
)

# Generate text
output = vllm_rollout.generate(prompts, rollout_config)
print("Generated text:")
print(output.text[0])
print(f"Prompt: {config.prompt}")
print(f"Output: {output.text[0]}")


def main(argv: Sequence[str]) -> None:
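As a rough illustration of the token-budget arithmetic introduced in decode(), here is a minimal sketch with a stub tokenizer. token_budget and the whitespace-based encode are hypothetical stand-ins for the Hugging Face tokenizer the script actually loads; max_target_length=64 matches the example command above.

# Hypothetical sketch of the prompt/generation budget (stub tokenizer).
def token_budget(prompts, encode, max_target_length, safety_buffer=256):
  # Longest prompt, in tokens, across the batch.
  max_prompt_length = max(len(encode(p)) for p in prompts)
  # Whatever is left of the target length is available for generation.
  max_tokens_to_generate = max_target_length - max_prompt_length
  # The cache size must cover prompt + generated tokens, plus a buffer for
  # special/formatting tokens not counted in max_prompt_length.
  cache_size = max_prompt_length + max_tokens_to_generate + safety_buffer
  return max_prompt_length, max_tokens_to_generate, cache_size


# Stub: roughly one token per whitespace-separated word.
encode = lambda text: text.split()
print(token_budget(["Suggest some famous landmarks in London."], encode, max_target_length=64))
# -> (6, 58, 320)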