Skip to content

Commit

Permalink
refactor: Accept multiple inference providers
Browse files Browse the repository at this point in the history
  • Loading branch information
frascuchon committed Feb 13, 2025
1 parent a19ae13 commit fd8a486
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
24 changes: 22 additions & 2 deletions packages/hub/src/lib/list-models.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,32 @@ describe("listModels", () => {
it("should search model by inference provider", async () => {
let count = 0;
for await (const entry of listModels({
search: { inferenceProvider: "together" },
search: { inferenceProvider: ["together"] },
additionalFields: ["inferenceProviderMapping"],
limit: 10,
})) {
count++;
expect(Object.keys(entry.inferenceProviderMapping)).includes("together");
if (Array.isArray(entry.inferenceProviderMapping)) {
expect(entry.inferenceProviderMapping.map(({ provider }) => provider)).to.include("together");
}
}

expect(count).to.equal(10);
});

it("should search model by several inference providers", async () => {
let count = 0;
for await (const entry of listModels({
search: { inferenceProvider: ["together", "replicate"] },
additionalFields: ["inferenceProviderMapping"],
limit: 10,
})) {
count++;
if (Array.isArray(entry.inferenceProviderMapping)) {
const providerNames = entry.inferenceProviderMapping.map(({ provider }) => provider);
expect(providerNames).to.include("together");
expect(providerNames).to.include("replicate");
}
}

expect(count).to.equal(10);
Expand Down
4 changes: 2 additions & 2 deletions packages/hub/src/lib/list-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export async function* listModels<
owner?: string;
task?: PipelineType;
tags?: string[];
inferenceProvider?: string;
inferenceProvider?: string[];
};
hubUrl?: string;
additionalFields?: T[];
Expand All @@ -85,11 +85,11 @@ export async function* listModels<
...(params?.search?.owner ? { author: params.search.owner } : undefined),
...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined),
...(params?.search?.query ? { search: params.search.query } : undefined),
...(params?.search?.inferenceProvider ? { inference_provider: params.search.inferenceProvider } : undefined),
}),
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
...(params?.search?.inferenceProvider?.map((val) => ["inference_provider", val]) ?? []),
]).toString();
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`;

Expand Down

0 comments on commit fd8a486

Please sign in to comment.