-
Notifications
You must be signed in to change notification settings - Fork 64
Adding dissagg mode support to Qwen3Moe #682
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
qcdipankar
wants to merge
10
commits into
quic:main
Choose a base branch
from
qcdipankar:qwen3moe_dissagg_mode
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+204
−57
Open
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
81cf999
Adding dissagg mode support to Qwen3Moe
qcdipankar 4819fbc
Cleaning of example script
qcdipankar d2ba282
Lint fix
qcdipankar 4585c93
Minor fixes
qcdipankar f24dc83
Lint Fix 2
qcdipankar 28d3f78
Adding qwen3moe to test
qcdipankar 28c1743
Cleaning test for dissagg
qcdipankar 728df6d
Addressing Changes and Review Comments
qcdipankar b96f544
Merge branch 'main' into qwen3moe_dissagg_mode
qcdipankar 17662e7
Fix lint error
qcdipankar File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
133 changes: 133 additions & 0 deletions
133
examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py
qcdipankar marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| # ----------------------------------------------------------------------------- | ||
| # | ||
| # Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
| # | ||
| # ----------------------------------------------------------------------------- | ||
|
|
||
| import time | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from transformers import AutoConfig, AutoTokenizer | ||
|
|
||
| from QEfficient import QEFFAutoModelForCausalLM | ||
| from QEfficient.generation.cloud_infer import QAICInferenceSession | ||
|
|
||
| model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507" # weights are not required to convert to fp32 | ||
| prompt = """ | ||
| Explain quantum computing in simple terms. | ||
| """ | ||
| config = AutoConfig.from_pretrained(model_id) | ||
| tokenizer = AutoTokenizer.from_pretrained(model_id) | ||
| PREFILL_SEQ_LEN = 128 | ||
| CTX_LEN = 128 * 3 | ||
|
|
||
| qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) | ||
| decode_qpc_path = qeff_model.compile( | ||
| prefill_seq_len=1, | ||
| ctx_len=CTX_LEN, | ||
| num_cores=16, | ||
| mxfp6_matmul=True, | ||
| mxint8_kv_cache=True, | ||
| num_devices=1, | ||
| mos=1, | ||
| aic_enable_depth_first=True, | ||
| num_speculative_tokens=None, | ||
| offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step | ||
| retain_full_kv=True, | ||
| ) | ||
|
|
||
| # Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68 | ||
|
|
||
| # prefill_qpc_path = "" | ||
|
|
||
| prefill_qpc_path = qeff_model.compile( | ||
| prefill_seq_len=PREFILL_SEQ_LEN, | ||
| ctx_len=CTX_LEN, | ||
| num_cores=16, | ||
| mxfp6_matmul=True, | ||
| mxint8_kv_cache=True, | ||
| num_devices=2, | ||
| split_retained_state_io=True, | ||
| mos=1, | ||
| aic_enable_depth_first=True, | ||
| num_speculative_tokens=None, | ||
| prefill_only=True, | ||
| enable_chunking=True, | ||
| # use_onnx_subfunctions=True, | ||
| ) | ||
|
|
||
|
|
||
| inputs = tokenizer(prompt, return_tensors="np", padding=True) | ||
| position_ids = inputs["attention_mask"].sum(1, keepdims=True) | ||
| generation_len = CTX_LEN - position_ids.max() | ||
| padded_len = inputs["input_ids"].shape[1] | ||
| num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float | ||
| padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len | ||
| inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) | ||
| inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) | ||
| inputs.pop("token_type_ids", None) | ||
| inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} | ||
| inputs.pop("past_key_values", None) | ||
| inputs = {k: v.detach().numpy() for k, v in inputs.items()} | ||
|
|
||
|
|
||
| prefill_session = QAICInferenceSession(prefill_qpc_path) | ||
| decode_session = QAICInferenceSession(decode_qpc_path) | ||
|
|
||
| all_outputs = [] | ||
| for i in range(num_chunks): | ||
| chunk_inputs = inputs.copy() | ||
| chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] | ||
| chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] | ||
| ins = time.time() | ||
| qpc_out = prefill_session.run(chunk_inputs) | ||
| print(f"time for this run={time.time() - ins}") | ||
| for i in range(config.num_hidden_layers): | ||
| inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] | ||
| inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] | ||
|
|
||
| all_outputs.append(np.argmax(qpc_out["logits"])) | ||
|
|
||
| decode_inputs = { | ||
| "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), | ||
| "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, | ||
| } | ||
| for i in range(config.num_hidden_layers): | ||
| decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] | ||
| decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] | ||
|
|
||
| st = time.time() | ||
| decode_out = decode_session.run(decode_inputs) | ||
| print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") | ||
| all_outputs.append(np.argmax(decode_out["logits"])) | ||
| pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 | ||
| loop_decode_inputs = { | ||
| "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), | ||
| "position_ids": pos_id, | ||
| } | ||
|
|
||
| for i in range(config.num_hidden_layers): | ||
| loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] | ||
| loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] | ||
|
|
||
| st = time.time() | ||
| for i in range(generation_len - 2): | ||
| decode_out = decode_session.run(loop_decode_inputs) | ||
| all_outputs.append(np.argmax(decode_out["logits"])) | ||
| pos_id += 1 | ||
| for i in range(config.num_hidden_layers): | ||
| loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] | ||
| loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] | ||
|
|
||
| loop_decode_inputs.update( | ||
| { | ||
| "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), | ||
| "position_ids": pos_id, | ||
| } | ||
| ) | ||
| ft = time.time() | ||
|
|
||
| print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") | ||
| print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}") |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.