Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

draft: add hyperbolic support #1191

Merged
merged 21 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ jobs:
run: VCR_MODE=playback pnpm --filter ...[${{ steps.since.outputs.SINCE }}] test
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_BLACK_FOREST_LABS_KEY: dummy
HF_FAL_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_HYPERBOLIC_KEY: dummy
HF_NEBIUS_KEY: dummy
HF_NOVITA_KEY: dummy
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
HF_NOVITA_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_BLACK_FOREST_LABS_KEY: dummy

browser:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -85,14 +86,15 @@ jobs:
run: VCR_MODE=playback pnpm --filter ...[${{ steps.since.outputs.SINCE }}] test:browser
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_BLACK_FOREST_LABS_KEY: dummy
HF_FAL_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_HYPERBOLIC_KEY: dummy
HF_NEBIUS_KEY: dummy
HF_NOVITA_KEY: dummy
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
HF_NOVITA_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_BLACK_FOREST_LABS_KEY: dummy

e2e:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -156,11 +158,12 @@ jobs:
env:
NPM_CONFIG_REGISTRY: http://localhost:4874/
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_BLACK_FOREST_LABS_KEY: dummy
HF_FAL_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_HYPERBOLIC_KEY: dummy
HF_NEBIUS_KEY: dummy
HF_NOVITA_KEY: dummy
HF_REPLICATE_KEY: dummy
HF_SAMBANOVA_KEY: dummy
HF_TOGETHER_KEY: dummy
HF_NOVITA_KEY: dummy
HF_FIREWORKS_KEY: dummy
HF_BLACK_FOREST_LABS_KEY: dummy
2 changes: 2 additions & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ You can send inference requests to third-party providers with the inference clie
Currently, we support the following providers:
- [Fal.ai](https://fal.ai)
- [Fireworks AI](https://fireworks.ai)
- [Hyperbolic](https://hyperbolic.xyz)
- [Nebius](https://studio.nebius.ai)
- [Novita](https://novita.ai/?utm_source=github_huggingface&utm_medium=github_readme&utm_campaign=link)
- [Replicate](https://replicate.com)
Expand All @@ -74,6 +75,7 @@ When authenticated with a third-party provider key, the request is made directly
Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
- [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
- [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
- [Hyperbolic supported models](https://huggingface.co/api/partners/hyperbolic/models)
- [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
Expand Down
17 changes: 16 additions & 1 deletion packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { SAMBANOVA_API_BASE_URL } from "../providers/sambanova";
import { TOGETHER_API_BASE_URL } from "../providers/together";
import { NOVITA_API_BASE_URL } from "../providers/novita";
import { FIREWORKS_AI_API_BASE_URL } from "../providers/fireworks-ai";
import { HYPERBOLIC_API_BASE_URL } from "../providers/hyperbolic";
import { BLACKFORESTLABS_AI_API_BASE_URL } from "../providers/black-forest-labs";
import type { InferenceProvider } from "../types";
import type { InferenceTask, Options, RequestArgs } from "../types";
Expand Down Expand Up @@ -132,7 +133,11 @@ export async function makeRequestOptions(
? args.data
: JSON.stringify({
...otherArgs,
...(chatCompletion || provider === "together" || provider === "nebius" ? { model } : undefined),
...(taskHint === "text-to-image" && provider === "hyperbolic"
? { model_name: model }
: chatCompletion || provider === "together" || provider === "nebius" || provider === "hyperbolic"
? { model }
: undefined),
}),
...(credentials ? { credentials } : undefined),
signal: options?.signal,
Expand Down Expand Up @@ -229,6 +234,16 @@ function makeUrl(params: {
}
return baseUrl;
}
case "hyperbolic": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
: HYPERBOLIC_API_BASE_URL;

if (params.taskHint === "text-to-image") {
return `${baseUrl}/v1/images/generations`;
}
return `${baseUrl}/v1/chat/completions`;
}
case "novita": {
const baseUrl = shouldProxy
? HF_HUB_INFERENCE_PROXY_TEMPLATE.replace("{{PROVIDER}}", params.provider)
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
"fal-ai": {},
"fireworks-ai": {},
"hf-inference": {},
hyperbolic: {},
nebius: {},
replicate: {},
sambanova: {},
Expand Down
18 changes: 18 additions & 0 deletions packages/inference/src/providers/hyperbolic.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/** Base URL for direct (non-proxied) requests to the Hyperbolic inference API. */
export const HYPERBOLIC_API_BASE_URL = "https://api.hyperbolic.xyz";

/**
 * See the registered mapping of HF model ID => Hyperbolic model ID here:
 *
 * https://huggingface.co/api/partners/hyperbolic/models
 *
 * This is a publicly available mapping.
 *
 * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
 * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
 *
 * - If you work at Hyperbolic and want to update this mapping, please use the model mapping API we provide on huggingface.co
 * - If you're a community member and want to add a new supported HF model to Hyperbolic, please open an issue on the present repo
 *   and we will tag Hyperbolic team members.
 *
 * Thanks!
 */
21 changes: 20 additions & 1 deletion packages/inference/src/tasks/cv/textToImage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ interface Base64ImageGeneration {
/** Text-to-image response variant: a list of URLs pointing at the generated images. */
interface OutputUrlImageGeneration {
	output: string[];
}
/**
 * Hyperbolic text-to-image response: each entry carries a base64-encoded image
 * (decoded into a Blob further down in `textToImage`).
 */
interface HyperbolicTextToImageOutput {
	images: Array<{ image: string }>;
}

interface BlackForestLabsResponse {
id: string;
polling_url: string;
Expand Down Expand Up @@ -50,7 +54,11 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
prompt: args.inputs,
};
const res = await request<
TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration | BlackForestLabsResponse
| TextToImageOutput
| Base64ImageGeneration
| OutputUrlImageGeneration
| BlackForestLabsResponse
| HyperbolicTextToImageOutput
>(payload, {
...options,
taskHint: "text-to-image",
Expand All @@ -64,6 +72,17 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
const image = await fetch(res.images[0].url);
return await image.blob();
}
if (
args.provider === "hyperbolic" &&
"images" in res &&
Array.isArray(res.images) &&
res.images[0] &&
typeof res.images[0].image === "string"
) {
const base64Response = await fetch(`data:image/jpeg;base64,${res.images[0].image}`);
const blob = await base64Response.blob();
return blob;
}
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
const base64Data = res.data[0].b64_json;
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
Expand Down
31 changes: 31 additions & 0 deletions packages/inference/src/tasks/nlp/textGeneration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
import { toArray } from "../../utils/toArray";
import { request } from "../custom/request";
import { omit } from "../../utils/omit";

export type { TextGenerationInput, TextGenerationOutput };

Expand All @@ -21,6 +22,12 @@ interface TogeteherTextCompletionOutput extends Omit<ChatCompletionOutput, "choi
}>;
}

/**
 * Hyperbolic serves text generation through a chat-completion-shaped endpoint;
 * only `choices[0].message.content` is read by `textGeneration` below.
 */
interface HyperbolicTextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
	choices: Array<{
		message: { content: string };
	}>;
}

/**
* Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
*/
Expand All @@ -43,6 +50,30 @@ export async function textGeneration(
return {
generated_text: completion.text,
};
} else if (args.provider === "hyperbolic") {
const payload = {
messages: [{ content: args.inputs, role: "user" }],
...(args.parameters
? {
max_tokens: args.parameters.max_new_tokens,
...omit(args.parameters, "max_new_tokens"),
}
: undefined),
...omit(args, ["inputs", "parameters"]),
};
const raw = await request<HyperbolicTextCompletionOutput>(payload, {
...options,
taskHint: "text-generation",
});
const isValidOutput =
typeof raw === "object" && "choices" in raw && Array.isArray(raw?.choices) && typeof raw?.model === "string";
if (!isValidOutput) {
throw new InferenceOutputError("Expected ChatCompletionOutput");
}
const completion = raw.choices[0];
return {
generated_text: completion.message.content,
};
} else {
const res = toArray(
await request<TextGenerationOutput | TextGenerationOutput[]>(args, {
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export const INFERENCE_PROVIDERS = [
"fal-ai",
"fireworks-ai",
"hf-inference",
"hyperbolic",
"nebius",
"novita",
"replicate",
Expand Down
82 changes: 81 additions & 1 deletion packages/inference/test/HfInference.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ import { assert, describe, expect, it } from "vitest";

import type { ChatCompletionStreamOutput } from "@huggingface/tasks";

import { chatCompletion, HfInference, textToImage } from "../src";
import type { TextToImageArgs } from "../src";
import { chatCompletion, chatCompletionStream, HfInference, textGeneration, textToImage } from "../src";
import { textToVideo } from "../src/tasks/cv/textToVideo";
import { readTestFile } from "./test-files";
import "./vcr";
Expand Down Expand Up @@ -1176,6 +1177,85 @@ describe.concurrent("HfInference", () => {
TIMEOUT
);

// Integration tests for the Hyperbolic provider. In CI these run against recorded
// VCR cassettes (HF_HYPERBOLIC_KEY is a dummy value there), so every request
// literal below must stay byte-for-byte stable to match the recordings.
describe.concurrent(
	"Hyperbolic",
	() => {
		// Dev-time mapping of HF model IDs => Hyperbolic model IDs
		// (in production this mapping is served by the Hub API).
		HARDCODED_MODEL_ID_MAPPING.hyperbolic = {
			"meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
			"meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct",
			"stabilityai/stable-diffusion-2": "SD2",
			"meta-llama/Llama-3.1-405B": "meta-llama/Meta-Llama-3.1-405B-Instruct",
		};

		it("chatCompletion - hyperbolic", async () => {
			const response = await chatCompletion({
				accessToken: env.HF_HYPERBOLIC_KEY,
				model: "meta-llama/Llama-3.2-3B-Instruct",
				provider: "hyperbolic",
				messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
				temperature: 0.1,
			});

			expect(response).toBeDefined();
			expect(response.choices).toBeDefined();
			expect(response.choices?.length).toBeGreaterThan(0);

			if (response.choices && response.choices.length > 0) {
				const generatedText = response.choices[0].message?.content;
				expect(generatedText).toBeDefined();
				expect(typeof generatedText).toBe("string");
				expect(generatedText).toContain("two");
			}
		});

		it("chatCompletion stream", async () => {
			const stream = chatCompletionStream({
				accessToken: env.HF_HYPERBOLIC_KEY,
				model: "meta-llama/Llama-3.3-70B-Instruct",
				provider: "hyperbolic",
				messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
			}) as AsyncGenerator<ChatCompletionStreamOutput>;
			// Accumulate streamed deltas and assert on the concatenated answer.
			let accumulated = "";
			for await (const chunk of stream) {
				if (chunk.choices && chunk.choices.length > 0) {
					accumulated += chunk.choices[0].delta.content;
				}
			}
			expect(accumulated).toContain("2");
		});

		it("textToImage", async () => {
			// Hyperbolic returns base64 image data, which textToImage converts to a Blob.
			const result = await textToImage({
				accessToken: env.HF_HYPERBOLIC_KEY,
				model: "stabilityai/stable-diffusion-2",
				provider: "hyperbolic",
				inputs: "award winning high resolution photo of a giant tortoise",
				parameters: {
					height: 128,
					width: 128,
				},
			} satisfies TextToImageArgs);
			expect(result).toBeInstanceOf(Blob);
		});

		it("textGeneration", async () => {
			// Exercises the hyperbolic branch of textGeneration, which maps the prompt
			// onto a chat-completion request under the hood.
			const result = await textGeneration({
				accessToken: env.HF_HYPERBOLIC_KEY,
				model: "meta-llama/Llama-3.1-405B",
				provider: "hyperbolic",
				inputs: "Paris is",
				parameters: {
					temperature: 0,
					top_p: 0.01,
					max_new_tokens: 10,
				},
			});
			expect(result).toMatchObject({ generated_text: "...the capital and most populous city of France," });
		});
	},
	TIMEOUT
);

describe.concurrent(
"Novita",
() => {
Expand Down
Loading