Skip to content

Commit 6f89131

Browse files
committed
[deserialization] PT2ArchiveDataMap
Pull Request resolved: #13973 This diff introduces the PT2ArchiveDataMap, which reads the weights in .pt2 archive files. 1. Open the archive file with miniz. 2. There are two JSON config files (model_weights_config.json, model_constants_config.json) with information on weight name -> {weight path, weight metadata}. Open them and extract the weight information into unordered_maps; then free the JSON blobs. 3. For get_tensor_layout calls, return the layout information parsed from the JSON. 4. For get_data calls, use miniz to calculate the offset + size of the entry and then read it with the data loader. PT2 archive files are not compressed (to allow mmap-ing), so the entry's bytes can be read directly at that offset. PT2 archive format: https://docs.google.com/document/d/1xdx3I4zK6naPEWX3e49rCUccZeAC9zMLCFKXvUQFR7o/edit?tab=t.0 Serde: https://docs.google.com/document/d/11X-KsLPMJGdEr4sG4sCNLnGLhSKrc8utDGMQqFbZx9E/edit?tab=t.0#heading=h.tsw6d16xh497 --- TODO in subsequent diffs - Convert strides to dim order. - Additional testing: failure cases, model file with constants as well as weights, model with no weights. - CMake for OSS. ghstack-source-id: 308366584 @exported-using-ghexport Differential Revision: [D81248896](https://our.internmc.facebook.com/intern/diff/D81248896/)
1 parent 151f3be commit 6f89131

File tree

6 files changed

+687
-0
lines changed

6 files changed

+687
-0
lines changed

