.Net: Add a request index to the streamed chat message content #10129

Draft · wants to merge 3 commits into base: main
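This PR threads a request index through the Gemini, MistralAI, and OpenAI streaming code paths and surfaces it as a new RequestIndex property on StreamingChatMessageContent, so consumers can tell which underlying request in an auto function-invocation loop produced a given chunk. As a minimal consumer-side sketch of what this enables (not part of this PR; the model id, API key, and tool-call settings below are illustrative assumptions):

using System;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;

var kernel = Kernel.CreateBuilder()
    .AddOpenAIChatCompletion(modelId: "gpt-4o", apiKey: "<key>") // placeholder credentials
    .Build();

var chat = kernel.GetRequiredService<IChatCompletionService>();
var history = new ChatHistory();
history.AddUserMessage("What is the weather in Paris?");

// Auto invocation may trigger follow-up requests after tool calls complete.
var settings = new OpenAIPromptExecutionSettings
{
    ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
};

await foreach (var chunk in chat.GetStreamingChatMessageContentsAsync(history, settings, kernel))
{
    // With this change, chunks produced by different underlying requests
    // (e.g. before and after an auto-invoked tool call) carry different
    // RequestIndex values, so the streams can be told apart.
    Console.WriteLine($"[request {chunk.RequestIndex}] {chunk.Content}");
}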
@@ -340,10 +340,10 @@ private async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMess
                 // This scenario should not happen but I leave it as a precaution
                 state.AutoInvoke = false;
                 // We return the first message
-                yield return this.GetStreamingChatContentFromChatContent(messageContent);
+                yield return this.GetStreamingChatContentFromChatContent(messageContent, state.Iteration);
                 // We return the second message
                 messageContent = chatResponsesEnumerator.Current;
-                yield return this.GetStreamingChatContentFromChatContent(messageContent);
+                yield return this.GetStreamingChatContentFromChatContent(messageContent, state.Iteration);
                 continue;
             }

@@ -356,7 +356,7 @@ private async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMess
                 state.AutoInvoke = false;

                 // If we don't want to attempt to invoke any functions, just return the result.
-                yield return this.GetStreamingChatContentFromChatContent(messageContent);
+                yield return this.GetStreamingChatContentFromChatContent(messageContent, state.Iteration);
             }
         }
         finally
@@ -617,7 +617,7 @@ private static GeminiRequest CreateRequest(
         return geminiRequest;
     }

-    private GeminiStreamingChatMessageContent GetStreamingChatContentFromChatContent(GeminiChatMessageContent message)
+    private GeminiStreamingChatMessageContent GetStreamingChatContentFromChatContent(GeminiChatMessageContent message, int requestIndex)
     {
         if (message.CalledToolResult is not null)
         {
@@ -627,7 +627,10 @@ private GeminiStreamingChatMessageContent GetStreamingChatContentFromChatContent
                 modelId: this._modelId,
                 calledToolResult: message.CalledToolResult,
                 metadata: message.Metadata,
-                choiceIndex: message.Metadata?.Index ?? 0);
+                choiceIndex: message.Metadata?.Index ?? 0)
+            {
+                RequestIndex = requestIndex
+            };
         }

         if (message.ToolCalls is not null)
@@ -638,15 +641,21 @@ private GeminiStreamingChatMessageContent GetStreamingChatContentFromChatContent
                 modelId: this._modelId,
                 toolCalls: message.ToolCalls,
                 metadata: message.Metadata,
-                choiceIndex: message.Metadata?.Index ?? 0);
+                choiceIndex: message.Metadata?.Index ?? 0)
+            {
+                RequestIndex = requestIndex
+            };
         }

         return new GeminiStreamingChatMessageContent(
             role: message.Role,
             content: message.Content,
             modelId: this._modelId,
             choiceIndex: message.Metadata?.Index ?? 0,
-            metadata: message.Metadata);
+            metadata: message.Metadata)
+        {
+            RequestIndex = requestIndex
+        };
     }

     private static void ValidateAutoInvoke(bool autoInvoke, int resultsPerPrompt)
@@ -287,7 +287,7 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes
             IAsyncEnumerable<StreamingChatMessageContent> response;
             try
             {
-                response = this.StreamChatMessageContentsAsync(chatHistory, mistralExecutionSettings, chatRequest, modelId, cancellationToken);
+                response = this.StreamChatMessageContentsAsync(chatHistory, mistralExecutionSettings, chatRequest, modelId, requestIndex, cancellationToken);
             }
             catch (Exception e) when (activity is not null)
             {
@@ -459,7 +459,7 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes

                 var lastChatMessage = chatHistory.Last();

-                yield return new StreamingChatMessageContent(lastChatMessage.Role, lastChatMessage.Content);
+                yield return new StreamingChatMessageContent(lastChatMessage.Role, lastChatMessage.Content) { RequestIndex = requestIndex };
                 yield break;
             }
         }
@@ -498,21 +498,21 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes
         }
     }

-    private async IAsyncEnumerable<StreamingChatMessageContent> StreamChatMessageContentsAsync(ChatHistory chatHistory, MistralAIPromptExecutionSettings executionSettings, ChatCompletionRequest chatRequest, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken)
+    private async IAsyncEnumerable<StreamingChatMessageContent> StreamChatMessageContentsAsync(ChatHistory chatHistory, MistralAIPromptExecutionSettings executionSettings, ChatCompletionRequest chatRequest, string modelId, int requestIndex, [EnumeratorCancellation] CancellationToken cancellationToken)
     {
         this.ValidateChatHistory(chatHistory);

         var endpoint = this.GetEndpoint(executionSettings, path: "chat/completions");
         using var httpRequestMessage = this.CreatePost(chatRequest, endpoint, this._apiKey, stream: true);
         using var response = await this.SendStreamingRequestAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);
         var responseStream = await response.Content.ReadAsStreamAndTranslateExceptionAsync(cancellationToken).ConfigureAwait(false);
-        await foreach (var streamingChatContent in this.ProcessChatResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false))
+        await foreach (var streamingChatContent in this.ProcessChatResponseStreamAsync(responseStream, modelId, requestIndex, cancellationToken).ConfigureAwait(false))
         {
             yield return streamingChatContent;
         }
     }

-    private async IAsyncEnumerable<StreamingChatMessageContent> ProcessChatResponseStreamAsync(Stream stream, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken)
+    private async IAsyncEnumerable<StreamingChatMessageContent> ProcessChatResponseStreamAsync(Stream stream, string modelId, int requestIndex, [EnumeratorCancellation] CancellationToken cancellationToken)
     {
         IAsyncEnumerator<MistralChatCompletionChunk>? responseEnumerator = null;

@@ -536,7 +536,10 @@ private async IAsyncEnumerable<StreamingChatMessageContent> ProcessChatResponseS
                     modelId: modelId,
                     encoding: chunk.GetEncoding(),
                     innerContent: chunk,
-                    metadata: chunk.GetMetadata());
+                    metadata: chunk.GetMetadata())
+                {
+                    RequestIndex = requestIndex
+                };
             }
         }
     }
@@ -309,7 +309,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
                     OpenAIFunctionToolCall.TrackStreamingToolingUpdate(chatCompletionUpdate.ToolCallUpdates, ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex);
                 }

-                var openAIStreamingChatMessageContent = new OpenAIStreamingChatMessageContent(chatCompletionUpdate, 0, targetModel, metadata);
+                var openAIStreamingChatMessageContent = new OpenAIStreamingChatMessageContent(chatCompletionUpdate, 0, targetModel, metadata) { RequestIndex = requestIndex };

                 if (openAIStreamingChatMessageContent.ToolCallUpdates is not null)
                 {
@@ -383,7 +383,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC

             if (lastMessage != null)
             {
-                yield return new OpenAIStreamingChatMessageContent(lastMessage.Role, lastMessage.Content);
+                yield return new OpenAIStreamingChatMessageContent(lastMessage.Role, lastMessage.Content) { RequestIndex = requestIndex };
                 yield break;
             }

@@ -102,6 +102,11 @@ public Encoding Encoding
         }
     }

+    /// <summary>
+    /// Index of the request that produced this message content.
+    /// </summary>
+    public int RequestIndex { get; init; } = 0;
+
     /// <summary>
     /// Initializes a new instance of the <see cref="StreamingChatMessageContent"/> class.
     /// </summary>
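A closing note on the design: RequestIndex is declared with an init accessor and defaults to 0, so existing construction sites keep compiling and connectors that never set it simply report index 0. A minimal sketch of the initializer pattern the connectors above use (the role and content values are illustrative):

using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

// Mirrors the object-initializer pattern used throughout this diff.
var chunk = new StreamingChatMessageContent(AuthorRole.Assistant, "Hello")
{
    RequestIndex = 1 // e.g. a follow-up request in an auto-invocation loop
};

// The property is init-only, so it cannot be reassigned after construction:
// chunk.RequestIndex = 2; // would not compile (CS8852)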