Skip to content

Commit 20803ba

Browse files
committed
improve python snippets to add huggingface_hub
1 parent d515d60 commit 20803ba

File tree

2 files changed

+58
-9
lines changed

packages/tasks/src/snippets/js.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import { openAIbaseUrl, type SnippetInferenceProvider } from "../inference-providers.js";
2-
import type { PipelineType } from "../pipelines.js";
2+
import type { PipelineType, WidgetType } from "../pipelines.js";
33
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
44
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
55
import { getModelInputSnippet } from "./inputs.js";
66
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
77

8-
const HFJS_METHODS: Record<string, string> = {
8+
const HFJS_METHODS: Partial<Record<WidgetType, string>> = {
99
"text-classification": "textClassification",
1010
"token-classification": "tokenClassification",
1111
"table-question-answering": "tableQuestionAnswering",
@@ -40,7 +40,7 @@ const output = await client.${HFJS_METHODS[model.pipeline_tag]}({
4040
provider: "${provider}",
4141
});
4242
43-
console.log(output)
43+
console.log(output);
4444
`,
4545
},
4646
]

packages/tasks/src/snippets/python.ts

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,42 @@ import {
33
openAIbaseUrl,
44
type SnippetInferenceProvider,
55
} from "../inference-providers.js";
6-
import type { PipelineType } from "../pipelines.js";
6+
import type { PipelineType, WidgetType } from "../pipelines.js";
77
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
88
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
99
import { getModelInputSnippet } from "./inputs.js";
1010
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";
1111

12+
const HFH_INFERENCE_CLIENT_METHODS: Partial<Record<WidgetType, string>> = {
13+
"audio-classification": "audio_classification",
14+
"audio-to-audio": "audio_to_audio",
15+
"automatic-speech-recognition": "automatic_speech_recognition",
16+
"text-to-speech": "text_to_speech",
17+
"image-classification": "image_classification",
18+
"image-segmentation": "image_segmentation",
19+
"image-to-image": "image_to_image",
20+
"image-to-text": "image_to_text",
21+
"object-detection": "object_detection",
22+
"text-to-image": "text_to_image",
23+
"text-to-video": "text_to_video",
24+
"zero-shot-image-classification": "zero_shot_image_classification",
25+
"document-question-answering": "document_question_answering",
26+
"visual-question-answering": "visual_question_answering",
27+
"feature-extraction": "feature_extraction",
28+
"fill-mask": "fill_mask",
29+
"question-answering": "question_answering",
30+
"sentence-similarity": "sentence_similarity",
31+
summarization: "summarization",
32+
"table-question-answering": "table_question_answering",
33+
"text-classification": "text_classification",
34+
"text-generation": "text_generation",
35+
"token-classification": "token_classification",
36+
translation: "translation",
37+
"zero-shot-classification": "zero_shot_classification",
38+
"tabular-classification": "tabular_classification",
39+
"tabular-regression": "tabular_regression",
40+
};
41+
1242
const snippetImportInferenceClient = (accessToken: string, provider: SnippetInferenceProvider): string =>
1343
`\
1444
from huggingface_hub import InferenceClient
@@ -168,8 +198,30 @@ output = query({
168198
];
169199
};
170200

171-
export const snippetBasic = (model: ModelDataMinimal): InferenceSnippet[] => {
201+
export const snippetBasic = (
202+
model: ModelDataMinimal,
203+
accessToken: string,
204+
provider: SnippetInferenceProvider
205+
): InferenceSnippet[] => {
172206
return [
207+
...(model.pipeline_tag && model.pipeline_tag in HFH_INFERENCE_CLIENT_METHODS
208+
? [
209+
{
210+
client: "huggingface_hub",
211+
content: `\
212+
${snippetImportInferenceClient(accessToken, provider)}
213+
214+
result = client.${HFH_INFERENCE_CLIENT_METHODS[model.pipeline_tag]}(
215+
model="${model.id}",
216+
inputs=${getModelInputSnippet(model)},
217+
provider="${provider}",
218+
)
219+
220+
print(result)
221+
`,
222+
},
223+
]
224+
: []),
173225
{
174226
client: "requests",
175227
content: `\
@@ -391,10 +443,7 @@ export function getPythonInferenceSnippet(
391443
? pythonSnippets[model.pipeline_tag]?.(model, accessToken, provider) ?? []
392444
: [];
393445

394-
const baseUrl =
395-
provider === "hf-inference"
396-
? `https://api-inference.huggingface.co/models/${model.id}`
397-
: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
446+
const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
398447

399448
return snippets.map((snippet) => {
400449
return {

0 commit comments

Comments (0)