extension/pt2_archive/TARGETS

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
load(":targets.bzl", "define_common_targets")
3+
4+
oncall("executorch")
5+
6+
define_common_targets()
Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/pt2_archive/pt2_archive_data_map.h>
10+
11+
#include <executorch/extension/data_loader/mmap_data_loader.h>
12+
#include <executorch/runtime/core/data_loader.h>
13+
#include <executorch/runtime/core/error.h>
14+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
15+
#include <executorch/runtime/core/freeable_buffer.h>
16+
#include <executorch/runtime/core/result.h>
17+
#include <executorch/runtime/core/span.h>
18+
#include <executorch/runtime/platform/compiler.h>
19+
20+
#include "miniz.h"
21+
22+
#include <nlohmann/json.hpp>
23+
#include <string.h>
24+
#include <unordered_map>
25+
26+
using json = nlohmann::json;
27+
28+
using executorch::runtime::Error;
29+
using executorch::runtime::FreeableBuffer;
30+
using executorch::runtime::Result;
31+
using executorch::runtime::Span;
32+
33+
using executorch::aten::ScalarType;
34+
using executorch::aten::string_view;
35+
using executorch::ET_RUNTIME_NAMESPACE::TensorLayout;
36+
using executorch::runtime::DataLoader;
37+
38+
using executorch::extension::MmapDataLoader;
39+
40+
// MZ_ZIP constants.
41+
// Number of bytes read from the start of a local file header; the file name
// and extra field follow this fixed-size part.
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
42+
// Byte offset, within the local header, of the 16-bit file name length.
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS = 26;
43+
// Byte offset, within the local header, of the 16-bit extra field length.
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
44+
// Expected value of the first four bytes (little-endian) of a local header.
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIG = 0x04034b50;
45+
46+
// PT2Archive constants.
47+
// Appended to the archive name to form the directory that holds the weights
// config file and the weight payload files.
constexpr const char* WEIGHTS_DIR = "/data/weights/";
48+
// Name of the JSON config file read from WEIGHTS_DIR.
constexpr const char* WEIGHTS_CONFIG_FILE = "model_weights_config.json";
49+
// Appended to the archive name to form the directory that holds the constants
// config file.
constexpr const char* CONSTANTS_DIR = "/data/constants/";
50+
// Name of the JSON config file read from CONSTANTS_DIR.
constexpr const char* CONSTANTS_CONFIG_FILE = "model_constants_config.json";
51+
52+
namespace {
53+
// Translates a PT2 serialization dtype id into the ExecuTorch ScalarType.
//
// The two enums are shifted by one relative to each other:
//   PT2: https://fburl.com/code/qjlmiifs (contains UNKNOWN at enum 0)
//   ET:  https://fburl.com/code/gq30tizb (starts with BYTE at enum 0)
// so the ET value is simply the PT2 value minus one. No range validation is
// performed here.
ScalarType convert_pt2_to_et_scalartype(uint32_t dtype) {
  const uint32_t et_dtype = dtype - 1;
  return static_cast<ScalarType>(et_dtype);
}
59+
60+
// Decodes the unsigned 16-bit little-endian integer stored in buf[0..1].
// Used to read the variable-length field sizes out of a ZIP local file header.
//
// `buf` must point to at least two readable bytes; the bytes are only read,
// so const buffers are accepted. The result is always in [0, 65535].
static int64_t read_le_16(const uint8_t* buf) {
  // Low byte first, then the high byte shifted into bits 8..15.
  const int64_t low = buf[0];
  const int64_t high = buf[1];
  return low | (high << 8);
}
64+
} // namespace
65+
66+
namespace executorch {
67+
namespace extension {
68+
69+
// Tears down the miniz reader when this instance holds an archive handle.
PT2ArchiveDataMap::~PT2ArchiveDataMap() {
  mz_zip_archive* const archive = zip_archive_.get();
  if (archive == nullptr) {
    // Nothing to close.
    return;
  }
  mz_zip_reader_end(archive);
}
75+
76+
/**
 * Extracts one JSON config file from the archive and records, for every
 * tensor it lists, the payload path name and the tensor layout.
 *
 * JSON format (for the information we care about) looks like this:
 *   "config": {
 *     "weight_name": {
 *       "path_name": "weight_0",
 *       "tensor_meta": {
 *         "dtype": <DTYPE>,
 *         "sizes": [{"as_int": <SIZE>}, {"as_int": <SIZE>}, ...],
 *         "strides": [{"as_int": <SIZE>}, {"as_int": <SIZE>}, ...],
 *       }
 *     }
 *   }
 *
 * @param zip_archive An opened miniz reader. On ANY failure the reader is
 *     closed with mz_zip_reader_end() before returning.
 * @param filename Path of the JSON file inside the archive.
 * @param tensor_name_to_path Receives tensor name -> "path_name" entries.
 * @param tensor_name_to_layout Receives tensor name -> layout entries.
 * @return Error::Ok on success; Error::InvalidExternalData if the file cannot
 *     be extracted, is not valid JSON, or lacks a required field. Entries
 *     parsed before a failure remain in the output maps.
 */
/*static*/ Error PT2ArchiveDataMap::parse_json(
    std::unique_ptr<mz_zip_archive>& zip_archive,
    const std::string& filename,
    std::unordered_map<std::string, std::string>& tensor_name_to_path,
    std::unordered_map<std::string, ConcreteTensorLayout>&
        tensor_name_to_layout) {
  // Closes the reader on every early return below, so that all failure paths
  // (including the ET_CHECK_OR_RETURN_ERROR ones) release it consistently.
  struct ReaderCloser {
    explicit ReaderCloser(mz_zip_archive* a) : archive(a) {}
    ~ReaderCloser() {
      if (archive != nullptr) {
        mz_zip_reader_end(archive);
      }
    }
    void dismiss() {
      archive = nullptr;
    }
    mz_zip_archive* archive;
  };
  ReaderCloser close_on_failure(zip_archive.get());

  // The extracted blob is allocated by miniz, so it is released with
  // mz_free(); unique_ptr guarantees that happens on every exit path.
  size_t uncomp_size = 0;
  std::unique_ptr<void, decltype(&mz_free)> buffer(
      mz_zip_reader_extract_file_to_heap(
          zip_archive.get(), filename.c_str(), &uncomp_size, 0),
      &mz_free);
  if (!buffer) {
    ET_LOG(Error, "Failed to extract file %s to heap", filename.c_str());
    return Error::InvalidExternalData;
  }

  // Copy the text out and drop the blob right away; only the string is
  // needed from here on.
  const std::string json_str(
      static_cast<const char*>(buffer.get()), uncomp_size);
  buffer.reset();

  try {
    // Parse JSON string.
    const json json_config = json::parse(json_str);
    ET_CHECK_OR_RETURN_ERROR(
        json_config.contains("config"),
        InvalidExternalData,
        "JSON config does not contain 'config' key; malformed archive file.");
    const json& config = json_config["config"];
    for (auto& item : config.items()) {
      const std::string& tensor_name = item.key();
      const json& entry = item.value();
      ET_CHECK_OR_RETURN_ERROR(
          entry.contains("path_name") && entry.contains("tensor_meta"),
          InvalidExternalData,
          "JSON config does not contain 'path_name' and 'tensor_meta' keys for key %s",
          tensor_name.c_str());

      // Add tensor_name -> path_name mapping. get<std::string>() throws a
      // json::exception (handled below) if the value is not a string.
      tensor_name_to_path[tensor_name] =
          entry["path_name"].get<std::string>();

      // Validate tensor_meta before building the layout.
      const json& tensor_meta = entry["tensor_meta"];
      ET_CHECK_OR_RETURN_ERROR(
          tensor_meta.contains("dtype") &&
              tensor_meta["dtype"].is_number_integer(),
          InvalidExternalData,
          "JSON config does not contain 'dtype' key for key %s",
          tensor_name.c_str());
      ET_CHECK_OR_RETURN_ERROR(
          tensor_meta.contains("sizes") && tensor_meta["sizes"].is_array(),
          InvalidExternalData,
          "JSON config does not contain 'sizes' key for key %s",
          tensor_name.c_str());
      ET_CHECK_OR_RETURN_ERROR(
          tensor_meta.contains("strides") && tensor_meta["strides"].is_array(),
          InvalidExternalData,
          "JSON config does not contain 'strides' key for key %s",
          tensor_name.c_str());

      // The PT2 -> ET conversion subtracts one, so a dtype of 0 (UNKNOWN) or
      // below would wrap around instead of naming a real scalar type.
      const int pt2_dtype = tensor_meta["dtype"].get<int>();
      ET_CHECK_OR_RETURN_ERROR(
          pt2_dtype > 0,
          InvalidExternalData,
          "Unsupported dtype %d for key %s",
          pt2_dtype,
          tensor_name.c_str());

      // Add tensor_name -> tensor_meta mapping.
      ConcreteTensorLayout concrete_layout;
      concrete_layout.scalar_type =
          convert_pt2_to_et_scalartype(static_cast<uint32_t>(pt2_dtype));
      int dim = 0;
      for (const auto& size : tensor_meta["sizes"]) {
        // at() throws a json::exception when "as_int" is missing or `size`
        // is not an object, instead of the unchecked access of operator[].
        concrete_layout.sizes.push_back(size.at("as_int").get<int32_t>());
        // TODO: Calculate dim order from strides. Assume contiguous for now.
        concrete_layout.dim_order.push_back(dim);
        ++dim;
      }
      tensor_name_to_layout[tensor_name] = std::move(concrete_layout);
    }
  } catch (const json::exception& e) {
    ET_LOG(Error, "Failed to parse JSON: %s", e.what());
    return Error::InvalidExternalData;
  }

  // Success: leave the reader open for the caller.
  close_on_failure.dismiss();
  return Error::Ok;
}
161+
162+
/**
 * Opens a .pt2 archive and builds a PT2ArchiveDataMap over it.
 *
 * Steps: open the zip with miniz, derive the archive's top-level directory
 * name from its first entry, parse the weights and constants JSON configs
 * into name -> path / name -> layout tables, then create an MmapDataLoader
 * over the same file for reading tensor payloads.
 *
 * @param pt2_archive_file_path Filesystem path of the .pt2 archive.
 * @return The data map on success. Error::InvalidArgument if the zip or the
 *     data loader cannot be opened; Error::InvalidExternalData if the archive
 *     is empty, its first entry is not inside a directory, or either JSON
 *     config fails to parse. The zip reader is closed on every failure path.
 */
/*static*/ Result<PT2ArchiveDataMap> PT2ArchiveDataMap::load(
    const std::string& pt2_archive_file_path) {
  ET_LOG(
      Info, "Loading PT2ArchiveDataMap from %s", pt2_archive_file_path.c_str());
  auto zip_archive = std::make_unique<mz_zip_archive>();
  // Open zip archive to get json config data.
  memset(zip_archive.get(), 0, sizeof(mz_zip_archive));
  const mz_bool status = mz_zip_reader_init_file(
      zip_archive.get(), pt2_archive_file_path.c_str(), 0);

  ET_CHECK_OR_RETURN_ERROR(
      status == 1,
      InvalidArgument,
      "Failed to open zip archive %s, status: %d",
      pt2_archive_file_path.c_str(),
      status);

  // From here on the reader is open. Close it on any early return so that a
  // failed load does not leak it; dismissed just before handing the archive
  // to the PT2ArchiveDataMap. Calling mz_zip_reader_end() on a reader that
  // was already closed (e.g. by parse_json) is harmless.
  struct ReaderCloser {
    explicit ReaderCloser(mz_zip_archive* a) : archive(a) {}
    ~ReaderCloser() {
      if (archive != nullptr) {
        mz_zip_reader_end(archive);
      }
    }
    void dismiss() {
      archive = nullptr;
    }
    mz_zip_archive* archive;
  };
  ReaderCloser close_on_failure(zip_archive.get());

  // Extract archive name: the leading path component of the first entry.
  const mz_uint n = mz_zip_reader_get_num_files(zip_archive.get());
  ET_CHECK_OR_RETURN_ERROR(
      n > 0, InvalidExternalData, "Archive does not contain any files");
  // First call queries the required buffer size, second call fills it.
  const mz_uint name_size =
      mz_zip_reader_get_filename(zip_archive.get(), 0, nullptr, 0);
  std::string buf(name_size, '\0');
  mz_zip_reader_get_filename(zip_archive.get(), 0, &buf[0], name_size);
  const auto pos = buf.find_first_of('/');
  ET_CHECK_OR_RETURN_ERROR(
      pos != std::string::npos,
      InvalidExternalData,
      "File in archive is not in a subdirectory");

  std::string archive_name = buf.substr(0, pos);

  // Set up data structures for tensor name -> {path, metadata}.
  std::unordered_map<std::string, std::string> tensor_name_to_path;
  std::unordered_map<std::string, ConcreteTensorLayout> tensor_name_to_layout;

  // Read model_weights_config.json file.
  const std::string model_weights =
      archive_name + WEIGHTS_DIR + WEIGHTS_CONFIG_FILE;
  Error err = parse_json(
      zip_archive, model_weights, tensor_name_to_path, tensor_name_to_layout);
  ET_CHECK_OR_RETURN_ERROR(
      err == Error::Ok,
      InvalidExternalData,
      "Failed to parse model weights json config");

  // Read model_constants_config.json file.
  const std::string model_constants =
      archive_name + CONSTANTS_DIR + CONSTANTS_CONFIG_FILE;
  err = parse_json(
      zip_archive, model_constants, tensor_name_to_path, tensor_name_to_layout);
  ET_CHECK_OR_RETURN_ERROR(
      err == Error::Ok,
      InvalidExternalData,
      "Failed to parse model constants json config");

  // Create data loader to wrap around zip archive.
  Result<MmapDataLoader> loader =
      MmapDataLoader::from(pt2_archive_file_path.c_str());
  // Error is an enum; cast it explicitly so the value matches the printf
  // conversion instead of relying on a mismatched varargs type.
  ET_CHECK_OR_RETURN_ERROR(
      loader.ok(),
      InvalidArgument,
      "Loader failed to load with error: %u",
      static_cast<unsigned int>(loader.error()));

  std::unique_ptr<DataLoader> loader_ptr =
      std::make_unique<MmapDataLoader>(std::move(loader.get()));

  // Success: ownership of the open reader moves into the data map, whose
  // destructor is responsible for closing it.
  close_on_failure.dismiss();
  return PT2ArchiveDataMap(
      std::move(zip_archive),
      std::move(loader_ptr),
      std::move(archive_name),
      std::move(tensor_name_to_layout),
      std::move(tensor_name_to_path));
}
236+
237+
/**
 * Returns the layout recorded for the tensor named `key`.
 *
 * @param key Tensor name as it appears in the archive's JSON config.
 * @return The tensor layout, or Error::NotFound if no layout was recorded
 *     for `key`.
 */
Result<const TensorLayout> PT2ArchiveDataMap::get_tensor_layout(
    string_view key) const {
  // A string_view is not guaranteed to be NUL-terminated, so build the lookup
  // key from data() + size() rather than treating data() as a C string.
  const std::string name(key.data(), key.size());
  const auto it = tensor_name_to_layout_.find(name);
  if (it == tensor_name_to_layout_.end()) {
    ET_LOG(Error, "Tensor layout not found for key %s", name.c_str());
    return Error::NotFound;
  }
  return it->second.create_tensor_layout();
}
245+
246+
/**
 * Loads the raw bytes of the tensor named `key` from the archive.
 *
 * The entry's payload is not extracted through miniz. Instead its byte offset
 * inside the archive file is computed from the ZIP local file header and the
 * bytes are read through the data loader, which requires the entry to be
 * stored uncompressed.
 *
 * @param key Tensor name as it appears in the archive's JSON config.
 * @return A buffer with the tensor bytes. Error::NotFound if `key` is
 *     unknown; Error::InvalidExternalData if the entry is missing from the
 *     archive, is compressed, or has a malformed local header; otherwise
 *     whatever the data loader reports.
 */
Result<FreeableBuffer> PT2ArchiveDataMap::get_data(string_view key) const {
  // A string_view is not guaranteed to be NUL-terminated, so build the lookup
  // key from data() + size() rather than treating data() as a C string.
  const std::string name(key.data(), key.size());
  const auto path_it = tensor_name_to_path_.find(name);
  if (path_it == tensor_name_to_path_.end()) {
    ET_LOG(Error, "Tensor data not found for key %s", name.c_str());
    return Error::NotFound;
  }

  // Load data from zip archive - see PyTorch equivalent:
  // https://www.internalfb.com/code/fbsource/[f25405534204]/fbcode/caffe2/caffe2/serialize/inline_container.cc?lines=614
  std::string file_path = archive_name_ + WEIGHTS_DIR + path_it->second;
  int file_index = mz_zip_reader_locate_file(
      zip_archive_.get(), file_path.c_str(), nullptr, 0);
  if (file_index < 0) {
    // The name -> path table also holds entries from the constants config.
    // NOTE(review): assumes constant payloads live under CONSTANTS_DIR next
    // to their config file - confirm against the PT2 archive format.
    const std::string constants_path =
        archive_name_ + CONSTANTS_DIR + path_it->second;
    const int constants_index = mz_zip_reader_locate_file(
        zip_archive_.get(), constants_path.c_str(), nullptr, 0);
    if (constants_index >= 0) {
      file_path = constants_path;
      file_index = constants_index;
    }
  }
  if (file_index < 0) {
    ET_LOG(Error, "Failed to locate file '%s' in archive", file_path.c_str());
    return Error::InvalidExternalData;
  }

  mz_zip_archive_file_stat file_stat;
  if (!mz_zip_reader_file_stat(
          zip_archive_.get(), static_cast<mz_uint>(file_index), &file_stat)) {
    ET_LOG(Error, "Failed to get file stat for file '%s'\n", file_path.c_str());
    return Error::InvalidExternalData;
  }

  // The bytes are read straight from the archive file below, which is only
  // meaningful when the entry is stored (method 0) rather than compressed.
  if (file_stat.m_method != 0 ||
      file_stat.m_comp_size != file_stat.m_uncomp_size) {
    ET_LOG(
        Error,
        "File '%s' is compressed; expected an uncompressed entry",
        file_path.c_str());
    return Error::InvalidExternalData;
  }
  const mz_uint64 file_size = file_stat.m_uncomp_size;

  // Read the fixed-size part of the local file header to find where the
  // payload starts.
  // NOLINTNEXTLINE(facebook-hte-CArray)
  mz_uint8 local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
  if (mz_zip_read_archive_data(
          zip_archive_.get(),
          file_stat.m_local_header_ofs,
          local_header,
          MZ_ZIP_LOCAL_DIR_HEADER_SIZE) != MZ_ZIP_LOCAL_DIR_HEADER_SIZE) {
    ET_LOG(Error, "Failed to read local header for '%s'\n", file_path.c_str());
    return Error::InvalidExternalData;
  }
  const mz_uint32 sig = MZ_READ_LE32(local_header);
  if (sig != MZ_ZIP_LOCAL_DIR_HEADER_SIG) {
    ET_LOG(
        Error,
        "Invalid local header signature for '%s': 0x%08X\n",
        file_path.c_str(),
        sig);
    return Error::InvalidExternalData;
  }

  // Calculate offset: the payload follows the fixed header, the file name and
  // the extra field, whose lengths are stored in the header itself.
  const mz_uint16 filename_len = static_cast<mz_uint16>(
      read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS));
  const mz_uint16 extra_len = static_cast<mz_uint16>(
      read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS));
  const mz_uint64 offset = file_stat.m_local_header_ofs +
      MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;

