Skip to content

Commit

Permalink
Improve Python snippets to add `huggingface_hub` client examples
Browse files Browse the repository at this point in the history
  • Loading branch information
julien-c committed Feb 5, 2025
1 parent d515d60 commit 20803ba
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 9 deletions.
6 changes: 3 additions & 3 deletions packages/tasks/src/snippets/js.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { openAIbaseUrl, type SnippetInferenceProvider } from "../inference-providers.js";
import type { PipelineType } from "../pipelines.js";
import type { PipelineType, WidgetType } from "../pipelines.js";
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
import { getModelInputSnippet } from "./inputs.js";
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";

const HFJS_METHODS: Record<string, string> = {
const HFJS_METHODS: Partial<Record<WidgetType, string>> = {
"text-classification": "textClassification",
"token-classification": "tokenClassification",
"table-question-answering": "tableQuestionAnswering",
Expand Down Expand Up @@ -40,7 +40,7 @@ const output = await client.${HFJS_METHODS[model.pipeline_tag]}({
provider: "${provider}",
});
console.log(output)
console.log(output);
`,
},
]
Expand Down
61 changes: 55 additions & 6 deletions packages/tasks/src/snippets/python.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,42 @@ import {
openAIbaseUrl,
type SnippetInferenceProvider,
} from "../inference-providers.js";
import type { PipelineType } from "../pipelines.js";
import type { PipelineType, WidgetType } from "../pipelines.js";
import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks/index.js";
import { stringifyGenerationConfig, stringifyMessages } from "./common.js";
import { getModelInputSnippet } from "./inputs.js";
import type { InferenceSnippet, ModelDataMinimal } from "./types.js";

/**
 * Maps a task (widget type) to the name of the dedicated `InferenceClient`
 * method in the Python `huggingface_hub` library.
 *
 * Tasks missing from this map have no dedicated client method and fall back
 * to a raw-requests snippet. Kept alphabetical for easy scanning.
 */
const HFH_INFERENCE_CLIENT_METHODS: Partial<Record<WidgetType, string>> = {
	"audio-classification": "audio_classification",
	"audio-to-audio": "audio_to_audio",
	"automatic-speech-recognition": "automatic_speech_recognition",
	"document-question-answering": "document_question_answering",
	"feature-extraction": "feature_extraction",
	"fill-mask": "fill_mask",
	"image-classification": "image_classification",
	"image-segmentation": "image_segmentation",
	"image-to-image": "image_to_image",
	"image-to-text": "image_to_text",
	"object-detection": "object_detection",
	"question-answering": "question_answering",
	"sentence-similarity": "sentence_similarity",
	"summarization": "summarization",
	"table-question-answering": "table_question_answering",
	"tabular-classification": "tabular_classification",
	"tabular-regression": "tabular_regression",
	"text-classification": "text_classification",
	"text-generation": "text_generation",
	"text-to-image": "text_to_image",
	"text-to-speech": "text_to_speech",
	"text-to-video": "text_to_video",
	"token-classification": "token_classification",
	"translation": "translation",
	"visual-question-answering": "visual_question_answering",
	"zero-shot-classification": "zero_shot_classification",
	"zero-shot-image-classification": "zero_shot_image_classification",
};

const snippetImportInferenceClient = (accessToken: string, provider: SnippetInferenceProvider): string =>
`\
from huggingface_hub import InferenceClient
Expand Down Expand Up @@ -168,8 +198,30 @@ output = query({
];
};

export const snippetBasic = (model: ModelDataMinimal): InferenceSnippet[] => {
export const snippetBasic = (
model: ModelDataMinimal,
accessToken: string,
provider: SnippetInferenceProvider
): InferenceSnippet[] => {
return [
...(model.pipeline_tag && model.pipeline_tag in HFH_INFERENCE_CLIENT_METHODS
? [
{
client: "huggingface_hub",
content: `\
${snippetImportInferenceClient(accessToken, provider)}
result = client.${HFH_INFERENCE_CLIENT_METHODS[model.pipeline_tag]}(
model="${model.id}",
inputs=${getModelInputSnippet(model)},
provider="${provider}",
)
print(result)
`,
},
]
: []),
{
client: "requests",
content: `\
Expand Down Expand Up @@ -391,10 +443,7 @@ export function getPythonInferenceSnippet(
? pythonSnippets[model.pipeline_tag]?.(model, accessToken, provider) ?? []
: [];

const baseUrl =
provider === "hf-inference"
? `https://api-inference.huggingface.co/models/${model.id}`
: HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);
const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", provider);

return snippets.map((snippet) => {
return {
Expand Down

0 comments on commit 20803ba

Please sign in to comment.