Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions extension/llm/export/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,14 @@ fbcode_target(_kind = runtime.python_test,
":export_lib",
],
)

# Library that embeds/reads model metadata (tokenizer config, chat templates,
# architecture info) in PTE files via the NamedData mechanism.
fbcode_target(_kind = runtime.python_library,
    name = "metadata",
    srcs = [
        "metadata.py",
    ],
    visibility = ["PUBLIC"],
    deps = [
        # Needed for EdgeProgramManager and PTE (de)serialization.
        "//executorch/exir:lib",
    ],
)
117 changes: 117 additions & 0 deletions extension/llm/export/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Model metadata storage for PTE files.

Embeds model metadata (tokenizer config, chat templates, architecture info)
directly in PTE files via the NamedData mechanism. Replaces the current
constant_methods approach (which creates full ExecutionPlan entries for
simple constant values).

Keys use a dotted namespace.field convention:
tokenizer.bos_id, tokenizer.eos_ids, context.max_seq_len, etc.
"""

from __future__ import annotations

import struct
from typing import Dict, List, Sequence, TYPE_CHECKING, Union

if TYPE_CHECKING:
from executorch.exir import EdgeProgramManager

METADATA_PREFIX = "metadata."

MetadataValue = Union[str, int, float, bytes, Sequence[int]]


def _encode_value(key: str, value: MetadataValue) -> bytes:
if isinstance(value, bool):
raise TypeError(f"bool not supported for key '{key}', use int (0/1) instead")
if isinstance(value, str):
return value.encode("utf-8")
elif isinstance(value, (list, tuple)):
for i, elem in enumerate(value):
if not isinstance(elem, int) or isinstance(elem, bool):
raise TypeError(
f"list element {i} for key '{key}' must be int, got {type(elem)}"
)
return struct.pack(f"<I{len(value)}q", len(value), *value)
elif isinstance(value, int):
return struct.pack("<q", value)
elif isinstance(value, float):
return struct.pack("<d", value)
elif isinstance(value, bytes):
return value
raise TypeError(f"Unsupported metadata type {type(value)} for key '{key}'")
Comment on lines +31 to +49


def add_metadata(
    edge_manager: EdgeProgramManager,
    metadata: Dict[str, MetadataValue],
) -> None:
    """Add metadata KV pairs to a PTE file during export.

    Must be called BEFORE edge_manager.to_executorch() so the entries are
    included when the program is serialized.

    Args:
        edge_manager: The EdgeProgramManager from to_edge() or
            to_edge_transform_and_lower().
        metadata: Dict mapping string keys to values (str, int, float, bytes,
            or list[int]). Keys are automatically prefixed with "metadata." to
            avoid collision with backend named data.

    Raises:
        TypeError: If a value (or a list element) has an unsupported type.
    """
    for key, value in metadata.items():
        # NOTE: _named_data_store is a private attribute of EdgeProgramManager;
        # if its internals change, this will break. Prefer a public API if one
        # is added.
        edge_manager._named_data_store.add_named_data(
            key=f"{METADATA_PREFIX}{key}",
            data=_encode_value(key, value),
        )


def read_metadata(pte_path: str) -> Dict[str, bytes]:
    """Read all metadata entries from a PTE file.

    Returns raw bytes for each key (without the "metadata." prefix).
    Use get_string/get_int/get_float for typed access.

    WARNING: Loads the entire PTE file into memory. Not suitable for
    large model files in production; intended for testing and debugging.
    """
    # Local import keeps the deserializer dependency off the write path.
    from executorch.exir._serialize._program import deserialize_pte_binary

    with open(pte_path, "rb") as f:
        parsed = deserialize_pte_binary(f.read())

    named = parsed.named_data
    if named is None:
        return {}

    prefix_len = len(METADATA_PREFIX)
    return {
        key[prefix_len:]: named.buffers[entry.buffer_index]
        for key, entry in named.pte_data.items()
        if key.startswith(METADATA_PREFIX)
    }


def get_string(metadata: Dict[str, bytes], key: str) -> str:
    """Decode the raw entry for ``key`` as a UTF-8 string.

    Raises:
        KeyError: If ``key`` is absent from ``metadata``.
        UnicodeDecodeError: If the entry is not valid UTF-8.
    """
    return metadata[key].decode("utf-8")


def get_int(metadata: Dict[str, bytes], key: str) -> int:
    """Decode the raw entry for ``key`` as a little-endian signed 64-bit int.

    Raises:
        KeyError: If ``key`` is absent from ``metadata``.
        struct.error: If the entry is not exactly 8 bytes.
    """
    return struct.unpack("<q", metadata[key])[0]


def get_float(metadata: Dict[str, bytes], key: str) -> float:
    """Decode the raw entry for ``key`` as a little-endian IEEE-754 float64.

    Raises:
        KeyError: If ``key`` is absent from ``metadata``.
        struct.error: If the entry is not exactly 8 bytes.
    """
    return struct.unpack("<d", metadata[key])[0]


def get_int_list(metadata: Dict[str, bytes], key: str) -> List[int]:
    """Decode the raw entry for ``key`` as a list of signed 64-bit ints.

    Layout (matching _encode_value): little-endian uint32 element count,
    followed by that many little-endian int64 values.

    Raises:
        KeyError: If ``key`` is absent from ``metadata``.
        ValueError: If the payload size does not match the encoded count,
            which indicates a corrupt or mis-typed entry. (Previously,
            trailing garbage bytes were silently ignored.)
        struct.error: If the payload is too short to hold the count.
    """
    data = metadata[key]
    (count,) = struct.unpack_from("<I", data, 0)
    expected = 4 + 8 * count
    if len(data) != expected:
        raise ValueError(
            f"metadata entry '{key}' is malformed: expected {expected} bytes "
            f"for {count} int64 elements, got {len(data)}"
        )
    return list(struct.unpack_from(f"<{count}q", data, 4))
12 changes: 12 additions & 0 deletions extension/llm/export/test/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,15 @@ fbcode_target(_kind = runtime.python_test,
"//caffe2:torch",
],
)

# Round-trip test: write metadata during export, read it back from the PTE.
fbcode_target(_kind = runtime.python_test,
    name = "test_metadata_roundtrip",
    srcs = [
        "test_metadata_roundtrip.py",
    ],
    deps = [
        "//executorch/extension/llm/export:metadata",
        "//caffe2:torch",
        "//executorch/exir:lib",
    ],
)
163 changes: 163 additions & 0 deletions extension/llm/export/test/test_metadata_roundtrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Round-trip test: export model with metadata stored in NamedData."""

