| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#include <executorch/extension/pt2_archive/pt2_archive_data_map.h> |
| 10 | + |
| 11 | +#include <executorch/extension/data_loader/mmap_data_loader.h> |
| 12 | +#include <executorch/runtime/core/data_loader.h> |
| 13 | +#include <executorch/runtime/core/error.h> |
| 14 | +#include <executorch/runtime/core/exec_aten/util/tensor_util.h> |
| 15 | +#include <executorch/runtime/core/freeable_buffer.h> |
| 16 | +#include <executorch/runtime/core/result.h> |
| 17 | +#include <executorch/runtime/core/span.h> |
| 18 | +#include <executorch/runtime/platform/compiler.h> |
| 19 | + |
| 20 | +#include "miniz.h" |
| 21 | + |
| 22 | +#include <nlohmann/json.hpp> |
| 23 | +#include <string.h> |
| 24 | +#include <unordered_map> |
| 25 | + |
| 26 | +using json = nlohmann::json; |
| 27 | + |
| 28 | +using executorch::runtime::Error; |
| 29 | +using executorch::runtime::FreeableBuffer; |
| 30 | +using executorch::runtime::Result; |
| 31 | +using executorch::runtime::Span; |
| 32 | + |
| 33 | +using executorch::aten::ScalarType; |
| 34 | +using executorch::aten::string_view; |
| 35 | +using executorch::ET_RUNTIME_NAMESPACE::TensorLayout; |
| 36 | +using executorch::runtime::DataLoader; |
| 37 | + |
| 38 | +using executorch::extension::MmapDataLoader; |
| 39 | + |
| 40 | +// MZ_ZIP constants. |
| 41 | +constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30; |
| 42 | +constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS = 26; |
| 43 | +constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28; |
| 44 | +constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIG = 0x04034b50; |
| 45 | + |
| 46 | +// PT2Archive constants. |
| 47 | +constexpr const char* WEIGHTS_DIR = "/data/weights/"; |
| 48 | +constexpr const char* WEIGHTS_CONFIG_FILE = "model_weights_config.json"; |
| 49 | +constexpr const char* CONSTANTS_DIR = "/data/constants/"; |
| 50 | +constexpr const char* CONSTANTS_CONFIG_FILE = "model_constants_config.json"; |
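|  | +// The code below assumes entries are laid out as |
|  | +// <archive_name>/data/weights/... and <archive_name>/data/constants/..., each |
|  | +// holding a JSON config plus the raw tensor payload files it references. |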
| 51 | + |
| 52 | +namespace { |
| 53 | +ScalarType convert_pt2_to_et_scalartype(uint32_t dtype) { |
| 54 | + // PT2 serialization dtypes and ET dtypes are off by 1. |
| 55 | + // PT2: https://fburl.com/code/qjlmiifs (contains UNKNOWN at enum 0) |
| 56 | + // ET: https://fburl.com/code/gq30tizb (starts with BYTE at enum 0) |
| 57 | + return static_cast<ScalarType>(dtype - 1); |
| 58 | +} |
| 59 | + |
| 60 | +// Used to read 16-bit little-endian fields from the miniz local file header. |
| 61 | +uint16_t read_le_16(const uint8_t* buf) { |
| 62 | +  return static_cast<uint16_t>(buf[0] + (buf[1] << 8)); |
| 63 | +} |
| 64 | +} // namespace |
| 65 | + |
| 66 | +namespace executorch { |
| 67 | +namespace extension { |
| 68 | + |
| 69 | +PT2ArchiveDataMap::~PT2ArchiveDataMap() { |
| 70 | + // Close zip archive resources. |
| 71 | + if (zip_archive_) { |
| 72 | + mz_zip_reader_end(zip_archive_.get()); |
| 73 | + } |
| 74 | +} |
| 75 | + |
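|  | +// Extracts `filename` (a model_*_config.json entry) from the archive and |
|  | +// populates tensor_name_to_path and tensor_name_to_layout from its "config" |
|  | +// section. Returns InvalidExternalData if the entry is missing or malformed. |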
| 76 | +/*static*/ Error PT2ArchiveDataMap::parse_json( |
| 77 | + std::unique_ptr<mz_zip_archive>& zip_archive, |
| 78 | + const std::string& filename, |
| 79 | + std::unordered_map<std::string, std::string>& tensor_name_to_path, |
| 80 | + std::unordered_map<std::string, ConcreteTensorLayout>& |
| 81 | + tensor_name_to_layout) { |
| 82 | + /** JSON format (for information we care about) looks like this: |
| 83 | + "config": { |
| 84 | + "weight_name": { |
| 85 | + "path_name": "weight_0", |
| 86 | + "tensor_meta": { |
| 87 | + "dtype": <DTYPE>, |
| 88 | + "sizes": [{"as_int": <SIZE>}, {"as_int": <SIZE>}, ...], |
| 89 | + "strides": [{"as_int": <SIZE>}, {"as_int": <SIZE>}, ...], |
| 90 | + } |
| 91 | + } |
| 92 | + } */ |
| 93 | + size_t uncomp_size = 0; |
| 94 | + void* buffer = mz_zip_reader_extract_file_to_heap( |
| 95 | + zip_archive.get(), filename.c_str(), &uncomp_size, 0); |
| 96 | + if (!buffer) { |
| 97 | + ET_LOG(Error, "Failed to extract file %s to heap", filename.c_str()); |
| 98 | + mz_zip_reader_end(zip_archive.get()); |
| 99 | + return Error::InvalidExternalData; |
| 100 | + } |
| 101 | +  std::string json_str(static_cast<const char*>(buffer), uncomp_size); |
| 102 | +  mz_free(buffer); // Free now; early returns below must not leak the buffer. |
| 103 | +  json json_config; |
| 104 | +  try { |
| 105 | +    json_config = json::parse(json_str); |
| 106 | + ET_CHECK_OR_RETURN_ERROR( |
| 107 | + json_config.contains("config"), |
| 108 | + InvalidExternalData, |
| 109 | + "JSON config does not contain 'config' key; malformed archive file."); |
| 110 | + auto config = json_config["config"]; |
| 111 | + for (auto& item : config.items()) { |
| 112 | + ET_CHECK_OR_RETURN_ERROR( |
| 113 | + item.value().contains("path_name") && |
| 114 | + item.value().contains("tensor_meta"), |
| 115 | + InvalidExternalData, |
| 116 | + "JSON config does not contain 'path_name' and 'tensor_meta' keys for key %s", |
| 117 | + item.key().c_str()); |
| 118 | + |
| 119 | + // Add tensor_name -> path_name mapping. |
| 120 | + tensor_name_to_path[item.key().c_str()] = item.value()["path_name"]; |
| 121 | + |
| 122 | + // Add tensor_name -> tensor_meta mapping. |
| 123 | + auto tensor_meta = item.value()["tensor_meta"]; |
| 124 | + ET_CHECK_OR_RETURN_ERROR( |
| 125 | + tensor_meta.contains("dtype") && |
| 126 | + tensor_meta["dtype"].is_number_integer(), |
| 127 | + InvalidExternalData, |
| 128 | + "JSON config does not contain 'dtype' key for key %s", |
| 129 | + item.key().c_str()); |
| 130 | + ET_CHECK_OR_RETURN_ERROR( |
| 131 | + tensor_meta.contains("sizes") && tensor_meta["sizes"].is_array(), |
| 132 | + InvalidExternalData, |
| 133 | + "JSON config does not contain 'sizes' key for key %s", |
| 134 | + item.key().c_str()); |
| 135 | + ET_CHECK_OR_RETURN_ERROR( |
| 136 | + tensor_meta.contains("strides") && tensor_meta["strides"].is_array(), |
| 137 | + InvalidExternalData, |
| 138 | + "JSON config does not contain 'strides' key for key %s", |
| 139 | + item.key().c_str()); |
| 140 | + ConcreteTensorLayout concrete_layout; |
| 141 | + concrete_layout.scalar_type = |
| 142 | + convert_pt2_to_et_scalartype(tensor_meta["dtype"].get<int>()); |
| 143 | + int i = 0; |
| 144 | + for (const auto& size : tensor_meta["sizes"]) { |
| 145 | + concrete_layout.sizes.push_back(size["as_int"].get<int32_t>()); |
| 146 | + // TODO: Calculate dim order from strides. Assume contiguous for now. |
| 147 | + concrete_layout.dim_order.push_back(i); |
| 148 | + ++i; |
| 149 | + } |
| 150 | + tensor_name_to_layout[item.key().c_str()] = std::move(concrete_layout); |
| 151 | + } |
| 153 | + } catch (const json::exception& e) { |
| 154 | + ET_LOG(Error, "Failed to parse JSON: %s", e.what()); |
| 156 | + mz_zip_reader_end(zip_archive.get()); |
| 157 | + return Error::InvalidExternalData; |
| 158 | + } |
| 159 | + return Error::Ok; |
| 160 | +} |
| 161 | + |
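|  | +// Opens the archive, derives the top-level archive name from the first entry, |
|  | +// parses the weights and constants JSON configs, and wraps the archive file in |
|  | +// an MmapDataLoader so tensor payloads can be mapped straight out of the zip. |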
| 162 | +/*static*/ Result<PT2ArchiveDataMap> PT2ArchiveDataMap::load( |
| 163 | + const std::string& pt2_archive_file_path) { |
| 164 | + ET_LOG( |
| 165 | + Info, "Loading PT2ArchiveDataMap from %s", pt2_archive_file_path.c_str()); |
| 166 | + auto zip_archive = std::make_unique<mz_zip_archive>(); |
| 167 | + // Open zip archive to get json config data. |
| 168 | + memset(zip_archive.get(), 0, sizeof(mz_zip_archive)); |
| 169 | + mz_bool status = mz_zip_reader_init_file( |
| 170 | + zip_archive.get(), pt2_archive_file_path.c_str(), 0); |
| 171 | + |
| 172 | + ET_CHECK_OR_RETURN_ERROR( |
| 173 | + status == 1, |
| 174 | + InvalidArgument, |
| 175 | + "Failed to open zip archive %s, status: %d", |
| 176 | + pt2_archive_file_path.c_str(), |
| 177 | + status); |
| 178 | + |
| 179 | + // Extract archive name. |
| 180 | + mz_uint n = mz_zip_reader_get_num_files(zip_archive.get()); |
| 181 | + ET_CHECK_OR_RETURN_ERROR( |
| 182 | + n > 0, InvalidExternalData, "Archive does not contain any files"); |
| 183 | + mz_uint name_size = |
| 184 | + mz_zip_reader_get_filename(zip_archive.get(), 0, nullptr, 0); |
| 185 | + std::string buf(name_size, '\0'); |
| 186 | + mz_zip_reader_get_filename(zip_archive.get(), 0, &buf[0], name_size); |
| 187 | + auto pos = buf.find_first_of('/'); |
| 188 | + ET_CHECK_OR_RETURN_ERROR( |
| 189 | + pos != std::string::npos, |
| 190 | + InvalidExternalData, |
| 191 | + "File in archive is not in a subdirectory"); |
| 192 | + |
| 193 | + std::string archive_name = buf.substr(0, pos); |
| 194 | + |
| 195 | + // Set up data structures for tensor name -> {path, metadata}. |
| 196 | + std::unordered_map<std::string, std::string> tensor_name_to_path; |
| 197 | + std::unordered_map<std::string, ConcreteTensorLayout> tensor_name_to_layout; |
| 198 | + |
| 199 | + // Read model_weights.json file. |
| 200 | + std::string model_weights = archive_name + WEIGHTS_DIR + WEIGHTS_CONFIG_FILE; |
| 201 | + Error err = parse_json( |
| 202 | + zip_archive, model_weights, tensor_name_to_path, tensor_name_to_layout); |
| 203 | + ET_CHECK_OR_RETURN_ERROR( |
| 204 | + err == Error::Ok, |
| 205 | + InvalidExternalData, |
| 206 | + "Failed to parse model weights json config"); |
| 207 | + |
| 208 | + // Read model_constants.json file. |
| 209 | + std::string model_constants = |
| 210 | + archive_name + CONSTANTS_DIR + CONSTANTS_CONFIG_FILE; |
| 211 | + err = parse_json( |
| 212 | + zip_archive, model_constants, tensor_name_to_path, tensor_name_to_layout); |
| 213 | + ET_CHECK_OR_RETURN_ERROR( |
| 214 | + err == Error::Ok, |
| 215 | + InvalidExternalData, |
| 216 | + "Failed to parse model constants json config"); |
| 217 | + |
| 218 | + // Create data loader to wrap around zip archive. |
| 219 | + Result<MmapDataLoader> loader = |
| 220 | + MmapDataLoader::from(pt2_archive_file_path.c_str()); |
| 221 | + ET_CHECK_OR_RETURN_ERROR( |
| 222 | + loader.ok(), |
| 223 | + InvalidArgument, |
| 224 | + "Loader failed to load with error: %zu", |
| 225 | + loader.error()); |
| 226 | + |
| 227 | + std::unique_ptr<DataLoader> loader_ptr = |
| 228 | + std::make_unique<MmapDataLoader>(std::move(loader.get())); |
| 229 | + return PT2ArchiveDataMap( |
| 230 | + std::move(zip_archive), |
| 231 | + std::move(loader_ptr), |
| 232 | + std::move(archive_name), |
| 233 | + std::move(tensor_name_to_layout), |
| 234 | + std::move(tensor_name_to_path)); |
| 235 | +} |
| 236 | + |
| 237 | +Result<const TensorLayout> PT2ArchiveDataMap::get_tensor_layout( |
| 238 | + string_view key) const { |
| 239 | +  auto it = tensor_name_to_layout_.find(key.data()); |
| 240 | +  if (it == tensor_name_to_layout_.end()) { |
| 241 | +    ET_LOG(Error, "Tensor layout not found for key %s", key.data()); |
| 242 | +    return Error::NotFound; |
| 243 | +  } |
| 244 | +  return it->second.create_tensor_layout(); |
| 244 | +} |
| 245 | + |
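|  | +// Tensor payload entries are assumed to be stored uncompressed, so rather than |
|  | +// extracting them we read the entry's local file header to find where the raw |
|  | +// bytes begin and map that range directly through the data loader. |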
| 246 | +Result<FreeableBuffer> PT2ArchiveDataMap::get_data(string_view key) const { |
| 247 | + if (tensor_name_to_path_.find(key.data()) == tensor_name_to_path_.end()) { |
| 248 | + ET_LOG(Error, "Tensor data not found for key %s", key.data()); |
| 249 | + return Error::NotFound; |
| 250 | + } |
| 251 | + |
| 252 | + // Load data from zip archive - see PyTorch equivalent: |
| 253 | + // https://www.internalfb.com/code/fbsource/[f25405534204]/fbcode/caffe2/caffe2/serialize/inline_container.cc?lines=614 |
| 254 | +  const std::string& path = tensor_name_to_path_.at(key.data()); |
| 255 | +  std::string file_path = archive_name_ + WEIGHTS_DIR + path; |
| 256 | +  int file_index = mz_zip_reader_locate_file( |
| 257 | +      zip_archive_.get(), file_path.c_str(), nullptr, 0); |
| 258 | +  if (file_index < 0) { |
| 259 | +    // Constants may live under the constants directory; fall back to it. |
| 260 | +    file_path = archive_name_ + CONSTANTS_DIR + path; |
| 261 | +    file_index = mz_zip_reader_locate_file( |
| 262 | +        zip_archive_.get(), file_path.c_str(), nullptr, 0); |
| 263 | +  } |
| 264 | +  ET_CHECK_OR_RETURN_ERROR( |
| 265 | +      file_index >= 0, NotFound, "'%s' not found in archive", file_path.c_str()); |
| 266 | + |
| 259 | + mz_zip_archive_file_stat file_stat; |
| 260 | + if (!mz_zip_reader_file_stat(zip_archive_.get(), file_index, &file_stat)) { |
| 261 | + ET_LOG(Error, "Failed to get file stat for file '%s'\n", file_path.c_str()); |
| 262 | + return Error::InvalidExternalData; |
| 263 | + } |
| 264 | + mz_uint64 file_size = file_stat.m_uncomp_size; |
| 265 | + // NOLINTNEXTLINE(facebook-hte-CArray) |
| 266 | + mz_uint8 local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE]; |
| 267 | + if (mz_zip_read_archive_data( |
| 268 | + zip_archive_.get(), |
| 269 | + file_stat.m_local_header_ofs, |
| 270 | + local_header, |
| 271 | + MZ_ZIP_LOCAL_DIR_HEADER_SIZE) != MZ_ZIP_LOCAL_DIR_HEADER_SIZE) { |
| 272 | + ET_LOG(Info, "Failed to read local header for '%s'\n", file_path.c_str()); |
| 273 | + return Error::InvalidExternalData; |
| 274 | + } |
| 275 | + mz_uint32 sig = MZ_READ_LE32(local_header); |
| 276 | + if (sig != MZ_ZIP_LOCAL_DIR_HEADER_SIG) { |
| 277 | + ET_LOG( |
| 278 | +        Error, |
| 279 | +        "Invalid local header signature for '%s': 0x%08X", |
| 280 | + file_path.c_str(), |
| 281 | + sig); |
| 282 | + return Error::InvalidExternalData; |
| 283 | + } |
| 284 | + |
| 285 | + // Calculate offset. |
| 286 | + mz_uint16 filename_len = |
| 287 | + read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS); |
| 288 | + mz_uint16 extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS); |
| 289 | + mz_uint64 offset = file_stat.m_local_header_ofs + |
| 290 | + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len; |
| 291 | + |
| 292 | + return loader_->load( |
| 293 | + offset, |
| 294 | + file_size, |
| 295 | + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); |
| 296 | +} |
| 297 | + |
| 298 | +Error PT2ArchiveDataMap::load_data_into( |
| 299 | + ET_UNUSED string_view key, |
| 300 | + ET_UNUSED void* buffer, |
| 301 | + ET_UNUSED size_t size) const { |
| 302 | + return Error::NotImplemented; |
| 303 | +} |
| 304 | + |
| 305 | +Result<uint32_t> PT2ArchiveDataMap::get_num_keys() const { |
| 306 | + return static_cast<uint32_t>(tensor_name_to_path_.size()); |
| 307 | +} |
| 308 | + |
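|  | +// Keys are returned in unordered_map iteration order, which is unspecified but |
|  | +// stays stable as long as the map is not modified, so a given index maps to |
|  | +// the same key across calls on the same instance. |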
| 309 | +Result<const char*> PT2ArchiveDataMap::get_key(uint32_t index) const { |
| 310 | + auto num_keys = get_num_keys().get(); |
| 311 | + ET_CHECK_OR_RETURN_ERROR( |
| 312 | + index < num_keys, |
| 313 | + InvalidArgument, |
| 314 | + "Index %u out of range of size %u", |
| 315 | + index, |
| 316 | + num_keys); |
| 317 | +  uint32_t i = 0; |
| 318 | + for (const auto& item : tensor_name_to_path_) { |
| 319 | + if (i == index) { |
| 320 | + return item.first.c_str(); |
| 321 | + } |
| 322 | + ++i; |
| 323 | + } |
| 324 | + // Should not reach here. |
| 325 | + return Error::Internal; |
| 326 | +} |
| 327 | + |
| 328 | +} // namespace extension |
| 329 | +} // namespace executorch |