Skip to content

Commit 9fc408a

Browse files
committed
PT2ArchiveDataMap
This diff introduces the PT2ArchiveDataMap, which reads the weights in .pt2 archive files. 1. Open the archive file with miniz. 2. There are two JSON config files (model_weights_config.json, model_constants_config.json) with information on weight name -> {weight path, weight metadata}. Open and extract weight information into unordered_maps; then free the JSON blobs. 3. For get_tensor_layout calls, return the JSON information. 4. For get_data calls, use miniz to calculate the offset + size and then use the data loader. PT2 archive files are not compressed (to allow mmap-ing), so this is fine. PT2 archive format: https://docs.google.com/document/d/1xdx3I4zK6naPEWX3e49rCUccZeAC9zMLCFKXvUQFR7o/edit?tab=t.0 Serde: https://docs.google.com/document/d/11X-KsLPMJGdEr4sG4sCNLnGLhSKrc8utDGMQqFbZx9E/edit?tab=t.0#heading=h.tsw6d16xh497 --- TODO in subsequent diffs - convert strides to dim order - Additional testing; failure cases, model file with constants as well as weights, model with no weights. - CMake for OSS Differential Revision: [D81248896](https://our.internmc.facebook.com/intern/diff/D81248896/) ghstack-source-id: 307655657 Pull Request resolved: #13973
1 parent 8c79a53 commit 9fc408a

File tree

6 files changed

+681
-0
lines changed

6 files changed

+681
-0
lines changed

