Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add inferenceProvider filter when listing models #1198

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions packages/hub/src/lib/list-models.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,18 @@ describe("listModels", () => {

expect(count).to.equal(10);
});

it("should search model by inference provider", async () => {
let count = 0;
for await (const entry of listModels({
search: { inferenceProvider: "together" },
additionalFields: ["inferenceProviderMapping"],
limit: 10,
})) {
count++;
expect(Object.keys(entry.inferenceProviderMapping)).includes("together");
}

expect(count).to.equal(10);
});
});
2 changes: 2 additions & 0 deletions packages/hub/src/lib/list-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ export async function* listModels<
owner?: string;
task?: PipelineType;
tags?: string[];
inferenceProvider?: string;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can search for several providers at once

Suggested change
inferenceProvider?: string;
inferenceProvider?: string[];

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing more than one provider raises an error:

Error: Internal Error - We're working hard to fix this as soon as possible!. URL: https://huggingface.co/api/models?limit=10&expand=pipeline_tag&expand=private&expand=gated&expand=downloads&expand=likes&expand=lastModified&expand=inferenceProviderMapping&inference_provider=together&inference_provider=replicate. Request ID: Root=1-67ae03c6-377f36e3001096932c7699a4

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this URL it works:
https://huggingface.co/api/models?limit=10&expand=pipeline_tag&expand=private&expand=gated&expand=downloads&expand=likes&expand=lastModified&expand=inferenceProviderMapping&inference_provider=together,replicate

inference_provider=together,replicate instead of inference_provider=together&inference_provider=replicate

That being said - this should not end up in a Server Error 😵 💫

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh cool. Thanks!. I thought the query param had the same syntax as the expand one

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And the filter terms are combined with an OR clause, right? I mean inference_provider=together,replicate returns models with one of those providers.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes it's an OR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay. I think it's ready.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @frascuchon !

};
hubUrl?: string;
additionalFields?: T[];
Expand All @@ -84,6 +85,7 @@ export async function* listModels<
...(params?.search?.owner ? { author: params.search.owner } : undefined),
...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined),
...(params?.search?.query ? { search: params.search.query } : undefined),
...(params?.search?.inferenceProvider ? { inference_provider: params.search.inferenceProvider } : undefined),
}),
...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
Expand Down
Loading