  return loader_->load(
      offset,
      file_size,
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
}
297+
298+
// Not supported by this data map: all arguments are ignored and every call
// reports Error::NotImplemented.
Error PT2ArchiveDataMap::load_data_into(
    ET_UNUSED string_view key,
    ET_UNUSED void* buffer,
    ET_UNUSED size_t size) const {
  constexpr Error kUnsupported = Error::NotImplemented;
  return kUnsupported;
}
304+
305+
// Reports the number of entries in the tensor name -> path table.
Result<uint32_t> PT2ArchiveDataMap::get_num_keys() const {
  const auto key_count = tensor_name_to_path_.size();
  return static_cast<uint32_t>(key_count);
}
308+
309+
/**
 * Returns the tensor name at position `index` in the name -> path table.
 *
 * Positions follow the table's iteration order. The returned pointer refers
 * to the key string stored inside the table.
 *
 * @param index Zero-based position; must be less than get_num_keys().
 * @return The key as a C string, or Error::InvalidArgument if `index` is out
 *     of range.
 */
Result<const char*> PT2ArchiveDataMap::get_key(uint32_t index) const {
  const auto num_keys = get_num_keys().get();
  ET_CHECK_OR_RETURN_ERROR(
      index < num_keys,
      InvalidArgument,
      "Index %u out of range of size %u",
      index,
      num_keys);
  // The bounds check above guarantees the iterator stays valid, so step it
  // forward `index` times using an unsigned counter that matches `index`.
  auto it = tensor_name_to_path_.begin();
  for (uint32_t i = 0; i < index; ++i) {
    ++it;
  }
  return it->first.c_str();
}
327+
328+
} // namespace extension
329+
} // namespace executorch

0 commit comments

Comments
 (0)