Commit d83636d: Ability to specify full file configs for export_llm (#11809)

1 parent 7f2fcb0

6 files changed: +327 -5 lines

examples/models/llama/config/llm_config.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -65,7 +65,9 @@ class BaseConfig:
         params: Model parameters, such as n_layers, hidden_size, etc.
             If left empty will use defaults specified in model_args.py.
         checkpoint: Path to the checkpoint file.
-            If left empty, the model will be initialized with random weights.
+            If left empty, the model will either be initialized with random weights
+            if it is a Llama model or the weights will be downloaded from HuggingFace
+            if it is a non-Llama model.
         checkpoint_dir: Path to directory containing sharded checkpoint files.
         tokenizer_path: Path to the tokenizer file.
         metadata: Json string containing metadata information.
```
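As a hedged illustration of the documented behavior (the model name is borrowed from the README added later in this commit, not from this file): leaving `checkpoint` unset for a non-Llama model pulls the weights from HuggingFace, whereas a Llama model would be randomly initialized instead.

```yaml
base:
  model_class: qwen3-0_6b  # non-Llama: weights are downloaded from HuggingFace
  # checkpoint: deliberately omitted
```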

examples/models/llama/export_llama_lib.py

Lines changed: 6 additions & 2 deletions

```diff
@@ -53,6 +53,8 @@
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
 
+from omegaconf import DictConfig
+
 from ..model_factory import EagerModelFactory
 from .source_transformation.apply_spin_quant_r1_r2 import (
     fuse_layer_norms,
@@ -571,12 +573,14 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
 
 
 def export_llama(
-    export_options: Union[argparse.Namespace, LlmConfig],
+    export_options: Union[argparse.Namespace, LlmConfig, DictConfig],
 ) -> str:
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
         llm_config = LlmConfig.from_args(export_options)
-    elif isinstance(export_options, LlmConfig):
+    elif isinstance(export_options, LlmConfig) or isinstance(
+        export_options, DictConfig
+    ):
         # Hydra CLI.
         llm_config = export_options
     else:
```
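Why the wider signature: merging an `LlmConfig` default with a loaded YAML file via OmegaConf yields an `omegaconf.DictConfig`, not an `LlmConfig` instance, so the old `isinstance` check would have rejected configs coming through `--config`. A minimal sketch of that fact (`ExportCfg` is a hypothetical stand-in for `LlmConfig`, not part of this commit):

```python
from dataclasses import dataclass

from omegaconf import DictConfig, OmegaConf


@dataclass
class ExportCfg:
    # Hypothetical stand-in for LlmConfig.
    max_seq_length: int = 128


defaults = OmegaConf.structured(ExportCfg())           # DictConfig backed by the dataclass
overrides = OmegaConf.create({"max_seq_length": 512})  # as if loaded from a YAML file

merged = OmegaConf.merge(defaults, overrides)
assert isinstance(merged, DictConfig)     # what export_llama() now accepts
assert not isinstance(merged, ExportCfg)  # merge does not return the dataclass type
assert merged.max_seq_length == 512
```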

extension/llm/export/README.md

Lines changed: 137 additions & 0 deletions

# LLM Export API

This directory contains the unified API for exporting Large Language Models (LLMs) to ExecuTorch. The `export_llm` module provides a streamlined interface to convert various LLM architectures to optimized `.pte` files for on-device inference.

## Overview

The LLM export process transforms a model from its original format to an optimized representation suitable for mobile and edge devices. This involves several key steps, sketched in code after the list:

1. **Model Instantiation**: Load the model architecture and weights from sources like Hugging Face
2. **Source Transformations**: Apply model-specific optimizations and quantization
3. **IR Export**: Convert to intermediate representations (EXIR, Edge dialect)
4. **Graph Transformations**: Apply backend-specific optimizations and PT2E quantization
5. **Backend Delegation**: Partition operations to hardware-specific backends (XNNPACK, CoreML, QNN, etc.)
6. **Serialization**: Export to final ExecuTorch `.pte` format
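A hedged, toy-sized sketch of steps 3 to 6 using a stand-in `Toy` module instead of an LLM; `export_llm` drives a much richer version of this flow, and the API names reflect recent ExecuTorch releases rather than anything defined in this commit:

```python
import torch
from torch.export import export

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x)


# IR export: eager module -> ExportedProgram -> Edge dialect.
ep = export(Toy(), (torch.randn(1, 8),))
edge = to_edge(ep)

# Backend delegation: partition supported ops to XNNPACK.
edge = edge.to_backend(XnnpackPartitioner())

# Serialization: write the final .pte flatbuffer.
with open("toy.pte", "wb") as f:
    f.write(edge.to_executorch().buffer)
```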
## Supported Models

- **Llama**: Llama 2, Llama 3, Llama 3.1, Llama 3.2 (1B, 3B, 8B variants)
- **Qwen**: Qwen 2.5, Qwen 3 (0.6B, 1.7B, 4B variants)
- **Phi**: Phi-3-Mini, Phi-4-Mini
- **Stories**: Stories110M (educational model)
- **SmolLM**: SmolLM2
## Usage

The export API supports two configuration approaches:

### Option 1: Hydra CLI Arguments

Use structured configuration arguments directly on the command line:

```bash
python -m extension.llm.export.export_llm \
    base.model_class=llama3 \
    model.use_sdpa_with_kv_cache=True \
    model.use_kv_cache=True \
    export.max_seq_length=128 \
    debug.verbose=True \
    backend.xnnpack.enabled=True \
    backend.xnnpack.extended_ops=True \
    quantization.qmode=8da4w
```

### Option 2: Configuration File

Create a YAML configuration file and reference it:

```bash
python -m extension.llm.export.export_llm --config my_config.yaml
```

Example `my_config.yaml`:

```yaml
base:
  model_class: llama3
  tokenizer_path: /path/to/tokenizer.json

model:
  use_kv_cache: true
  use_sdpa_with_kv_cache: true
  enable_dynamic_shape: true

export:
  max_seq_length: 512
  output_dir: ./exported_models
  output_name: llama3_optimized.pte

quantization:
  qmode: 8da4w
  group_size: 32

backend:
  xnnpack:
    enabled: true
    extended_ops: true

debug:
  verbose: true
```

**Important**: You cannot mix both approaches. Use either CLI arguments OR a config file, not both.
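Both approaches ultimately hand a config object to `export_llama`. For completeness, a hedged sketch of driving the same export from Python rather than the CLI; the import paths are inferred from this commit's diffs and should be treated as assumptions:

```python
from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.export_llama_lib import export_llama
from omegaconf import OmegaConf

# Start from the structured defaults, then overlay the YAML file,
# mirroring what export_llm.py does for --config.
defaults = LlmConfig()
overrides = OmegaConf.load("my_config.yaml")
llm_config = OmegaConf.merge(defaults, overrides)

pte_path = export_llama(llm_config)  # returns the output path as a string
print(pte_path)
```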
## Example Commands

### Export Qwen3 0.6B with XNNPACK backend and quantization

```bash
python -m extension.llm.export.export_llm \
    base.model_class=qwen3-0_6b \
    base.params=examples/models/qwen3/0_6b_config.json \
    base.metadata='{"get_bos_id": 151644, "get_eos_ids":[151645]}' \
    model.use_kv_cache=true \
    model.use_sdpa_with_kv_cache=true \
    model.dtype_override=FP32 \
    export.max_seq_length=512 \
    export.output_name=qwen3_0_6b.pte \
    quantization.qmode=8da4w \
    backend.xnnpack.enabled=true \
    backend.xnnpack.extended_ops=true \
    debug.verbose=true
```

### Export Phi-4-Mini with custom checkpoint

```bash
python -m extension.llm.export.export_llm \
    base.model_class=phi_4_mini \
    base.checkpoint=/path/to/phi4_checkpoint.pth \
    base.params=examples/models/phi-4-mini/config.json \
    base.metadata='{"get_bos_id":151643, "get_eos_ids":[151643]}' \
    model.use_kv_cache=true \
    model.use_sdpa_with_kv_cache=true \
    export.max_seq_length=256 \
    export.output_name=phi4_mini.pte \
    backend.xnnpack.enabled=true \
    debug.verbose=true
```

### Export with CoreML backend (iOS optimization)

```bash
python -m extension.llm.export.export_llm \
    base.model_class=llama3 \
    model.use_kv_cache=true \
    export.max_seq_length=128 \
    backend.coreml.enabled=true \
    backend.coreml.compute_units=ALL \
    quantization.pt2e_quantize=coreml_c4w \
    debug.verbose=true
```

## Configuration Options

For a complete reference of all available configuration options, see the [LlmConfig class definition](../../../examples/models/llama/config/llm_config.py) which documents all supported parameters for base, model, export, quantization, backend, and debug configurations.

## Further Reading

- [Llama Examples](../../../examples/models/llama/README.md) - Comprehensive Llama export guide
- [LLM Runner](../runner/) - Running exported models
- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview

extension/llm/export/export_llm.py

Lines changed: 50 additions & 2 deletions

```diff
@@ -23,8 +23,16 @@
     backend.xnnpack.enabled=True \
     backend.xnnpack.extended_ops=True \
     quantization.qmode="8da4w"
+
+Example usage using config file:
+python -m extension.llm.export.export_llm \
+    --config example_llm_config.yaml
 """
 
+import argparse
+import sys
+from typing import Any, List, Tuple
+
 import hydra
 
 from executorch.examples.models.llama.config.llm_config import LlmConfig
@@ -36,10 +44,50 @@
 cs.store(name="llm_config", node=LlmConfig)
 
 
-@hydra.main(version_base=None, config_path=None, config_name="llm_config")
-def main(llm_config: LlmConfig) -> None:
+def parse_config_arg() -> Tuple[str, List[Any]]:
+    """First parse out the arg for whether to use Hydra or the old CLI."""
+    parser = argparse.ArgumentParser(add_help=True)
+    parser.add_argument("--config", type=str, help="Path to the LlmConfig file")
+    args, remaining = parser.parse_known_args()
+    return args.config, remaining
+
+
+def pop_config_arg() -> str:
+    """
+    Removes '--config' and its value from sys.argv.
+    Assumes --config is specified and argparse has already validated the args.
+    """
+    idx = sys.argv.index("--config")
+    value = sys.argv[idx + 1]
+    del sys.argv[idx : idx + 2]
+    return value
+
+
+@hydra.main(version_base=None, config_name="llm_config")
+def hydra_main(llm_config: LlmConfig) -> None:
     export_llama(OmegaConf.to_object(llm_config))
 
 
+def main() -> None:
+    config, remaining_args = parse_config_arg()
+    if config:
+        # Check if there are any remaining hydra CLI args when --config is specified
+        # This might change in the future to allow overriding config file values
+        if remaining_args:
+            raise ValueError(
+                "Cannot specify additional CLI arguments when using --config. "
+                f"Found: {remaining_args}. Use either --config file or hydra CLI args, not both."
+            )
+
+        config_file_path = pop_config_arg()
+        default_llm_config = LlmConfig()
+        llm_config_from_file = OmegaConf.load(config_file_path)
+        # Override defaults with values specified in the .yaml provided by --config.
+        merged_llm_config = OmegaConf.merge(default_llm_config, llm_config_from_file)
+        export_llama(merged_llm_config)
+    else:
+        hydra_main()
+
+
 if __name__ == "__main__":
     main()
```
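The `OmegaConf.merge` call above gives the `--config` path its semantics: values from the YAML file override the `LlmConfig` defaults, while unspecified fields keep their defaults (and `pop_config_arg` strips `--config` from `sys.argv` before anything else consumes it). A self-contained sketch of the merge behavior, using toy keys rather than the real config schema:

```python
from omegaconf import OmegaConf

# Toy stand-ins for the LlmConfig defaults and a user-provided YAML file.
defaults = OmegaConf.create({"export": {"max_seq_length": 128, "output_dir": "."}})
from_file = OmegaConf.create({"export": {"max_seq_length": 512}})

merged = OmegaConf.merge(defaults, from_file)
assert merged.export.max_seq_length == 512  # file value wins
assert merged.export.output_dir == "."      # unspecified field keeps its default
```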

New test file for `export_llm`

Lines changed: 129 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import tempfile
import unittest
from unittest.mock import MagicMock, patch

from executorch.extension.llm.export.export_llm import (
    main,
    parse_config_arg,
    pop_config_arg,
)


class TestExportLlm(unittest.TestCase):
    def test_parse_config_arg_with_config(self) -> None:
        """Test parse_config_arg when --config is provided."""
        # Mock sys.argv to include --config
        test_argv = ["script.py", "--config", "test_config.yaml", "extra", "args"]
        with patch.object(sys, "argv", test_argv):
            config_path, remaining = parse_config_arg()
            self.assertEqual(config_path, "test_config.yaml")
            self.assertEqual(remaining, ["extra", "args"])

    def test_parse_config_arg_without_config(self) -> None:
        """Test parse_config_arg when --config is not provided."""
        test_argv = ["script.py", "debug.verbose=True"]
        with patch.object(sys, "argv", test_argv):
            config_path, remaining = parse_config_arg()
            self.assertIsNone(config_path)
            self.assertEqual(remaining, ["debug.verbose=True"])

    def test_pop_config_arg(self) -> None:
        """Test pop_config_arg removes --config and its value from sys.argv."""
        test_argv = ["script.py", "--config", "test_config.yaml", "other", "args"]
        with patch.object(sys, "argv", test_argv):
            config_path = pop_config_arg()
            self.assertEqual(config_path, "test_config.yaml")
            self.assertEqual(sys.argv, ["script.py", "other", "args"])

    @patch("executorch.extension.llm.export.export_llm.export_llama")
    def test_with_config(self, mock_export_llama: MagicMock) -> None:
        """Test main function with --config file and no hydra args."""
        # Create a temporary config file
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            f.write(
                """
base:
  tokenizer_path: /path/to/tokenizer.json
export:
  max_seq_length: 256
"""
            )
            config_file = f.name

        try:
            test_argv = ["script.py", "--config", config_file]
            with patch.object(sys, "argv", test_argv):
                main()

            # Verify export_llama was called with config
            mock_export_llama.assert_called_once()
            called_config = mock_export_llama.call_args[0][0]
            self.assertEqual(
                called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json"
            )
            self.assertEqual(called_config["export"]["max_seq_length"], 256)
        finally:
            os.unlink(config_file)

    def test_with_cli_args(self) -> None:
        """Test main function with only hydra CLI args."""
        test_argv = ["script.py", "debug.verbose=True"]
        with patch.object(sys, "argv", test_argv):
            with patch(
                "executorch.extension.llm.export.export_llm.hydra_main"
            ) as mock_hydra:
                main()
                mock_hydra.assert_called_once()

    def test_config_with_cli_args_error(self) -> None:
        """Test that --config rejects additional CLI arguments to prevent mixing approaches."""
        # Create a temporary config file
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            f.write("base:\n checkpoint: /path/to/checkpoint.pth")
            config_file = f.name

        try:
            test_argv = ["script.py", "--config", config_file, "debug.verbose=True"]
            with patch.object(sys, "argv", test_argv):
                with self.assertRaises(ValueError) as cm:
                    main()

            error_msg = str(cm.exception)
            self.assertIn(
                "Cannot specify additional CLI arguments when using --config",
                error_msg,
            )
        finally:
            os.unlink(config_file)

    def test_config_rejects_multiple_cli_args(self) -> None:
        """Test that --config rejects multiple CLI arguments (not just single ones)."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            f.write("export:\n max_seq_length: 128")
            config_file = f.name

        try:
            test_argv = [
                "script.py",
                "--config",
                config_file,
                "debug.verbose=True",
                "export.output_dir=/tmp",
            ]
            with patch.object(sys, "argv", test_argv):
                with self.assertRaises(ValueError):
                    main()
        finally:
            os.unlink(config_file)


if __name__ == "__main__":
    unittest.main()
```
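These tests patch `sys.argv` and mock out `export_llama` and `hydra_main`, so they exercise only the argument-routing logic and never perform a real export. Assuming the file is saved as `test_export_llm.py` (its path is not captured above), they should run under pytest or unittest:

```bash
python -m pytest test_export_llm.py -v
```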

requirements-dev.txt

Lines changed: 2 additions & 0 deletions

```diff
@@ -9,3 +9,5 @@ wheel # For building the pip package archive.
 zstd # Imported by resolve_buck.py.
 lintrunner==0.12.7
 lintrunner-adapters==0.12.4
+hydra-core>=1.3.0
+omegaconf>=2.3.0
```
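For development environments set up before this commit, the two new dependencies can also be installed directly, matching the pins above:

```bash
pip install "hydra-core>=1.3.0" "omegaconf>=2.3.0"
```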
