Skip to content

Commit 62e314a

Browse files
saksham36, julien-c, SBrandeis
authored
Black Forest Labs Image Models (#1193)
Co-authored-by: Julien Chaumond <[email protected]> Co-authored-by: SBrandeis <[email protected]>
1 parent 57154a5 commit 62e314a

File tree

12 files changed

+190
-8
lines changed

12 files changed

+190
-8
lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ jobs:
4848
HF_TOGETHER_KEY: dummy
4949
HF_NOVITA_KEY: dummy
5050
HF_FIREWORKS_KEY: dummy
51+
HF_BLACK_FOREST_LABS_KEY: dummy
5152

5253
browser:
5354
runs-on: ubuntu-latest
@@ -91,6 +92,7 @@ jobs:
9192
HF_TOGETHER_KEY: dummy
9293
HF_NOVITA_KEY: dummy
9394
HF_FIREWORKS_KEY: dummy
95+
HF_BLACK_FOREST_LABS_KEY: dummy
9496

9597
e2e:
9698
runs-on: ubuntu-latest
@@ -161,3 +163,4 @@ jobs:
161163
HF_TOGETHER_KEY: dummy
162164
HF_NOVITA_KEY: dummy
163165
HF_FIREWORKS_KEY: dummy
166+
HF_BLACK_FOREST_LABS_KEY: dummy

packages/inference/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ Currently, we support the following providers:
5454
- [Replicate](https://replicate.com)
5555
- [Sambanova](https://sambanova.ai)
5656
- [Together](https://together.xyz)
57+
- [Black Forest Labs](https://blackforestlabs.ai)
5758

5859
To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
5960
```ts

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
66
import { TOGETHER_API_BASE_URL } from "../providers/together";
77
import { NOVITA_API_BASE_URL } from "../providers/novita";
88
import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
9+
import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
910
import type { InferenceProvider } from "../types";
1011
import type { InferenceTask, Options, RequestArgs } from "../types";
1112
import { isUrl } from "./isUrl";
@@ -80,8 +81,13 @@ export async function makeRequestOptions(
8081

8182
const headers: Record<string, string> = {};
8283
if (accessToken) {
83-
headers["Authorization"] =
84-
provider === "fal-ai" && authMethod === "provider-key" ? `Key ${accessToken}` : `Bearer ${accessToken}`;
84+
if (provider === "fal-ai" && authMethod === "provider-key") {
85+
headers["Authorization"] = `Key ${accessToken}`;
86+
} else if (provider === "black-forest-labs" && authMethod === "provider-key") {
87+
headers["X-Key"] = accessToken;
88+
} else {
89+
headers["Authorization"] = `Bearer ${accessToken}`;
90+
}
8591
}
8692

8793
// e.g. @huggingface/inference/3.1.3
@@ -148,6 +154,12 @@ function makeUrl(params: {
148154

149155
const shouldProxy = params.provider !== "hf-inference" && params.authMethod !== "provider-key";
150156
switch (params.provider) {
157+
case "black-forest-labs": {
158+
const baseUrl = shouldProxy
159+
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
160+
: BLACKFORESTLABS_AI_API_BASE_URL;
161+
return `${baseUrl}/${params.model}`;
162+
}
151163
case "fal-ai": {
152164
const baseUrl = shouldProxy
153165
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
export const BLACKFORESTLABS_AI_API_BASE_URL = "https://api.us1.bfl.ai/v1";
2+
3+
/**
4+
* See the registered mapping of HF model ID => Black Forest Labs model ID here:
5+
*
6+
* https://huggingface.co/api/partners/blackforestlabs/models
7+
*
8+
* This is a publicly available mapping.
9+
*
10+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
11+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
12+
*
13+
* - If you work at Black Forest Labs and want to update this mapping, please use the model mapping API we provide on huggingface.co
14+
* - If you're a community member and want to add a new supported HF model to Black Forest Labs, please open an issue on the present repo
15+
* and we will tag Black Forest Labs team members.
16+
*
17+
* Thanks!
18+
*/

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
1616
* Example:
1717
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
1818
*/
19+
"black-forest-labs": {},
1920
"fal-ai": {},
2021
"fireworks-ai": {},
2122
"hf-inference": {},

packages/inference/src/providers/novita.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ export const NOVITA_API_BASE_URL = "https://api.novita.ai/v3/openai";
1515
* and we will tag Novita team members.
1616
*
1717
* Thanks!
18-
*/
18+
*/

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
33
import type { BaseArgs, InferenceProvider, Options } from "../../types";
44
import { omit } from "../../utils/omit";
55
import { request } from "../custom/request";
6+
import { delay } from "../../utils/delay";
67

78
export type TextToImageArgs = BaseArgs & TextToImageInput;
89

@@ -14,6 +15,10 @@ interface Base64ImageGeneration {
1415
interface OutputUrlImageGeneration {
1516
output: string[];
1617
}
18+
interface BlackForestLabsResponse {
19+
id: string;
20+
polling_url: string;
21+
}
1722

1823
function getResponseFormatArg(provider: InferenceProvider) {
1924
switch (provider) {
@@ -44,12 +49,17 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
4449
...getResponseFormatArg(args.provider),
4550
prompt: args.inputs,
4651
};
47-
const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(payload, {
52+
const res = await request<
53+
TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration | BlackForestLabsResponse
54+
>(payload, {
4855
...options,
4956
taskHint: "text-to-image",
5057
});
5158

5259
if (res && typeof res === "object") {
60+
if (args.provider === "black-forest-labs" && "polling_url" in res && typeof res.polling_url === "string") {
61+
return await pollBflResponse(res.polling_url);
62+
}
5363
if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
5464
const image = await fetch(res.images[0].url);
5565
return await image.blob();
@@ -72,3 +82,33 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
7282
}
7383
return res;
7484
}
85+
86+
async function pollBflResponse(url: string): Promise<Blob> {
87+
const urlObj = new URL(url);
88+
for (let step = 0; step < 5; step++) {
89+
await delay(1000);
90+
console.debug(`Polling Black Forest Labs API for the result... ${step + 1}/5`);
91+
urlObj.searchParams.set("attempt", step.toString(10));
92+
const resp = await fetch(urlObj, { headers: { "Content-Type": "application/json" } });
93+
if (!resp.ok) {
94+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
95+
}
96+
const payload = await resp.json();
97+
if (
98+
typeof payload === "object" &&
99+
payload &&
100+
"status" in payload &&
101+
typeof payload.status === "string" &&
102+
payload.status === "Ready" &&
103+
"result" in payload &&
104+
typeof payload.result === "object" &&
105+
payload.result &&
106+
"sample" in payload.result &&
107+
typeof payload.result.sample === "string"
108+
) {
109+
const image = await fetch(payload.result.sample);
110+
return await image.blob();
111+
}
112+
}
113+
throw new InferenceOutputError("Failed to fetch result from black forest labs API");
114+
}

packages/inference/src/types.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,15 @@ export interface Options {
2929
export type InferenceTask = Exclude<PipelineType, "other">;
3030

3131
export const INFERENCE_PROVIDERS = [
32+
"black-forest-labs",
3233
"fal-ai",
3334
"fireworks-ai",
34-
"nebius",
3535
"hf-inference",
36+
"nebius",
37+
"novita",
3638
"replicate",
3739
"sambanova",
3840
"together",
39-
"novita",
4041
] as const;
4142

4243
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];

packages/inference/src/utils/delay.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
export function delay(ms: number): Promise<void> {
2+
return new Promise((resolve) => {
3+
setTimeout(() => resolve(), ms);
4+
});
5+
}

packages/inference/test/HfInference.spec.ts

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { assert, describe, expect, it } from "vitest";
22

33
import type { ChatCompletionStreamOutput } from "@huggingface/tasks";
44

5-
import { chatCompletion, HfInference } from "../src";
5+
import { chatCompletion, HfInference, textToImage } from "../src";
66
import { textToVideo } from "../src/tasks/cv/textToVideo";
77
import { readTestFile } from "./test-files";
88
import "./vcr";
@@ -1214,4 +1214,30 @@ describe.concurrent("HfInference", () => {
12141214
},
12151215
TIMEOUT
12161216
);
1217+
describe.concurrent(
1218+
"Black Forest Labs",
1219+
() => {
1220+
HARDCODED_MODEL_ID_MAPPING["black-forest-labs"] = {
1221+
"black-forest-labs/FLUX.1-dev": "flux-dev",
1222+
// "black-forest-labs/FLUX.1-schnell": "flux-pro",
1223+
};
1224+
1225+
it("textToImage", async () => {
1226+
const res = await textToImage({
1227+
model: "black-forest-labs/FLUX.1-dev",
1228+
provider: "black-forest-labs",
1229+
accessToken: env.HF_BLACK_FOREST_LABS_KEY,
1230+
inputs: "A raccoon driving a truck",
1231+
parameters: {
1232+
height: 256,
1233+
width: 256,
1234+
num_inference_steps: 4,
1235+
seed: 8817,
1236+
},
1237+
});
1238+
expect(res).toBeInstanceOf(Blob);
1239+
});
1240+
},
1241+
TIMEOUT
1242+
);
12171243
});

0 commit comments

Comments
 (0)