Commit 61b1445
BugFix: Fix reshape error for llama swiftkv models (#432)

This fixes the issue with higher-BS compilation for SwiftKV models.

```
Compiler command: ['/opt/qti-aic/exec/qaic-exec', '-aic-hw', '-aic-hw-version=2.0',
'-m=/prj/qct/aisyssol_scratch/users/shagsood/quic_shagun/LlamaSwiftKVForCausalLM-a5879ebc0e59ab40/LlamaSwiftKVForCausalLM.onnx',
'-compile-only', '-retained-state', '-convert-to-fp16', '-aic-num-cores=16',
'-network-specialization-config=/prj/qct/aisyssol_scratch/users/shagsood/quic_shagun/LlamaSwiftKVForCausalLM-a5879ebc0e59ab40/qpc-60f86f912a187346/specializations.json',
'-custom-IO-list-file=/prj/qct/aisyssol_scratch/users/shagsood/quic_shagun/LlamaSwiftKVForCausalLM-a5879ebc0e59ab40/qpc-60f86f912a187346/custom_io.yaml',
'-mdp-load-partition-config=/prj/qct/aisyssol_scratch/users/shagsood/quic_shagun/LlamaSwiftKVForCausalLM-a5879ebc0e59ab40/qpc-60f86f912a187346/mdp_ts_4.json',
'-aic-binary-dir=/prj/qct/aisyssol_scratch/users/shagsood/quic_shagun/LlamaSwiftKVForCausalLM-a5879ebc0e59ab40/qpc-60f86f912a187346/qpc']

Compiler exitcode: 1
Compiler stderr:
QAIC_ERROR:
Error message: [Operator-'/model/layers.16/self_attn/Reshape'] : Reshape: input shape (4, 4, 4096) and output shape (4, 1, 32, 128) have different number of elements (in 65536 vs. out 16384)
Unable to AddNodesToGraphFromModel
```

Tested with BS4; able to compile now.

Signed-off-by: quic-shagun <[email protected]>
1 parent 740f7c2 commit 61b1445

File tree

1 file changed: +2 −2 lines changed


QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -371,8 +371,8 @@ def forward(
             hidden_states = orig_hidden_states[torch.arange(orig_hidden_states.shape[0]).reshape(-1, 1), last_pos_id, :]
             causal_mask = causal_mask[torch.arange(orig_hidden_states.shape[0]).reshape(-1, 1), :, last_pos_id, :]
         else:
-            hidden_states = orig_hidden_states[torch.arange(bsz), last_pos_id, :]
-            causal_mask = causal_mask[torch.arange(bsz), :, last_pos_id, :]
+            hidden_states = orig_hidden_states[torch.arange(bsz).reshape(-1, 1), last_pos_id, :]
+            causal_mask = causal_mask[torch.arange(bsz).reshape(-1, 1), :, last_pos_id, :]

         hidden_states, next_decoder_cache = self._run_swiftkv_layers(
             hidden_states, position_ids, past_key_values, causal_mask, batch_index
```
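The root cause is an advanced-indexing broadcast: a 1-D `torch.arange(bsz)` index broadcasts against a `(bsz, 1)`-shaped `last_pos_id` to produce `bsz * bsz` gathered rows instead of `bsz`, which later trips the element-count check in the Reshape op. A minimal sketch of the shape behavior, using small illustrative tensors (the names and sizes here are assumptions, not the model's actual tensors):

```python
import torch

# Illustrative shapes: batch of hidden states, one "last position" per batch row.
bsz, seq_len, hidden = 4, 4, 8
orig_hidden_states = torch.randn(bsz, seq_len, hidden)
last_pos_id = torch.tensor([[3], [1], [2], [0]])  # shape (bsz, 1)

# Before the fix: index shapes (bsz,) and (bsz, 1) broadcast to (bsz, bsz),
# gathering bsz*bsz rows -- the source of the 65536-vs-16384 mismatch.
broken = orig_hidden_states[torch.arange(bsz), last_pos_id, :]
print(broken.shape)  # torch.Size([4, 4, 8])

# After the fix: (bsz, 1) and (bsz, 1) broadcast to (bsz, 1),
# gathering exactly one row per batch element.
fixed = orig_hidden_states[torch.arange(bsz).reshape(-1, 1), last_pos_id, :]
print(fixed.shape)  # torch.Size([4, 1, 8])
```

With the real model's hidden size of 4096, the broken indexing yields `(4, 4, 4096)` = 65536 elements where the downstream Reshape expects `(4, 1, 32, 128)` = 16384, matching the compiler error above.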

Comments (0)