Skip to content

Commit

Permalink
Merge branch 'main' into add-link-to-local-chat
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavs10 authored Feb 13, 2025
2 parents 6bb5f4e + 62e314a commit 19989df
Show file tree
Hide file tree
Showing 14 changed files with 339 additions and 127 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ jobs:
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
HF_NOVITA_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_BLACK_FOREST_LABS_KEY: dummy

browser:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -88,7 +90,9 @@ jobs:
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
HF_NOVITA_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_BLACK_FOREST_LABS_KEY: dummy

e2e:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -157,4 +161,6 @@ jobs:
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
HF_NOVITA_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_BLACK_FOREST_LABS_KEY: dummy
2 changes: 2 additions & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ Currently, we support the following providers:
- [Fal.ai](https://fal.ai)
- [Fireworks AI](https://fireworks.ai)
- [Nebius](https://studio.nebius.ai)
- [Novita](https://novita.ai/?utm_source=github_huggingface&utm_medium=github_readme&utm_campaign=link)
- [Replicate](https://replicate.com)
- [Sambanova](https://sambanova.ai)
- [Together](https://together.xyz)
- [Black Forest Labs](https://blackforestlabs.ai)

To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
```ts
Expand Down
46 changes: 31 additions & 15 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import { NEBIUS_API_BASE_URL } from "../providers/nebius";
import { REPLICATE_API_BASE_URL } from "../providers/replicate";
import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL } from "../providers/together";
import { NOVITA_API_BASE_URL } from "../providers/novita";
import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
import type { InferenceProvider } from "../types";
import type { InferenceTask, Options, RequestArgs } from "../types";
import { isUrl } from "./isUrl";
Expand All @@ -28,8 +30,6 @@ export async function makeRequestOptions(
stream?: boolean;
},
options?: Options & {
/** When a model can be used for multiple tasks, and we want to run a non-default task */
forceTask?: string | InferenceTask;
/** To load default model if needed */
taskHint?: InferenceTask;
chatCompletion?: boolean;
Expand All @@ -39,14 +39,11 @@ export async function makeRequestOptions(
let otherArgs = remainingArgs;
const provider = maybeProvider ?? "hf-inference";

const { forceTask, includeCredentials, taskHint, chatCompletion } = options ?? {};
const { includeCredentials, taskHint, chatCompletion } = options ?? {};

if (endpointUrl && provider !== "hf-inference") {
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
}
if (forceTask && provider !== "hf-inference") {
throw new Error(`Cannot use forceTask with a third-party provider.`);
}
if (maybeModel && isUrl(maybeModel)) {
throw new Error(`Model URLs are no longer supported. Use endpointUrl instead.`);
}
Expand Down Expand Up @@ -77,16 +74,20 @@ export async function makeRequestOptions(
: makeUrl({
authMethod,
chatCompletion: chatCompletion ?? false,
forceTask,
model,
provider: provider ?? "hf-inference",
taskHint,
});

const headers: Record<string, string> = {};
if (accessToken) {
headers["Authorization"] =
provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
if (provider === "fal-ai" && authMethod === "provider-key") {
headers["Authorization"] = `Key ${accessToken}`;
} else if (provider === "black-forest-labs" && authMethod === "provider-key") {
headers["X-Key"] = accessToken;
} else {
headers["Authorization"] = `Bearer ${accessToken}`;
}
}

// e.g. @huggingface/inference/3.1.3
Expand Down Expand Up @@ -146,14 +147,19 @@ function makeUrl(params: {
model: string;
provider: InferenceProvider;
taskHint: InferenceTask | undefined;
forceTask?: string | InferenceTask;
}): string {
if (params.authMethod === "none" && params.provider !== "hf-inference") {
throw new Error("Authentication is required when requesting a third-party provider. Please provide accessToken");
}

const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
switch (params.provider) {
case "black-forest-labs": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: BLACKFORESTLABS_AI_API_BASE_URL;
return `${baseUrl}/${params.model}`;
}
case "fal-ai": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
Expand Down Expand Up @@ -213,6 +219,7 @@ function makeUrl(params: {
}
return baseUrl;
}

case "fireworks-ai": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
Expand All @@ -222,15 +229,24 @@ function makeUrl(params: {
}
return baseUrl;
}
case "novita": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: NOVITA_API_BASE_URL;
if (params.taskHint === "text-generation") {
if (params.chatCompletion) {
return `${baseUrl}/chat/completions`;
}
return `${baseUrl}/completions`;
}
return baseUrl;
}
default: {
const baseUrl = HF_HUB_INFERENCE_PROXY_TEMPLATE.replaceAll("{{PROVIDER}}", "hf-inference");
const url = params.forceTask
? `${baseUrl}/pipeline/${params.forceTask}/${params.model}`
: `${baseUrl}/models/${params.model}`;
if (params.taskHint === "text-generation" && params.chatCompletion) {
return url + `/v1/chat/completions`;
return `${baseUrl}/models/${params.model}/v1/chat/completions`;
}
return url;
return `${baseUrl}/models/${params.model}`;
}
}
}
Expand Down
18 changes: 18 additions & 0 deletions packages/inference/src/providers/black-forest-labs.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/** Base URL used for direct (provider-key) requests to the Black Forest Labs API. */
export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";

/**
 * See the registered mapping of HF model ID => Black Forest Labs model ID here:
 *
 * https://huggingface.co/api/partners/blackforestlabs/models
 *
 * This is a publicly available mapping.
 *
 * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
 * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
 *
 * - If you work at Black Forest Labs and want to update this mapping, please use the model mapping API we provide on huggingface.co
 * - If you're a community member and want to add a new supported HF model to Black Forest Labs, please open an issue on the present repo
 * and we will tag Black Forest Labs team members.
 *
 * Thanks!
 */
2 changes: 2 additions & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
* Example:
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
*/
"black-forest-labs": {},
"fal-ai": {},
"fireworks-ai": {},
"hf-inference": {},
nebius: {},
replicate: {},
sambanova: {},
together: {},
novita: {},
};
18 changes: 18 additions & 0 deletions packages/inference/src/providers/novita.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/** Base URL used for direct (provider-key) requests to the Novita API. */
export const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";

/**
 * See the registered mapping of HF model ID => Novita model ID here:
 *
 * https://huggingface.co/api/partners/novita/models
 *
 * This is a publicly available mapping.
 *
 * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
 * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
 *
 * - If you work at Novita and want to update this mapping, please use the model mapping API we provide on huggingface.co
 * - If you're a community member and want to add a new supported HF model to Novita, please open an issue on the present repo
 * and we will tag Novita team members.
 *
 * Thanks!
 */
42 changes: 41 additions & 1 deletion packages/inference/src/tasks/cv/textToImage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, InferenceProvider, Options } from "../../types";
import { omit } from "../../utils/omit";
import { request } from "../custom/request";
import { delay } from "../../utils/delay";

export type TextToImageArgs = BaseArgs & TextToImageInput;

Expand All @@ -14,6 +15,10 @@ interface Base64ImageGeneration {
interface OutputUrlImageGeneration {
output: string[];
}
interface BlackForestLabsResponse {
id: string;
polling_url: string;
}

function getResponseFormatArg(provider: InferenceProvider) {
switch (provider) {
Expand Down Expand Up @@ -44,12 +49,17 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
...getResponseFormatArg(args.provider),
prompt: args.inputs,
};
const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(payload, {
const res = await request<
TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration | BlackForestLabsResponse
>(payload, {
...options,
taskHint: "text-to-image",
});

if (res && typeof res === "object") {
if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
return await pollBflResponse(res.polling_url);
}
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
const image = await fetch(res.images[0].url);
return await image.blob();
Expand All @@ -72,3 +82,33 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
}
return res;
}

/**
 * Polls the Black Forest Labs API until the generated image is ready, then downloads it.
 *
 * Makes up to 5 attempts, waiting one second before each one. On a "Ready" payload,
 * fetches the sample URL and returns the image bytes as a Blob.
 *
 * @param url polling URL returned by the initial generation request
 * @throws InferenceOutputError when a poll request fails or no result is ready after 5 attempts
 */
async function pollBflResponse(url: string): Promise<Blob> {
	const pollingUrl = new URL(url);
	let attempt = 0;
	while (attempt < 5) {
		await delay(1000);
		console.debug(`Polling Black Forest Labs API for the result... ${attempt + 1}/5`);
		// Tag each request with the attempt number via a query parameter.
		pollingUrl.searchParams.set("attempt", attempt.toString(10));
		const response = await fetch(pollingUrl, { headers: { "Content-Type": "application/json" } });
		if (!response.ok) {
			throw new InferenceOutputError("Failed to fetch result from black forest labs API");
		}
		const payload = await response.json();
		// Only a fully-shaped "Ready" payload with a string sample URL counts as done;
		// anything else means the job is still running, so poll again.
		if (
			typeof payload === "object" &&
			payload !== null &&
			"status" in payload &&
			typeof payload.status === "string" &&
			payload.status === "Ready" &&
			"result" in payload &&
			typeof payload.result === "object" &&
			payload.result !== null &&
			"sample" in payload.result &&
			typeof payload.result.sample === "string"
		) {
			const image = await fetch(payload.result.sample);
			return await image.blob();
		}
		attempt++;
	}
	throw new InferenceOutputError("Failed to fetch result from black forest labs API");
}
4 changes: 0 additions & 4 deletions packages/inference/src/tasks/nlp/featureExtraction.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { getDefaultTask } from "../../lib/getDefaultTask";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

Expand All @@ -25,12 +24,9 @@ export async function featureExtraction(
args: FeatureExtractionArgs,
options?: Options
): Promise<FeatureExtractionOutput> {
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : undefined;

const res = await request<FeatureExtractionOutput>(args, {
...options,
taskHint: "feature-extraction",
...(defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }),
});
let isValidOutput = true;

Expand Down
3 changes: 0 additions & 3 deletions packages/inference/src/tasks/nlp/sentenceSimilarity.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import type { SentenceSimilarityInput, SentenceSimilarityOutput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { getDefaultTask } from "../../lib/getDefaultTask";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";
import { omit } from "../../utils/omit";
Expand All @@ -14,11 +13,9 @@ export async function sentenceSimilarity(
args: SentenceSimilarityArgs,
options?: Options
): Promise<SentenceSimilarityOutput> {
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : undefined;
const res = await request<SentenceSimilarityOutput>(prepareInput(args), {
...options,
taskHint: "sentence-similarity",
...(defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }),
});

const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
Expand Down
5 changes: 4 additions & 1 deletion packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@ export interface Options {
export type InferenceTask = Exclude<PipelineType, "other">;

export const INFERENCE_PROVIDERS = [
"black-forest-labs",
"fal-ai",
"fireworks-ai",
"nebius",
"hf-inference",
"nebius",
"novita",
"replicate",
"sambanova",
"together",
] as const;

export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];

export interface BaseArgs {
Expand Down
5 changes: 5 additions & 0 deletions packages/inference/src/utils/delay.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
/**
 * Returns a promise that resolves after the given number of milliseconds.
 *
 * @param ms how long to wait before resolving
 */
export function delay(ms: number): Promise<void> {
	return new Promise<void>((resolve) => setTimeout(resolve, ms));
}
Loading

0 comments on commit 19989df

Please sign in to comment.