47 changes: 26 additions & 21 deletions QEfficient/transformers/models/modeling_auto.py
@@ -40,7 +40,6 @@
from QEfficient.generation.vlm_generation import VisionLanguageGeneration
from QEfficient.transformers.modeling_utils import (
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH,
SPECIALIZED_DISAGG_SERVING_MODEL_ARCH,
)
from QEfficient.transformers.models.pytorch_transforms import (
BlockedKVAttentionTransform,
@@ -2592,27 +2591,33 @@ def export(
)
enable_chunking = kwargs.get("enable_chunking", False)

# TODO: move this to a DA Serving utility class
if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH:
if prefill_only:
if self.continuous_batching and not enable_chunking:
raise NotImplementedError("Can't enable prefix-caching without chunking")
self.prefill(enable=True, enable_chunking=enable_chunking)
self.hash_params.pop("retain_full_kv", None)
seq_len = self.get_seq_len_and_handle_specialized_prefill_model(
prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking
if prefill_only:
if not enable_chunking and self.continuous_batching:
raise NotImplementedError(
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
)
kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len
else:
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
self.hash_params.pop("prefill_only", None)
self.hash_params.pop("NUM_Q_BLOCKS", None)
self.hash_params.pop("NUM_FFN_BLOCKS", None)
self.hash_params.pop("ENABLE_OPT_SWA", None)
self.hash_params.pop("chunking", None)
if kwargs.get("retain_full_kv", False):
kv_cache_shape[2] = seq_len + self.model.config.sliding_window
self.hash_params["retain_full_kv"] = True
self.prefill(enable=True, enable_chunking=enable_chunking)
self.hash_params.pop("retain_full_kv", None)
seq_len = self.get_seq_len_and_handle_specialized_prefill_model(
prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking
)
kv_cache_shape[2] = (
seq_len + (self.model.config.sliding_window if self.model.config.sliding_window is not None else 0)
if enable_chunking
else seq_len
)
else:
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
self.hash_params.pop("prefill_only", None)
self.hash_params.pop("NUM_Q_BLOCKS", None)
self.hash_params.pop("NUM_FFN_BLOCKS", None)
self.hash_params.pop("ENABLE_OPT_SWA", None)
self.hash_params.pop("chunking", None)
if kwargs.get("retain_full_kv", False):
kv_cache_shape[2] = seq_len + (
self.model.config.sliding_window if self.model.config.sliding_window is not None else 0
)
self.hash_params["retain_full_kv"] = True

example_inputs = {
"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
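For reference, the cache-sizing arithmetic in the export() change above can be read as a small helper. The following is an illustrative sketch, not QEfficient API: the function name and arguments are stand-ins for the attributes and kwargs used in export().

# Illustrative sketch (hypothetical helper): how the third KV-cache dimension is
# chosen in the diff above, assuming `sliding_window` may be None.
from typing import Optional


def kv_cache_len(
    seq_len: int,
    sliding_window: Optional[int],
    prefill_only: bool,
    enable_chunking: bool,
    retain_full_kv: bool,
) -> int:
    extra = sliding_window if sliding_window is not None else 0
    if prefill_only:
        # Chunked prefill reserves an extra sliding-window worth of positions so a
        # chunk can attend to the tail of the previous chunk.
        return seq_len + extra if enable_chunking else seq_len
    # Non-prefill export: the cache only grows when the full KV is retained.
    return seq_len + extra if retain_full_kv else seq_len


# Example: 128-token prefill with a 4096-token sliding window and chunking enabled.
assert kv_cache_len(128, 4096, prefill_only=True, enable_chunking=True, retain_full_kv=False) == 4224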
7 changes: 7 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -419,6 +419,7 @@
QEffQwen3Model,
)
from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import (
QEffPrefillChunkedQwen3MoeSparseMoeBlock,
QEffQwen3MoeAttention,
QEffQwen3MoeDecoderLayer,
QEffQwen3MoeForCausalLM,
@@ -663,19 +664,25 @@ class PrefillOnlyTransform(ModuleMappingTransform):

class PrefillOnlyChunkedTransform(ModuleMappingTransform):
_module_mapping = {
# GPT_OSS
QEffGptOssModel: QEffPrefillOnlyGptOssModel,
QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP,
# Qwen3Moe
QEffQwen3MoeSparseMoeBlock: QEffPrefillChunkedQwen3MoeSparseMoeBlock,
}


class RevertPrefillKeepAttentionTransform(ModuleMappingTransform):
_module_mapping = {
# GPT_OSS
QEffGptOssModel: QEffPrefillOnlyGptOssModel,
QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
QEffPrefillOnlyGptOssMLP: QEffGptOssMLP,
QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP,
# Qwen3Moe
QEffPrefillChunkedQwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock,
}


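The transforms above pair each stock module class with its prefill-only or chunked replacement through `_module_mapping`. As a rough mental model (a hypothetical sketch, not the actual `ModuleMappingTransform` implementation), applying such a mapping amounts to re-classing matching submodules in place:

# Hypothetical illustration only; QEfficient's ModuleMappingTransform may differ.
import torch.nn as nn


def swap_modules(model: nn.Module, mapping: dict) -> nn.Module:
    # Re-class every submodule whose exact type appears in `mapping`; parameters
    # and buffers are kept, only the forward() implementation changes.
    for module in model.modules():
        replacement = mapping.get(type(module))
        if replacement is not None:
            module.__class__ = replacement
    return model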
50 changes: 23 additions & 27 deletions QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -104,7 +104,6 @@ def eager_attention_forward(
key_states = repeat_kv(key, module.num_key_value_groups)

value_states = repeat_kv(value, module.num_key_value_groups)

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
@@ -118,53 +117,50 @@
return attn_output, attn_weights


class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
def __qeff_init__(self):
self.gate_proj_w = []
self.up_proj_w = []
self.down_proj_w = []
with torch.no_grad():
for e in range(self.num_experts):
self.gate_proj_w.append(self.experts[e].gate_proj.weight.T)
self.up_proj_w.append(self.experts[e].up_proj.weight.T)
self.down_proj_w.append(self.experts[e].down_proj.weight.T)
self.gate_proj_w = torch.stack(self.gate_proj_w)
self.up_proj_w = torch.stack(self.up_proj_w)
self.down_proj_w = torch.stack(self.down_proj_w)

def alt_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
B, S, H = hidden_states.shape
T = B * S
x = hidden_states.view(T, H)

router_logits = self.gate(x) # [T, E]
prob = F.softmax(router_logits, -1, dtype=torch.float)
top_w, top_i = torch.topk(prob, self.top_k, -1)
if self.norm_topk_prob: # only diff with mixtral sparse moe block!
top_w /= top_w.sum(-1, keepdim=True)
top_w = top_w.to(x.dtype)
top_w = top_w.to(hidden_states.dtype)
masked_logits = torch.zeros_like(router_logits)
masked_logits.scatter_(1, top_i, top_w)

# Routing weights for each expert [T, E]
routing_weights = masked_logits

# ────────────────── allocate the output tensor ─────
expert_out = x.new_zeros((T, H)) # accumulation buffer

# ───────────────────────── Expert computation loop ─────────────────────────────
for e in range(self.num_experts):
routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1]
W_g, W_u = self.experts[e].gate_proj, self.experts[e].up_proj # [H, I], [H, I]
W_d = self.experts[e].down_proj # [I, H]
gate = W_g(x) # [T, I]
up = W_u(x) # [T, I]
down = W_d(up * self.experts[e].act_fn(gate)) # [T, H]

W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T # [H, I], [H, I]
W_d = self.experts[e].down_proj.weight.T # [I, H]
gate = x @ W_g # [T, I]
up = x @ W_u # [T, I]
down = (up * self.experts[e].act_fn(gate)) @ W_d # [T, H]
masked_down = torch.where(routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out))
expert_out += masked_down
return expert_out.view(B, S, H), router_logits


class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
def __qeff_init__(self):
self.gate_proj_w = []
self.up_proj_w = []
self.down_proj_w = []
with torch.no_grad():
for e in range(self.num_experts):
self.gate_proj_w.append(self.experts[e].gate_proj.weight.T)
self.up_proj_w.append(self.experts[e].up_proj.weight.T)
self.down_proj_w.append(self.experts[e].down_proj.weight.T)
self.gate_proj_w = torch.stack(self.gate_proj_w)
self.up_proj_w = torch.stack(self.up_proj_w)
self.down_proj_w = torch.stack(self.down_proj_w)

def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
B, S, H = hidden_states.shape
T = B * S
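The chunked block above replaces expert-wise gather/scatter dispatch with a dense [tokens, experts] routing matrix, so every expert runs over all tokens and non-selected tokens are masked by torch.where. A self-contained sketch of that gating step with toy shapes, grounded in the forward shown above:

# Dense top-k routing weights as built in QEffPrefillChunkedQwen3MoeSparseMoeBlock.
import torch
import torch.nn.functional as F

T, E, top_k = 4, 8, 2                          # 4 tokens, 8 experts, 2 experts per token
router_logits = torch.randn(T, E)

prob = F.softmax(router_logits, dim=-1, dtype=torch.float)
top_w, top_i = torch.topk(prob, top_k, dim=-1)
top_w = top_w / top_w.sum(-1, keepdim=True)    # the norm_topk_prob=True branch

routing_weights = torch.zeros_like(router_logits)
routing_weights.scatter_(1, top_i, top_w)      # [T, E]; zero for unpicked experts

# Each token contributes to exactly top_k experts; the per-expert loop multiplies the
# expert output by its column of this matrix and zeroes everything else.
assert (routing_weights > 0).sum(-1).eq(top_k).all()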
133 changes: 133 additions & 0 deletions examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py
@@ -0,0 +1,133 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import time

import numpy as np
import torch
from transformers import AutoConfig, AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM
from QEfficient.generation.cloud_infer import QAICInferenceSession

model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507" # weights are not required to convert to fp32
prompt = """
Explain quantum computing in simple terms.
"""
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
PREFILL_SEQ_LEN = 128
CTX_LEN = 128 * 3

qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
decode_qpc_path = qeff_model.compile(
prefill_seq_len=1,
ctx_len=CTX_LEN,
num_cores=16,
mxfp6_matmul=True,
mxint8_kv_cache=True,
num_devices=1,
mos=1,
aic_enable_depth_first=True,
num_speculative_tokens=None,
offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step
retain_full_kv=True,
)

# The following compile call errors out by default; the user is expected to run the printed command, pass the generated QPC path as prefill_qpc_path, and comment out lines 55-68.

# prefill_qpc_path = ""

prefill_qpc_path = qeff_model.compile(
prefill_seq_len=PREFILL_SEQ_LEN,
ctx_len=CTX_LEN,
num_cores=16,
mxfp6_matmul=True,
mxint8_kv_cache=True,
num_devices=2,
split_retained_state_io=True,
mos=1,
aic_enable_depth_first=True,
num_speculative_tokens=None,
prefill_only=True,
enable_chunking=True,
# use_onnx_subfunctions=True,
)


inputs = tokenizer(prompt, return_tensors="np", padding=True)
position_ids = inputs["attention_mask"].sum(1, keepdims=True)
generation_len = CTX_LEN - position_ids.max()
padded_len = inputs["input_ids"].shape[1]
num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float
padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len
inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
inputs.pop("token_type_ids", None)
inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
inputs.pop("past_key_values", None)
inputs = {k: v.detach().numpy() for k, v in inputs.items()}


prefill_session = QAICInferenceSession(prefill_qpc_path)
decode_session = QAICInferenceSession(decode_qpc_path)

all_outputs = []
for i in range(num_chunks):
chunk_inputs = inputs.copy()
chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
ins = time.time()
qpc_out = prefill_session.run(chunk_inputs)
print(f"time for this run={time.time() - ins}")
for i in range(config.num_hidden_layers):
inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]

all_outputs.append(np.argmax(qpc_out["logits"]))

decode_inputs = {
"input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1),
"position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1,
}
for i in range(config.num_hidden_layers):
decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]

st = time.time()
decode_out = decode_session.run(decode_inputs)
print(f"time for first run of decode with KV as input = {time.time() - st} sec\n")
all_outputs.append(np.argmax(decode_out["logits"]))
pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1
loop_decode_inputs = {
"input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
"position_ids": pos_id,
}

for i in range(config.num_hidden_layers):
loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]

st = time.time()
for i in range(generation_len - 2):
decode_out = decode_session.run(loop_decode_inputs)
all_outputs.append(np.argmax(decode_out["logits"]))
pos_id += 1
for i in range(config.num_hidden_layers):
loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]

loop_decode_inputs.update(
{
"input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
"position_ids": pos_id,
}
)
ft = time.time()

print(f"decode tok/sec={(generation_len - 2) / (ft - st)}")
print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}")