extension/named_data_map/TARGETS

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
load(":targets.bzl", "define_common_targets")
3+
4+
oncall("executorch")
5+
6+
define_common_targets()
Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/named_data_map/pt2_archive_data_map.h>
10+
11+
#include <executorch/extension/data_loader/mmap_data_loader.h>
12+
#include <executorch/runtime/core/data_loader.h>
13+
#include <executorch/runtime/core/error.h>
14+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
15+
#include <executorch/runtime/core/freeable_buffer.h>
16+
#include <executorch/runtime/core/result.h>
17+
#include <executorch/runtime/core/span.h>
18+
#include <executorch/runtime/platform/compiler.h>
19+
20+
#include "miniz.h"
21+
22+
#include <nlohmann/json.hpp>
23+
#include <string.h>
24+
#include <unordered_map>
25+
26+
using json = nlohmann::json;
27+
28+
using executorch::runtime::Error;
29+
using executorch::runtime::FreeableBuffer;
30+
using executorch::runtime::Result;
31+
using executorch::runtime::Span;
32+
33+
using executorch::aten::ScalarType;
34+
using executorch::aten::string_view;
35+
using executorch::ET_RUNTIME_NAMESPACE::TensorLayout;
36+
using executorch::runtime::DataLoader;
37+
38+
using executorch::extension::MmapDataLoader;
39+
40+
// ZIP local file header layout, used to find where an entry's data begins.
// Magic number that opens every local file header.
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIG{0x04034b50};
// Size in bytes of the fixed part of a local file header.
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE{30};
// Byte offset, within the header, of the 16-bit file name length.
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS{26};
// Byte offset, within the header, of the 16-bit extra field length.
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS{28};
45+
46+
// Locations inside a PT2 archive. The directory constants are appended to the
// archive's top-level directory name, hence the leading '/'.
constexpr const char* WEIGHTS_DIR{"/data/weights/"};
constexpr const char* CONSTANTS_DIR{"/data/constants/"};
// JSON files mapping tensor names to their file name and metadata.
constexpr const char* WEIGHTS_CONFIG_FILE{"model_weights_config.json"};
constexpr const char* CONSTANTS_CONFIG_FILE{"model_constants_config.json"};
51+
52+
namespace {
53+
// Maps a PT2 serialized dtype value onto the ExecuTorch ScalarType enum.
ScalarType convert_pt2_to_et_scalartype(uint32_t dtype) {
  // The PT2 enum reserves 0 for UNKNOWN while the ET enum starts with BYTE at
  // 0, so every PT2 value sits exactly one above its ET counterpart.
  // PT2: https://fburl.com/code/qjlmiifs
  // ET: https://fburl.com/code/gq30tizb
  const uint32_t et_dtype = dtype - 1;
  return static_cast<ScalarType>(et_dtype);
}
59+
60+
// Decodes the unsigned 16-bit little-endian integer stored in buf[0..1].
// Used to read length fields out of a raw ZIP local file header.
// `buf` must point to at least two readable bytes; it is not modified.
static int64_t read_le_16(const uint8_t* buf) {
  const int64_t low = buf[0];
  const int64_t high = buf[1];
  return low | (high << 8);
}
64+
} // namespace
65+
66+
namespace executorch {
67+
namespace extension {
68+
69+
// Releases the miniz reader state held by this data map, if any.
PT2ArchiveDataMap::~PT2ArchiveDataMap() {
  mz_zip_archive* const archive = zip_archive_.get();
  if (archive == nullptr) {
    // Nothing to close when this object holds no archive.
    return;
  }
  mz_zip_reader_end(archive);
}
75+
76+
/**
 * Extracts the JSON config file `filename` from `zip_archive` and records, for
 * every tensor it lists, the tensor's file name and layout.
 *
 * JSON format (for information we care about) looks like this:
 *   "config": {
 *     "weight_name": {
 *       "path_name": "weight_0",
 *       "tensor_meta": {
 *         "dtype": <DTYPE>,
 *         "sizes": [{"as_int": <SIZE>}, {"as_int": <SIZE>}, ...],
 *         "strides": [{"as_int": <SIZE>}, {"as_int": <SIZE>}, ...],
 *       }
 *     }
 *   }
 *
 * @param zip_archive An archive opened for reading. It is closed with
 *     mz_zip_reader_end() whenever this function fails.
 * @param filename Full path of the JSON config inside the archive.
 * @param tensor_name_to_path Receives tensor name -> "path_name" entries.
 * @param tensor_name_to_layout Receives tensor name -> layout entries.
 * @return Error::Ok on success, Error::InvalidExternalData if the file cannot
 *     be extracted or its contents are malformed. Entries added before a
 *     failure was detected are left in the output maps.
 */
/*static*/ Error PT2ArchiveDataMap::parse_json(
    std::unique_ptr<mz_zip_archive>& zip_archive,
    const std::string& filename,
    std::unordered_map<std::string, std::string>& tensor_name_to_path,
    std::unordered_map<std::string, ConcreteTensorLayout>&
        tensor_name_to_layout) {
  size_t uncomp_size = 0;
  // The buffer is allocated by miniz, so it must be released with mz_free.
  // Owning it here guarantees it is released on every return path.
  const std::unique_ptr<void, decltype(&mz_free)> buffer(
      mz_zip_reader_extract_file_to_heap(
          zip_archive.get(), filename.c_str(), &uncomp_size, 0),
      &mz_free);
  if (!buffer) {
    ET_LOG(Error, "Failed to extract file %s to heap", filename.c_str());
    mz_zip_reader_end(zip_archive.get());
    return Error::InvalidExternalData;
  }

  // All parsing lives in one callable so that the archive can be closed in a
  // single place regardless of which validation step fails.
  const auto parse = [&]() -> Error {
    try {
      const char* const text = static_cast<const char*>(buffer.get());
      const json json_config = json::parse(text, text + uncomp_size);
      ET_CHECK_OR_RETURN_ERROR(
          json_config.contains("config"),
          InvalidExternalData,
          "JSON config does not contain 'config' key; malformed archive file.");
      const json& config = json_config.at("config");
      for (auto& item : config.items()) {
        const std::string& name = item.key();
        const json& entry = item.value();
        ET_CHECK_OR_RETURN_ERROR(
            entry.contains("path_name") && entry.contains("tensor_meta"),
            InvalidExternalData,
            "JSON config does not contain 'path_name' and 'tensor_meta' keys for key %s",
            name.c_str());

        const json& tensor_meta = entry.at("tensor_meta");
        ET_CHECK_OR_RETURN_ERROR(
            tensor_meta.contains("dtype") &&
                tensor_meta.at("dtype").is_number_integer(),
            InvalidExternalData,
            "JSON config does not contain 'dtype' key for key %s",
            name.c_str());
        ET_CHECK_OR_RETURN_ERROR(
            tensor_meta.contains("sizes") && tensor_meta.at("sizes").is_array(),
            InvalidExternalData,
            "JSON config does not contain 'sizes' key for key %s",
            name.c_str());
        ET_CHECK_OR_RETURN_ERROR(
            tensor_meta.contains("strides") &&
                tensor_meta.at("strides").is_array(),
            InvalidExternalData,
            "JSON config does not contain 'strides' key for key %s",
            name.c_str());

        // PT2 dtype 0 is UNKNOWN and has no ET equivalent; rejecting
        // non-positive values also keeps the "minus one" conversion from
        // wrapping around.
        const int dtype = tensor_meta.at("dtype").get<int>();
        ET_CHECK_OR_RETURN_ERROR(
            dtype > 0,
            InvalidExternalData,
            "Unsupported dtype %d for key %s",
            dtype,
            name.c_str());

        ConcreteTensorLayout concrete_layout;
        concrete_layout.scalar_type = convert_pt2_to_et_scalartype(dtype);
        int dim = 0;
        for (const auto& size : tensor_meta.at("sizes")) {
          // at() throws a json::exception (handled below) when "as_int" is
          // missing, whereas operator[] on a const value would be undefined.
          concrete_layout.sizes.push_back(size.at("as_int").get<int32_t>());
          // TODO: Calculate dim order from strides. Assume contiguous for now.
          concrete_layout.dim_order.push_back(dim);
          ++dim;
        }

        // Record the entry only once all of its fields have been validated.
        tensor_name_to_path[name] = entry.at("path_name").get<std::string>();
        tensor_name_to_layout[name] = std::move(concrete_layout);
      }
    } catch (const json::exception& e) {
      ET_LOG(Error, "Failed to parse JSON: %s", e.what());
      return Error::InvalidExternalData;
    }
    return Error::Ok;
  };

  const Error err = parse();
  if (err != Error::Ok) {
    mz_zip_reader_end(zip_archive.get());
  }
  return err;
}
161+
162+
/**
 * Opens the PT2 archive at `pt2_archive_file_path`, reads its weights and
 * constants JSON configs, and builds a data map backed by an MmapDataLoader
 * over the same file.
 *
 * @param pt2_archive_file_path Path of the .pt2 archive on disk.
 * @return The data map on success. Error::InvalidArgument if the archive or
 *     the data loader cannot be opened; Error::InvalidExternalData if the
 *     archive is empty, has no top-level directory, or a config is malformed.
 */
/*static*/ Result<PT2ArchiveDataMap> PT2ArchiveDataMap::load(
    const std::string& pt2_archive_file_path) {
  ET_LOG(
      Info, "Loading PT2ArchiveDataMap from %s", pt2_archive_file_path.c_str());
  auto zip_archive = std::make_unique<mz_zip_archive>();
  // Open zip archive to get json config data.
  memset(zip_archive.get(), 0, sizeof(mz_zip_archive));
  mz_bool status = mz_zip_reader_init_file(
      zip_archive.get(), pt2_archive_file_path.c_str(), 0);

  ET_CHECK_OR_RETURN_ERROR(
      status == 1,
      InvalidArgument,
      "Failed to open zip archive %s, status: %d",
      pt2_archive_file_path.c_str(),
      status);

  // From here on the reader holds resources. Close it on every early return;
  // the guard is disarmed once ownership passes to the returned data map.
  struct ArchiveCloser {
    mz_zip_archive* archive;
    ~ArchiveCloser() {
      if (archive != nullptr) {
        mz_zip_reader_end(archive);
      }
    }
  } closer{zip_archive.get()};

  // Extract archive name: the directory component of the first entry.
  mz_uint n = mz_zip_reader_get_num_files(zip_archive.get());
  ET_CHECK_OR_RETURN_ERROR(
      n > 0, InvalidExternalData, "Archive does not contain any files");
  // First call with a null buffer to learn the required size.
  size_t name_size =
      mz_zip_reader_get_filename(zip_archive.get(), 0, nullptr, 0);
  std::string buf(name_size, '\0');
  mz_zip_reader_get_filename(zip_archive.get(), 0, &buf[0], name_size);
  auto pos = buf.find_first_of('/');
  ET_CHECK_OR_RETURN_ERROR(
      pos != std::string::npos,
      InvalidExternalData,
      "File in archive is not in a subdirectory");

  std::string archive_name = buf.substr(0, pos);

  // Set up data structures for tensor name -> {path, metadata}.
  std::unordered_map<std::string, std::string> tensor_name_to_path;
  std::unordered_map<std::string, ConcreteTensorLayout> tensor_name_to_layout;

  // Read the weights config file.
  std::string model_weights = archive_name + WEIGHTS_DIR + WEIGHTS_CONFIG_FILE;
  Error err = parse_json(
      zip_archive, model_weights, tensor_name_to_path, tensor_name_to_layout);
  ET_CHECK_OR_RETURN_ERROR(
      err == Error::Ok,
      InvalidExternalData,
      "Failed to parse model weights json config");

  // Read the constants config file.
  std::string model_constants =
      archive_name + CONSTANTS_DIR + CONSTANTS_CONFIG_FILE;
  err = parse_json(
      zip_archive, model_constants, tensor_name_to_path, tensor_name_to_layout);
  ET_CHECK_OR_RETURN_ERROR(
      err == Error::Ok,
      InvalidExternalData,
      "Failed to parse model constants json config");

  // Create data loader to wrap around zip archive.
  Result<MmapDataLoader> loader =
      MmapDataLoader::from(pt2_archive_file_path.c_str());
  ET_CHECK_OR_RETURN_ERROR(
      loader.ok(),
      InvalidArgument,
      "Loader failed to load with error: %zu",
      // Cast so the argument matches the %zu conversion.
      static_cast<size_t>(loader.error()));

  std::unique_ptr<DataLoader> loader_ptr =
      std::make_unique<MmapDataLoader>(std::move(loader.get()));

  // The returned data map now owns the archive and closes it in its
  // destructor.
  closer.archive = nullptr;
  return PT2ArchiveDataMap(
      std::move(zip_archive),
      std::move(loader_ptr),
      std::move(archive_name),
      std::move(tensor_name_to_layout),
      std::move(tensor_name_to_path));
}
236+
237+
/**
 * Returns the layout recorded for the tensor named `key`.
 *
 * @param key Tensor name as it appears in the archive's JSON configs.
 * @return The tensor's layout, or Error::NotFound if `key` is unknown.
 */
Result<const TensorLayout> PT2ArchiveDataMap::get_tensor_layout(
    string_view key) const {
  // A string_view is not guaranteed to be NUL-terminated, so build the lookup
  // key from its explicit length rather than from data() alone.
  const std::string name(key.data(), key.size());
  const auto it = tensor_name_to_layout_.find(name);
  if (it == tensor_name_to_layout_.end()) {
    ET_LOG(Error, "Tensor layout not found for key %s", name.c_str());
    return Error::NotFound;
  }
  return it->second.create_tensor_layout();
}
245+
246+
/**
 * Returns the raw bytes of the tensor named `key`.
 *
 * The entry's position inside the archive is computed from its ZIP local file
 * header, and the bytes are then loaded through the data loader. This is only
 * valid for entries stored without compression.
 *
 * @param key Tensor name as it appears in the archive's JSON configs.
 * @return A buffer over the tensor bytes. Error::NotFound if `key` is
 *     unknown; Error::InvalidExternalData if the entry is missing from the
 *     archive, is compressed, or has a malformed header; otherwise whatever
 *     the data loader returns.
 */
Result<FreeableBuffer> PT2ArchiveDataMap::get_data(string_view key) const {
  // A string_view is not guaranteed to be NUL-terminated, so build the lookup
  // key from its explicit length rather than from data() alone.
  const std::string name(key.data(), key.size());
  const auto path_it = tensor_name_to_path_.find(name);
  if (path_it == tensor_name_to_path_.end()) {
    ET_LOG(Error, "Tensor data not found for key %s", name.c_str());
    return Error::NotFound;
  }
  const std::string& path_name = path_it->second;

  // Load data from zip archive - see PyTorch equivalent:
  // https://www.internalfb.com/code/fbsource/[f25405534204]/fbcode/caffe2/caffe2/serialize/inline_container.cc?lines=614
  std::string file_path = archive_name_ + WEIGHTS_DIR + path_name;
  int file_index = mz_zip_reader_locate_file(
      zip_archive_.get(), file_path.c_str(), nullptr, 0);
  if (file_index < 0) {
    // Names from the constants config share this map, so fall back to the
    // constants directory when the entry is not a weight.
    file_path = archive_name_ + CONSTANTS_DIR + path_name;
    file_index = mz_zip_reader_locate_file(
        zip_archive_.get(), file_path.c_str(), nullptr, 0);
  }
  if (file_index < 0) {
    ET_LOG(
        Error,
        "Failed to locate '%s' under the weights or constants directory",
        path_name.c_str());
    return Error::InvalidExternalData;
  }

  mz_zip_archive_file_stat file_stat;
  if (!mz_zip_reader_file_stat(
          zip_archive_.get(), static_cast<mz_uint>(file_index), &file_stat)) {
    ET_LOG(Error, "Failed to get file stat for file '%s'\n", file_path.c_str());
    return Error::InvalidExternalData;
  }
  // Method 0 means "stored". Anything else is compressed, and reading the
  // uncompressed size at the data offset would return the wrong bytes.
  if (file_stat.m_method != 0) {
    ET_LOG(
        Error,
        "File '%s' is compressed (method %u); only stored entries are supported",
        file_path.c_str(),
        static_cast<unsigned int>(file_stat.m_method));
    return Error::InvalidExternalData;
  }
  const mz_uint64 file_size = file_stat.m_uncomp_size;

  mz_uint8 local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
  const size_t header_bytes_read = mz_zip_read_archive_data(
      zip_archive_.get(),
      file_stat.m_local_header_ofs,
      local_header,
      MZ_ZIP_LOCAL_DIR_HEADER_SIZE);
  if (header_bytes_read != static_cast<size_t>(MZ_ZIP_LOCAL_DIR_HEADER_SIZE)) {
    ET_LOG(Error, "Failed to read local header for '%s'\n", file_path.c_str());
    return Error::InvalidExternalData;
  }
  const mz_uint32 sig = MZ_READ_LE32(local_header);
  if (sig != static_cast<mz_uint32>(MZ_ZIP_LOCAL_DIR_HEADER_SIG)) {
    ET_LOG(
        Error,
        "Invalid local header signature for '%s': 0x%08X\n",
        file_path.c_str(),
        sig);
    return Error::InvalidExternalData;
  }

  // Calculate offset: the data starts after the fixed header, the file name,
  // and the extra field, whose lengths are stored in the local header.
  const mz_uint16 filename_len = static_cast<mz_uint16>(
      read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS));
  const mz_uint16 extra_len = static_cast<mz_uint16>(
      read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS));
  const mz_uint64 offset = file_stat.m_local_header_ofs +
      MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;

