
Commit 3d6bcf4

xenova, axrati, and tangkunyin authored
Add support for multi-chunk external data files (huggingface#1212)
* Cache path vs Local path
* Update src/utils/hub.js
* complete the full parameters of option for pipeline api, see: huggingface#1200
  Signed-off-by: Thomas Tang <[email protected]>
* Add support for multi-chunk external data files
* Add MAX_EXTERNAL_DATA_CHUNKS check
* Support loading inference sessions from paths (node.js external data)
* Fix types
* Add support for node.js external data & optimize downloading
* Remove debug logs
* Optimize FileResponse body stream for large models
* Expose `use_external_data_format` in pipeline options. TODO: Improve support and allow per-file specification
* Fix post-merge
* Dedent comments
* Default use_external_data_format to null (unset)
* let -> const where possible
* Support setting external data format in config.json
* Add external data model architecture tests
* Simplify tests
* Formatting

---------

Signed-off-by: Thomas Tang <[email protected]>
Co-authored-by: axrati <[email protected]>
Co-authored-by: Axm <[email protected]>
Co-authored-by: Thomas Tang <[email protected]>
1 parent d799fb6 commit 3d6bcf4
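
The headline change is the `use_external_data_format` option now accepted by `pipeline()`. A minimal usage sketch, assuming a hypothetical repository whose ONNX weights are stored in external data files (the model id and task below are illustrative, not taken from this commit):

import { pipeline } from '@huggingface/transformers';

// Hypothetical >2GB model whose weights live next to the .onnx graph as
// external data files (e.g. model.onnx + model.onnx_data). `true` requests a
// single data file; a number requests that many chunks; an object maps
// individual model files to either form.
const generator = await pipeline('text-generation', 'your-org/your-large-model', {
    dtype: 'q4',
    use_external_data_format: true,
});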

File tree

6 files changed, +211 -124 lines changed


src/backends/onnx.js (+3 -3)
@@ -141,19 +141,19 @@ let wasmInitPromise = null;
 
 /**
  * Create an ONNX inference session.
- * @param {Uint8Array} buffer The ONNX model buffer.
+ * @param {Uint8Array|string} buffer_or_path The ONNX model buffer or path.
  * @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options ONNX inference session options.
  * @param {Object} session_config ONNX inference session configuration.
  * @returns {Promise<import('onnxruntime-common').InferenceSession & { config: Object}>} The ONNX inference session.
  */
-export async function createInferenceSession(buffer, session_options, session_config) {
+export async function createInferenceSession(buffer_or_path, session_options, session_config) {
     if (wasmInitPromise) {
         // A previous session has already initialized the WASM runtime
         // so we wait for it to resolve before creating this new session.
         await wasmInitPromise;
     }
 
-    const sessionPromise = InferenceSession.create(buffer, session_options);
+    const sessionPromise = InferenceSession.create(buffer_or_path, session_options);
     wasmInitPromise ??= sessionPromise;
     const session = await sessionPromise;
     session.config = session_config;
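
To illustrate the widened signature, a rough sketch of the two call shapes `createInferenceSession` now accepts; the URL, path, and session options below are placeholders rather than values used by the library:

// Browser / WASM case: the model is fetched and handed over as raw bytes.
const resp = await fetch('https://example.com/model.onnx'); // placeholder URL
const webSession = await createInferenceSession(
    new Uint8Array(await resp.arrayBuffer()),
    { executionProviders: ['wasm'] },
    {},
);

// Node.js case: a filesystem path is forwarded to onnxruntime, which can then
// pick up neighbouring *.onnx_data chunk files from disk itself.
const nodeSession = await createInferenceSession(
    '/tmp/transformers-cache/model.onnx', // placeholder path
    { executionProviders: ['cpu'] },
    {},
);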

src/configs.js (+1 -1)
@@ -408,5 +408,5 @@ export class AutoConfig {
  * for more information.
  * @property {import('./utils/devices.js').DeviceType} [device] The default device to use for the model.
  * @property {import('./utils/dtypes.js').DataType|Record<string, import('./utils/dtypes.js').DataType>} [dtype] The default data type to use for the model.
- * @property {boolean|Record<string, boolean>} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
+ * @property {import('./utils/hub.js').ExternalData|Record<string, import('./utils/hub.js').ExternalData>} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
  */
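
Judging from how the value is consumed in src/models.js below, the `ExternalData` type admits a boolean or a chunk count, and the `Record` form keys it per model file. A hedged sketch of the shapes this JSDoc allows (the file names are examples only):

// Plausible values for `use_external_data_format`:
const single = true;   // one external data file per model file
const chunked = 3;     // e.g. model.onnx_data, model.onnx_data_1, model.onnx_data_2
const perFile = {
    'decoder_model_merged.onnx': 2, // this file ships two external data chunks
    'encoder_model.onnx': false,    // this file keeps all weights inline
};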

src/models.js (+39 -25)
@@ -68,6 +68,7 @@ import {
 import {
     getModelFile,
     getModelJSON,
+    MAX_EXTERNAL_DATA_CHUNKS,
 } from './utils/hub.js';
 
 import {
@@ -153,7 +154,7 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
  * @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
  * @param {string} fileName The name of the model file.
  * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
- * @returns {Promise<{buffer: Uint8Array, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
+ * @returns {Promise<{buffer_or_path: Uint8Array|string, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
  * @private
  */
 async function getSession(pretrained_model_name_or_path, fileName, options) {
@@ -228,7 +229,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
 
     // Construct the model file name
     const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
-    const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
+    const baseName = `${fileName}${suffix}.onnx`;
+    const modelFileName = `${options.subfolder ?? ''}/${baseName}`;
 
     const session_options = { ...options.session_options };
 
@@ -246,29 +248,38 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
         );
     }
 
-    const bufferPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options);
+    const bufferOrPathPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options, apis.IS_NODE_ENV);
 
     // handle onnx external data files
     const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
-    /** @type {Promise<{path: string, data: Uint8Array}>[]} */
+    /** @type {Promise<string|{path: string, data: Uint8Array}>[]} */
     let externalDataPromises = [];
-    if (use_external_data_format && (
-        use_external_data_format === true ||
-        (
-            typeof use_external_data_format === 'object' &&
-            use_external_data_format.hasOwnProperty(fileName) &&
-            use_external_data_format[fileName] === true
-        )
-    )) {
-        if (apis.IS_NODE_ENV) {
-            throw new Error('External data format is not yet supported in Node.js');
+    if (use_external_data_format) {
+        let external_data_format;
+        if (typeof use_external_data_format === 'object') {
+            if (use_external_data_format.hasOwnProperty(baseName)) {
+                external_data_format = use_external_data_format[baseName];
+            } else if (use_external_data_format.hasOwnProperty(fileName)) {
+                external_data_format = use_external_data_format[fileName];
+            } else {
+                external_data_format = false;
+            }
+        } else {
+            external_data_format = use_external_data_format;
+        }
+
+        const num_chunks = +external_data_format; // (false=0, true=1, number remains the same)
+        if (num_chunks > MAX_EXTERNAL_DATA_CHUNKS) {
+            throw new Error(`The number of external data chunks (${num_chunks}) exceeds the maximum allowed value (${MAX_EXTERNAL_DATA_CHUNKS}).`);
+        }
+        for (let i = 0; i < num_chunks; ++i) {
+            const path = `${baseName}_data${i === 0 ? '' : '_' + i}`;
+            const fullPath = `${options.subfolder ?? ''}/${path}`;
+            externalDataPromises.push(new Promise(async (resolve, reject) => {
+                const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options, apis.IS_NODE_ENV);
+                resolve(data instanceof Uint8Array ? { path, data } : path);
+            }));
         }
-        const path = `${fileName}${suffix}.onnx_data`;
-        const fullPath = `${options.subfolder ?? ''}/${path}`;
-        externalDataPromises.push(new Promise(async (resolve, reject) => {
-            const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options);
-            resolve({ path, data })
-        }));
 
     } else if (session_options.externalData !== undefined) {
         externalDataPromises = session_options.externalData.map(async (ext) => {
@@ -285,7 +296,10 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
     }
 
     if (externalDataPromises.length > 0) {
-        session_options.externalData = await Promise.all(externalDataPromises);
+        const externalData = await Promise.all(externalDataPromises);
+        if (!apis.IS_NODE_ENV) {
+            session_options.externalData = externalData;
+        }
     }
 
     if (selectedDevice === 'webgpu') {
@@ -303,9 +317,9 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
         }
     }
 
-    const buffer = await bufferPromise;
+    const buffer_or_path = await bufferOrPathPromise;
 
-    return { buffer, session_options, session_config };
+    return { buffer_or_path, session_options, session_config };
 }
 
 /**
@@ -320,8 +334,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
 async function constructSessions(pretrained_model_name_or_path, names, options) {
     return Object.fromEntries(await Promise.all(
         Object.keys(names).map(async (name) => {
-            const { buffer, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
-            const session = await createInferenceSession(buffer, session_options, session_config);
+            const { buffer_or_path, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
+            const session = await createInferenceSession(buffer_or_path, session_options, session_config);
             return [name, session];
         })
     ));
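
To make the chunk-naming convention concrete, here is what the loop above computes; the base name, subfolder, and chunk count in this worked example are illustrative:

// Worked example of the `${baseName}_data${i === 0 ? '' : '_' + i}` pattern.
const baseName = 'model_q4.onnx'; // illustrative
const subfolder = 'onnx';         // illustrative
for (let i = 0; i < 3; ++i) {
    const path = `${baseName}_data${i === 0 ? '' : '_' + i}`;
    console.log(`${subfolder}/${path}`);
}
// Prints:
//   onnx/model_q4.onnx_data
//   onnx/model_q4.onnx_data_1
//   onnx/model_q4.onnx_data_2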

src/pipelines.js (+4)
@@ -3399,6 +3399,8 @@ export async function pipeline(
         revision = 'main',
         device = null,
         dtype = null,
+        subfolder = 'onnx',
+        use_external_data_format = null,
         model_file_name = null,
         session_options = {},
     } = {}
@@ -3429,6 +3431,8 @@
         revision,
         device,
         dtype,
+        subfolder,
+        use_external_data_format,
         model_file_name,
         session_options,
     }
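
Together with the src/models.js changes above, both newly forwarded options can now be set per pipeline call. A hedged sketch, with a placeholder model id and file name:

import { pipeline } from '@huggingface/transformers';

const classifier = await pipeline('text-classification', 'your-org/your-model', {
    subfolder: 'onnx', // same as the new default, shown here for clarity
    use_external_data_format: {
        'model.onnx': 2, // this file's weights are split into two data chunks
    },
});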
