fix: Upgrade Hugging Face Inference API to support Inference Providers #5454
packages/components/nodes/chatmodels/ChatHuggingFace/core.ts

```diff
@@ -56,9 +56,9 @@
         this.apiKey = fields?.apiKey ?? getEnvironmentVariable('HUGGINGFACEHUB_API_KEY')
         this.endpointUrl = fields?.endpointUrl
         this.includeCredentials = fields?.includeCredentials
-        if (!this.apiKey) {
+        if (!this.apiKey || this.apiKey.trim() === '') {
             throw new Error(
-                'Please set an API key for HuggingFace Hub in the environment variable HUGGINGFACEHUB_API_KEY or in the apiKey field of the HuggingFaceInference constructor.'
+                'Please set an API key for HuggingFace Hub. Either configure it in the credential settings in the UI, or set the environment variable HUGGINGFACEHUB_API_KEY.'
             )
         }
     }
```
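Worth noting why the `.trim()` was added: a key consisting only of whitespace is still truthy, so the old guard accepted it and the request only failed later at the API. A tiny sketch of the difference (the key value is hypothetical, not from the PR):

```ts
const apiKey = '   ' // e.g. pasted with stray whitespace

// Old guard: '   ' is truthy, so this never fires
if (!apiKey) throw new Error('missing key')

// New guard: catches both empty and whitespace-only keys before any request is made
if (!apiKey || apiKey.trim() === '') throw new Error('missing key')
```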
```diff
@@ -68,71 +68,120 @@
     }

     invocationParams(options?: this['ParsedCallOptions']) {
-        return {
-            model: this.model,
-            parameters: {
-                // make it behave similar to openai, returning only the generated text
-                return_full_text: false,
-                temperature: this.temperature,
-                max_new_tokens: this.maxTokens,
-                stop: options?.stop ?? this.stopSequences,
-                top_p: this.topP,
-                top_k: this.topK,
-                repetition_penalty: this.frequencyPenalty
-            }
-        }
+        // Return parameters compatible with chatCompletion API (OpenAI-compatible format)
+        const params: any = {
+            temperature: this.temperature,
+            max_tokens: this.maxTokens,
+            stop: options?.stop ?? this.stopSequences,
+            top_p: this.topP
+        }
+        // Include optional parameters if they are defined
+        if (this.topK !== undefined) {
+            params.top_k = this.topK
+        }
+        if (this.frequencyPenalty !== undefined) {
+            params.frequency_penalty = this.frequencyPenalty
+        }
+        return params
```
A contributor commented on lines +72 to +85:

> The handling of optional parameters is inconsistent.

```ts
const params: any = {};
if (this.temperature !== undefined) params.temperature = this.temperature;
if (this.maxTokens !== undefined) params.max_tokens = this.maxTokens;
if (this.topP !== undefined) params.top_p = this.topP;
if (this.topK !== undefined) params.top_k = this.topK;
if (this.frequencyPenalty !== undefined) params.frequency_penalty = this.frequencyPenalty;
const stop = options?.stop ?? this.stopSequences;
if (stop) params.stop = stop;
return params;
```
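A brief note on the suggestion's upside: building `params` only from fields the user actually set keeps the request body minimal, so keys that are `undefined` are never serialized and the provider's own defaults apply instead.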
```diff
     }

     async *_streamResponseChunks(
         prompt: string,
         options: this['ParsedCallOptions'],
         runManager?: CallbackManagerForLLMRun
     ): AsyncGenerator<GenerationChunk> {
-        const hfi = await this._prepareHFInference()
-        const stream = await this.caller.call(async () =>
-            hfi.textGenerationStream({
-                ...this.invocationParams(options),
-                inputs: prompt
-            })
-        )
-        for await (const chunk of stream) {
-            const token = chunk.token.text
-            yield new GenerationChunk({ text: token, generationInfo: chunk })
-            await runManager?.handleLLMNewToken(token ?? '')
-
-            // stream is done
-            if (chunk.generated_text)
-                yield new GenerationChunk({
-                    text: '',
-                    generationInfo: { finished: true }
-                })
-        }
+        try {
+            const client = await this._prepareHFInference()
+            // Use chatCompletionStream for chat models (v4 supports streaming via Inference Providers)
+            const stream = await this.caller.call(async () =>
+                client.chatCompletionStream({
+                    model: this.model,
+                    messages: [{ role: 'user', content: prompt }],
+                    ...this.invocationParams(options)
+                })
+            )
+            for await (const chunk of stream) {
+                const token = chunk.choices[0]?.delta?.content || ''
+                if (token) {
+                    yield new GenerationChunk({ text: token, generationInfo: chunk })
+                    await runManager?.handleLLMNewToken(token)
+                }
+                // stream is done when finish_reason is set
+                if (chunk.choices[0]?.finish_reason) {
+                    yield new GenerationChunk({
+                        text: '',
+                        generationInfo: { finished: true }
+                    })
+                    break
+                }
+            }
+        } catch (error: any) {
+            // Provide more helpful error messages
+            if (error?.message?.includes('endpointUrl') || error?.message?.includes('third-party provider')) {
+                throw new Error(`Cannot use custom endpoint with model "${this.model}" that includes a provider. Please leave the Endpoint field blank in the UI. Original error: ${error.message}`)
+            }
+            throw error
+        }
     }
```
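For readers unfamiliar with the new client, here is a minimal standalone sketch of how `chatCompletionStream` from `@huggingface/inference` is consumed, mirroring the loop above. The model id and prompt are placeholders, not values from this PR:

```ts
import { InferenceClient } from '@huggingface/inference'

async function main() {
    // Assumes HUGGINGFACEHUB_API_KEY is set in the environment
    const client = new InferenceClient(process.env.HUGGINGFACEHUB_API_KEY)

    const stream = client.chatCompletionStream({
        model: 'meta-llama/Llama-3.1-8B-Instruct', // example model id only
        messages: [{ role: 'user', content: 'Write a haiku about rivers.' }],
        max_tokens: 128
    })

    for await (const chunk of stream) {
        // OpenAI-style chunks: the text delta lives in choices[0].delta.content
        const token = chunk.choices[0]?.delta?.content || ''
        if (token) process.stdout.write(token)
        // finish_reason is set on the final chunk
        if (chunk.choices[0]?.finish_reason) break
    }
}

main().catch(console.error)
```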
```diff

     /** @ignore */
     async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
-        const hfi = await this._prepareHFInference()
-        const args = { ...this.invocationParams(options), inputs: prompt }
-        const res = await this.caller.callWithOptions({ signal: options.signal }, hfi.textGeneration.bind(hfi), args)
-        return res.generated_text
+        try {
+            const client = await this._prepareHFInference()
+            // Use chatCompletion for chat models (v4 supports conversational models via Inference Providers)
+            const args = {
+                model: this.model,
+                messages: [{ role: 'user', content: prompt }],
+                ...this.invocationParams(options)
+            }
+            const res = await this.caller.callWithOptions({ signal: options.signal }, client.chatCompletion.bind(client), args)
+            const content = res.choices[0]?.message?.content || ''
+            if (!content) {
+                console.error('[ChatHuggingFace] No content in response:', JSON.stringify(res))
+                throw new Error(`No content received from HuggingFace API. Response: ${JSON.stringify(res)}`)
+            }
+            return content
+        } catch (error: any) {
+            console.error('[ChatHuggingFace] Error in _call:', error.message)
+            // Provide more helpful error messages
+            if (error?.message?.includes('endpointUrl') || error?.message?.includes('third-party provider')) {
+                throw new Error(`Cannot use custom endpoint with model "${this.model}" that includes a provider. Please leave the Endpoint field blank in the UI. Original error: ${error.message}`)
+            }
+            if (error?.message?.includes('Invalid username or password') || error?.message?.includes('authentication')) {
+                throw new Error(`HuggingFace API authentication failed. Please verify your API key is correct and starts with "hf_". Original error: ${error.message}`)
+            }
+            throw error
+        }
     }
```
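The non-streaming counterpart used by `_call`, again as a hedged sketch with a placeholder model and prompt: the full reply arrives in `choices[0].message.content` rather than in streamed deltas.

```ts
import { InferenceClient } from '@huggingface/inference'

async function main() {
    const client = new InferenceClient(process.env.HUGGINGFACEHUB_API_KEY)

    // Single-shot chat completion; model id and prompt are illustrative only
    const res = await client.chatCompletion({
        model: 'meta-llama/Llama-3.1-8B-Instruct',
        messages: [{ role: 'user', content: 'What is the capital of France?' }],
        max_tokens: 64
    })

    console.log(res.choices[0]?.message?.content ?? '')
}

main().catch(console.error)
```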
```diff

     /** @ignore */
     private async _prepareHFInference() {
-        const { HfInference } = await HuggingFaceInference.imports()
-        const hfi = new HfInference(this.apiKey, {
-            includeCredentials: this.includeCredentials
-        })
-        return this.endpointUrl ? hfi.endpoint(this.endpointUrl) : hfi
+        if (!this.apiKey || this.apiKey.trim() === '') {
+            console.error('[ChatHuggingFace] API key validation failed: Empty or undefined')
+            throw new Error('HuggingFace API key is required. Please configure it in the credential settings.')
+        }
+
+        const { InferenceClient } = await HuggingFaceInference.imports()
+        // Use InferenceClient for chat models (works better with Inference Providers)
+        const client = new InferenceClient(this.apiKey)
+
+        // Don't override endpoint if model uses a provider (contains ':') or if endpoint is router-based
+        // When using Inference Providers, endpoint should be left blank - InferenceClient handles routing automatically
+        if (this.endpointUrl && !this.model.includes(':') && !this.endpointUrl.includes('/v1/chat/completions') && !this.endpointUrl.includes('router.huggingface.co')) {
+            return client.endpoint(this.endpointUrl)
+        }
+
+        // Return client without endpoint override - InferenceClient will use Inference Providers automatically
+        return client
     }
```
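In plain terms, the routing rule above is: honor a custom endpoint only when the model id carries no provider qualifier (the `includes(':')` check, presumably for ids written as something like `model:provider`) and the endpoint does not already point at the Hugging Face router; in every other case the client is returned untouched so Inference Providers routing can take over.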
```diff

     /** @ignore */
     static async imports(): Promise<{
-        HfInference: typeof import('@huggingface/inference').HfInference
+        InferenceClient: typeof import('@huggingface/inference').InferenceClient
     }> {
         try {
-            const { HfInference } = await import('@huggingface/inference')
-            return { HfInference }
+            const { InferenceClient } = await import('@huggingface/inference')
+            return { InferenceClient }
         } catch (e) {
             throw new Error('Please install huggingface as a dependency with, e.g. `pnpm install @huggingface/inference`')
         }
```
GitHub check annotations on this diff reported failures on lines 121, 148, 151, and 170 of packages/components/nodes/chatmodels/ChatHuggingFace/core.ts.