Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions firebaseai/src/ContextWindowCompressionConfig.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Collections.Generic;
using Firebase.AI.Internal;

namespace Firebase.AI
{
/// <summary>
/// Settings for the sliding window flavor of context compression.
/// </summary>
public class SlidingWindow
{
  /// <summary>
  /// The session reduction target, i.e., the number of tokens that should be kept.
  /// Left unset (null) when no explicit target was supplied.
  /// </summary>
  public int? TargetTokens { get; }

  /// <summary>
  /// Creates a new `SlidingWindow` value.
  /// </summary>
  /// <param name="targetTokens">The optional number of tokens to keep.</param>
  public SlidingWindow(int? targetTokens = null) => TargetTokens = targetTokens;

  /// <summary>
  /// Intended for internal use only.
  /// Builds the dictionary form of this object; the "targetTokens" key is only
  /// written when a target has been set.
  /// </summary>
  internal Dictionary<string, object> ToJson()
  {
    Dictionary<string, object> json = new();
    if (TargetTokens is int tokensToKeep)
    {
      json["targetTokens"] = tokensToKeep;
    }
    return json;
  }
}

/// <summary>
/// Turns on context window compression, used to manage the model's context window.
/// </summary>
public class ContextWindowCompressionConfig
{
  /// <summary>
  /// The token count (measured before running a turn) at which context window
  /// compression is triggered.
  /// </summary>
  public int? TriggerTokens { get; }

  /// <summary>
  /// The sliding window compression mechanism to apply, if any.
  /// </summary>
  public SlidingWindow? SlidingWindow { get; }

  /// <summary>
  /// Creates a new `ContextWindowCompressionConfig` value.
  /// </summary>
  /// <param name="triggerTokens">The optional token count that triggers compression.</param>
  /// <param name="slidingWindow">The optional sliding window configuration.</param>
  public ContextWindowCompressionConfig(int? triggerTokens = null, SlidingWindow? slidingWindow = null)
  {
    (TriggerTokens, SlidingWindow) = (triggerTokens, slidingWindow);
  }

  /// <summary>
  /// Intended for internal use only.
  /// Builds the dictionary form of this object; a key is only written for each
  /// member that has been set.
  /// </summary>
  internal Dictionary<string, object> ToJson()
  {
    Dictionary<string, object> json = new();

    if (TriggerTokens is int threshold)
    {
      json["triggerTokens"] = threshold;
    }

    if (SlidingWindow is { } window)
    {
      json["slidingWindow"] = window.ToJson();
    }

    return json;
  }
}
}
7 changes: 6 additions & 1 deletion firebaseai/src/LiveGenerationConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,11 @@ public readonly struct LiveGenerationConfig
private readonly float? _frequencyPenalty;
private readonly AudioTranscriptionConfig? _inputAudioTranscription;
private readonly AudioTranscriptionConfig? _outputAudioTranscription;
private readonly ContextWindowCompressionConfig? _contextWindowCompression;

internal readonly AudioTranscriptionConfig? InputAudioTranscription => _inputAudioTranscription;
internal readonly AudioTranscriptionConfig? OutputAudioTranscription => _outputAudioTranscription;
internal readonly ContextWindowCompressionConfig? ContextWindowCompression => _contextWindowCompression;

