Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion stack/app/api/v1/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class CreateRunPayload(BaseModel):

assistant_id: Optional[str] = None
thread_id: Optional[str] = None
input: list[dict]
input: Optional[list[dict]] = None
config: Optional[RunnableConfig] = None


Expand Down
1 change: 1 addition & 0 deletions ui/src/components/features/chat-panel/chat-panel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ export default function ChatPanel() {
<MessagesContainer
threadId={threadId as string}
stream={stream as TStreamState}
startStream={startStream}
/>
) : (
<div className="self-center h-full items-center flex">
Expand Down
18 changes: 14 additions & 4 deletions ui/src/components/features/chat-panel/components/composer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ export function Composer({ onChange, onSend, value, disabled, isStreaming, onSto
if (e.shiftKey) return;

e.preventDefault();
onSend();
if (value.trim()) {
onSend();
}
}
};

Expand All @@ -37,9 +39,17 @@ export function Composer({ onChange, onSend, value, disabled, isStreaming, onSto
<Tooltip delayDuration={0}>
<TooltipTrigger asChild>
<Input
endIcon={<Button onClick={isStreaming ? onStop : onSend} variant="ghost">
{isStreaming ? <StopCircle size={18}/> : <Send size={18}/>}
</Button>}
endIcon={
<Button
onClick={isStreaming ? onStop : onSend}
variant="ghost"
disabled={!isStreaming && !value.trim()}
>
{isStreaming ?
<StopCircle size={18}/> :
<Send size={18} className={!value.trim() ? 'opacity-50' : ''}/>}
</Button>
}
onChange={onChange}
value={value}
onKeyDown={handleKeyDown}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,21 @@ import { useChatMessages } from "@/hooks/useChat";
import ToolContainer from "../../tools/tool-container";
import { ToolResult } from "../../tools/tool-result";
import { useSlugRoutes } from "@/hooks/useSlugParams";
import { ArrowDownCircle } from "lucide-react";
import { useStream } from "@/hooks/useStream";


type Props = {
streamingMessage?: TMessage | null;
onRetry?: VoidFunction;
threadId: string;
stream: TStreamState;
startStream: ({
input,
thread_id,
assistant_id,
user_id,
}: TStartStreamProps) => Promise<void>;
};

function usePrevious<T>(value: T): T | undefined {
Expand All @@ -33,16 +41,16 @@ export default function MessagesContainer({
streamingMessage,
onRetry,
threadId,
stream
stream,
startStream,
}: Props) {
const { messages } = useChatMessages(threadId, stream);
const { messages, next } = useChatMessages(threadId, stream);
const prevMessages = usePrevious(messages);
const {assistantId} = useSlugRoutes();

const {data: selectedAssistant, isLoading: isLoadingAssistant} = useAssistant(assistantId as string, {
enabled: !!assistantId
})

const { assistantId } = useSlugRoutes();
const { data: selectedAssistant, isLoading: isLoadingAssistant } =
useAssistant(assistantId as string, {
enabled: !!assistantId,
});

const divRef = useRef<HTMLDivElement>(null);

Expand All @@ -57,11 +65,12 @@ export default function MessagesContainer({
: undefined,
});
}
}, [messages]);
}, [messages, prevMessages]);

return (
<div className="p-6 overflow-y-scroll" ref={divRef}>
{messages?.map((message, index) => {
console.log("Rendering message:", message);
const isToolCall =
message.tool_calls?.length && message.tool_calls.length > 0;

Expand All @@ -83,9 +92,28 @@ export default function MessagesContainer({
}

return (
<MessageItem message={message} assistant={selectedAssistant} key={`${message.id}-${index}`} />
<MessageItem
message={message}
assistant={selectedAssistant}
key={`${message.id}-${index}`}
/>
);
})}
{next.length > 0 && stream?.status !== "inflight" && (
<div
className="flex items-center rounded-md bg-blue-50 px-2 py-1 text-xs font-medium text-blue-800 ring-1 ring-inset ring-blue-600/20 cursor-pointer"
onClick={() =>
startStream({
input: null,
thread_id: threadId,
assistant_id: assistantId as string,
})
}
>
<ArrowDownCircle className="h-5 w-5 mr-1" />
Click to continue.
</div>
)}
</div>
);
}
38 changes: 26 additions & 12 deletions ui/src/hooks/useChat.ts
Original file line number Diff line number Diff line change
@@ -1,45 +1,59 @@
import { TMessage, TStreamState } from "@/data-provider/types";
import { useEffect } from "react";
import { useEffect, useRef } from "react";
import { mergeMessagesById } from "./useStream";
import { useThreadState } from "@/data-provider/query-service";
import { useAtom } from "jotai";
import { messagesAtom } from "@/store";

