Skip to content

wip #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
"@types/jest": "^29.5.1",
"@types/node": "^20.2.4",
"jest": "^29.5.0",
"openai": "^3.2.1",
"openai": "^3.3.0",
"ts-jest": "^29.1.0",
"ts-toolbelt": "^9.6.0",
"typescript": "^5.0.4"
},
"dependencies": {
"arktype": "1.0.14-alpha",
"zod": "^3.21.4"
}
}
15 changes: 11 additions & 4 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion src/Chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ export class Chat<
) {}

toArray() {
return (this.messages as any[]).map((m: ChatCompletionRequestMessage) => ({
return (this.messages as any[])
.filter((m) => m.content !== undefined)
.map((m) => ({
role: m.role,
content: new PromptBuilder(m.content)
.addInputValidation<ExtractArgs<typeof m.content, typeof this.args>>()
Expand Down
28 changes: 21 additions & 7 deletions src/ChatBuilder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { z } from "zod";
import { F } from "ts-toolbelt";
import { ChatCompletionRequestMessage } from "openai";
import { Chat } from "./Chat";
import { user, assistant, system } from "./ChatHelpers";
import { User, Assistant, System, Function } from "./ChatHelpers";
import { ExtractArgs, ExtractChatArgs, TypeToZodShape, ReplaceChatArgs } from "./types";

export class ChatBuilder<
Expand All @@ -19,31 +19,45 @@ export class ChatBuilder<
return new ChatBuilder(this.messages) as any;
}

user<TUserText extends string>(
User<TUserText extends string>(
str: TUserText
): ChatBuilder<
[...TMessages, { role: "user"; content: TUserText }],
F.Narrow<TExpectedInput> & ExtractArgs<TUserText>
> {
return new ChatBuilder([...this.messages, user(str)]) as any;
return new ChatBuilder([...this.messages, User(str)]) as any;
}

system<TSystemText extends string>(
System<TSystemText extends string>(
str: TSystemText
): ChatBuilder<
[...TMessages, { role: "system"; content: TSystemText }],
F.Narrow<TExpectedInput> & ExtractArgs<TSystemText>
> {
return new ChatBuilder([...this.messages, system(str)]) as any;
return new ChatBuilder([...this.messages, System(str)]) as any;
}

assistant<TAssistantText extends string>(
Assistant<TAssistantText extends string>(
str: TAssistantText
): ChatBuilder<
[...TMessages, { role: "assistant"; content: TAssistantText }],
F.Narrow<TExpectedInput> & ExtractArgs<TAssistantText>
> {
return new ChatBuilder([...this.messages, assistant(str)]) as any;
return new ChatBuilder([...this.messages, Assistant(str)]) as any;
}

// Backwards compadibility
user = this.User;
system = this.System;
assistant = this.Assistant;

Function<TAssistantText extends string>(
str: TAssistantText
): ChatBuilder<
[...TMessages, { role: "function"; content: TAssistantText }],
F.Narrow<TExpectedInput> & ExtractArgs<TAssistantText>
> {
return new ChatBuilder([...this.messages, Function(str)]) as any;
}

addZodInputValidation<TShape extends TExpectedInput>(
Expand Down
21 changes: 17 additions & 4 deletions src/ChatHelpers.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// ChatMessage creation helpers
// Ideally these would Dedent their content, but ts is checker is way too slow
// https://tinyurl.com/message-creators-literal-types
export function system<T extends string>(
export function System<T extends string>(
literals: TemplateStringsArray | T,
...placeholders: unknown[]
) {
Expand All @@ -10,7 +9,7 @@ export function system<T extends string>(
content: dedent(literals, ...placeholders),
};
}
export function user<T extends string>(
export function User<T extends string>(
literals: TemplateStringsArray | T,
...placeholders: unknown[]
) {
Expand All @@ -19,7 +18,7 @@ export function user<T extends string>(
content: dedent(literals, ...placeholders),
};
}
export function assistant<T extends string>(
export function Assistant<T extends string>(
literals: TemplateStringsArray | T,
...placeholders: unknown[]
) {
Expand All @@ -28,6 +27,20 @@ export function assistant<T extends string>(
content: dedent(literals, ...placeholders),
};
}
export function Function<T extends string>(
literals: TemplateStringsArray | T,
...placeholders: unknown[]
) {
return {
role: "function" as const,
content: dedent(literals, ...placeholders),
};
}

// backwards compadibility
export const system = System;
export const user = User;
export const assistant = Assistant;

export function dedent<T extends string>(
templ: TemplateStringsArray | T,
Expand Down
85 changes: 69 additions & 16 deletions src/PromptBuilder.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { z } from "zod";
import { Type } from "arktype";
import { F } from "ts-toolbelt";
import { Prompt } from "./Prompt";
import { ExtractArgs, ReplaceArgs, TypeToZodShape } from "./types";
Expand All @@ -18,23 +19,13 @@ export class PromptBuilder<
addZodInputValidation<TShape extends TExpectedInput>(
shape: TypeToZodShape<TShape>
) {
const zodValidator = z.object(shape as any);
return new (class extends PromptBuilder<TPromptTemplate, TShape> {
validate(args: Record<string, any>): args is TShape {
return zodValidator.safeParse(args).success;
}

get type() {
return this.template as ReplaceArgs<TPromptTemplate, TShape>;
}
return new ZodPromptBuilder(this.template, shape);
}

build<TSuppliedInputArgs extends TShape>(
args: F.Narrow<TSuppliedInputArgs>
) {
zodValidator.parse(args);
return new Prompt(this.template, args).toString();
}
})(this.template);
addArkTypeInputValidation<TShape extends TExpectedInput>(
shape: Type<TShape>
) {
return new ArkTypePromptBuilder(this.template, shape);
}

validate(args: Record<string, any>): args is TExpectedInput {
Expand All @@ -52,3 +43,65 @@ export class PromptBuilder<
return new Prompt(this.template, args).toString();
}
}

class ZodPromptBuilder<
TPromptTemplate extends string,
TExpectedInput extends ExtractArgs<TPromptTemplate, {}>
> extends PromptBuilder<TPromptTemplate, TExpectedInput> {
constructor(
public template: TPromptTemplate,
public shape: TypeToZodShape<TExpectedInput>
) {
super(template);
}
validate(args: Record<string, any>): args is TExpectedInput {
const zodValidator = z.object(this.shape as any);
return zodValidator.safeParse(args).success;
}

get type() {
return this.template as ReplaceArgs<TPromptTemplate, TExpectedInput>;
}

build<TSuppliedInputArgs extends TExpectedInput>(
args: F.Narrow<TSuppliedInputArgs>
) {
const zodValidator = z.object(this.shape as any);
zodValidator.parse(args);
return new Prompt(this.template, args).toString();
}
}

class ArkTypePromptBuilder<
TPromptTemplate extends string,
TExpectedInput extends ExtractArgs<TPromptTemplate, {}>
> extends PromptBuilder<TPromptTemplate, TExpectedInput> {
constructor(
public template: TPromptTemplate,
public shape: Type<TExpectedInput>
) {
super(template);
}
validate(args: Record<string, any>): args is TExpectedInput {
try {
this.shape(args);
return true;
} catch (e) {
return false;
}
}

get type() {
return this.template as ReplaceArgs<TPromptTemplate, TExpectedInput>;
}

build<TSuppliedInputArgs extends TExpectedInput>(
args: F.Narrow<TSuppliedInputArgs>
) {
const { problems } = this.shape(args);
if (problems?.summary) {
throw new Error(problems.summary);
}
return new Prompt(this.template, args).toString();
}
}
32 changes: 16 additions & 16 deletions src/__tests__/Chat.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { strict as assert } from "node:assert";
import { Chat } from "../Chat";
import { system, user, assistant } from "../ChatHelpers";
import { System, User, Assistant } from "../ChatHelpers";
import { Equal, Expect } from "./types.test";

describe("Chat", () => {
Expand All @@ -14,7 +14,7 @@ describe("Chat", () => {
const chat = new Chat(
[
// ^?
user("Tell me a {{jokeType}} joke"),
User("Tell me a {{jokeType}} joke"),
],
// @ts-expect-error
{}
Expand All @@ -28,13 +28,13 @@ describe("Chat", () => {
const chat = new Chat(
[
// ^?
user(`Tell me a {{jokeType}} joke`),
User(`Tell me a {{jokeType}} joke`),
],
{
jokeType: "funny" as const,
}
).toArray();
const usrMsg = user("Tell me a funny joke");
const usrMsg = User("Tell me a funny joke");
// ^?
type test = Expect<Equal<typeof chat, [typeof usrMsg]>>;
assert.deepEqual(chat, [usrMsg]);
Expand All @@ -44,19 +44,19 @@ describe("Chat", () => {
const chat = new Chat(
[
// ^?
user(`Tell me a {{jokeType1}} joke`),
assistant(`{{var2}} joke?`),
system(`joke? {{var3}}`),
User(`Tell me a {{jokeType1}} joke`),
Assistant(`{{var2}} joke?`),
System(`joke? {{var3}}`),
],
{
jokeType1: "funny",
var2: "foo",
var3: "bar",
} as const
).toArray();
const usrMsg = user("Tell me a funny joke");
const astMsg = assistant("foo joke?");
const sysMsg = system("joke? bar");
const usrMsg = User("Tell me a funny joke");
const astMsg = Assistant("foo joke?");
const sysMsg = System("joke? bar");
type test = Expect<
Equal<typeof chat, [typeof usrMsg, typeof astMsg, typeof sysMsg]>
>;
Expand All @@ -67,15 +67,15 @@ describe("Chat", () => {
const chat = new Chat(
[
// ^?
user(`Tell me a joke`),
assistant(`joke?`),
system(`joke?`),
User(`Tell me a joke`),
Assistant(`joke?`),
System(`joke?`),
],
{}
).toArray();
const usrMsg = user("Tell me a joke");
const astMsg = assistant("joke?");
const sysMsg = system("joke?");
const usrMsg = User("Tell me a joke");
const astMsg = Assistant("joke?");
const sysMsg = System("joke?");
type test = Expect<
Equal<typeof chat, [typeof usrMsg, typeof astMsg, typeof sysMsg]>
>;
Expand Down
Loading