  return loader_->load(
      offset,
      file_size,
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
}
296+
297+
// Copying tensor data into a caller-provided buffer is not supported by this
// data map: all arguments are ignored and Error::NotImplemented is returned.
Error PT2ArchiveDataMap::load_data_into(
    ET_UNUSED string_view key,
    ET_UNUSED void* buffer,
    ET_UNUSED size_t size) const {
  const Error unsupported = Error::NotImplemented;
  return unsupported;
}
303+
304+
// Reports how many tensor names this data map knows about.
Result<uint32_t> PT2ArchiveDataMap::get_num_keys() const {
  const auto key_count = tensor_name_to_path_.size();
  // The container size is narrowed to the 32-bit count used by the interface.
  return static_cast<uint32_t>(key_count);
}
307+
308+
/**
 * Returns the tensor name at position `index` in the map's iteration order.
 *
 * @param index Position in [0, get_num_keys()).
 * @return A pointer to the name, which points into a string stored in this
 *     data map's name table. Error::InvalidArgument if `index` is out of
 *     range.
 */
Result<const char*> PT2ArchiveDataMap::get_key(uint32_t index) const {
  const uint32_t num_keys = get_num_keys().get();
  ET_CHECK_OR_RETURN_ERROR(
      index < num_keys,
      InvalidArgument,
      "Index %u out of range of size %u",
      index,
      num_keys);
  // Walk to the index-th entry with an unsigned counter so that it is never
  // compared against a signed value.
  auto it = tensor_name_to_path_.begin();
  for (uint32_t i = 0; i < index && it != tensor_name_to_path_.end(); ++i) {
    ++it;
  }
  if (it == tensor_name_to_path_.end()) {
    // Should not reach here: the range check above guarantees an entry.
    return Error::Internal;
  }
  return it->first.c_str();
}
326+
327+
} // namespace extension
328+
} // namespace executorch

0 commit comments

Comments
 (0)