[Inference] Improve error handling #1504


Merged · 10 commits · Jun 2, 2025
102 changes: 102 additions & 0 deletions packages/inference/README.md
@@ -120,6 +120,108 @@ await textGeneration({

This will enable tree-shaking by your bundler.

### Error handling

The inference package provides specific error types to help you handle different error scenarios effectively.

#### Error Types

The package defines several error types that extend the base `Error` class:

- `InferenceClientError`: Base error class for all Hugging Face Inference errors
- `InferenceClientInputError`: Thrown when there are issues with input parameters
- `InferenceClientProviderApiError`: Thrown when there are API-level errors from providers
- `InferenceClientHubApiError`: Thrown when there are API-level errors from the Hugging Face Hub
- `InferenceClientProviderOutputError`: Thrown when a provider returns a response in an unexpected format

#### Example Usage

```typescript
import { InferenceClient } from "@huggingface/inference";
import {
InferenceClientError,
InferenceClientProviderApiError,
InferenceClientProviderOutputError,
InferenceClientHubApiError,
} from "@huggingface/inference";

const client = new InferenceClient();

try {
  const result = await client.textGeneration({
    model: "gpt2",
    inputs: "Hello, I'm a language model",
  });
} catch (error) {
  if (error instanceof InferenceClientProviderApiError) {
    // Handle API errors from the inference provider (e.g., rate limits, authentication issues)
    console.error("Provider API Error:", error.message);
    console.error("HTTP Request details:", error.httpRequest);
    console.error("HTTP Response details:", error.httpResponse);
  } else if (error instanceof InferenceClientHubApiError) {
    // Handle API errors from the Hugging Face Hub
    console.error("Hub API Error:", error.message);
    console.error("HTTP Request details:", error.httpRequest);
    console.error("HTTP Response details:", error.httpResponse);
  } else if (error instanceof InferenceClientProviderOutputError) {
    // Handle malformed responses from providers
    console.error("Provider Output Error:", error.message);
  } else if (error instanceof InferenceClientInputError) {
    // Handle invalid input parameters
    console.error("Input Error:", error.message);
  } else {
    // Handle unexpected errors
    console.error("Unexpected error:", error);
  }
}

// Catch all errors from @huggingface/inference
try {
  const result = await client.textGeneration({
    model: "gpt2",
    inputs: "Hello, I'm a language model",
  });
} catch (error) {
  if (error instanceof InferenceClientError) {
    // Handle errors from @huggingface/inference
    console.error("Error from InferenceClient:", error);
  } else {
    // Handle unexpected errors
    console.error("Unexpected error:", error);
  }
}
```

#### Error Details

##### InferenceClientProviderApiError

This error occurs when the API request to the selected inference provider fails.

It has several properties:
- `message`: A descriptive error message
- `httpRequest`: Details about the failed request (URL, method, headers), with the `Authorization` header redacted
- `httpResponse`: Details about the response, including the request ID, status code, and body

##### InferenceClientHubApiError

This error occurs when the API request to the Hugging Face Hub fails.

It has several properties:
- `message`: A descriptive error message
- `httpRequest`: Details about the failed request (URL, method, headers), with the `Authorization` header redacted
- `httpResponse`: Details about the response, including the request ID, status code, and body
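Since both HTTP error types expose these properties, you can branch on the response status for logging or retries. Below is a minimal illustrative sketch; the retry-once policy and the 429 check are choices made for this example, not behavior of the library:

```typescript
import { InferenceClient, InferenceClientProviderApiError } from "@huggingface/inference";

const client = new InferenceClient();

// Retry a generation once if the provider returns HTTP 429 (rate limited).
async function generateWithRetry(inputs: string) {
  try {
    return await client.textGeneration({ model: "gpt2", inputs });
  } catch (error) {
    if (error instanceof InferenceClientProviderApiError && error.httpResponse.status === 429) {
      console.warn(`Rate limited (request ${error.httpResponse.requestId}), retrying once...`);
      return await client.textGeneration({ model: "gpt2", inputs });
    }
    throw error; // anything else: surface the original error
  }
}
```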


##### InferenceClientProviderOutputError

This error occurs when a provider returns a response in an unexpected format.

##### InferenceClientInputError

This error occurs when input parameters are invalid or missing. The error message describes what's wrong with the input.
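For example, `resolveProvider` (see the `getInferenceProviderMapping.ts` diff below) throws this error when both `endpointUrl` and `provider` are specified. A minimal sketch, assuming `endpointUrl` and `provider` are passed as call arguments as in the examples above:

```typescript
import { InferenceClient, InferenceClientInputError } from "@huggingface/inference";

const client = new InferenceClient();

try {
  await client.textGeneration({
    model: "gpt2",
    inputs: "Hello, I'm a language model",
    provider: "hf-inference",
    endpointUrl: "https://example.com/my-endpoint", // hypothetical endpoint URL
  });
} catch (error) {
  if (error instanceof InferenceClientInputError) {
    // e.g. "Specifying both endpointUrl and provider is not supported."
    console.error("Input Error:", error.message);
  }
}
```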


### Natural Language Processing

#### Text Generation
82 changes: 82 additions & 0 deletions packages/inference/src/errors.ts
@@ -0,0 +1,82 @@
```typescript
import type { JsonObject } from "./vendor/type-fest/basic.js";

/**
 * Base class for all inference-related errors.
 */
export abstract class InferenceClientError extends Error {
  constructor(message: string) {
    super(message);
    this.name = "InferenceClientError";
  }
}

export class InferenceClientInputError extends InferenceClientError {
  constructor(message: string) {
    super(message);
    this.name = "InputError";
  }
}

interface HttpRequest {
  url: string;
  method: string;
  headers?: Record<string, string>;
  body?: JsonObject;
}

interface HttpResponse {
  requestId: string;
  status: number;
  body: JsonObject | string;
}

abstract class InferenceClientHttpRequestError extends InferenceClientError {
  httpRequest: HttpRequest;
  httpResponse: HttpResponse;
  constructor(message: string, httpRequest: HttpRequest, httpResponse: HttpResponse) {
    super(message);
    this.httpRequest = {
      ...httpRequest,
      ...(httpRequest.headers
        ? {
            headers: {
              ...httpRequest.headers,
              ...("Authorization" in httpRequest.headers ? { Authorization: `Bearer [redacted]` } : undefined),
              /// redact authentication in the request headers
            },
          }
        : undefined),
    };
    this.httpResponse = httpResponse;
  }
}

/**
 * Thrown when the HTTP request to the provider fails, e.g. due to API issues or server errors.
 */
export class InferenceClientProviderApiError extends InferenceClientHttpRequestError {
  constructor(message: string, httpRequest: HttpRequest, httpResponse: HttpResponse) {
    super(message, httpRequest, httpResponse);
    this.name = "ProviderApiError";
  }
}

/**
 * Thrown when the HTTP request to the hub fails, e.g. due to API issues or server errors.
 */
export class InferenceClientHubApiError extends InferenceClientHttpRequestError {
  constructor(message: string, httpRequest: HttpRequest, httpResponse: HttpResponse) {
    super(message, httpRequest, httpResponse);
    this.name = "HubApiError";
  }
}

/**
 * Thrown when the inference output returned by the provider is invalid / does not match the expectations
 */
export class InferenceClientProviderOutputError extends InferenceClientError {
  constructor(message: string) {
    super(message);
    this.name = "ProviderOutputError";
  }
}
```
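Note how `InferenceClientHttpRequestError` redacts the `Authorization` header before storing the request, so instances can be logged without leaking tokens. A quick sketch of that behavior; the URL, token, and response values here are made up for illustration:

```typescript
import { InferenceClientProviderApiError } from "@huggingface/inference";

const error = new InferenceClientProviderApiError(
  "Provider request failed",
  {
    url: "https://provider.example.com/v1/generate", // hypothetical URL
    method: "POST",
    headers: { Authorization: "Bearer hf_fake_token" }, // fake token
  },
  { requestId: "req-123", status: 500, body: "Internal Server Error" }
);

console.log(error.httpRequest.headers?.Authorization); // "Bearer [redacted]"
```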
2 changes: 1 addition & 1 deletion packages/inference/src/index.ts
```diff
@@ -1,5 +1,5 @@
 export { InferenceClient, InferenceClientEndpoint, HfInference } from "./InferenceClient.js";
-export { InferenceOutputError } from "./lib/InferenceOutputError.js";
+export * from "./errors.js";
 export * from "./types.js";
 export * from "./tasks/index.js";
 import * as snippets from "./snippets/index.js";
```
8 changes: 0 additions & 8 deletions packages/inference/src/lib/InferenceOutputError.ts

This file was deleted.

64 changes: 42 additions & 22 deletions packages/inference/src/lib/getInferenceProviderMapping.ts
```diff
@@ -4,6 +4,7 @@ import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../providers/consts.js";
 import { EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS } from "../providers/hf-inference.js";
 import type { InferenceProvider, InferenceProviderOrPolicy, ModelId } from "../types.js";
 import { typedInclude } from "../utils/typedInclude.js";
+import { InferenceClientHubApiError, InferenceClientInputError } from "../errors.js";
 
 export const inferenceProviderMappingCache = new Map<ModelId, InferenceProviderMapping>();
 
```
```diff
@@ -32,27 +33,46 @@ export async function fetchInferenceProviderMappingForModel(
 		// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 		inferenceProviderMapping = inferenceProviderMappingCache.get(modelId)!;
 	} else {
-		const resp = await (options?.fetch ?? fetch)(
-			`${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`,
-			{
-				headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
-			}
-		);
-		if (resp.status === 404) {
-			throw new Error(`Model ${modelId} does not exist`);
-		}
-		inferenceProviderMapping = await resp
-			.json()
-			.then((json) => json.inferenceProviderMapping)
-			.catch(() => null);
-
-		if (inferenceProviderMapping) {
-			inferenceProviderMappingCache.set(modelId, inferenceProviderMapping);
-		}
-
-		if (!inferenceProviderMapping) {
-			throw new Error(`We have not been able to find inference provider information for model ${modelId}.`);
-		}
+		const url = `${HF_HUB_URL}/api/models/${modelId}?expand[]=inferenceProviderMapping`;
+		const resp = await (options?.fetch ?? fetch)(url, {
+			headers: accessToken?.startsWith("hf_") ? { Authorization: `Bearer ${accessToken}` } : {},
+		});
+		if (!resp.ok) {
+			if (resp.headers.get("Content-Type")?.startsWith("application/json")) {
+				const error = await resp.json();
+				if ("error" in error && typeof error.error === "string") {
+					throw new InferenceClientHubApiError(
+						`Failed to fetch inference provider mapping for model ${modelId}: ${error.error}`,
+						{ url, method: "GET" },
+						{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: error }
+					);
+				}
+			} else {
+				throw new InferenceClientHubApiError(
+					`Failed to fetch inference provider mapping for model ${modelId}`,
+					{ url, method: "GET" },
+					{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
+				);
+			}
+		}
+		let payload: { inferenceProviderMapping?: InferenceProviderMapping } | null = null;
+		try {
+			payload = await resp.json();
+		} catch {
+			throw new InferenceClientHubApiError(
+				`Failed to fetch inference provider mapping for model ${modelId}: malformed API response, invalid JSON`,
+				{ url, method: "GET" },
+				{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
+			);
+		}
+		if (!payload?.inferenceProviderMapping) {
+			throw new InferenceClientHubApiError(
+				`We have not been able to find inference provider information for model ${modelId}.`,
+				{ url, method: "GET" },
+				{ requestId: resp.headers.get("x-request-id") ?? "", status: resp.status, body: await resp.text() }
+			);
+		}
+		inferenceProviderMapping = payload.inferenceProviderMapping;
 	}
 	return inferenceProviderMapping;
 }
```
```diff
@@ -83,7 +103,7 @@ export async function getInferenceProviderMapping(
 			? EQUIVALENT_SENTENCE_TRANSFORMERS_TASKS
 			: [params.task];
 		if (!typedInclude(equivalentTasks, providerMapping.task)) {
-			throw new Error(
+			throw new InferenceClientInputError(
 				`Model ${params.modelId} is not supported for task ${params.task} and provider ${params.provider}. Supported task: ${providerMapping.task}.`
 			);
 		}
```
```diff
@@ -104,7 +124,7 @@ export async function resolveProvider(
 ): Promise<InferenceProvider> {
 	if (endpointUrl) {
 		if (provider) {
-			throw new Error("Specifying both endpointUrl and provider is not supported.");
+			throw new InferenceClientInputError("Specifying both endpointUrl and provider is not supported.");
 		}
 		/// Defaulting to hf-inference helpers / API
 		return "hf-inference";
```
}
if (provider === "auto") {
if (!modelId) {
throw new Error("Specifying a model is required when provider is 'auto'");
throw new InferenceClientInputError("Specifying a model is required when provider is 'auto'");
}
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
}
if (!provider) {
throw new Error(`No Inference Provider available for model ${modelId}.`);
throw new InferenceClientInputError(`No Inference Provider available for model ${modelId}.`);
}
return provider;
}
11 changes: 8 additions & 3 deletions packages/inference/src/lib/getProviderHelper.ts
```diff
@@ -48,6 +48,7 @@ import * as Replicate from "../providers/replicate.js";
 import * as Sambanova from "../providers/sambanova.js";
 import * as Together from "../providers/together.js";
 import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types.js";
+import { InferenceClientInputError } from "../errors.js";
 
 export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
 	"black-forest-labs": {
```
```diff
@@ -281,14 +282,18 @@ export function getProviderHelper(
 		return new HFInference.HFInferenceTask();
 	}
 	if (!task) {
-		throw new Error("you need to provide a task name when using an external provider, e.g. 'text-to-image'");
+		throw new InferenceClientInputError(
+			"you need to provide a task name when using an external provider, e.g. 'text-to-image'"
+		);
 	}
 	if (!(provider in PROVIDERS)) {
-		throw new Error(`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`);
+		throw new InferenceClientInputError(
+			`Provider '${provider}' not supported. Available providers: ${Object.keys(PROVIDERS)}`
+		);
 	}
 	const providerTasks = PROVIDERS[provider];
 	if (!providerTasks || !(task in providerTasks)) {
-		throw new Error(
+		throw new InferenceClientInputError(
 			`Task '${task}' not supported for provider '${provider}'. Available tasks: ${Object.keys(providerTasks ?? {})}`
 		);
 	}
```