Skip to content

Commit

Permalink
implement a cache
Browse files (browse the repository at this point in the history)
  • Loading branch information
julien-c committed Feb 5, 2025
1 parent 12e4bed commit 2737081
Showing 1 changed file with 26 additions and 11 deletions.
37 changes: 26 additions & 11 deletions packages/inference/src/lib/getProviderModelId.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import type { WidgetType } from "@huggingface/tasks";
import { HF_HUB_URL } from "../config";
import { HARDCODED_MODEL_ID_MAPPING } from "../providers/consts";
import type { InferenceProvider, InferenceTask, Options, RequestArgs } from "../types";
import type { ModelId } from "../types";
import { type InferenceProvider, type InferenceTask, type Options, type RequestArgs } from "../types";

/** Shape of a single provider's entry inside a model's `inferenceProviderMapping`. */
type ProviderMappingEntry = {
	providerId: string;
	status: "live" | "staging";
	task: WidgetType;
};

/**
 * Value of the `inferenceProviderMapping` field for one model: at most one entry
 * per inference provider; providers without an entry are simply absent.
 */
type InferenceProviderMapping = Partial<Record<InferenceProvider, ProviderMappingEntry>>;

/**
 * Module-level cache of provider mappings, keyed by model ID, consulted before
 * querying `${HF_HUB_URL}/api/models/<model>?expand[]=inferenceProviderMapping`.
 *
 * NOTE(review): no `.set(...)` call is visible in the excerpt reviewed here —
 * confirm the cache is actually populated after a successful fetch.
 * NOTE(review): entries are keyed by model ID only, not by access token —
 * confirm that sharing a mapping across callers with different tokens is intended.
 */
const inferenceProviderMappingCache: Map<ModelId, InferenceProviderMapping> = new Map();

export async function getProviderModelId(
params: {
Expand Down Expand Up @@ -29,16 +35,25 @@ export async function getProviderModelId(
return HARDCODED_MODEL_ID_MAPPING[params.model];
}

// TODO: cache this call
const inferenceProviderMapping = await (options?.fetch ?? fetch)(
`${HF_HUB_URL}/api/models/${params.model}?expand[]=inferenceProviderMapping`,
{
headers: args.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${args.accessToken}` } : {},
}
)
.then((resp) => resp.json())
.then((json) => json.inferenceProviderMapping)
.catch(() => null);
let inferenceProviderMapping: InferenceProviderMapping | null;
if (inferenceProviderMappingCache.has(params.model)) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
inferenceProviderMapping = inferenceProviderMappingCache.get(params.model)!;
} else {
inferenceProviderMapping = await (options?.fetch ?? fetch)(
`${HF_HUB_URL}/api/models/${params.model}?expand[]=inferenceProviderMapping`,
{
headers: args.accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${args.accessToken}` } : {},
}
)
.then((resp) => resp.json())
.then((json) => json.inferenceProviderMapping)
.catch(() => null);
}

if (!inferenceProviderMapping) {
throw new Error(`We have not been able to find inference provider information for model ${params.model}.`);
}

const providerMapping = inferenceProviderMapping[params.provider];
// If provider listed => takes precedence over hard-coded mapping
Expand Down

0 comments on commit 2737081

Please sign in to comment.