Skip to content

[Inference Snippet] Add a directRequest option (false by default) #1516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 43 additions & 21 deletions packages/inference/src/snippets/getInferenceSnippets.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions.j
import type { InferenceProviderOrPolicy, InferenceTask, RequestArgs } from "../types.js";
import { templates } from "./templates.exported.js";

export type InferenceSnippetOptions = { streaming?: boolean; billTo?: string; accessToken?: string } & Record<
string,
unknown
>;
export type InferenceSnippetOptions = {
streaming?: boolean;
billTo?: string;
accessToken?: string;
directRequest?: boolean;
} & Record<string, unknown>;

const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
const JS_CLIENTS = ["fetch", "huggingface.js", "openai"] as const;
Expand Down Expand Up @@ -124,7 +126,10 @@ const HF_JS_METHODS: Partial<Record<WidgetType, string>> = {
translation: "translation",
};

const ACCESS_TOKEN_PLACEHOLDER = "<ACCESS_TOKEN>"; // Placeholder to replace with env variable in snippets
// Placeholders to replace with an env variable in snippets.
// Little hack to support both direct requests and routing: tokens for routed requests must start with "hf_".
const ACCESS_TOKEN_ROUTING_PLACEHOLDER = "hf_token_placeholder";
const ACCESS_TOKEN_DIRECT_REQUEST_PLACEHOLDER = "not_hf_token_placeholder";

// Snippet generators
const snippetGenerator = (templateName: string, inputPreparationFn?: InputPreparationFn) => {
Expand Down Expand Up @@ -153,7 +158,11 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
console.error(`Failed to get provider helper for ${provider} (${task})`, e);
return [];
}
const accessTokenOrPlaceholder = opts?.accessToken ?? ACCESS_TOKEN_PLACEHOLDER;

const placeholder = opts?.directRequest
? ACCESS_TOKEN_DIRECT_REQUEST_PLACEHOLDER
: ACCESS_TOKEN_ROUTING_PLACEHOLDER;
const accessTokenOrPlaceholder = opts?.accessToken ?? placeholder;

/// Prepare inputs + make request
const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
Expand Down Expand Up @@ -255,8 +264,8 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
}

