diff --git a/packages/hub/src/types/api/api-model.ts b/packages/hub/src/types/api/api-model.ts index 1e051f0d9..e9510f279 100644 --- a/packages/hub/src/types/api/api-model.ts +++ b/packages/hub/src/types/api/api-model.ts @@ -1,4 +1,4 @@ -import type { ModelLibraryKey, TransformersInfo } from "@huggingface/tasks"; +import type { ModelLibraryKey, TransformersInfo, WidgetType } from "@huggingface/tasks"; import type { License, PipelineType } from "../public"; export interface ApiModelInfo { @@ -18,9 +18,8 @@ export interface ApiModelInfo { downloadsAllTime: number; files: string[]; gitalyUid: string; - inferenceProviderMapping: Record< - string, - { providerId: string; status: "prod" | "staging"; task: PipelineType | "conversational" } + inferenceProviderMapping: Partial< + Record >; lastAuthor: { email: string; user?: string }; lastModified: string; // convert to date diff --git a/packages/inference/src/lib/makeRequestOptions.ts b/packages/inference/src/lib/makeRequestOptions.ts index d6312c213..6d0869d43 100644 --- a/packages/inference/src/lib/makeRequestOptions.ts +++ b/packages/inference/src/lib/makeRequestOptions.ts @@ -172,9 +172,9 @@ async function mapModel(params: { // TODO: cache this call const info = await modelInfo({ name: params.model, additionalFields: ["inferenceProviderMapping"] }); + const inferenceProviderMapping = info.inferenceProviderMapping[params.provider]; // If provider listed => takes precedence over hard-coded mapping - if (params.provider in info.inferenceProviderMapping) { - const inferenceProviderMapping = info.inferenceProviderMapping[params.provider]; + if (inferenceProviderMapping) { if (inferenceProviderMapping.task !== task) { throw new Error( `Model ${params.model} is not supported for task ${task} and provider ${params.provider}. Supported task: ${inferenceProviderMapping.task}.`