-
Notifications
You must be signed in to change notification settings - Fork 54
FirebaseAI: Implement live session resumption and context window compression #1410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| /* | ||
| * Copyright 2026 Google LLC | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| using System.Collections.Generic; | ||
| using Firebase.AI.Internal; | ||
|
|
||
| namespace Firebase.AI | ||
| { | ||
/// <summary>
/// Settings for the sliding-window strategy used to compress the context window.
/// </summary>
public class SlidingWindow
{
  /// <summary>
  /// Target number of tokens to keep once the session has been reduced.
  /// Null when no target was supplied.
  /// </summary>
  public int? TargetTokens { get; }

  /// <summary>
  /// Creates a sliding window configuration.
  /// </summary>
  /// <param name="targetTokens">Optional number of tokens to keep; stored unchanged.</param>
  public SlidingWindow(int? targetTokens = null) => TargetTokens = targetTokens;

  /// <summary>
  /// Builds the dictionary form of this configuration. The "targetTokens"
  /// entry is only present when a value was supplied.
  /// </summary>
  /// <returns>A new dictionary holding the configured values.</returns>
  internal Dictionary<string, object> ToJson()
  {
    Dictionary<string, object> json = new();
    if (TargetTokens is int keep)
    {
      json.Add("targetTokens", keep);
    }
    return json;
  }
}
|
|
||
/// <summary>
/// Turns on context window compression, which manages the model's context window.
/// </summary>
public class ContextWindowCompressionConfig
{
  /// <summary>
  /// Token count, measured before a turn runs, at which context window
  /// compression is triggered.
  /// </summary>
  public int? TriggerTokens { get; }

  /// <summary>
  /// The sliding window mechanism used for compression.
  /// </summary>
  public SlidingWindow? SlidingWindow { get; }

  /// <summary>
  /// Creates a context window compression configuration.
  /// </summary>
  /// <param name="triggerTokens">Optional token count that triggers compression.</param>
  /// <param name="slidingWindow">Optional sliding window settings.</param>
  public ContextWindowCompressionConfig(int? triggerTokens = null, SlidingWindow? slidingWindow = null)
  {
    SlidingWindow = slidingWindow;
    TriggerTokens = triggerTokens;
  }

  /// <summary>
  /// Builds the dictionary form of this configuration. "triggerTokens" and
  /// "slidingWindow" are each included only when the matching property is set.
  /// </summary>
  /// <returns>A new dictionary holding the configured values.</returns>
  internal Dictionary<string, object> ToJson()
  {
    Dictionary<string, object> json = new();
    if (TriggerTokens is int trigger)
    {
      json.Add("triggerTokens", trigger);
    }
    if (SlidingWindow is { } window)
    {
      // The nested configuration is serialized through its own ToJson.
      json.Add("slidingWindow", window.ToJson());
    }
    return json;
  }
}
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -118,34 +118,35 @@ private string GetModelName() | |
| /// </summary> | ||
| /// <param name="cancellationToken">The token that can be used to cancel the creation of the session.</param> | ||
| /// <returns>The LiveSession, once it is established.</returns> | ||
| public async Task<LiveSession> ConnectAsync(CancellationToken cancellationToken = default) | ||
| public async Task<LiveSession> ConnectAsync(SessionResumptionConfig? sessionResumption = null, CancellationToken cancellationToken = default) | ||
| { | ||
| ClientWebSocket clientWebSocket = new(); | ||
|
|
||
| string endpoint = GetURL(); | ||
|
|
||
| // Set initial headers | ||
| string version = Firebase.Internal.FirebaseInterops.GetVersionInfoSdkVersion(); | ||
| clientWebSocket.Options.SetRequestHeader("x-goog-api-client", $"gl-csharp/8.0 fire/{version}"); | ||
| if (Firebase.Internal.FirebaseInterops.GetIsDataCollectionDefaultEnabled(_firebaseApp)) | ||
| Func<SessionResumptionConfig?, CancellationToken, Task<ClientWebSocket>> connectFactory = async (resumptionConfig, cancelToken) => | ||
| { | ||
| clientWebSocket.Options.SetRequestHeader("X-Firebase-AppId", _firebaseApp.Options.AppId); | ||
| clientWebSocket.Options.SetRequestHeader("X-Firebase-AppVersion", UnityEngine.Application.version); | ||
| } | ||
| // Add additional Firebase tokens to the header. | ||
| await Firebase.Internal.FirebaseInterops.AddFirebaseTokensAsync(clientWebSocket, _firebaseApp); | ||
| ClientWebSocket clientWebSocket = new(); | ||
| string endpoint = GetURL(); | ||
|
|
||
| // Set initial headers | ||
| string version = Firebase.Internal.FirebaseInterops.GetVersionInfoSdkVersion(); | ||
| clientWebSocket.Options.SetRequestHeader("x-goog-api-client", $"gl-csharp/8.0 fire/{version}"); | ||
| if (Firebase.Internal.FirebaseInterops.GetIsDataCollectionDefaultEnabled(_firebaseApp)) | ||
| { | ||
| clientWebSocket.Options.SetRequestHeader("X-Firebase-AppId", _firebaseApp.Options.AppId); | ||
| clientWebSocket.Options.SetRequestHeader("X-Firebase-AppVersion", UnityEngine.Application.version); | ||
| } | ||
| // Add additional Firebase tokens to the header. | ||
| await Firebase.Internal.FirebaseInterops.AddFirebaseTokensAsync(clientWebSocket, _firebaseApp); | ||
|
|
||
| // Add a timeout to the initial connection, using the RequestOptions. | ||
| using var connectionCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); | ||
| TimeSpan connectionTimeout = _requestOptions?.Timeout ?? RequestOptions.DefaultTimeout; | ||
| connectionCts.CancelAfter(connectionTimeout); | ||
| // Add a timeout to the initial connection, using the RequestOptions. | ||
| using var connectionCts = CancellationTokenSource.CreateLinkedTokenSource(cancelToken); | ||
| TimeSpan connectionTimeout = _requestOptions?.Timeout ?? RequestOptions.DefaultTimeout; | ||
| connectionCts.CancelAfter(connectionTimeout); | ||
|
|
||
| await clientWebSocket.ConnectAsync(new Uri(endpoint), connectionCts.Token); | ||
| await clientWebSocket.ConnectAsync(new Uri(endpoint), connectionCts.Token); | ||
|
|
||
| if (clientWebSocket.State != WebSocketState.Open) | ||
| { | ||
| throw new WebSocketException("ClientWebSocket failed to connect, can't create LiveSession."); | ||
| } | ||
| if (clientWebSocket.State != WebSocketState.Open) | ||
| { | ||
| throw new WebSocketException("ClientWebSocket failed to connect, can't create LiveSession."); | ||
| } | ||
|
|
||
| try | ||
| { | ||
|
|
@@ -175,25 +176,36 @@ public async Task<LiveSession> ConnectAsync(CancellationToken cancellationToken | |
| { | ||
| setupDict["tools"] = _tools.Select(t => t.ToJson()).ToList(); | ||
| } | ||
|
|
||
| if (resumptionConfig != null) | ||
| { | ||
| setupDict["sessionResumption"] = resumptionConfig.ToJson(); | ||
| } | ||
| if (_liveConfig?.ContextWindowCompression != null) | ||
| { | ||
| setupDict["contextWindowCompression"] = _liveConfig?.ContextWindowCompression.ToJson(); | ||
| } | ||
|
|
||
| Dictionary<string, object> jsonDict = new() { | ||
| { "setup", setupDict } | ||
| }; | ||
|
|
||
| var byteArray = Encoding.UTF8.GetBytes(Json.Serialize(jsonDict)); | ||
| await clientWebSocket.SendAsync(new ArraySegment<byte>(byteArray), WebSocketMessageType.Binary, true, cancellationToken); | ||
| await clientWebSocket.SendAsync(new ArraySegment<byte>(byteArray), WebSocketMessageType.Binary, true, cancelToken); | ||
|
|
||
| return new LiveSession(clientWebSocket); | ||
| return clientWebSocket; | ||
| } | ||
| catch (Exception) | ||
| { | ||
| if (clientWebSocket.State == WebSocketState.Open) | ||
| { | ||
| // Try to clean up the WebSocket, to avoid leaking connections. | ||
| await clientWebSocket.CloseAsync(WebSocketCloseStatus.EndpointUnavailable, | ||
| "Failed to send initial setup message.", CancellationToken.None); | ||
| } | ||
| // Try to clean up the WebSocket, to avoid leaking connections. | ||
| // It might not be available in scope, we rely on GC mostly here unless we catch on clientWebSocket explicitly. | ||
| // Wait, clientWebSocket is available because this is all within the lambda! | ||
| throw; | ||
| } | ||
|
Comment on lines
198
to
204
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The catch (Exception)
{
if (clientWebSocket.State == WebSocketState.Open)
{
await clientWebSocket.CloseAsync(WebSocketCloseStatus.InternalServerError, "Initial setup message failed.", CancellationToken.None);
}
throw;
} |
||
| }; | ||
|
|
||
| var webSocket = await connectFactory(sessionResumption, cancellationToken); | ||
| return new LiveSession(webSocket, connectFactory); | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,7 +33,8 @@ namespace Firebase.AI | |
| public class LiveSession : IDisposable | ||
| { | ||
|
|
||
| private readonly ClientWebSocket _clientWebSocket; | ||
| private ClientWebSocket _clientWebSocket; | ||
| private readonly Func<SessionResumptionConfig?, CancellationToken, Task<ClientWebSocket>> _connectionFactory; | ||
|
|
||
| private readonly SemaphoreSlim _sendLock = new(1, 1); | ||
|
|
||
|
|
@@ -44,7 +45,7 @@ public class LiveSession : IDisposable | |
| /// Intended for internal use only. | ||
| /// Use `LiveGenerativeModel.ConnectAsync` instead to ensure proper initialization. | ||
| /// </summary> | ||
| internal LiveSession(ClientWebSocket clientWebSocket) | ||
| internal LiveSession(ClientWebSocket clientWebSocket, Func<SessionResumptionConfig?, CancellationToken, Task<ClientWebSocket>> connectionFactory = null) | ||
| { | ||
| if (clientWebSocket.State != WebSocketState.Open) | ||
| { | ||
|
|
@@ -53,6 +54,7 @@ internal LiveSession(ClientWebSocket clientWebSocket) | |
| } | ||
|
|
||
| _clientWebSocket = clientWebSocket; | ||
| _connectionFactory = connectionFactory; | ||
| } | ||
|
|
||
| protected virtual void Dispose(bool disposing) | ||
|
|
@@ -297,10 +299,31 @@ public async IAsyncEnumerable<LiveSessionResponse> ReceiveAsync( | |
| Memory<byte> buffer = new(receiveBuffer); | ||
| while (!cancellationToken.IsCancellationRequested) | ||
| { | ||
| ValueWebSocketReceiveResult result = await _clientWebSocket.ReceiveAsync(buffer, cancellationToken); | ||
| ClientWebSocket currentWebSocket; | ||
| await _sendLock.WaitAsync(cancellationToken); | ||
| try { | ||
| currentWebSocket = _clientWebSocket; | ||
| } finally { _sendLock.Release(); } | ||
|
|
||
| ValueWebSocketReceiveResult result; | ||
| try | ||
| { | ||
| result = await currentWebSocket.ReceiveAsync(buffer, cancellationToken); | ||
| } | ||
| catch (Exception) when (currentWebSocket != _clientWebSocket && !cancellationToken.IsCancellationRequested) | ||
| { | ||
| // The socket was closed or disposed because of session resumption, grab the new one | ||
| await Task.Delay(10, cancellationToken); | ||
| continue; | ||
|
Comment on lines
+313
to
+317
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The catch (WebSocketException) when (currentWebSocket != _clientWebSocket && !cancellationToken.IsCancellationRequested)
{
// The socket was closed or disposed because of session resumption, grab the new one
await Task.Delay(10, cancellationToken);
continue;
} |
||
| } | ||
|
|
||
| if (result.MessageType == WebSocketMessageType.Close) | ||
| { | ||
| if (currentWebSocket != _clientWebSocket && !cancellationToken.IsCancellationRequested) | ||
| { | ||
| await Task.Delay(10, cancellationToken); | ||
| continue; | ||
| } | ||
| // Close initiated by the server | ||
| // TODO: Should this just close without logging anything? | ||
| break; | ||
|
|
@@ -338,6 +361,48 @@ public async IAsyncEnumerable<LiveSessionResponse> ReceiveAsync( | |
| cancellationToken.ThrowIfCancellationRequested(); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Resumes an existing live session with the server. | ||
| /// | ||
| /// This closes the current WebSocket connection and establishes a new one using | ||
| /// the same configuration as the original session. | ||
| /// </summary> | ||
| /// <param name="sessionResumption">The configuration for session resumption.</param> | ||
| /// <param name="cancellationToken">A token to cancel the operation.</param> | ||
| public async Task ResumeSessionAsync(SessionResumptionConfig? sessionResumption = null, CancellationToken cancellationToken = default) | ||
| { | ||
| if (_connectionFactory == null) | ||
| { | ||
| throw new InvalidOperationException("ResumeSession is not supported on this instance."); | ||
| } | ||
|
|
||
| ClientWebSocket newSession = await _connectionFactory(sessionResumption, cancellationToken); | ||
| ClientWebSocket oldSession; | ||
|
|
||
| await _sendLock.WaitAsync(cancellationToken); | ||
| try | ||
| { | ||
| oldSession = _clientWebSocket; | ||
| _clientWebSocket = newSession; | ||
| } | ||
| finally | ||
| { | ||
| _sendLock.Release(); | ||
| } | ||
|
|
||
| try | ||
| { | ||
| if (oldSession.State == WebSocketState.Open) | ||
| { | ||
| await oldSession.CloseAsync(WebSocketCloseStatus.NormalClosure, "Session resumed", CancellationToken.None); | ||
| } | ||
| } | ||
| catch (Exception) | ||
| { | ||
| // Ignore errors when closing the old socket. | ||
| } | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Close the `LiveSession`. | ||
| /// </summary> | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment on these lines, "// It might not be available in scope, we rely on GC mostly here unless we catch on clientWebSocket explicitly. // Wait, clientWebSocket is available because this is all within the lambda!", is misleading.
`clientWebSocket` is clearly in scope within the lambda. This comment should be removed or updated to accurately reflect the code's behavior and the intended cleanup strategy.