Skip to content

Commit fe7b4e6

Browse files
committed
Deserialize to named data store output
Should this be a public API? May be relatively easy for user to impl with the existing deserialize functionality. Differential Revision: [D83510300](https://our.internmc.facebook.com/intern/diff/D83510300/) ghstack-source-id: 312879155 Pull Request resolved: #15469
1 parent c85ece4 commit fe7b4e6

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

extension/flat_tensor/serialize/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ runtime.python_library(
2929
],
3030
visibility = [
3131
"//executorch/...",
32+
"@EXECUTORCH_CLIENTS",
3233
],
3334
deps = [
3435
":schema",

extension/flat_tensor/serialize/serialize.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@
1313
import os
1414
import tempfile
1515
from dataclasses import dataclass
16-
from typing import ClassVar, Dict, List, Literal, Optional
16+
from typing import ClassVar, Dict, List, Literal, Optional, Union
1717

1818
import executorch.extension.flat_tensor.serialize as serialize_package
1919

2020
from executorch.exir._serialize._cord import Cord
2121
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
2222
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
23+
from executorch.exir._serialize._named_data_store import (
24+
NamedDataStore,
25+
NamedDataStoreOutput,
26+
)
2327
from executorch.exir._serialize._program import _insert_flatbuffer_header
2428
from executorch.exir._serialize.data_serializer import (
2529
DataEntry,
@@ -389,6 +393,8 @@ def serialize(
389393
def deserialize(self, blob: Cord) -> DataPayload:
390394
"""
391395
Deserializes a flat_tensor blob into a list of tensor metadata and tensors.
396+
397+
Note: deserialization does not preserve alignment information.
392398
"""
393399

394400
data = bytes(blob)
@@ -436,3 +442,12 @@ def deserialize(self, blob: Cord) -> DataPayload:
436442
payload.named_data[named_data.key] = entry
437443

438444
return payload
445+
446+
def deserialize_to_named_data_store_output(self, blob: bytes, name: str) -> NamedDataStoreOutput:
447+
bytes = Cord(blob)
448+
data_payload = self.deserialize(bytes)
449+
return NamedDataStoreOutput(
450+
buffers = data_payload.buffers,
451+
pte_data = {},
452+
external_data = {name:data_payload.named_data}
453+
)

0 commit comments

Comments
 (0)