Commit e2e33c4

Qualcomm AI Engine Direct - GA Static Gemma3-1B (#14108)
Summary:
- e2e script for GA Static [Gemma3-1B](https://huggingface.co/google/gemma-3-1b-it)
- perf: 16a4w block-quant token rate in kv mode ~= 110 tokens/sec (SM8750), max_seq_len=1024
- acc: wikitext PPL ~= 21.375 (fp) -> 23.086 (htp)
- add model params config
- add end-to-end example in README
- new architecture:
  - add new class to support the global/local RoPE static llama architecture required by Gemma3
  - enable global/local static llama architecture support in the runner
- refactoring:
  - refactor attention mask to improve integration with the global/local RoPE static llama model
  - refactor kv_inference and prefill_inference for better readability
- unit tests:
  - add unit test for Gemma3-1B
  - improve readability of memory-size constants in unit tests
- LLM model config visualization:
  - support tabular LLMmodelConfig visualization

### Test plan
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma3-1b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```
1 parent 8d1684e commit e2e33c4

24 files changed: +1,389 −192 lines
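For readers new to the global/local terminology in this commit: Gemma3 interleaves full-attention (global) layers with sliding-window (local) layers, which is why the static llama runner now needs per-layer attention masks and per-layer RoPE handling (in the upstream Gemma3 configuration, local layers also use a smaller RoPE base frequency than global layers). The snippet below is a minimal illustration of the mask distinction only, not code from this change; the window value shown is a placeholder, and the 512-token window published for Gemma3-1B should be treated as an assumption here.

```python
from typing import Optional

import torch


def causal_mask(seq_len: int, window: Optional[int] = None) -> torch.Tensor:
    """True where a query position may attend to a key position.

    window=None -> global layer (full causal attention)
    window=W    -> local layer (causal attention over the last W positions)
    """
    q = torch.arange(seq_len).unsqueeze(1)  # query positions, shape [S, 1]
    k = torch.arange(seq_len).unsqueeze(0)  # key positions,   shape [1, S]
    mask = k <= q                           # causal: no attending to the future
    if window is not None:
        mask &= (q - k) < window            # local: restrict to a sliding window
    return mask


global_mask = causal_mask(8)            # full causal mask for global layers
local_mask = causal_mask(8, window=4)   # sliding-window causal mask for local layers
```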

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 77 additions & 6 deletions
```diff
@@ -4638,6 +4638,77 @@ def test_qnn_backend_generate_optrace(self):
 
 
 class TestExampleLLMScript(TestQNN):
+    def test_static_gemma3_1b(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "My favourite condiment is "
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--ptq",
+            "16a4w_block",
+            "--temperature",
+            "0",
+            "--decoder_model",
+            "gemma3-1b",
+            "--model_mode",
+            "kv",
+            "--max_seq_len",
+            "1024",
+            "--eval_perplexity",
+            "--tasks",
+            "wikitext",
+            "--limit",
+            "1",
+            "--enable_masked_softmax",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                if not self.compile_only:
+                    self.assertLessEqual(msg["wiki_ppl"], 23)
+                if not self.enable_x86_64:
+                    pte_size = msg["pte_size"]
+                    self.assertLessEqual(pte_size, 1_200_000_000)  # 1.2GB
+                inference_speed_ref = {"SM8650": 70, "SM8750": 100}
+                if (
+                    not self.compile_only
+                    and not self.enable_x86_64
+                    and self.model in inference_speed_ref
+                ):
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
+                    )
+
     def test_llama3_2_1b(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
@@ -4708,7 +4779,7 @@ def test_llama3_2_1b(self):
                 # Inference speed on x86 is slow, so we only check when running on Android
                 if not self.enable_x86_64:
                     pte_size = msg["pte_size"]
-                    self.assertLessEqual(pte_size, 1300000000)
+                    self.assertLessEqual(pte_size, 1_300_000_000)  # 1.3GB
                 if not self.compile_only and not self.enable_x86_64:
                     self.assertGreaterEqual(msg["inference_speed"], 66)  # Lanai
 
@@ -4784,7 +4855,7 @@ def test_llama_stories_260k(self):
                 # x86 does not allow weight sharing, so we don't check pte size
                 if not self.enable_x86_64:
                     pte_size = msg["pte_size"]
-                    self.assertLessEqual(pte_size, 2020000)
+                    self.assertLessEqual(pte_size, 2_020_000)  # 2MB
                 if not self.compile_only and not self.enable_x86_64:
                     self.assertGreaterEqual(msg["inference_speed"], 1600)  # Lanai
 
@@ -4859,7 +4930,7 @@ def test_llama_stories_110m(self):
                 # x86 does not allow weight sharing, so we don't check pte size
                 if not self.enable_x86_64:
                     pte_size = msg["pte_size"]
-                    self.assertLessEqual(pte_size, 130000000)
+                    self.assertLessEqual(pte_size, 130_000_000)  # 130MB
                 if not self.compile_only and not self.enable_x86_64:
                     self.assertGreaterEqual(msg["inference_speed"], 220)  # Lanai
 
@@ -4922,7 +4993,7 @@ def test_static_phi4(self):
             else:
                 inference_speed_ref = {"SM8650": 14, "SM8750": 19}
                 self.assertLessEqual(msg["wiki_ppl"], 12)
-                self.assertLessEqual(msg["pte_size"], 4000000000)  # 4gb
+                self.assertLessEqual(msg["pte_size"], 4_000_000_000)  # 4GB
                 if self.model in inference_speed_ref:
                     self.assertGreaterEqual(
                         msg["inference_speed"], inference_speed_ref[self.model]
@@ -4981,7 +5052,7 @@ def test_static_qwen2_5(self):
             else:
                 inference_speed_ref = {"SM8650": 115, "SM8750": 155}
                 self.assertLessEqual(msg["wiki_ppl"], 15)
-                self.assertLessEqual(msg["pte_size"], 600000000)  # 600mb
+                self.assertLessEqual(msg["pte_size"], 600_000_000)  # 600MB
                 if self.model in inference_speed_ref:
                     self.assertGreaterEqual(
                         msg["inference_speed"], inference_speed_ref[self.model]
@@ -5040,7 +5111,7 @@ def test_static_qwen3(self):
             else:
                 inference_speed_ref = {"SM8650": 38, "SM8750": 56}
                 self.assertLessEqual(msg["wiki_ppl"], 18)
-                self.assertLessEqual(msg["pte_size"], 950_000_000)  # 950mb
+                self.assertLessEqual(msg["pte_size"], 950_000_000)  # 950MB
                 if self.model in inference_speed_ref:
                     self.assertGreaterEqual(
                         msg["inference_speed"], inference_speed_ref[self.model]
```
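The new test follows the same IPC pattern as the other LLM tests in this file: the host test opens a `multiprocessing.connection.Listener`, and the launched llama.py script reports its metrics back as a JSON payload that the assertions above read. A minimal sketch of the reporting side (illustrative only; the metric values, `ip`, and `port` are placeholders, not taken from this commit):

```python
import json
from multiprocessing.connection import Client

# Placeholder values; in the real script these come from the --ip/--port arguments.
ip, port = "localhost", 12345

metrics = {
    "wiki_ppl": 22.9,            # perplexity from --eval_perplexity on wikitext
    "pte_size": 1_150_000_000,   # exported .pte size in bytes
    "inference_speed": 105,      # decode tokens/sec on device
}

# The test's Listener.accept() receives this connection and json.loads() the payload.
with Client((ip, port)) as conn:
    conn.send(json.dumps(metrics))
```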

examples/models/gemma3/__init__.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -0,0 +1,16 @@
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.examples.models.gemma3.convert_weights import convert_weights
+from executorch.examples.models.llama.model import Llama2Model
+
+
+class Gemma3Model(Llama2Model):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+__all__ = [
+    "Gemma3Model",
+    "convert_weights",
+]
```
Lines changed: 23 additions & 0 deletions
```diff
@@ -0,0 +1,23 @@
+{
+  "dim": 1152,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 6912,
+  "n_heads": 4,
+  "head_dim": 256,
+  "n_kv_heads": 1,
+  "n_layers": 26,
+  "act_fn": "gelu_approx",
+  "norm_type": "gemma3",
+  "norm_eps": 1e-06,
+  "post_attention_norm": true,
+  "post_ffn_norm": true,
+  "rope_theta": 1000000.0,
+  "use_scaled_rope": false,
+  "apply_embedding": true,
+  "embedding_scale_factor": 33.941125497,
+  "vocab_size": 262144,
+  "use_hf_rope": true,
+  "attention_qkv_bias": false,
+  "use_qk_norm": true,
+  "qk_norm_before_rope": true
+}
```
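One detail worth calling out in this config: `embedding_scale_factor` is simply sqrt(dim), matching Gemma's convention of scaling token embeddings by the square root of the hidden size. A quick check (illustrative only):

```python
import math

dim = 1152
print(math.sqrt(dim))  # ~33.9411255, matching embedding_scale_factor above
```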
Lines changed: 110 additions & 0 deletions
```diff
@@ -0,0 +1,110 @@
+import argparse
+
+import json
+import os
+from typing import Dict
+
+import torch
+from safetensors.torch import load_file
+
+from torchtune.models.convert_weights import get_mapped_key
+
+
+# Weight mappings from Gemma 3's checkpoint to ExecuTorch's transformer parameters.
+_GEMMA3_TO_EXECUTORCH = {
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    "model.norm.weight": "norm.weight",
+    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+    "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm_fn.weight",
+    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+    "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm_fn.weight",
+    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.post_attention_norm.weight",
+    "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.ffn_norm.weight",
+    "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.post_ffn_norm.weight",
+    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
+    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+}
+
+
+def gemma3_to_executorch(
+    state_dict: Dict[str, torch.Tensor]
+) -> Dict[str, torch.Tensor]:
+    """
+    Convert the state dict so that it matches what ExecuTorch's transformer definition expects.
+    """
+    converted_state_dict = {}
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, _GEMMA3_TO_EXECUTORCH)
+        converted_state_dict[new_key] = value
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
+    return converted_state_dict
+
+
+def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
+    index_path = os.path.join(input_dir, "model.safetensors.index.json")
+    if os.path.exists(index_path):
+        # Sharded checkpoint.
+        with open(index_path, "r") as f:
+            index = json.load(f)
+        weight_map = index["weight_map"]
+        checkpoint_shards = sorted(set(weight_map.values()))
+
+        # Load all the shards into memory
+        shard_to_weights = {}
+        for shard in checkpoint_shards:
+            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))
+
+        # Merge tensors into consolidated state dict.
+        merged_state_dict = {}
+        for weight_name, shard in weight_map.items():
+            tensor = shard_to_weights[shard][weight_name]
+            merged_state_dict[weight_name] = tensor
+        return merged_state_dict
+    else:
+        # Single checkpoint.
+        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
+        return state_dict
+
+
+def load_checkpoint(input_dir: str) -> Dict:
+    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
+    if os.path.exists(pytorch_path):
+        print("Loading checkpoint from PyTorch .bin file")
+        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
+    print("Loading checkpoint from safetensors directory")
+    return load_checkpoint_from_safetensors(input_dir)
+
+
+def convert_weights(input_dir: str, output_file: str) -> None:
+    print("Loading checkpoint...")
+    sd = load_checkpoint(input_dir)
+    print("Converting checkpoint...")
+    sd = gemma3_to_executorch(sd)
+    print("Saving checkpoint...")
+    torch.save(sd, output_file)
+    print("Done.")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert Gemma3 weights to ExecuTorch transformer format."
+    )
+    parser.add_argument(
+        "input_dir",
+        type=str,
+        help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.",
+    )
+    parser.add_argument("output", type=str, help="Path to the output checkpoint")
+
+    args = parser.parse_args()
+    convert_weights(args.input_dir, args.output)
+
+
+if __name__ == "__main__":
+    main()
```
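For reference, a typical use of the converter through the export added in `examples/models/gemma3/__init__.py`. This is illustrative only; the paths below are placeholders for a locally downloaded google/gemma-3-1b-it snapshot and the desired output file:

```python
from executorch.examples.models.gemma3 import convert_weights

# Placeholder paths: the directory holding the downloaded checkpoint files and
# the output .pth written in ExecuTorch's transformer parameter naming.
convert_weights("/path/to/gemma-3-1b-it", "/path/to/gemma3_1b.pth")
```

The same conversion is available from the command line by running the module with the input directory and output path as positional arguments, per the argparse definition in `main()`.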

examples/models/llama/model_args.py

Lines changed: 41 additions & 0 deletions
```diff
@@ -1,7 +1,39 @@
 import dataclasses
 from dataclasses import dataclass
+from enum import Enum
+from functools import partial
 from typing import Any, Dict, Optional
 
+import torch.nn.functional as F
+
+
+class ActFn(Enum):
+    SILU = "silu"
+    GELU = "gelu"
+    GELU_APPROX = "gelu_approx"
+
+    @classmethod
+    def from_string(cls, value: str) -> "ActFn":
+        """Convert string to ActFn enum."""
+        try:
+            return cls(value)
+        except ValueError:
+            valid_values = [e.value for e in cls]
+            raise ValueError(
+                f"Invalid activation function: {value}. Valid options: {valid_values}"
+            )
+
+    def get_function(self):
+        """Return the corresponding activation function."""
+        if self == ActFn.SILU:
+            return F.silu
+        elif self == ActFn.GELU:
+            return F.gelu
+        elif self == ActFn.GELU_APPROX:
+            return partial(F.gelu, approximate="tanh")
+        else:
+            raise ValueError(f"Unsupported activation function: {self}")
+
 
 @dataclass
 class ModelArgs:
@@ -15,13 +47,17 @@ class ModelArgs:
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
+    post_attention_norm: bool = False
+    post_ffn_norm: bool = False
     max_batch_size: int = 1
     max_seq_len: int = 2048
     max_context_len: int = 2048
     moe: bool = False  # True to enable the MoE (Mixture of Experts)
     num_experts: int = 8  # Number of experts
     num_activated_experts: int = 2  # Number of experts to activate
     attention_type: str = "mha"  # Attention type, registered in attention.py
+    norm_type: str = "rmsnorm"  # Normalization type, registered in norm.py
+    act_fn: ActFn = dataclasses.field(default=ActFn.SILU)  # Activation function type
     attention_qkv_bias: bool = False
     use_kv_cache: bool = False  # Use key/value cache
     use_sdpa_with_kv_cache_op: bool = (
@@ -37,6 +73,7 @@ class ModelArgs:
     # A dictionary mapping from pruned token-id to original token-id
     output_prune_map: Optional[Dict[int, int]] = None
     apply_embedding: bool = True  # Use embedding inside the transformer
+    embedding_scale_factor: float = 1.0  # Multiple by which to scale embeddings.
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
     qk_norm_before_rope: bool = False  # when to apply qk norm
@@ -103,3 +140,7 @@ def find_multiple(n: int, k: int) -> int:
 
         if self.head_dim is None:
             self.head_dim = self.dim // self.n_heads
+
+        # Convert string act_fn to enum if needed
+        if isinstance(self.act_fn, str):
+            self.act_fn = ActFn.from_string(self.act_fn)
```
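To make the new knob concrete, here is a small, illustrative snippet (not part of this diff; the import path is inferred from the file's location) showing how a params value such as `"act_fn": "gelu_approx"` in the Gemma3 config above resolves to a callable:

```python
import torch
from executorch.examples.models.llama.model_args import ActFn  # inferred import path

# Mirrors the string-to-enum conversion added above for configs that give act_fn as a string.
act = ActFn.from_string("gelu_approx")
gelu_tanh = act.get_function()  # partial(F.gelu, approximate="tanh")

x = torch.randn(4)
print(gelu_tanh(x))  # tanh-approximated GELU, the activation Gemma3's feed-forward uses
```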

examples/qualcomm/oss_scripts/llama/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -28,6 +28,7 @@ list(
     ${CMAKE_CURRENT_LIST_DIR}/qnn_llama_runner.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/cache_utils.h
     ${CMAKE_CURRENT_LIST_DIR}/runner/decoder_runner.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/decoder_runner.h
     ${CMAKE_CURRENT_LIST_DIR}/runner/prompt_processor.cpp
```