function usePrevious<T>(value: T): T | undefined {
const ref = useRef<T>();
useEffect(() => {
ref.current = value;
});
return ref.current;
}

export function useChatMessages(
threadId: string | null,
stream: TStreamState | null,
) {
const [streamedMessages, setStreamedMessages] = useAtom(messagesAtom)

const [streamedMessages, setStreamedMessages] = useAtom(messagesAtom);
const prevStreamStatus = usePrevious(stream?.status);

const { data: threadData, refetch, isFetched } = useThreadState(threadId as string, {
enabled: !!threadId
});

// Refetch messages after streaming
// Only refetch when transitioning from inflight to non-inflight
useEffect(() => {
if (stream?.status !== "inflight" && threadId) {
if (prevStreamStatus === "inflight" &&
stream?.status !== "inflight" &&
threadId) {
refetch();
}
}, [stream?.status, threadId, refetch]);
}, [stream?.status, threadId, refetch, prevStreamStatus]);

// Stop persisting streamed messages after streaming and message refetch
// Clear streamed messages after fetching thread state
useEffect(() => {
if (isFetched) {
setStreamedMessages([])
setStreamedMessages([]);
}
},[isFetched])
}, [isFetched, setStreamedMessages]);

// Update streamed messages during streaming
useEffect(() => {
if (stream?.messages) {
setStreamedMessages(stream.messages as TMessage[])
setStreamedMessages(stream.messages as TMessage[]);
}
}, [stream?.messages])
}, [stream?.messages, setStreamedMessages]);

const messages = threadData?.values ? threadData.values : null;
const next = threadData?.next || [];

return {
messages: mergeMessagesById(messages, streamedMessages),
next: threadData?.next || [],
next,
refreshMessages: refetch
};
}

19 changes: 17 additions & 2 deletions ui/src/hooks/useStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ export const useStream = () => {
async ({ input, thread_id, assistant_id, user_id }: TStartStreamProps) => {
const controller = new AbortController();
setController(controller);
setCurrentState({ status: "inflight", messages: [] });
// setCurrentState({ status: "inflight", messages: [] });
setCurrentState(prev => ({
status: "inflight",
messages: input === null ? prev?.messages || [] : [], // Keep messages for continuation
thread_id: prev?.thread_id
}));
setIsStreaming(true);

await fetchEventSource(
Expand All @@ -36,11 +41,18 @@ export const useStream = () => {
"Content-Type": "application/json",
"X-API-KEY": "personaflow_api_key",
},
body: JSON.stringify({ user_id, input, thread_id, assistant_id }),
body: JSON.stringify({
user_id,
input,
thread_id,
assistant_id
}),
openWhenHidden: true,
onmessage(msg) {
console.log("Stream event received:", msg.event, msg.data);
if (msg.event === "data") {
const messages = JSON.parse(msg.data);
console.log("Parsed messages:", messages);
setCurrentState((currentState) => ({
status: "inflight" as TStreamState["status"],
messages: mergeMessagesById(currentState?.messages, messages),
Expand Down Expand Up @@ -72,6 +84,7 @@ export const useStream = () => {
thread_id: currentState?.thread_id
}));
setController(null);
setIsStreaming(false);
},
onerror(error) {
setCurrentState((currentState) => ({
Expand All @@ -81,6 +94,7 @@ export const useStream = () => {
thread_id: currentState?.thread_id
}));
setController(null);
setIsStreaming(false);
throw error;
},
},
Expand All @@ -107,6 +121,7 @@ export const useStream = () => {
thread_id: currentState?.thread_id
}));
}
setIsStreaming(false);
},
[controller],
);
Expand Down