/// <summary>
/// Creates a new `LiveGenerationConfig` value.
Expand Down Expand Up @@ -191,7 +193,8 @@ public LiveGenerationConfig(
float? presencePenalty = null,
float? frequencyPenalty = null,
AudioTranscriptionConfig? inputAudioTranscription = null,
AudioTranscriptionConfig? outputAudioTranscription = null)
AudioTranscriptionConfig? outputAudioTranscription = null,
ContextWindowCompressionConfig? contextWindowCompression = null)
{
_speechConfig = speechConfig;
_responseModalities = responseModalities != null ?
Expand All @@ -204,6 +207,7 @@ public LiveGenerationConfig(
_frequencyPenalty = frequencyPenalty;
_inputAudioTranscription = inputAudioTranscription;
_outputAudioTranscription = outputAudioTranscription;
_contextWindowCompression = contextWindowCompression;
}

/// <summary>
Expand All @@ -225,6 +229,7 @@ internal Dictionary<string, object> ToJson()
if (_maxOutputTokens.HasValue) jsonDict["maxOutputTokens"] = _maxOutputTokens.Value;
if (_presencePenalty.HasValue) jsonDict["presencePenalty"] = _presencePenalty.Value;
if (_frequencyPenalty.HasValue) jsonDict["frequencyPenalty"] = _frequencyPenalty.Value;
if (_contextWindowCompression != null) jsonDict["contextWindowCompression"] = _contextWindowCompression.ToJson();

return jsonDict;
}
Expand Down
74 changes: 43 additions & 31 deletions firebaseai/src/LiveGenerativeModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,34 +118,35 @@ private string GetModelName()
/// </summary>
/// <param name="cancellationToken">The token that can be used to cancel the creation of the session.</param>
/// <returns>The LiveSession, once it is established.</returns>
public async Task<LiveSession> ConnectAsync(CancellationToken cancellationToken = default)
public async Task<LiveSession> ConnectAsync(SessionResumptionConfig? sessionResumption = null, CancellationToken cancellationToken = default)
{
ClientWebSocket clientWebSocket = new();

string endpoint = GetURL();

// Set initial headers
string version = Firebase.Internal.FirebaseInterops.GetVersionInfoSdkVersion();
clientWebSocket.Options.SetRequestHeader("x-goog-api-client", $"gl-csharp/8.0 fire/{version}");
if (Firebase.Internal.FirebaseInterops.GetIsDataCollectionDefaultEnabled(_firebaseApp))
Func<SessionResumptionConfig?, CancellationToken, Task<ClientWebSocket>> connectFactory = async (resumptionConfig, cancelToken) =>
{
clientWebSocket.Options.SetRequestHeader("X-Firebase-AppId", _firebaseApp.Options.AppId);
clientWebSocket.Options.SetRequestHeader("X-Firebase-AppVersion", UnityEngine.Application.version);
}
// Add additional Firebase tokens to the header.
await Firebase.Internal.FirebaseInterops.AddFirebaseTokensAsync(clientWebSocket, _firebaseApp);
ClientWebSocket clientWebSocket = new();
string endpoint = GetURL();

// Set initial headers
string version = Firebase.Internal.FirebaseInterops.GetVersionInfoSdkVersion();
clientWebSocket.Options.SetRequestHeader("x-goog-api-client", $"gl-csharp/8.0 fire/{version}");
if (Firebase.Internal.FirebaseInterops.GetIsDataCollectionDefaultEnabled(_firebaseApp))
{
clientWebSocket.Options.SetRequestHeader("X-Firebase-AppId", _firebaseApp.Options.AppId);
clientWebSocket.Options.SetRequestHeader("X-Firebase-AppVersion", UnityEngine.Application.version);
}
// Add additional Firebase tokens to the header.
await Firebase.Internal.FirebaseInterops.AddFirebaseTokensAsync(clientWebSocket, _firebaseApp);

// Add a timeout to the initial connection, using the RequestOptions.
using var connectionCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
TimeSpan connectionTimeout = _requestOptions?.Timeout ?? RequestOptions.DefaultTimeout;
connectionCts.CancelAfter(connectionTimeout);
// Add a timeout to the initial connection, using the RequestOptions.
using var connectionCts = CancellationTokenSource.CreateLinkedTokenSource(cancelToken);
TimeSpan connectionTimeout = _requestOptions?.Timeout ?? RequestOptions.DefaultTimeout;
connectionCts.CancelAfter(connectionTimeout);

await clientWebSocket.ConnectAsync(new Uri(endpoint), connectionCts.Token);
await clientWebSocket.ConnectAsync(new Uri(endpoint), connectionCts.Token);

if (clientWebSocket.State != WebSocketState.Open)
{
throw new WebSocketException("ClientWebSocket failed to connect, can't create LiveSession.");
}
if (clientWebSocket.State != WebSocketState.Open)
{
throw new WebSocketException("ClientWebSocket failed to connect, can't create LiveSession.");
}

try
{
Expand Down Expand Up @@ -175,25 +176,36 @@ public async Task<LiveSession> ConnectAsync(CancellationToken cancellationToken
{
setupDict["tools"] = _tools.Select(t => t.ToJson()).ToList();
}

if (resumptionConfig != null)
{
setupDict["sessionResumption"] = resumptionConfig.ToJson();
}
if (_liveConfig?.ContextWindowCompression != null)
{
setupDict["contextWindowCompression"] = _liveConfig?.ContextWindowCompression.ToJson();
}

Dictionary<string, object> jsonDict = new() {
{ "setup", setupDict }
};

var byteArray = Encoding.UTF8.GetBytes(Json.Serialize(jsonDict));
await clientWebSocket.SendAsync(new ArraySegment<byte>(byteArray), WebSocketMessageType.Binary, true, cancellationToken);
await clientWebSocket.SendAsync(new ArraySegment<byte>(byteArray), WebSocketMessageType.Binary, true, cancelToken);

return new LiveSession(clientWebSocket);
return clientWebSocket;
}
catch (Exception)
{
if (clientWebSocket.State == WebSocketState.Open)
{
// Try to clean up the WebSocket, to avoid leaking connections.
await clientWebSocket.CloseAsync(WebSocketCloseStatus.EndpointUnavailable,
"Failed to send initial setup message.", CancellationToken.None);
}
// Try to clean up the WebSocket, to avoid leaking connections.
// It might not be available in scope, we rely on GC mostly here unless we catch on clientWebSocket explicitly.
// Wait, clientWebSocket is available because this is all within the lambda!
throw;
Comment on lines +200 to 203
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comment on these lines, "// It might not be available in scope, we rely on GC mostly here unless we catch on clientWebSocket explicitly. // Wait, clientWebSocket is available because this is all within the lambda!", is misleading. clientWebSocket is clearly in scope within the lambda. This comment should be removed or updated to accurately reflect the code's behavior and the intended cleanup strategy.

        // Try to clean up the WebSocket, to avoid leaking connections.
        throw;

}
Comment on lines 198 to 204
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The catch (Exception) block within the connectFactory lambda does not explicitly close the clientWebSocket if an exception occurs during the SendAsync operation. This could lead to a resource leak where an open WebSocket connection is not properly disposed of if the initial setup message fails to send.

      catch (Exception)
      {
        if (clientWebSocket.State == WebSocketState.Open)
        {
          await clientWebSocket.CloseAsync(WebSocketCloseStatus.InternalServerError, "Initial setup message failed.", CancellationToken.None);
        }
        throw;
      }

};

var webSocket = await connectFactory(sessionResumption, cancellationToken);
return new LiveSession(webSocket, connectFactory);
}
}

