[InferenceSnippet] Take token from env variable if not set #1514

Merged · 10 commits · Jun 4, 2025

76 changes: 69 additions & 7 deletions packages/inference/src/snippets/getInferenceSnippets.ts
@@ -14,7 +14,10 @@ import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions.js";
import type { InferenceProviderOrPolicy, InferenceTask, RequestArgs } from "../types.js";
import { templates } from "./templates.exported.js";

export type InferenceSnippetOptions = { streaming?: boolean; billTo?: string } & Record<string, unknown>;
export type InferenceSnippetOptions = { streaming?: boolean; billTo?: string; accessToken?: string } & Record<
string,
unknown
>;

const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
const JS_CLIENTS = ["fetch", "huggingface.js", "openai"] as const;
@@ -121,11 +124,12 @@ const HF_JS_METHODS: Partial<Record<WidgetType, string>> = {
translation: "translation",
};

const ACCESS_TOKEN_PLACEHOLDER = "<ACCESS_TOKEN>"; // Placeholder to replace with env variable in snippets

// Snippet generators
const snippetGenerator = (templateName: string, inputPreparationFn?: InputPreparationFn) => {
return (
model: ModelDataMinimal,
accessToken: string,
provider: InferenceProviderOrPolicy,
inferenceProviderMapping?: InferenceProviderModelMapping,
opts?: InferenceSnippetOptions
@@ -149,13 +153,15 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPreparationFn) => {
console.error(`Failed to get provider helper for ${provider} (${task})`, e);
return [];
}
const accessTokenOrPlaceholder = opts?.accessToken ?? ACCESS_TOKEN_PLACEHOLDER;

/// Prepare inputs + make request
const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
const request = makeRequestOptionsFromResolvedModel(
providerModelId,
providerHelper,
{
accessToken,
accessToken: accessTokenOrPlaceholder,
provider,
...inputs,
} as RequestArgs,
@@ -180,7 +186,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPreparationFn) => {

/// Prepare template injection data
const params: TemplateParams = {
accessToken,
accessToken: accessTokenOrPlaceholder,
authorizationHeader: (request.info.headers as Record<string, string>)?.Authorization,
baseUrl: removeSuffix(request.url, "/chat/completions"),
fullUrl: request.url,
@@ -248,6 +254,11 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPreparationFn) => {
snippet = `${importSection}\n\n${snippet}`;
}

/// Replace access token placeholder
if (snippet.includes(ACCESS_TOKEN_PLACEHOLDER)) {
snippet = replaceAccessTokenPlaceholder(snippet, language, provider);
}

/// Snippet is ready!
return { language, client: client as string, content: snippet };
})
@@ -299,7 +310,6 @@ const snippets: Partial<
PipelineType,
(
model: ModelDataMinimal,
accessToken: string,
provider: InferenceProviderOrPolicy,
inferenceProviderMapping?: InferenceProviderModelMapping,
opts?: InferenceSnippetOptions
@@ -339,13 +349,12 @@

export function getInferenceSnippets(
model: ModelDataMinimal,
accessToken: string,
provider: InferenceProviderOrPolicy,
inferenceProviderMapping?: InferenceProviderModelMapping,
opts?: Record<string, unknown>
): InferenceSnippet[] {
return model.pipeline_tag && model.pipeline_tag in snippets
? snippets[model.pipeline_tag]?.(model, accessToken, provider, inferenceProviderMapping, opts) ?? []
? snippets[model.pipeline_tag]?.(model, provider, inferenceProviderMapping, opts) ?? []
: [];
}

@@ -420,3 +429,56 @@ function indentString(str: string): string {
function removeSuffix(str: string, suffix: string) {
return str.endsWith(suffix) ? str.slice(0, -suffix.length) : str;
}

function replaceAccessTokenPlaceholder(
snippet: string,
language: InferenceSnippetLanguage,
provider: InferenceProviderOrPolicy
): string {
// If "opts.accessToken" is not set, the snippets are generated with a placeholder.
// Once snippets are rendered, we replace the placeholder with code to fetch the access token from an environment variable.

// Determine if HF_TOKEN or specific provider token should be used
const accessTokenEnvVar =
!snippet.includes("https://") || // no URL provided => using a client => use $HF_TOKEN
snippet.includes("https://router.huggingface.co") || // explicit routed request => use $HF_TOKEN
provider == "hf-inference" // hf-inference provider => use $HF_TOKEN
? "HF_TOKEN"
: provider.toUpperCase().replace("-", "_") + "_API_KEY"; // e.g. "REPLICATE_API_KEY"

// Replace the placeholder with the env variable
if (language === "sh") {
snippet = snippet.replace(
`'Authorization: Bearer ${ACCESS_TOKEN_PLACEHOLDER}'`,
`"Authorization: Bearer $${accessTokenEnvVar}"` // e.g. "Authorization: Bearer $HF_TOKEN"
);
} else if (language === "python") {
snippet = "import os\n" + snippet;
snippet = snippet.replace(
`"${ACCESS_TOKEN_PLACEHOLDER}"`,
`os.environ["${accessTokenEnvVar}"]` // e.g. os.environ["HF_TOKEN"]
);
snippet = snippet.replace(
`"Bearer ${ACCESS_TOKEN_PLACEHOLDER}"`,
`f"Bearer {os.environ['${accessTokenEnvVar}']}"` // e.g. f"Bearer {os.environ['HF_TOKEN']}"
);
snippet = snippet.replace(
`"Key ${ACCESS_TOKEN_PLACEHOLDER}"`,
`f"Key {os.environ['${accessTokenEnvVar}']}"` // e.g. f"Key {os.environ['FAL_AI_API_KEY']}"
);
} else if (language === "js") {
snippet = snippet.replace(
`"${ACCESS_TOKEN_PLACEHOLDER}"`,
`process.env.${accessTokenEnvVar}` // e.g. process.env.HF_TOKEN
);
snippet = snippet.replace(
`Authorization: "Bearer ${ACCESS_TOKEN_PLACEHOLDER}",`,
`Authorization: \`Bearer $\{process.env.${accessTokenEnvVar}}\`,` // e.g. Authorization: `Bearer ${process.env.HF_TOKEN}`,
);
snippet = snippet.replace(
`Authorization: "Key ${ACCESS_TOKEN_PLACEHOLDER}",`,
`Authorization: \`Key $\{process.env.${accessTokenEnvVar}}\`,` // e.g. Authorization: `Key ${process.env.FAL_AI_API_KEY}`,
);
}
return snippet;
}
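
For reference, the selection rule above can be read on its own: the snippet falls back to HF_TOKEN whenever the request is routed through Hugging Face (a client-based snippet with no hard-coded URL, an explicit https://router.huggingface.co call, or the hf-inference provider), and otherwise derives a provider-specific variable from the provider name. A minimal standalone TypeScript sketch of that rule (envVarForProvider is a hypothetical helper, not part of this PR):

function envVarForProvider(snippet: string, provider: string): string {
  // Routed through Hugging Face => $HF_TOKEN
  const routedThroughHF =
    !snippet.includes("https://") || // no URL in the snippet => a client handles routing
    snippet.includes("https://router.huggingface.co") ||
    provider === "hf-inference";
  // Direct provider call => e.g. $TOGETHER_API_KEY, $FAL_AI_API_KEY
  // (only the first hyphen is replaced, as in the PR code)
  return routedThroughHF ? "HF_TOKEN" : provider.toUpperCase().replace("-", "_") + "_API_KEY";
}

// envVarForProvider('await fetch("https://api.together.xyz/v1/...", ...)', "together") === "TOGETHER_API_KEY"
// envVarForProvider('client.chatCompletion({ ... })', "together") === "HF_TOKEN"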
13 changes: 12 additions & 1 deletion packages/tasks-gen/scripts/generate-snippets-fixtures.ts
@@ -240,6 +240,18 @@ const TEST_CASES: {
providers: ["hf-inference"],
opts: { billTo: "huggingface" },
},
{
testName: "with-access-token",
task: "conversational",
model: {
id: "meta-llama/Llama-3.1-8B-Instruct",
pipeline_tag: "text-generation",
tags: ["conversational"],
inference: "",
},
providers: ["hf-inference"],
opts: { accessToken: "hf_xxx" },
},
{
testName: "text-to-speech",
task: "text-to-speech",
@@ -314,7 +326,6 @@ function generateInferenceSnippet(
): InferenceSnippet[] {
const allSnippets = snippets.getInferenceSnippets(
model,
"api_token",
provider,
{
hfModelId: model.id,
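The regenerated fixtures below show the end-to-end effect: snippets that previously hard-coded the "api_token" test value now pull the token from the environment ($HF_TOKEN, process.env.HF_TOKEN, os.environ["HF_TOKEN"], or a provider-specific key such as TOGETHER_API_KEY), while an explicit opts.accessToken is inlined verbatim. A rough usage sketch against the new signature, assuming the same snippets export the fixture script uses and the model fields from the with-access-token test case:

import { snippets } from "@huggingface/inference";

const model = {
  id: "meta-llama/Llama-3.1-8B-Instruct",
  pipeline_tag: "text-generation" as const,
  tags: ["conversational"],
  inference: "",
};

// Without opts.accessToken: generated snippets read the token from an env
// variable, e.g. process.env.HF_TOKEN (JS) or os.environ["HF_TOKEN"] (Python).
const envBased = snippets.getInferenceSnippets(model, "hf-inference");

// With opts.accessToken: the given token is embedded directly in the snippet.
const inlined = snippets.getInferenceSnippets(model, "hf-inference", undefined, {
  accessToken: "hf_xxx",
});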
@@ -3,7 +3,7 @@ async function query(data) {
"https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo",
{
headers: {
Authorization: "Bearer api_token",
Authorization: `Bearer ${process.env.HF_TOKEN}`,
"Content-Type": "audio/flac",
},
method: "POST",
@@ -1,6 +1,6 @@
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("api_token");
const client = new InferenceClient(process.env.HF_TOKEN);

const data = fs.readFileSync("sample1.flac");

@@ -1,8 +1,9 @@
import os
from huggingface_hub import InferenceClient

client = InferenceClient(
provider="hf-inference",
api_key="api_token",
api_key=os.environ["HF_TOKEN"],
)

output = client.automatic_speech_recognition("sample1.flac", model="openai/whisper-large-v3-turbo")
@@ -1,8 +1,9 @@
import os
import requests

API_URL = "https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo"
headers = {
"Authorization": "Bearer api_token",
"Authorization": f"Bearer {os.environ['HF_TOKEN']}",
}

def query(filename):
@@ -1,5 +1,5 @@
curl https://router.huggingface.co/hf-inference/models/openai/whisper-large-v3-turbo \
-X POST \
-H 'Authorization: Bearer api_token' \
-H "Authorization: Bearer $HF_TOKEN" \
-H 'Content-Type: audio/flac' \
--data-binary @"sample1.flac"
@@ -3,7 +3,7 @@ async function query(data) {
"https://router.huggingface.co/hf-inference/models/FacebookAI/xlm-roberta-large-finetuned-conll03-english",
{
headers: {
Authorization: "Bearer api_token",
Authorization: `Bearer ${process.env.HF_TOKEN}`,
"Content-Type": "application/json",
},
method: "POST",
@@ -1,6 +1,6 @@
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("api_token");
const client = new InferenceClient(process.env.HF_TOKEN);

const output = await client.tokenClassification({
model: "FacebookAI/xlm-roberta-large-finetuned-conll03-english",
@@ -1,8 +1,9 @@
import os
from huggingface_hub import InferenceClient

client = InferenceClient(
provider="hf-inference",
api_key="api_token",
api_key=os.environ["HF_TOKEN"],
)

result = client.token_classification(
@@ -1,8 +1,9 @@
import os
import requests

API_URL = "https://router.huggingface.co/hf-inference/models/FacebookAI/xlm-roberta-large-finetuned-conll03-english"
headers = {
"Authorization": "Bearer api_token",
"Authorization": f"Bearer {os.environ['HF_TOKEN']}",
}

def query(payload):
@@ -1,6 +1,6 @@
curl https://router.huggingface.co/hf-inference/models/FacebookAI/xlm-roberta-large-finetuned-conll03-english \
-X POST \
-H 'Authorization: Bearer api_token' \
-H "Authorization: Bearer $HF_TOKEN" \
-H 'Content-Type: application/json' \
-d '{
"inputs": "\"My name is Sarah Jessica Parker but you can call me Jessica\""
@@ -1,6 +1,6 @@
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("api_token");
const client = new InferenceClient(process.env.HF_TOKEN);

const chatCompletion = await client.chatCompletion({
provider: "hf-inference",
@@ -2,7 +2,7 @@ import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1",
apiKey: "api_token",
apiKey: process.env.HF_TOKEN,
defaultHeaders: {
"X-HF-Bill-To": "huggingface"
}
@@ -1,8 +1,9 @@
import os
from huggingface_hub import InferenceClient

client = InferenceClient(
provider="hf-inference",
api_key="api_token",
api_key=os.environ["HF_TOKEN"],
bill_to="huggingface",
)

@@ -1,8 +1,9 @@
import os
from openai import OpenAI

client = OpenAI(
base_url="https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1",
api_key="api_token",
api_key=os.environ["HF_TOKEN"],
default_headers={
"X-HF-Bill-To": "huggingface"
}
@@ -1,8 +1,9 @@
import os
import requests

API_URL = "https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions"
headers = {
"Authorization": "Bearer api_token",
"Authorization": f"Bearer {os.environ['HF_TOKEN']}",
"X-HF-Bill-To": "huggingface"
}

@@ -1,5 +1,5 @@
curl https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions \
-H 'Authorization: Bearer api_token' \
-H "Authorization: Bearer $HF_TOKEN" \
-H 'Content-Type: application/json' \
-H 'X-HF-Bill-To: huggingface' \
-d '{
@@ -1,6 +1,6 @@
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("api_token");
const client = new InferenceClient(process.env.HF_TOKEN);

const chatCompletion = await client.chatCompletion({
provider: "hf-inference",
@@ -1,6 +1,6 @@
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient("api_token");
const client = new InferenceClient(process.env.HF_TOKEN);

const chatCompletion = await client.chatCompletion({
provider: "together",
@@ -2,7 +2,7 @@ import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1",
apiKey: "api_token",
apiKey: process.env.HF_TOKEN,
});

const chatCompletion = await client.chat.completions.create({
@@ -2,7 +2,7 @@ import { OpenAI } from "openai";

const client = new OpenAI({
baseURL: "https://api.together.xyz/v1",
apiKey: "api_token",
apiKey: process.env.TOGETHER_API_KEY,
});

const chatCompletion = await client.chat.completions.create({
@@ -1,8 +1,9 @@
import os
from huggingface_hub import InferenceClient

client = InferenceClient(
provider="hf-inference",
api_key="api_token",
api_key=os.environ["HF_TOKEN"],
)

completion = client.chat.completions.create(
@@ -1,8 +1,9 @@
import os
from huggingface_hub import InferenceClient

client = InferenceClient(
provider="together",
api_key="api_token",
api_key=os.environ["HF_TOKEN"],
)

completion = client.chat.completions.create(
@@ -1,8 +1,9 @@
import os
from openai import OpenAI

client = OpenAI(
base_url="https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct/v1",
api_key="api_token",
api_key=os.environ["HF_TOKEN"],
)

completion = client.chat.completions.create(