|
13 | 13 | import os |
14 | 14 | import tempfile |
15 | 15 | from dataclasses import dataclass |
16 | | -from typing import ClassVar, Dict, List, Literal, Optional |
| 16 | +from typing import ClassVar, Dict, List, Literal, Optional, Union |
17 | 17 |
|
18 | 18 | import executorch.extension.flat_tensor.serialize as serialize_package |
19 | 19 |
|
20 | 20 | from executorch.exir._serialize._cord import Cord |
21 | 21 | from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass |
22 | 22 | from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile |
| 23 | +from executorch.exir._serialize._named_data_store import ( |
| 24 | + NamedDataStore, |
| 25 | + NamedDataStoreOutput, |
| 26 | +) |
23 | 27 | from executorch.exir._serialize._program import _insert_flatbuffer_header |
24 | 28 | from executorch.exir._serialize.data_serializer import ( |
25 | 29 | DataEntry, |
@@ -389,6 +393,8 @@ def serialize( |
389 | 393 | def deserialize(self, blob: Cord) -> DataPayload: |
390 | 394 | """ |
391 | 395 | Deserializes a flat_tensor blob into a list of tensor metadata and tensors. |
| 396 | +
|
| 397 | + Note: deserialization does not preserve alignment information. |
392 | 398 | """ |
393 | 399 |
|
394 | 400 | data = bytes(blob) |
@@ -436,3 +442,12 @@ def deserialize(self, blob: Cord) -> DataPayload: |
436 | 442 | payload.named_data[named_data.key] = entry |
437 | 443 |
|
438 | 444 | return payload |
| 445 | + |
| 446 | + def deserialize_to_named_data_store_output(self, blob: bytes, name: str) -> NamedDataStoreOutput: |
| 447 | + bytes = Cord(blob) |
| 448 | + data_payload = self.deserialize(bytes) |
| 449 | + return NamedDataStoreOutput( |
| 450 | + buffers = data_payload.buffers, |
| 451 | + pte_data = {}, |
| 452 | + external_data = {name:data_payload.named_data} |
| 453 | + ) |
0 commit comments