Expand Down
71 changes: 68 additions & 3 deletions firebaseai/src/LiveSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ namespace Firebase.AI
public class LiveSession : IDisposable
{

private readonly ClientWebSocket _clientWebSocket;
private ClientWebSocket _clientWebSocket;
private readonly Func<SessionResumptionConfig?, CancellationToken, Task<ClientWebSocket>> _connectionFactory;

private readonly SemaphoreSlim _sendLock = new(1, 1);

Expand All @@ -44,7 +45,7 @@ public class LiveSession : IDisposable
/// Intended for internal use only.
/// Use `LiveGenerativeModel.ConnectAsync` instead to ensure proper initialization.
/// </summary>
internal LiveSession(ClientWebSocket clientWebSocket)
internal LiveSession(ClientWebSocket clientWebSocket, Func<SessionResumptionConfig?, CancellationToken, Task<ClientWebSocket>> connectionFactory = null)
{
if (clientWebSocket.State != WebSocketState.Open)
{
Expand All @@ -53,6 +54,7 @@ internal LiveSession(ClientWebSocket clientWebSocket)
}

_clientWebSocket = clientWebSocket;
_connectionFactory = connectionFactory;
}

protected virtual void Dispose(bool disposing)
Expand Down Expand Up @@ -297,10 +299,31 @@ public async IAsyncEnumerable<LiveSessionResponse> ReceiveAsync(
Memory<byte> buffer = new(receiveBuffer);
while (!cancellationToken.IsCancellationRequested)
{
ValueWebSocketReceiveResult result = await _clientWebSocket.ReceiveAsync(buffer, cancellationToken);
ClientWebSocket currentWebSocket;
await _sendLock.WaitAsync(cancellationToken);
try {
currentWebSocket = _clientWebSocket;
} finally { _sendLock.Release(); }

ValueWebSocketReceiveResult result;
try
{
result = await currentWebSocket.ReceiveAsync(buffer, cancellationToken);
}
catch (Exception) when (currentWebSocket != _clientWebSocket && !cancellationToken.IsCancellationRequested)
{
// The socket was closed or disposed because of session resumption, grab the new one
await Task.Delay(10, cancellationToken);
continue;
Comment on lines +313 to +317
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The catch (Exception) block on line 313 in ReceiveAsync is too broad. Catching a general Exception can mask specific issues that are not related to session resumption or cancellation. It would be more robust to catch specific WebSocket-related exceptions (e.g., WebSocketException, OperationCanceledException) if the intent is to handle only those scenarios.

        catch (WebSocketException) when (currentWebSocket != _clientWebSocket && !cancellationToken.IsCancellationRequested)
        {
          // The socket was closed or disposed because of session resumption, grab the new one
          await Task.Delay(10, cancellationToken);
          continue;
        }

}

if (result.MessageType == WebSocketMessageType.Close)
{
if (currentWebSocket != _clientWebSocket && !cancellationToken.IsCancellationRequested)
{
await Task.Delay(10, cancellationToken);
continue;
}
// Close initiated by the server
// TODO: Should this just close without logging anything?
break;
Expand Down Expand Up @@ -338,6 +361,48 @@ public async IAsyncEnumerable<LiveSessionResponse> ReceiveAsync(
cancellationToken.ThrowIfCancellationRequested();
}

/// <summary>
/// Resumes an existing live session with the server.
///
/// This establishes a new WebSocket connection through the connection factory
/// supplied at construction, swaps it in as the active connection, and then
/// closes and disposes the previous connection.
/// </summary>
/// <param name="sessionResumption">The configuration for session resumption.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <exception cref="InvalidOperationException">
/// Thrown when this instance was created without a connection factory.
/// </exception>
public async Task ResumeSessionAsync(SessionResumptionConfig? sessionResumption = null, CancellationToken cancellationToken = default)
{
  if (_connectionFactory == null)
  {
    throw new InvalidOperationException("ResumeSession is not supported on this instance.");
  }

  // Connect first, so the current socket stays in place if the reconnect fails.
  ClientWebSocket newSession = await _connectionFactory(sessionResumption, cancellationToken);

  try
  {
    await _sendLock.WaitAsync(cancellationToken);
  }
  catch
  {
    // The lock could not be taken (e.g. the token was cancelled), so the new
    // socket was never published. Dispose it here, otherwise nothing else
    // holds a reference to it and the open connection would leak.
    newSession.Dispose();
    throw;
  }

  ClientWebSocket oldSession;
  try
  {
    // Swap under the same lock ReceiveAsync uses to read the field.
    oldSession = _clientWebSocket;
    _clientWebSocket = newSession;
  }
  finally
  {
    _sendLock.Release();
  }

  try
  {
    if (oldSession.State == WebSocketState.Open)
    {
      await oldSession.CloseAsync(WebSocketCloseStatus.NormalClosure, "Session resumed", CancellationToken.None);
    }
  }
  catch (Exception)
  {
    // Best effort: the old socket is being discarded, so a failed close
    // handshake is not reported to the caller.
  }
  finally
  {
    // Release the replaced socket's resources whether or not the close succeeded.
    // NOTE(review): assumes no other code still uses the old socket once the
    // swap above is visible, apart from ReceiveAsync's retry path - confirm.
    oldSession.Dispose();
  }
}

/// <summary>
/// Close the `LiveSession`.
/// </summary>
Expand Down
Loading
Loading