import tempfile
import unittest

import torch
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
from executorch.extension.llm.export.metadata import (
add_metadata,
get_float,
get_int,
get_int_list,
get_string,
read_metadata,
)
from torch.export import export


class SimpleModel(torch.nn.Module):
    """Minimal single-linear-layer model used purely as an export fixture.

    Maps a (batch, 10) input to a (batch, 5) output; the model's behavior
    is irrelevant to the tests, which target the metadata plumbing.
    """

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

class TestMetadataRoundTrip(unittest.TestCase):
    """End-to-end checks that metadata added during export survives PTE
    serialization and can be read back with the typed accessors.

    The export/lower and serialize/read boilerplate was duplicated across
    every test; it is factored into _make_edge and _serialize_and_read.
    """

    def _make_edge(self):
        """Export SimpleModel and lower it to an edge program fixture."""
        exported = export(SimpleModel(), (torch.randn(1, 10),))
        # IR validity checking is disabled: the fixture model is incidental,
        # the tests target the metadata plumbing only.
        return to_edge_transform_and_lower(
            exported,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )

    def _serialize_and_read(self, et_program):
        """Write the serialized program to a temp .pte file and return the
        metadata entries read back from it."""
        with tempfile.NamedTemporaryFile(suffix=".pte") as f:
            f.write(et_program.buffer)
            f.flush()
            return read_metadata(f.name)

    def test_roundtrip(self):
        edge = self._make_edge()

        chat_template = (
            "{% for message in messages %}"
            "{{ message.role }}: {{ message.content }}\n"
            "{% endfor %}"
        )

        add_metadata(
            edge,
            {
                "tokenizer.model": "BPE",
                "tokenizer.vocab_size": 128256,
                "chat_template": chat_template,
                "model.arch": "llama",
                "model.context_length": 8192,
                "general.name": "Llama-3.2-1B",
                "model.temperature": 0.7,
            },
        )

        metadata = self._serialize_and_read(edge.to_executorch())

        self.assertEqual(len(metadata), 7)
        self.assertEqual(get_string(metadata, "tokenizer.model"), "BPE")
        self.assertEqual(get_int(metadata, "tokenizer.vocab_size"), 128256)
        self.assertEqual(get_string(metadata, "chat_template"), chat_template)
        self.assertEqual(get_string(metadata, "model.arch"), "llama")
        self.assertEqual(get_int(metadata, "model.context_length"), 8192)
        self.assertEqual(get_string(metadata, "general.name"), "Llama-3.2-1B")
        self.assertAlmostEqual(get_float(metadata, "model.temperature"), 0.7)

    def test_empty_metadata(self):
        edge = self._make_edge()
        add_metadata(edge, {})
        metadata = self._serialize_and_read(edge.to_executorch())
        self.assertEqual(len(metadata), 0)

    def test_raw_bytes_metadata(self):
        edge = self._make_edge()
        raw = b"\x00\x01\x02\xff"
        add_metadata(edge, {"binary_blob": raw})
        metadata = self._serialize_and_read(edge.to_executorch())
        self.assertEqual(metadata["binary_blob"], raw)

    def test_llm_metadata_replaces_constant_methods(self):
        """POC: these metadata fields currently live as constant_methods
        (full ExecutionPlan entries). This test shows they can be stored
        as lightweight NamedData entries instead."""
        edge = self._make_edge()

        add_metadata(
            edge,
            {
                "tokenizer.bos_id": 128000,
                "tokenizer.eos_ids": [128009, 128001],
                "context.max_seq_len": 8192,
                "context.max_context_len": 8192,
                "model.vocab_size": 128256,
                "model.use_kv_cache": 1,
                "model.use_sdpa_with_kv_cache": 1,
                "model.n_layers": 16,
                "tokenizer.chat_template": "{% for m in messages %}{{ m.content }}{% endfor %}",
            },
        )

        metadata = self._serialize_and_read(edge.to_executorch())

        self.assertEqual(get_int(metadata, "tokenizer.bos_id"), 128000)
        self.assertEqual(
            get_int_list(metadata, "tokenizer.eos_ids"), [128009, 128001]
        )
        self.assertEqual(get_int(metadata, "context.max_seq_len"), 8192)
        self.assertEqual(get_int(metadata, "context.max_context_len"), 8192)
        self.assertEqual(get_int(metadata, "model.vocab_size"), 128256)
        self.assertEqual(get_int(metadata, "model.use_kv_cache"), 1)
        self.assertEqual(get_int(metadata, "model.use_sdpa_with_kv_cache"), 1)
        self.assertEqual(get_int(metadata, "model.n_layers"), 16)
        self.assertIn(
            "{% for m in messages %}",
            get_string(metadata, "tokenizer.chat_template"),
        )
Loading
Loading