Skip to content

[WIP] Translate JS to TS #1197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
25 changes: 14 additions & 11 deletions src/backends/onnx.js → src/backends/onnx.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@
* @module backends/onnx
*/

import { env, apis } from '../env.js';
import { env, apis } from '../env';

// NOTE: Import order matters here. We need to import `onnxruntime-node` before `onnxruntime-web`.
// In either case, we select the default export if it exists, otherwise we use the named export.
import * as ONNX_NODE from 'onnxruntime-node';
import * as ONNX_WEB from 'onnxruntime-web';
import { DeviceType } from '../utils/devices';
import { InferenceSession as ONNXInferenceSession } from 'onnxruntime-common';

export { Tensor } from 'onnxruntime-common';

/**
* @typedef {import('onnxruntime-common').InferenceSession.ExecutionProviderConfig} ONNXExecutionProviders
*/
type ONNXExecutionProviders = ONNXInferenceSession.ExecutionProviderConfig;

/** @type {Record<import("../utils/devices.js").DeviceType, ONNXExecutionProviders>} */
const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({
const DEVICE_TO_EXECUTION_PROVIDER_MAPPING: Record<DeviceType, ONNXExecutionProviders> = Object.freeze({
auto: null, // Auto-detect based on device and environment
gpu: null, // Auto-detect GPU
cpu: 'cpu', // CPU
Expand All @@ -49,10 +52,10 @@ const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({
* The list of supported devices, sorted by priority/performance.
* @type {import("../utils/devices.js").DeviceType[]}
*/
const supportedDevices = [];
const supportedDevices: DeviceType[] = [];

/** @type {ONNXExecutionProviders[]} */
let defaultDevices;
let defaultDevices: ONNXExecutionProviders[];
let ONNX;
const ORT_SYMBOL = Symbol.for('onnxruntime');

Expand All @@ -61,7 +64,7 @@ if (ORT_SYMBOL in globalThis) {
ONNX = globalThis[ORT_SYMBOL];

} else if (apis.IS_NODE_ENV) {
ONNX = ONNX_NODE.default ?? ONNX_NODE;
ONNX = ONNX_NODE;

// Updated as of ONNX Runtime 1.20.1
// The following table lists the supported versions of ONNX Runtime Node.js binding provided with pre-built binaries.
Expand Down Expand Up @@ -109,7 +112,7 @@ const InferenceSession = ONNX.InferenceSession;
* @param {import("../utils/devices.js").DeviceType|"auto"|null} [device=null] (Optional) The device to run the inference on.
* @returns {ONNXExecutionProviders[]} The execution providers to use for the given device.
*/
export function deviceToExecutionProviders(device = null) {
export function deviceToExecutionProviders(device: DeviceType | "auto" | null = null): ONNXExecutionProviders[] {
// Use the default execution providers if the user hasn't specified anything
if (!device) return defaultDevices;

Expand Down Expand Up @@ -137,7 +140,7 @@ export function deviceToExecutionProviders(device = null) {
* will wait for this Promise to resolve before creating their own InferenceSession.
* @type {Promise<any>|null}
*/
let wasmInitPromise = null;
let wasmInitPromise: Promise<any> | null = null;

/**
* Create an ONNX inference session.
Expand All @@ -146,7 +149,7 @@ let wasmInitPromise = null;
* @param {Object} session_config ONNX inference session configuration.
* @returns {Promise<import('onnxruntime-common').InferenceSession & { config: Object}>} The ONNX inference session.
*/
export async function createInferenceSession(buffer, session_options, session_config) {
export async function createInferenceSession(buffer: Uint8Array, session_options: ONNXInferenceSession.SessionOptions, session_config: Object): Promise<ONNXInferenceSession & { config: Object; }> {
if (wasmInitPromise) {
// A previous session has already initialized the WASM runtime
// so we wait for it to resolve before creating this new session.
Expand All @@ -165,13 +168,13 @@ export async function createInferenceSession(buffer, session_options, session_co
* @param {any} x The object to check
* @returns {boolean} Whether the object is an ONNX tensor.
*/
export function isONNXTensor(x) {
export function isONNXTensor(x: any): boolean {
return x instanceof ONNX.Tensor;
}

/** @type {import('onnxruntime-common').Env} */
// @ts-ignore
const ONNX_ENV = ONNX?.env;
const ONNX_ENV: Env = ONNX?.env;
if (ONNX_ENV?.wasm) {
// Initialize wasm backend with suitable default settings.

Expand Down Expand Up @@ -202,7 +205,7 @@ if (ONNX_ENV?.webgpu) {
* Check if ONNX's WASM backend is being proxied.
* @returns {boolean} Whether ONNX's WASM backend is being proxied.
*/
export function isONNXProxy() {
export function isONNXProxy(): boolean {
// TODO: Update this when allowing non-WASM backends.
return ONNX_ENV?.wasm?.proxy;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import { FEATURE_EXTRACTOR_NAME } from "../utils/constants.js";
import { Callable } from "../utils/generic.js";
import { getModelJSON } from "../utils/hub.js";
import { getModelJSON, PretrainedOptions } from "../utils/hub.js";

/**
* Base class for feature extractors.
*/
export class FeatureExtractor extends Callable {
config: Object;
/**
* Constructs a new FeatureExtractor instance.
*
* @param {Object} config The configuration for the feature extractor.
*/
constructor(config) {
constructor(config: Object) {
super();
this.config = config
}
Expand All @@ -27,11 +28,11 @@ export class FeatureExtractor extends Callable {
* Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
* user or organization name, like `dbmdz/bert-base-german-cased`.
* - A path to a *directory* containing feature_extractor files, e.g., `./my_model_directory/`.
* @param {import('../utils/hub.js').PretrainedOptions} options Additional options for loading the feature_extractor.
* @param {import('../utils/hub').PretrainedOptions} options Additional options for loading the feature_extractor.
*
* @returns {Promise<FeatureExtractor>} A new instance of the Feature Extractor class.
*/
static async from_pretrained(pretrained_model_name_or_path, options) {
static async from_pretrained(pretrained_model_name_or_path: string, options: PretrainedOptions): Promise<FeatureExtractor> {
const config = await getModelJSON(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, true, options);
return new this(config);
}
Expand All @@ -44,9 +45,10 @@ export class FeatureExtractor extends Callable {
* @param {string} feature_extractor The name of the feature extractor.
* @private
*/
export function validate_audio_inputs(audio, feature_extractor) {
export function validate_audio_inputs(audio: Float32Array | Float64Array, feature_extractor: string) {
if (!(audio instanceof Float32Array || audio instanceof Float64Array)) {
throw new Error(
// @ts-expect-error TS2339
`${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead. ` +
`If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.`
)
Expand Down
Loading