diff --git a/js/packages/teams-ai/src/models/OpenAIModel.ts b/js/packages/teams-ai/src/models/OpenAIModel.ts
index c75df8264..fc1a96a5e 100644
--- a/js/packages/teams-ai/src/models/OpenAIModel.ts
+++ b/js/packages/teams-ai/src/models/OpenAIModel.ts
@@ -27,6 +27,7 @@ import { Tokenizer } from '../tokenizers';
 import { ActionCall, PromptResponse } from '../types';
 import { PromptCompletionModel, PromptCompletionModelEmitter } from './PromptCompletionModel';
+import { StreamingResponse } from '../StreamingResponse';
 
 /**
  * Base model options common to both OpenAI and Azure OpenAI services.
  */
@@ -436,7 +437,8 @@ export class OpenAIModel implements PromptCompletionModel {
 
                 // Signal response received
                 const response: PromptResponse<string> = { status: 'success', input, message };
-                this._events.emit('responseReceived', context, memory, response);
+                const streamer: StreamingResponse = memory.getValue('temp.streamer');
+                this._events.emit('responseReceived', context, memory, response, streamer);
 
                 // Let any pending events flush before returning
                 await new Promise((resolve) => setTimeout(resolve, 0));
diff --git a/js/packages/teams-ai/src/models/PromptCompletionModel.ts b/js/packages/teams-ai/src/models/PromptCompletionModel.ts
index 3a457b0fb..538ad0cef 100644
--- a/js/packages/teams-ai/src/models/PromptCompletionModel.ts
+++ b/js/packages/teams-ai/src/models/PromptCompletionModel.ts
@@ -14,6 +14,7 @@ import { Tokenizer } from '../tokenizers';
 import { PromptResponse } from '../types';
 import { Memory } from '../MemoryFork';
 import StrictEventEmitter from '../external/strict-event-emitter-types';
+import { StreamingResponse } from '../StreamingResponse';
 
 /**
  * Events emitted by a PromptCompletionModel.
@@ -51,7 +52,7 @@ export interface PromptCompletionModelEvents {
      * @param memory An interface for accessing state values.
      * @param response Final response returned by the model.
      */
-    responseReceived: (context: TurnContext, memory: Memory, response: PromptResponse<string>) => void;
+    responseReceived: (context: TurnContext, memory: Memory, response: PromptResponse<string>, streamer: StreamingResponse) => void;
 }
 
 /**
@@ -81,7 +82,8 @@ export type PromptCompletionModelChunkReceivedEvent = (
 export type PromptCompletionModelResponseReceivedEvent = (
     context: TurnContext,
     memory: Memory,
-    response: PromptResponse<string>
+    response: PromptResponse<string>,
+    streamer: StreamingResponse
 ) => void;
 
 /**
diff --git a/js/packages/teams-ai/src/planners/ActionPlanner.ts b/js/packages/teams-ai/src/planners/ActionPlanner.ts
index d1797ad33..32a006aaf 100644
--- a/js/packages/teams-ai/src/planners/ActionPlanner.ts
+++ b/js/packages/teams-ai/src/planners/ActionPlanner.ts
@@ -11,7 +11,7 @@ import { TurnContext } from 'botbuilder';
 import { AI } from '../AI';
 import { DefaultAugmentation } from '../augmentations';
 import { Memory } from '../MemoryFork';
-import { PromptCompletionModel } from '../models';
+import { PromptCompletionModel, PromptCompletionModelResponseReceivedEvent } from '../models';
 import { PromptTemplate, PromptManager } from '../prompts';
 import { Tokenizer } from '../tokenizers';
 import { TurnState } from '../TurnState';
@@ -85,6 +85,11 @@ export interface ActionPlannerOptions<TState extends TurnState = TurnState> {
      * Optional message to send a client at the start of a streaming response.
      */
     startStreamingMessage?: string;
+
+    /**
+     * Optional handler to run when a stream is about to conclude.
+     */
+    endStreamHandler?: PromptCompletionModelResponseReceivedEvent;
 }
 
 /**
diff --git a/js/packages/teams-ai/src/planners/LLMClient.ts b/js/packages/teams-ai/src/planners/LLMClient.ts
index 131658634..93848c3c9 100644
--- a/js/packages/teams-ai/src/planners/LLMClient.ts
+++ b/js/packages/teams-ai/src/planners/LLMClient.ts
@@ -13,7 +13,8 @@ import { Memory, MemoryFork } from '../MemoryFork';
 import {
     PromptCompletionModel,
     PromptCompletionModelBeforeCompletionEvent,
-    PromptCompletionModelChunkReceivedEvent
+    PromptCompletionModelChunkReceivedEvent,
+    PromptCompletionModelResponseReceivedEvent
 } from '../models';
 import { ConversationHistory, Message, Prompt, PromptFunctions, PromptTemplate } from '../prompts';
 import { StreamingResponse } from '../StreamingResponse';
@@ -91,6 +92,11 @@ export interface LLMClientOptions<TContent = any> {
      * Optional message to send a client at the start of a streaming response.
      */
     startStreamingMessage?: string;
+
+    /**
+     * Optional handler to run when a stream is about to conclude.
+     */
+    endStreamHandler?: PromptCompletionModelResponseReceivedEvent;
 }
 
 /**
@@ -193,6 +199,7 @@ export interface ConfiguredLLMClientOptions<TContent = any> {
  */
 export class LLMClient<TContent = any> {
     private readonly _startStreamingMessage: string | undefined;
+    private readonly _endStreamHandler: PromptCompletionModelResponseReceivedEvent | undefined;
 
     /**
      * Configured options for this LLMClient instance.
@@ -226,6 +233,7 @@ export class LLMClient<TContent = any> {
         }
 
         this._startStreamingMessage = options.startStreamingMessage;
+        this._endStreamHandler = options.endStreamHandler;
     }
 
     /**
@@ -290,6 +298,7 @@ export class LLMClient<TContent = any> {
 
                 // Create streamer and send initial message
                 streamer = new StreamingResponse(context);
+                memory.setValue('temp.streamer', streamer);
                 if (this._startStreamingMessage) {
                     streamer.queueInformativeUpdate(this._startStreamingMessage);
                 }
@@ -313,6 +322,10 @@ export class LLMClient<TContent = any> {
         if (this.options.model.events) {
             this.options.model.events.on('beforeCompletion', beforeCompletion);
             this.options.model.events.on('chunkReceived', chunkReceived);
+
+            if (this._endStreamHandler) {
+                this.options.model.events.on('responseReceived', this._endStreamHandler);
+            }
         }
 
         try {
@@ -325,7 +338,7 @@ export class LLMClient<TContent = any> {
 
             // End the stream if streaming
             // - We're not listening for the response received event because we can't await the completion of events.
-            if (streamer) {
+            if (streamer && !this._endStreamHandler) {
                 await streamer.endStream();
             }
 
@@ -335,6 +348,10 @@ export class LLMClient<TContent = any> {
             if (this.options.model.events) {
                 this.options.model.events.off('beforeCompletion', beforeCompletion);
                 this.options.model.events.off('chunkReceived', chunkReceived);
+
+                if (this._endStreamHandler) {
+                    this.options.model.events.off('responseReceived', this._endStreamHandler);
+                }
             }
         }
     }
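
Usage sketch (not part of the diff): with these changes an application can register the new `endStreamHandler` through `ActionPlannerOptions`. The handler receives the final `PromptResponse` together with the `StreamingResponse` that `LLMClient` stashes under `temp.streamer`, and because `LLMClient` now skips its own `endStream()` call whenever a handler is registered, the handler is expected to conclude the stream itself. The model name, prompt folder, env var, and `stream: true` model option below are illustrative assumptions, not part of this change.

```ts
import * as path from 'path';
import {
    ActionPlanner,
    OpenAIModel,
    PromptManager,
    PromptCompletionModelResponseReceivedEvent
} from '@microsoft/teams-ai';

// Invoked when the model emits 'responseReceived'. Because LLMClient no longer
// calls endStream() when a handler is registered, the handler concludes the stream.
const endStreamHandler: PromptCompletionModelResponseReceivedEvent = (context, memory, response, streamer) => {
    if (!streamer) {
        // Non-streaming turn: OpenAIModel found no 'temp.streamer' value in memory.
        return;
    }

    // Last chance to inspect `response` or queue a final update before closing.
    // Model events can't be awaited, so the stream is ended fire-and-forget.
    void streamer.endStream();
};

const planner = new ActionPlanner({
    model: new OpenAIModel({
        apiKey: process.env.OPENAI_KEY!, // hypothetical env var
        defaultModel: 'gpt-4o', // placeholder model name
        stream: true // assumes the model's streaming option is enabled
    }),
    prompts: new PromptManager({ promptsFolder: path.join(__dirname, '../prompts') }),
    defaultPrompt: 'chat',
    startStreamingMessage: 'Loading stream results...',
    endStreamHandler
});
```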