/// Replace access token placeholder
if (snippet.includes(ACCESS_TOKEN_PLACEHOLDER)) {
snippet = replaceAccessTokenPlaceholder(snippet, language, provider);
if (snippet.includes(placeholder)) {
snippet = replaceAccessTokenPlaceholder(opts?.directRequest, placeholder, snippet, language, provider);
}

/// Snippet is ready!
Expand Down Expand Up @@ -431,6 +440,8 @@ function removeSuffix(str: string, suffix: string) {
}

function replaceAccessTokenPlaceholder(
directRequest: boolean | undefined,
placeholder: string,
snippet: string,
language: InferenceSnippetLanguage,
provider: InferenceProviderOrPolicy
Expand All @@ -439,46 +450,57 @@ function replaceAccessTokenPlaceholder(
// Once snippets are rendered, we replace the placeholder with code to fetch the access token from an environment variable.

// Determine if HF_TOKEN or specific provider token should be used
const accessTokenEnvVar =
!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
snippet.includes("https://router.huggingface.co") || // explicit routed request => use $HF_TOKEN
provider == "hf-inference" // hf-inference provider => use $HF_TOKEN
? "HF_TOKEN"
: provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"
const useHfToken =
provider == "hf-inference" || // hf-inference provider => use $HF_TOKEN
(!directRequest && // if explicit directRequest => use provider-specific token
(!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
snippet.includes("https://router.huggingface.co"))); // explicit routed request => use $HF_TOKEN

const accessTokenEnvVar = useHfToken
? "HF_TOKEN" // e.g. routed request or hf-inference
: provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"

// Replace the placeholder with the env variable
if (language === "sh") {
snippet = snippet.replace(
`'Authorization: Bearer ${ACCESS_TOKEN_PLACEHOLDER}'`,
`'Authorization: Bearer ${placeholder}'`,
`"Authorization: Bearer $${accessTokenEnvVar}"` // e.g. "Authorization: Bearer $HF_TOKEN"
);
} else if (language === "python") {
snippet = "import os\n" + snippet;
snippet = snippet.replace(
`"${ACCESS_TOKEN_PLACEHOLDER}"`,
`"${placeholder}"`,
`os.environ["${accessTokenEnvVar}"]` // e.g. os.environ["HF_TOKEN"]
);
snippet = snippet.replace(
`"Bearer ${ACCESS_TOKEN_PLACEHOLDER}"`,
`"Bearer ${placeholder}"`,
`f"Bearer {os.environ['${accessTokenEnvVar}']}"` // e.g. f"Bearer {os.environ['HF_TOKEN']}"
);
snippet = snippet.replace(
`"Key ${ACCESS_TOKEN_PLACEHOLDER}"`,
`"Key ${placeholder}"`,
`f"Key {os.environ['${accessTokenEnvVar}']}"` // e.g. f"Key {os.environ['FAL_AI_API_KEY']}"
);
snippet = snippet.replace(
`"X-Key ${placeholder}"`,
`f"X-Key {os.environ['${accessTokenEnvVar}']}"` // e.g. f"X-Key {os.environ['BLACK_FOREST_LABS_API_KEY']}"
);
} else if (language === "js") {
snippet = snippet.replace(
`"${ACCESS_TOKEN_PLACEHOLDER}"`,
`"${placeholder}"`,
`process.env.${accessTokenEnvVar}` // e.g. process.env.HF_TOKEN
);
snippet = snippet.replace(
`Authorization: "Bearer ${ACCESS_TOKEN_PLACEHOLDER}",`,
`Authorization: "Bearer ${placeholder}",`,
`Authorization: \`Bearer $\{process.env.${accessTokenEnvVar}}\`,` // e.g. Authorization: `Bearer ${process.env.HF_TOKEN}`,
);
snippet = snippet.replace(
`Authorization: "Key ${ACCESS_TOKEN_PLACEHOLDER}",`,
`Authorization: "Key ${placeholder}",`,
`Authorization: \`Key $\{process.env.${accessTokenEnvVar}}\`,` // e.g. Authorization: `Key ${process.env.FAL_AI_API_KEY}`,
);
snippet = snippet.replace(
`Authorization: "X-Key ${placeholder}",`,
`Authorization: \`X-Key $\{process.env.${accessTokenEnvVar}}\`,` // e.g. Authorization: `X-Key ${process.env.BLACK_FOREST_LABS_API_KEY}`,
);
}
return snippet;
}
12 changes: 12 additions & 0 deletions packages/tasks-gen/scripts/generate-snippets-fixtures.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,18 @@ const TEST_CASES: {
providers: ["hf-inference"],
opts: { accessToken: "hf_xxx" },
},
{
testName: "explicit-direct-request",
task: "conversational",
model: {
id: "meta-llama/Llama-3.1-8B-Instruct",
pipeline_tag: "text-generation",
tags: ["conversational"],
inference: "",
},
providers: ["together"],
opts: { directRequest: true },
},
{
testName: "text-to-speech",
task: "text-to-speech",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api.together.xyz/v1",
apiKey: process.env.TOGETHER_API_KEY,
baseURL: "https://router.huggingface.co/together/v1",
apiKey: process.env.HF_TOKEN,
});

const chatCompletion = await client.chat.completions.create({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from openai import OpenAI

client = OpenAI(
base_url="https://api.together.xyz/v1",
api_key=os.environ["TOGETHER_API_KEY"],
base_url="https://router.huggingface.co/together/v1",
api_key=os.environ["HF_TOKEN"],
)

completion = client.chat.completions.create(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import requests

API_URL = "https://api.together.xyz/v1/chat/completions"
API_URL = "https://router.huggingface.co/together/v1/chat/completions"
headers = {
"Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}",
"Authorization": f"Bearer {os.environ['HF_TOKEN']}",
}

def query(payload):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
curl https://api.together.xyz/v1/chat/completions \
-H "Authorization: Bearer $TOGETHER_API_KEY" \
curl https://router.huggingface.co/together/v1/chat/completions \
-H "Authorization: Bearer $HF_TOKEN" \
-H 'Content-Type: application/json' \
-d '{
"messages": [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api.together.xyz/v1",
apiKey: process.env.TOGETHER_API_KEY,
baseURL: "https://router.huggingface.co/together/v1",
apiKey: process.env.HF_TOKEN,
});

const stream = await client.chat.completions.create({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from openai import OpenAI

client = OpenAI(
base_url="https://api.together.xyz/v1",
api_key=os.environ["TOGETHER_API_KEY"],
base_url="https://router.huggingface.co/together/v1",
api_key=os.environ["HF_TOKEN"],
)

stream = client.chat.completions.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import json
import requests

API_URL = "https://api.together.xyz/v1/chat/completions"
API_URL = "https://router.huggingface.co/together/v1/chat/completions"
headers = {
"Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}",
"Authorization": f"Bearer {os.environ['HF_TOKEN']}",
}

def query(payload):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
curl https://api.together.xyz/v1/chat/completions \
-H "Authorization: Bearer $TOGETHER_API_KEY" \
curl https://router.huggingface.co/together/v1/chat/completions \
-H "Authorization: Bearer $HF_TOKEN" \
-H 'Content-Type: application/json' \
-d '{
"messages": [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api.fireworks.ai/inference/v1",
apiKey: process.env.FIREWORKS_AI_API_KEY,
baseURL: "https://router.huggingface.co/fireworks-ai/inference/v1",
apiKey: process.env.HF_TOKEN,
});

const chatCompletion = await client.chat.completions.create({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from openai import OpenAI

client = OpenAI(
base_url="https://api.fireworks.ai/inference/v1",
api_key=os.environ["FIREWORKS_AI_API_KEY"],
base_url="https://router.huggingface.co/fireworks-ai/inference/v1",
api_key=os.environ["HF_TOKEN"],
)

completion = client.chat.completions.create(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import requests

API_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
API_URL = "https://router.huggingface.co/fireworks-ai/inference/v1/chat/completions"
headers = {
"Authorization": f"Bearer {os.environ['FIREWORKS_AI_API_KEY']}",
"Authorization": f"Bearer {os.environ['HF_TOKEN']}",
}

def query(payload):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
curl https://api.fireworks.ai/inference/v1/chat/completions \
-H "Authorization: Bearer $FIREWORKS_AI_API_KEY" \
curl https://router.huggingface.co/fireworks-ai/inference/v1/chat/completions \
-H "Authorization: Bearer $HF_TOKEN" \
-H 'Content-Type: application/json' \
-d '{
"messages": [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api.fireworks.ai/inference/v1",
apiKey: process.env.FIREWORKS_AI_API_KEY,
baseURL: "https://router.huggingface.co/fireworks-ai/inference/v1",
apiKey: process.env.HF_TOKEN,
});

const stream = await client.chat.completions.create({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from openai import OpenAI

client = OpenAI(
base_url="https://api.fireworks.ai/inference/v1",
api_key=os.environ["FIREWORKS_AI_API_KEY"],
base_url="https://router.huggingface.co/fireworks-ai/inference/v1",
api_key=os.environ["HF_TOKEN"],
)

stream = client.chat.completions.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import json
import requests

API_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
API_URL = "https://router.huggingface.co/fireworks-ai/inference/v1/chat/completions"
headers = {
"Authorization": f"Bearer {os.environ['FIREWORKS_AI_API_KEY']}",
"Authorization": f"Bearer {os.environ['HF_TOKEN']}",
}

def query(payload):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
curl https://api.fireworks.ai/inference/v1/chat/completions \
-H "Authorization: Bearer $FIREWORKS_AI_API_KEY" \
curl https://router.huggingface.co/fireworks-ai/inference/v1/chat/completions \
-H "Authorization: Bearer $HF_TOKEN" \
-H 'Content-Type: application/json' \
-d '{
"messages": [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient(process.env.TOGETHER_API_KEY);

const chatCompletion = await client.chatCompletion({
provider: "together",
model: "meta-llama/Llama-3.1-8B-Instruct",
messages: [
{
role: "user",
content: "What is the capital of France?",
},
],
});

console.log(chatCompletion.choices[0].message);
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api.together.xyz/v1",
apiKey: process.env.TOGETHER_API_KEY,
});

const chatCompletion = await client.chat.completions.create({
model: "<together alias for meta-llama/Llama-3.1-8B-Instruct>",
messages: [
{
role: "user",
content: "What is the capital of France?",
},
],
});

console.log(chatCompletion.choices[0].message);
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os
from huggingface_hub import InferenceClient

client = InferenceClient(
provider="together",
api_key=os.environ["TOGETHER_API_KEY"],
)

completion = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=[
{
"role": "user",
"content": "What is the capital of France?"
}
],
)

print(completion.choices[0].message)
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os
from openai import OpenAI

client = OpenAI(
base_url="https://api.together.xyz/v1",
api_key=os.environ["TOGETHER_API_KEY"],
)

completion = client.chat.completions.create(
model="<together alias for meta-llama/Llama-3.1-8B-Instruct>",
messages=[
{
"role": "user",
"content": "What is the capital of France?"
}
],
)

print(completion.choices[0].message)
Loading