Skip to content

Commit

Permalink
fix: streaming ws message + i18n updates (#454)
Browse files Browse the repository at this point in the history
* i18n updates

* feat: streaming message improvements

* fixes

* fixes

* adjust speed
  • Loading branch information
paulclindo authored Sep 24, 2024
1 parent f947032 commit 760e1fe
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 85 deletions.
195 changes: 111 additions & 84 deletions apps/shinkai-desktop/src/components/chat/message-stream.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,97 @@ import { ShinkaiMessageBuilderWrapper } from '@shinkai_network/shinkai-message-t
import { FunctionKey } from '@shinkai_network/shinkai-node-state/lib/constants';
import { Message } from '@shinkai_network/shinkai-ui';
import { useQueryClient } from '@tanstack/react-query';
import { useCallback, useEffect, useRef, useState } from 'react';
import { useEffect, useRef, useState } from 'react';
import { useParams } from 'react-router-dom';
import useWebSocket from 'react-use-websocket';

import { useAuth } from '../../store/auth';

type AnimationState = {
displayedContent: string;
pendingContent: string;
};

type UseWebSocketMessage = {
enabled?: boolean;
};

const START_ANIMATION_SPEED = 4;
const END_ANIMATION_SPEED = 15;

const createSmoothMessage = (params: {
onTextUpdate: (delta: string, text: string) => void;
onFinished?: () => void;
startSpeed?: number;
}) => {
const { startSpeed = START_ANIMATION_SPEED } = params;

let buffer = '';
const outputQueue: string[] = [];
let isAnimationActive = false;
let animationFrameId: number | null = null;

const stopAnimation = () => {
isAnimationActive = false;
if (animationFrameId !== null) {
cancelAnimationFrame(animationFrameId);
animationFrameId = null;
}
};

const startAnimation = (speed = startSpeed) =>
new Promise<void>((resolve) => {
if (isAnimationActive) {
resolve();
return;
}

isAnimationActive = true;

const updateText = () => {
if (!isAnimationActive) {
cancelAnimationFrame(animationFrameId as number);
animationFrameId = null;
resolve();
return;
}

if (outputQueue.length > 0) {
const charsToAdd = outputQueue.splice(0, speed).join('');
buffer += charsToAdd;
params.onTextUpdate(charsToAdd, buffer);
} else {
isAnimationActive = false;
animationFrameId = null;
params.onFinished?.();
resolve();
return;
}
animationFrameId = requestAnimationFrame(updateText);
};

animationFrameId = requestAnimationFrame(updateText);
});

const pushToQueue = (text: string) => {
outputQueue.push(...text.split(''));
};

const reset = () => {
buffer = '';
outputQueue.length = 0;
stopAnimation();
};

return {
isAnimationActive,
isTokenRemain: () => outputQueue.length > 0,
pushToQueue,
startAnimation,
stopAnimation,
reset,
};
};

export const useWebSocketMessage = ({ enabled }: UseWebSocketMessage) => {
const auth = useAuth((state) => state.auth);
const nodeAddressUrl = new URL(auth?.node_address ?? 'http://localhost:9850');
Expand All @@ -32,68 +108,57 @@ export const useWebSocketMessage = ({ enabled }: UseWebSocketMessage) => {
const { inboxId: encodedInboxId = '' } = useParams();
const inboxId = decodeURIComponent(encodedInboxId);

const isStreamFinishedRef = useRef(false);
const [animationState, setAnimationState] = useState<AnimationState>({
displayedContent: '',
pendingContent: '',
});
const isStreamingFinished = useRef(false);

const animationFrameRef = useRef<number | null>(null);
const isAnimatingRef = useRef(false);

const animateText = useCallback(() => {
setAnimationState((prevState) => {
if (
prevState.pendingContent.length === 0 &&
isStreamFinishedRef.current
) {
isAnimatingRef.current = false;
return prevState;
}

const chunkSize = Math.max(
1,
Math.round(prevState.pendingContent.length / 90),
);

const nextChunk = prevState.pendingContent.slice(0, chunkSize);
const remainingPending = prevState.pendingContent.slice(chunkSize);
const textControllerRef = useRef<ReturnType<
typeof createSmoothMessage
> | null>(null);

return {
displayedContent: prevState.displayedContent + nextChunk,
pendingContent: remainingPending,
};
useEffect(() => {
textControllerRef.current = createSmoothMessage({
onTextUpdate: (_, text) => {
if (isStreamingFinished.current) return;
setAnimationState({
displayedContent: text,
});
},
});

if (isAnimatingRef.current) {
animationFrameRef.current = requestAnimationFrame(animateText);
}
}, []);

const startAnimation = useCallback(() => {
if (!isAnimatingRef.current) {
isAnimatingRef.current = true;
animationFrameRef.current = requestAnimationFrame(animateText);
}
}, [animateText]);

useEffect(() => {
if (!enabled) return;
if (!textControllerRef.current) return;
if (lastMessage?.data) {
try {
const parseData: WsMessage = JSON.parse(lastMessage.data);
if (parseData.message_type !== 'Stream') return;
isStreamFinishedRef.current = false;
isStreamingFinished.current = false;
if (parseData.metadata?.is_done === true) {
isStreamFinishedRef.current = true;
textControllerRef.current.stopAnimation();
if (textControllerRef.current.isTokenRemain()) {
textControllerRef.current.startAnimation(END_ANIMATION_SPEED);
}

const paginationKey = [
FunctionKey.GET_CHAT_CONVERSATION_PAGINATION,
{ inboxId: inboxId as string },
];
queryClient.invalidateQueries({ queryKey: paginationKey });
isStreamingFinished.current = true;
// TODO: unify streaming message as part of messages cache to avoid layout shift
setTimeout(() => {
setAnimationState({ displayedContent: '' });
textControllerRef.current?.reset();
}, 600);
}

setAnimationState((prevState) => ({
...prevState,
pendingContent: prevState.pendingContent + parseData.message,
}));
textControllerRef.current?.pushToQueue(parseData.message);

startAnimation();
if (!textControllerRef.current.isAnimationActive)
textControllerRef.current.startAnimation();
} catch (error) {
console.error('Failed to parse ws message', error);
}
Expand All @@ -106,7 +171,6 @@ export const useWebSocketMessage = ({ enabled }: UseWebSocketMessage) => {
inboxId,
lastMessage?.data,
queryClient,
startAnimation,
]);

useEffect(() => {
Expand Down Expand Up @@ -138,45 +202,8 @@ export const useWebSocketMessage = ({ enabled }: UseWebSocketMessage) => {
sendMessage,
]);

useEffect(() => {
return () => {
if (animationFrameRef.current) {
cancelAnimationFrame(animationFrameRef.current);
}
};
}, []);

useEffect(() => {
if (isStreamFinishedRef.current) {
const paginationKey = [
FunctionKey.GET_CHAT_CONVERSATION_PAGINATION,
{
nodeAddress: auth?.node_address ?? '',
inboxId: inboxId as string,
shinkaiIdentity: auth?.shinkai_identity ?? '',
profile: auth?.profile ?? '',
},
];
queryClient.invalidateQueries({ queryKey: paginationKey });
setTimeout(() => {
setAnimationState({
displayedContent: '',
pendingContent: '',
});
}, 500);
}
}, [
isStreamFinishedRef.current,
auth?.node_address,
auth?.profile,
auth?.shinkai_identity,
inboxId,
queryClient,
]);

return {
messageContent: animationState.displayedContent,
isStreamFinished: isStreamFinishedRef.current,
readyState,
};
};
Expand Down
7 changes: 6 additions & 1 deletion libs/shinkai-i18n/locales/en-US.json
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@
"createRegistrationCode": "Create Registration Code",
"analytics": "Analytics",
"publicKeys": "Public Keys",
"galxe": "Galxe Quest"
"galxe": "Galxe Quest",
"promptLibrary": "Prompt Library"
},
"shinkaiIdentity": {
"label": "Shinkai Identity",
Expand Down Expand Up @@ -319,6 +320,9 @@
"profileIdentity": "Profile Identity",
"myDeviceEncryption": "My Device Encryption",
"myDeviceIdentity": "My Device Identity"
},
"promptLibrary": {
"label": "Prompt Library"
}
},
"exportConnection": {
Expand Down Expand Up @@ -450,6 +454,7 @@
"modelName": "Model Name",
"modelId": "Model ID",
"modelType": "Model Type",
"customModelType": "Custom Model Type",
"toggleCustomModel": "Add a custom model",
"selectModel": "Select your Model"
},
Expand Down
5 changes: 5 additions & 0 deletions libs/shinkai-i18n/locales/es-ES.json
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@
"form": {
"agentName": "Nombre de la IA",
"apiKey": "Clave API",
"customModelType": "Tipo de Modelo Personalizado",
"externalUrl": "URL Externa",
"modelId": "ID del Modelo",
"modelName": "Nombre del Modelo",
Expand Down Expand Up @@ -329,9 +330,13 @@
"exportConnection": "Exportar Conexión",
"galxe": "Misión Galxe",
"general": "General",
"promptLibrary": "Biblioteca de Prompts",
"publicKeys": "Claves Públicas",
"shinkaiNode": "Administrador de Nodo Shinkai"
},
"promptLibrary": {
"label": "Biblioteca de Prompts"
},
"publicKeys": {
"label": "Claves Públicas",
"myDeviceEncryption": "Cifrado de Mi Dispositivo",
Expand Down
5 changes: 5 additions & 0 deletions libs/shinkai-i18n/locales/id-ID.json
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@
"form": {
"agentName": "Nama AI",
"apiKey": "Kunci API",
"customModelType": "Tipe Model Kustom",
"externalUrl": "URL Eksternal",
"modelId": "ID Model",
"modelName": "Nama Model",
Expand Down Expand Up @@ -329,9 +330,13 @@
"exportConnection": "Ekspor Koneksi",
"galxe": "Pertualangan Galxe",
"general": "Umum",
"promptLibrary": "Perpustakaan Prompt",
"publicKeys": "Kunci Publik",
"shinkaiNode": "Manajer Node Shinkai"
},
"promptLibrary": {
"label": "Perpustakaan Prompt"
},
"publicKeys": {
"label": "Kunci Publik",
"myDeviceEncryption": "Enkripsi Perangkat Saya",
Expand Down
5 changes: 5 additions & 0 deletions libs/shinkai-i18n/locales/ja-JP.json
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@
"form": {
"agentName": "AI名",
"apiKey": "APIキー",
"customModelType": "カスタムモデルタイプ",
"externalUrl": "外部URL",
"modelId": "モデルID",
"modelName": "モデル名",
Expand Down Expand Up @@ -329,9 +330,13 @@
"exportConnection": "接続のエクスポート",
"galxe": "ガルクス・クエスト",
"general": "一般",
"promptLibrary": "プロンプトライブラリ",
"publicKeys": "公開鍵",
"shinkaiNode": "シンカイ・ノード・マネージャー"
},
"promptLibrary": {
"label": "プロンプトライブラリ"
},
"publicKeys": {
"label": "公開鍵",
"myDeviceEncryption": "私のデバイスの暗号化",
Expand Down
5 changes: 5 additions & 0 deletions libs/shinkai-i18n/locales/zh-CN.json
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@
"form": {
"agentName": "人工智能名称",
"apiKey": "API密钥",
"customModelType": "自定义模型类型",
"externalUrl": "外部URL",
"modelId": "模型ID",
"modelName": "模型名称",
Expand Down Expand Up @@ -329,9 +330,13 @@
"exportConnection": "导出连接",
"galxe": "Galxe任务",
"general": "常规",
"promptLibrary": "提示库",
"publicKeys": "公钥",
"shinkaiNode": "深海节点管理器"
},
"promptLibrary": {
"label": "提示库"
},
"publicKeys": {
"label": "公钥",
"myDeviceEncryption": "我的设备加密",
Expand Down

0 comments on commit 760e1fe

Please sign in to comment.