Skip to content

Commit a160b10

Browse files
authored
[inference] Fix types for Tool calling (#1367)
EDIT: ok, ready for review ![tenor](https://github.com/user-attachments/assets/e5871f1a-d42a-449b-9cb1-77aed17f45d8)
1 parent 2161017 commit a160b10

File tree

6 files changed

+39
-19
lines changed

6 files changed

+39
-19
lines changed

README.md

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ await uploadFile({
2727
}
2828
});
2929

30-
// Use HF Inference API, or external Inference Providers!
30+
// Use all supported Inference Providers!
3131

3232
await inference.chatCompletion({
3333
model: "meta-llama/Llama-3.1-8B-Instruct",
@@ -55,7 +55,7 @@ await inference.textToImage({
5555

5656
This is a collection of JS libraries to interact with the Hugging Face API, with TS types included.
5757

58-
- [@huggingface/inference](packages/inference/README.md): Use HF Inference API (serverless), Inference Endpoints (dedicated) and all supported Inference Providers to make calls to 100,000+ Machine Learning models
58+
- [@huggingface/inference](packages/inference/README.md): Use all supported (serverless) Inference Providers or switch to Inference Endpoints (dedicated) to make calls to 100,000+ Machine Learning models
5959
- [@huggingface/hub](packages/hub/README.md): Interact with huggingface.co to create or delete repos and commit / download files
6060
- [@huggingface/agents](packages/agents/README.md): Interact with HF models through a natural language interface
6161
- [@huggingface/gguf](packages/gguf/README.md): A GGUF parser that works on remotely hosted files.
@@ -128,18 +128,18 @@ import { InferenceClient } from "@huggingface/inference";
128128

129129
const HF_TOKEN = "hf_...";
130130

131-
const inference = new InferenceClient(HF_TOKEN);
131+
const client = new InferenceClient(HF_TOKEN);
132132

133133
// Chat completion API
134-
const out = await inference.chatCompletion({
134+
const out = await client.chatCompletion({
135135
model: "meta-llama/Llama-3.1-8B-Instruct",
136136
messages: [{ role: "user", content: "Hello, nice to meet you!" }],
137137
max_tokens: 512
138138
});
139139
console.log(out.choices[0].message);
140140

141141
// Streaming chat completion API
142-
for await (const chunk of inference.chatCompletionStream({
142+
for await (const chunk of client.chatCompletionStream({
143143
model: "meta-llama/Llama-3.1-8B-Instruct",
144144
messages: [{ role: "user", content: "Hello, nice to meet you!" }],
145145
max_tokens: 512
@@ -148,14 +148,14 @@ for await (const chunk of inference.chatCompletionStream({
148148
}
149149

150150
/// Using a third-party provider:
151-
await inference.chatCompletion({
151+
await client.chatCompletion({
152152
model: "meta-llama/Llama-3.1-8B-Instruct",
153153
messages: [{ role: "user", content: "Hello, nice to meet you!" }],
154154
max_tokens: 512,
155155
provider: "sambanova", // or together, fal-ai, replicate, cohere …
156156
})
157157

158-
await inference.textToImage({
158+
await client.textToImage({
159159
model: "black-forest-labs/FLUX.1-dev",
160160
inputs: "a picture of a green bird",
161161
provider: "fal-ai",
@@ -164,7 +164,7 @@ await inference.textToImage({
164164

165165

166166
// You can also omit "model" to use the recommended model for the task
167-
await inference.translation({
167+
await client.translation({
168168
inputs: "My name is Wolfgang and I live in Amsterdam",
169169
parameters: {
170170
src_lang: "en",
@@ -173,17 +173,17 @@ await inference.translation({
173173
});
174174

175175
// pass multimodal files or URLs as inputs
176-
await inference.imageToText({
176+
await client.imageToText({
177177
model: 'nlpconnect/vit-gpt2-image-captioning',
178178
data: await (await fetch('https://picsum.photos/300/300')).blob(),
179179
})
180180

181181
// Using your own dedicated inference endpoint: https://hf.co/docs/inference-endpoints/
182-
const gpt2 = inference.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
183-
const { generated_text } = await gpt2.textGeneration({ inputs: 'The answer to the universe is' });
182+
const gpt2Client = client.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
183+
const { generated_text } = await gpt2Client.textGeneration({ inputs: 'The answer to the universe is' });
184184

185185
// Chat Completion
186-
const llamaEndpoint = inference.endpoint(
186+
const llamaEndpoint = client.endpoint(
187187
"https://router.huggingface.co/hf-inference/models/meta-llama/Llama-3.1-8B-Instruct"
188188
);
189189
const out = await llamaEndpoint.chatCompletion({

packages/inference/test/InferenceClient.spec.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ describe.concurrent("InferenceClient", () => {
376376
);
377377
});
378378

379-
it("textGeneration - gpt2", async () => {
379+
it.skip("textGeneration - gpt2", async () => {
380380
expect(
381381
await hf.textGeneration({
382382
model: "gpt2",
@@ -387,7 +387,7 @@ describe.concurrent("InferenceClient", () => {
387387
});
388388
});
389389

390-
it("textGeneration - openai-community/gpt2", async () => {
390+
it.skip("textGeneration - openai-community/gpt2", async () => {
391391
expect(
392392
await hf.textGeneration({
393393
model: "openai-community/gpt2",

packages/tasks-gen/scripts/inference-tgi-import.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ function toCamelCase(str: string, joiner = "") {
3434
.join(joiner);
3535
}
3636

37+
const OVERRIDES_TYPES_RENAME_PROPERTIES: Record<string, Record<string, string>> = {
38+
ChatCompletionInputFunctionDefinition: { arguments: "parameters" },
39+
};
40+
const OVERRIDES_TYPES_OVERRIDE_PROPERTY_TYPE: Record<string, Record<string, unknown>> = {
41+
ChatCompletionOutputFunctionDefinition: { arguments: { type: "string" } },
42+
};
43+
3744
async function _extractAndAdapt(task: string, mainComponentName: string, type: "input" | "output" | "stream_output") {
3845
console.debug(`✨ Importing`, task, type);
3946

@@ -57,6 +64,17 @@ async function _extractAndAdapt(task: string, mainComponentName: string, type: "
5764
_scan(item);
5865
}
5966
} else if (data && typeof data === "object") {
67+
/// This next section can be removed once TGI is no longer used as the source of types.
68+
if (typeof data.title === "string" && data.title in OVERRIDES_TYPES_RENAME_PROPERTIES) {
69+
const [[oldName, newName]] = Object.entries(OVERRIDES_TYPES_RENAME_PROPERTIES[data.title]);
70+
data.required = JSON.parse(JSON.stringify(data.required).replaceAll(oldName, newName));
71+
data.properties = JSON.parse(JSON.stringify(data.properties).replaceAll(oldName, newName));
72+
}
73+
if (typeof data.title === "string" && data.title in OVERRIDES_TYPES_OVERRIDE_PROPERTY_TYPE) {
74+
const [[prop, newType]] = Object.entries(OVERRIDES_TYPES_OVERRIDE_PROPERTY_TYPE[data.title]);
75+
(data.properties as Record<string, unknown>)[prop] = newType;
76+
}
77+
/// End of overrides section
6078
for (const key of Object.keys(data)) {
6179
if (key === "$ref" && typeof data[key] === "string") {
6280
// Verify reference exists

packages/tasks/src/tasks/chat-completion/inference.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ export interface ChatCompletionInputToolCall {
130130
[property: string]: unknown;
131131
}
132132
export interface ChatCompletionInputFunctionDefinition {
133-
arguments: unknown;
134133
description?: string;
135134
name: string;
135+
parameters: unknown;
136136
[property: string]: unknown;
137137
}
138138
export interface ChatCompletionInputGrammarType {
@@ -235,7 +235,7 @@ export interface ChatCompletionOutputToolCall {
235235
[property: string]: unknown;
236236
}
237237
export interface ChatCompletionOutputFunctionDefinition {
238-
arguments: unknown;
238+
arguments: string;
239239
description?: string;
240240
name: string;
241241
[property: string]: unknown;

packages/tasks/src/tasks/chat-completion/spec/input.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,9 @@
275275
},
276276
"ChatCompletionInputFunctionDefinition": {
277277
"type": "object",
278-
"required": ["name", "arguments"],
278+
"required": ["name", "parameters"],
279279
"properties": {
280-
"arguments": {},
280+
"parameters": {},
281281
"description": {
282282
"type": "string",
283283
"nullable": true

packages/tasks/src/tasks/chat-completion/spec/output.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@
173173
"type": "object",
174174
"required": ["name", "arguments"],
175175
"properties": {
176-
"arguments": {},
176+
"arguments": {
177+
"type": "string"
178+
},
177179
"description": {
178180
"type": "string",
179181
"nullable": true

0 commit comments

Comments
 (0)