@@ -3,12 +3,42 @@ import {
3
3
openAIbaseUrl ,
4
4
type SnippetInferenceProvider ,
5
5
} from "../inference-providers.js" ;
6
- import type { PipelineType } from "../pipelines.js" ;
6
+ import type { PipelineType , WidgetType } from "../pipelines.js" ;
7
7
import type { ChatCompletionInputMessage , GenerationParameters } from "../tasks/index.js" ;
8
8
import { stringifyGenerationConfig , stringifyMessages } from "./common.js" ;
9
9
import { getModelInputSnippet } from "./inputs.js" ;
10
10
import type { InferenceSnippet , ModelDataMinimal } from "./types.js" ;
11
11
12
+ const HFH_INFERENCE_CLIENT_METHODS : Partial < Record < WidgetType , string > > = {
13
+ "audio-classification" : "audio_classification" ,
14
+ "audio-to-audio" : "audio_to_audio" ,
15
+ "automatic-speech-recognition" : "automatic_speech_recognition" ,
16
+ "text-to-speech" : "text_to_speech" ,
17
+ "image-classification" : "image_classification" ,
18
+ "image-segmentation" : "image_segmentation" ,
19
+ "image-to-image" : "image_to_image" ,
20
+ "image-to-text" : "image_to_text" ,
21
+ "object-detection" : "object_detection" ,
22
+ "text-to-image" : "text_to_image" ,
23
+ "text-to-video" : "text_to_video" ,
24
+ "zero-shot-image-classification" : "zero_shot_image_classification" ,
25
+ "document-question-answering" : "document_question_answering" ,
26
+ "visual-question-answering" : "visual_question_answering" ,
27
+ "feature-extraction" : "feature_extraction" ,
28
+ "fill-mask" : "fill_mask" ,
29
+ "question-answering" : "question_answering" ,
30
+ "sentence-similarity" : "sentence_similarity" ,
31
+ summarization : "summarization" ,
32
+ "table-question-answering" : "table_question_answering" ,
33
+ "text-classification" : "text_classification" ,
34
+ "text-generation" : "text_generation" ,
35
+ "token-classification" : "token_classification" ,
36
+ translation : "translation" ,
37
+ "zero-shot-classification" : "zero_shot_classification" ,
38
+ "tabular-classification" : "tabular_classification" ,
39
+ "tabular-regression" : "tabular_regression" ,
40
+ } ;
41
+
12
42
const snippetImportInferenceClient = ( accessToken : string , provider : SnippetInferenceProvider ) : string =>
13
43
`\
14
44
from huggingface_hub import InferenceClient
@@ -168,8 +198,30 @@ output = query({
168
198
] ;
169
199
} ;
170
200
171
- export const snippetBasic = ( model : ModelDataMinimal ) : InferenceSnippet [ ] => {
201
+ export const snippetBasic = (
202
+ model : ModelDataMinimal ,
203
+ accessToken : string ,
204
+ provider : SnippetInferenceProvider
205
+ ) : InferenceSnippet [ ] => {
172
206
return [
207
+ ...( model . pipeline_tag && model . pipeline_tag in HFH_INFERENCE_CLIENT_METHODS
208
+ ? [
209
+ {
210
+ client : "huggingface_hub" ,
211
+ content : `\
212
+ ${ snippetImportInferenceClient ( accessToken , provider ) }
213
+
214
+ result = client.${ HFH_INFERENCE_CLIENT_METHODS [ model . pipeline_tag ] } (
215
+ model="${ model . id } ",
216
+ inputs=${ getModelInputSnippet ( model ) } ,
217
+ provider="${ provider } ",
218
+ )
219
+
220
+ print(result)
221
+ ` ,
222
+ } ,
223
+ ]
224
+ : [ ] ) ,
173
225
{
174
226
client : "requests" ,
175
227
content : `\
@@ -391,10 +443,7 @@ export function getPythonInferenceSnippet(
391
443
? pythonSnippets [ model . pipeline_tag ] ?.( model , accessToken , provider ) ?? [ ]
392
444
: [ ] ;
393
445
394
- const baseUrl =
395
- provider === "hf-inference"
396
- ? `https://api-inference.huggingface.co/models/${ model . id } `
397
- : HF_HUB_INFERENCE_PROXY_TEMPLATE . replace ( "{{PROVIDER}}" , provider ) ;
446
+ const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE . replace ( "{{PROVIDER}}" , provider ) ;
398
447
399
448
return snippets . map ( ( snippet ) => {
400
449
return {
0 commit comments