Conversation

@DannyYuyang-quic (Contributor) commented Sep 9, 2025

Summary:

  • e2e script for GA Static Gemma3-1B
    • perf: 16a4w block-quant token rate in KV mode: ~110 tokens/sec (SM8750), max_seq_len=1024
    • acc: PPL on the wikitext dataset: fp 21.375 -> htp 23.086
    • add model params config
    • add end-to-end example to README
  • add new architecture:
    • add a new class to support the global/local RoPE static llama architecture required by Gemma3
    • enable global/local static llama architecture support in the runner
  • refactoring:
    • refactor the attention mask to improve integration with the global/local RoPE static llama model (a minimal mask sketch follows this list)
    • refactor kv_inference and prefill_inference for better readability
  • unit tests:
    • add a unit test for Gemma3-1B
    • improve readability of memory size constants in unit tests
  • LLM model config visualization:
    • support tabular LLMmodelConfig visualization
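
For readers unfamiliar with the global/local attention mentioned above, here is a minimal, hypothetical sketch of how a global (causal) mask differs from a local (sliding-window) mask; the function names, shapes, and window size are illustrative only and are not taken from this PR.

```python
# Illustrative sketch only: global (causal) vs. local (sliding-window) masks,
# the two mask types that global/local attention alternates between.
# The window size below is made up for demonstration.
import torch


def causal_mask(seq_len: int) -> torch.Tensor:
    # Global attention: position i may attend to every position j <= i.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))


def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # Local attention: position i may attend only to j with i - window < j <= i.
    pos = torch.arange(seq_len)
    diff = pos.unsqueeze(1) - pos.unsqueeze(0)  # diff[i, j] = i - j
    return (diff >= 0) & (diff < window)


if __name__ == "__main__":
    print(causal_mask(6).int())
    print(sliding_window_mask(6, window=3).int())
```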

Test plan

python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma3-1b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1

pytorch-bot bot commented Sep 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/14108

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Cancelled Job, 5 Unrelated Failures

As of commit 9c80e2f with merge base 8496f27:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label on Sep 9, 2025
@DannyYuyang-quic (Contributor, Author)

@pytorchbot label "release notes: qualcomm"

@pytorch-bot bot added the release notes: qualcomm label on Sep 9, 2025
@DannyYuyang-quic (Contributor, Author)

Hi @cccclai,
This PR enables support for Gemma3-1B in the static version.

Both accuracy and performance in Hybrid/KV mode are promising.
However, since Gemma3 uses a global/local attention mechanism, implementing lookahead decoding is a bit tricky, so lookahead decoding is temporarily blocked for Gemma3 only. We plan to enable lookahead decoding support for Gemma3 in a future update.

cc: @haowhsu-quic

@facebook-github-bot (Contributor)

@cccclai has imported this pull request. If you are a Meta employee, you can view this in D82034374.

@jackzhxng (Contributor) left a comment

Are the changes in examples/models/gemma3 and examples/models/llama relevant?

@cccclai (Contributor) commented Sep 9, 2025

> Are the changes in examples/models/gemma3 and examples/models/llama relevant?

Yes, because we reuse the config from etllm for the Qualcomm LLM models as well.

@cccclai (Contributor) commented Sep 9, 2025

There are some merge conflicts, can you resolve them?

@DannyYuyang-quic force-pushed the dev1/danny/GA_static_gemma3 branch from 2333ffc to 9c80e2f on September 10, 2025 01:26
@facebook-github-bot (Contributor)

@cccclai has imported this pull request. If you are a Meta employee, you can view this in D82034374.

@jackzhxng (Contributor) left a comment

Looks good. Trunk errors look like flakes.

@cccclai merged commit e2e33c4 into pytorch:main on Sep 10, 2025 (456 of 472 checks passed).
@mergennachin (Contributor)

@DannyYuyang-quic

There's a bug in this code that was uncovered in our internal testing:

executorch/examples/qualcomm/oss_scripts/llama/model/static_llama.py", line 266, in forward_sha
    attn = attn + atten_mask
           ~~~~~^~~~~~~~~~~~
TypeError: unsupported operand type(s) for +: 'Tensor' and 'AttentionMask'

@mergennachin (Contributor)

atten_mask is sometimes a Tensor and sometimes an AttentionMask object.

cc @cccclai

@mergennachin (Contributor)

Can you fix this ASAP? Otherwise, I'll revert this PR.

@haowhsu-quic (Collaborator)

May I know the details of the internal test scenario? I tested the latest mainline without masked_softmax and everything works fine.

@haowhsu-quic (Collaborator)

atten_mask is now a wrapper class for causal / sliding attention. We unwrap it when performing the lowering process in llama.py.
Is it possible that you are using the inputs from LlamaModel.get_example_inputs() to do the inference? If so, please add an asterisk to unpack the attention_mask:

tokens, atten_mask, pos_ids, k_caches, v_caches = model.get_example_inputs()
# Unpack the attention-mask wrapper and the cache lists into plain tensors.
logits, new_k_caches, new_v_caches = module(
  tokens,
  *atten_mask,
  pos_ids,
  *k_caches,
  *v_caches,
)
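
For illustration, here is a toy sketch of why unpacking with * avoids the TypeError reported above; the class name and fields below are hypothetical and do not mirror the actual wrapper in llama.py.

```python
# Hypothetical illustration only (not the actual wrapper class in llama.py):
# an iterable mask bundle that unpacks into plain tensors via "*".
from dataclasses import dataclass

import torch


@dataclass
class ToyAttentionMask:
    causal_mask: torch.Tensor          # global (causal) attention mask
    sliding_window_mask: torch.Tensor  # local (sliding-window) attention mask

    def __iter__(self):
        # "*mask" at the call site yields these tensors one by one, so the
        # module receives Tensors instead of the wrapper object itself.
        yield self.causal_mask
        yield self.sliding_window_mask
```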

@cccclai (Contributor) commented Sep 11, 2025

I'll try to fix it; it looks like just an internal reference needs to be updated.
