Skip to content

Commit b174dcb

Browse files
committed
ToolBuilder
1 parent af206d6 commit b174dcb

File tree

3 files changed

+145
-5
lines changed

3 files changed

+145
-5
lines changed

src/Chat.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import OpenAI from "openai";
22
import { PromptBuilder } from "./PromptBuilder";
33
import { ExtractArgs, ExtractChatArgs, ReplaceChatArgs } from "./types";
4+
import { ToolBuilder } from "./ToolBuilder";
45

56
export class Chat<
7+
ToolNames extends string,
68
TMessages extends
79
| []
810
| [
@@ -14,8 +16,13 @@ export class Chat<
1416
constructor(
1517
public messages: TMessages,
1618
public args: TSuppliedInputArgs,
19+
public tools = {} as Record<ToolNames, ToolBuilder>,
20+
public mustUseTool: boolean = false
1721
) {}
1822

23+
toJSONSchema() {
24+
}
25+
1926
toArray() {
2027
return this.messages.map((m) => ({
2128
role: m.role,

src/ToolBuilder.ts

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
interface Tool<I = unknown, O = unknown> {
2+
name: string;
3+
type: "query" | "mutation"
4+
build: (input: I) => O;
5+
}
6+
7+
export class ToolBuilder<TType extends "query" | "mutation" = "query", I = unknown, O = unknown> {
8+
private name: string;
9+
private implementation?: (input: I) => O;
10+
private type: TType;
11+
12+
constructor(name: string, type: TType = "query" as TType) {
13+
this.name = name;
14+
this.type = type;
15+
}
16+
17+
addInputValidation<T = I>(): ToolBuilder<TType, T, O> {
18+
// Implementation here
19+
return this as unknown as ToolBuilder<TType, T, O>;
20+
}
21+
22+
addOutputValidation<T = O>(): ToolBuilder<TType, I, T> {
23+
// Implementation here
24+
return this as unknown as ToolBuilder<TType, I, T>;
25+
}
26+
27+
query(queryFunction: (input: I) => O): ToolBuilder<"query", I, O> {
28+
29+
return {
30+
...this,
31+
implementation: queryFunction,
32+
type: "query"
33+
};
34+
}
35+
36+
mutation(mutationFunction: (input: I) => O): ToolBuilder<"mutation", I, O> {
37+
return {
38+
...this,
39+
implementation: mutationFunction,
40+
type: "mutation"
41+
};
42+
}
43+
44+
build(): Tool<I, O> {
45+
return {
46+
name: this.name,
47+
build: this.implementation!,
48+
type: this.type
49+
};
50+
}
51+
}

src/__tests__/Chat.test.ts

Lines changed: 87 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ import { strict as assert } from "node:assert";
22
import { Chat } from "../Chat";
33
import { system, user, assistant } from "../ChatHelpers";
44
import { Equal, Expect } from "./types.test";
5+
import { ToolBuilder } from "../ToolBuilder";
56

67
describe("Chat", () => {
78
it("should allow empty array", () => {
8-
const chat = new Chat([], {}).toArray();
9+
const chat = new Chat([], {}, {}).toArray();
910
type test = Expect<Equal<typeof chat, []>>;
1011
assert.deepEqual(chat, []);
1112
});
@@ -54,16 +55,97 @@ describe("Chat", () => {
5455
});
5556

5657
it("should allow chat of all diffent types with no args", () => {
57-
const chat = new Chat(
58-
[user(`Tell me a joke`), assistant(`joke?`), system(`joke?`)],
59-
{},
60-
).toArray();
6158
const usrMsg = user("Tell me a joke");
6259
const astMsg = assistant("joke?");
6360
const sysMsg = system("joke?");
61+
62+
const chat = new Chat([usrMsg, astMsg, sysMsg], {}).toArray();
6463
type test = Expect<
6564
Equal<typeof chat, [typeof usrMsg, typeof astMsg, typeof sysMsg]>
6665
>;
6766
assert.deepEqual(chat, [usrMsg, astMsg, sysMsg]);
6867
});
68+
69+
it("should allow me to pass in tools", () => {
70+
const usrMsg = user("Tell me a joke");
71+
const astMsg = assistant("joke?");
72+
const sysMsg = system("joke?");
73+
const tools = {
74+
google: new ToolBuilder("google")
75+
.addInputValidation<{ query: string }>()
76+
.addOutputValidation<{ results: string[] }>()
77+
.query(({ query }) => {
78+
return {
79+
results: ["foo", "bar"],
80+
};
81+
}),
82+
wikipedia: new ToolBuilder("wikipedia")
83+
.addInputValidation<{ page: string }>()
84+
.addOutputValidation<{ results: string[] }>()
85+
.query(({ page }) => {
86+
return {
87+
results: ["foo", "bar"],
88+
};
89+
}),
90+
sendEmail: new ToolBuilder("sendEmail")
91+
.addInputValidation<{ to: string; subject: string; body: string }>()
92+
.addOutputValidation<{ success: boolean }>()
93+
.mutation(({ to, subject, body }) => {
94+
return {
95+
success: true,
96+
};
97+
}),
98+
} as const;
99+
100+
const chat = new Chat([usrMsg, astMsg, sysMsg], {}, tools);
101+
102+
type tests = [
103+
Expect<
104+
Equal<
105+
typeof chat,
106+
Chat<
107+
keyof typeof tools,
108+
[typeof usrMsg, typeof astMsg, typeof sysMsg],
109+
{}
110+
>
111+
>
112+
>,
113+
Expect<
114+
Equal<
115+
typeof tools,
116+
{
117+
readonly google: ToolBuilder<
118+
"query",
119+
{
120+
query: string;
121+
},
122+
{
123+
results: string[];
124+
}
125+
>;
126+
readonly wikipedia: ToolBuilder<
127+
"query",
128+
{
129+
page: string;
130+
},
131+
{
132+
results: string[];
133+
}
134+
>;
135+
readonly sendEmail: ToolBuilder<
136+
"mutation",
137+
{
138+
to: string;
139+
subject: string;
140+
body: string;
141+
},
142+
{
143+
success: boolean;
144+
}
145+
>;
146+
}
147+
>
148+
>
149+
];
150+
});
69151
});

0 commit comments

Comments
 (0)