Skip to content

Commit

Permalink
tests: use task methods for better typing + match hf.js API
Browse files Browse the repository at this point in the history
  • Loading branch information
SBrandeis committed Feb 13, 2025
1 parent 6bdd200 commit de26ffa
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions packages/inference/test/HfInference.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { assert, describe, expect, it } from "vitest";
import type { ChatCompletionStreamOutput } from "@huggingface/tasks";

import type { TextToImageArgs } from "../src";
import { chatCompletion, HfInference } from "../src";
import { chatCompletion, chatCompletionStream, HfInference, textGeneration, textToImage } from "../src";
import { textToVideo } from "../src/tasks/cv/textToVideo";
import { readTestFile } from "./test-files";
import "./vcr";
Expand Down Expand Up @@ -1180,8 +1180,6 @@ describe.concurrent("HfInference", () => {
describe.concurrent(
"Hyperbolic",
() => {
const client = new HfInference(env.HF_HYPERBOLIC_KEY);

HARDCODED_MODEL_ID_MAPPING.hyperbolic = {
"meta-llama/Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.3-70B-Instruct": "meta-llama/Llama-3.3-70B-Instruct",
Expand All @@ -1190,7 +1188,8 @@ describe.concurrent("HfInference", () => {
};

it("chatCompletion - hyperbolic", async () => {
const res = await client.chatCompletion({
const res = await chatCompletion({
accessToken: env.HF_HYPERBOLIC_KEY,
model: "meta-llama/Llama-3.2-3B-Instruct",
provider: "hyperbolic",
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
Expand All @@ -1210,7 +1209,8 @@ describe.concurrent("HfInference", () => {
});

it("chatCompletion stream", async () => {
const stream = client.chatCompletionStream({
const stream = chatCompletionStream({
accessToken: env.HF_HYPERBOLIC_KEY,
model: "meta-llama/Llama-3.3-70B-Instruct",
provider: "hyperbolic",
messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
Expand All @@ -1225,12 +1225,12 @@ describe.concurrent("HfInference", () => {
});

it("textToImage", async () => {
const res = await client.textToImage({
const res = await textToImage({
accessToken: env.HF_HYPERBOLIC_KEY,
model: "stabilityai/stable-diffusion-2",
provider: "hyperbolic",
inputs: "award winning high resolution photo of a giant tortoise",
parameters: {
model_name: "SD2",
height: 128,
width: 128,
},
Expand All @@ -1239,13 +1239,16 @@ describe.concurrent("HfInference", () => {
});

it("textGeneration", async () => {
const res = await client.textGeneration({
const res = await textGeneration({
accessToken: env.HF_HYPERBOLIC_KEY,
model: "meta-llama/Llama-3.1-405B",
provider: "hyperbolic",
messages: [{ role: "user", content: "Paris is" }],
temperature: 0,
top_p: 0.01,
max_tokens: 10,
inputs: "Paris is",
parameters: {
temperature: 0,
top_p: 0.01,
max_new_tokens: 10,
}
});
expect(res).toMatchObject({ generated_text: "...the capital and most populous city of France," });
});
Expand Down

0 comments on commit de26ffa

Please sign in to comment.