Skip to content

Support device-level configuration across all devices #1276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/configs.js
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ export class AutoConfig {
/**
* Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`.
* @typedef {Object} TransformersJSConfig
* @property {Record<import('./utils/devices.js').DeviceType, DeviceConfig>} [device_config] Device-specific configurations.
* @property {import('./utils/tensor.js').DataType|Record<import('./utils/dtypes.js').DataType, import('./utils/tensor.js').DataType>} [kv_cache_dtype] The data type of the key-value cache.
* @property {Record<string, number>} [free_dimension_overrides] Override the free dimensions of the model.
* See https://onnxruntime.ai/docs/tutorials/web/env-flags-and-session-options.html#freedimensionoverrides
Expand All @@ -412,3 +413,8 @@ export class AutoConfig {
* @property {import('./utils/dtypes.js').DataType|Record<string, import('./utils/dtypes.js').DataType>} [dtype] The default data type to use for the model.
* @property {import('./utils/hub.js').ExternalData|Record<string, import('./utils/hub.js').ExternalData>} [use_external_data_format=false] Whether to load the model using the external data format (used for models >= 2GB in size).
*/

/**
 * Device-specific configuration options. When a device is selected, these
 * fields are spread over (and thus override) the corresponding top-level
 * `TransformersJSConfig` fields. The `device` and `device_config` keys are
 * excluded to prevent a device entry from re-selecting a device or nesting
 * further device configs.
 * @typedef {Omit<TransformersJSConfig, "device" | "device_config">} DeviceConfig
 */
32 changes: 22 additions & 10 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
* @private
*/
async function getSession(pretrained_model_name_or_path, fileName, options) {
const custom_config = options.config?.['transformers.js_config'] ?? {};
let custom_config = options.config?.['transformers.js_config'] ?? {};

let device = options.device ?? custom_config.device;
if (device && typeof device !== 'string') {
if (device.hasOwnProperty(fileName)) {
Expand All @@ -173,8 +174,18 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
const selectedDevice = /** @type {import("./utils/devices.js").DeviceType} */(
device ?? (apis.IS_NODE_ENV ? 'cpu' : 'wasm')
);

const executionProviders = deviceToExecutionProviders(selectedDevice);

// Update custom config with the selected device's config, if it exists
const device_config = custom_config.device_config ?? {};
if (device_config.hasOwnProperty(selectedDevice)) {
custom_config = {
...custom_config,
...device_config[selectedDevice],
};
}

// If options.dtype is specified, we use it to choose the suffix for the model file.
// Otherwise, we use the default dtype for the device.
let dtype = options.dtype ?? custom_config.dtype;
Expand All @@ -191,11 +202,11 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
// Try to choose the auto dtype based on the custom config
let config_dtype = custom_config.dtype;
if (typeof config_dtype !== 'string') {
config_dtype = config_dtype[fileName];
config_dtype = config_dtype?.[fileName];
}

if (config_dtype && config_dtype !== DATA_TYPES.auto && DATA_TYPES.hasOwnProperty(config_dtype)) {
// Defined by the custom config, and is not "auto"
// Defined by the config, and is not "auto"
dtype = config_dtype;
} else {
// Choose default dtype based on device, falling back to fp32
Expand All @@ -212,10 +223,11 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
}

// Only valid for models with a decoder
const kv_cache_dtype = custom_config.kv_cache_dtype
? (typeof custom_config.kv_cache_dtype === 'string'
? custom_config.kv_cache_dtype
: custom_config.kv_cache_dtype[selectedDtype] ?? 'float32')
const kv_cache_dtype_config = custom_config.kv_cache_dtype;
const kv_cache_dtype = kv_cache_dtype_config
? (typeof kv_cache_dtype_config === 'string'
? kv_cache_dtype_config
: kv_cache_dtype_config[selectedDtype] ?? 'float32')
: undefined;

if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
Expand Down Expand Up @@ -243,14 +255,14 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
session_options.freeDimensionOverrides ??= free_dimension_overrides;
} else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
console.warn(
'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "transformers.js_config". ' +
'When `free_dimension_overrides` is not set, you may experience significant performance degradation.'
`WebNN does not currently support dynamic shapes and requires 'free_dimension_overrides' to be set in config.json, preferably as a field within config["transformers.js_config"]["device_config"]["${selectedDevice}"]. ` +
`When 'free_dimension_overrides' is not set, you may experience significant performance degradation.`
);
}

const bufferOrPathPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options, apis.IS_NODE_ENV);

// handle onnx external data files
// Handle onnx external data files
const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
/** @type {Promise<string|{path: string, data: Uint8Array}>[]} */
let externalDataPromises = [];
Expand Down