Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C#: Handle all response types. #3395

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 27 additions & 52 deletions csharp/lib/BaseClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Glide.Internals;

using static Glide.ConnectionConfiguration;
using static Glide.Internals.ResponseHandler;
using static Glide.Route;

namespace Glide;
Expand Down Expand Up @@ -54,7 +55,7 @@ protected BaseClient(BaseClientConfiguration config)
}
}

protected delegate T ResponseHandler<T>(object? response);
protected delegate T ResponseHandler<T>(IntPtr response);

protected async Task<T> Command<T>(RequestType requestType, GlideString[] arguments, ResponseHandler<T> responseHandler, Route? route = null) where T : class?
{
Expand Down Expand Up @@ -112,10 +113,10 @@ protected async Task<T> Command<T>(RequestType requestType, GlideString[] argume
return responseHandler(await message);
}

protected static string HandleOk(object? response)
protected static string HandleOk(IntPtr response)
=> HandleServerResponse<GlideString, string>(response, false, gs => gs.GetString());

protected static T HandleServerResponse<T>(object? response, bool isNullable) where T : class?
protected static T HandleServerResponse<T>(IntPtr response, bool isNullable) where T : class?
=> HandleServerResponse<T, T>(response, isNullable, o => o);

/// <summary>
Expand All @@ -128,69 +129,40 @@ protected static T HandleServerResponse<T>(object? response, bool isNullable) wh
/// <param name="converter">Optional converted to convert <typeparamref name="R"/> to <typeparamref name="T"/>.</param>
/// <returns></returns>
/// <exception cref="Exception"></exception>
protected static T HandleServerResponse<R, T>(object? response, bool isNullable, Func<R, T> converter) where T : class? where R : class?
protected static T HandleServerResponse<R, T>(IntPtr response, bool isNullable, Func<R, T> converter) where T : class? where R : class?
{
if (response is null)
try
{
if (isNullable)
object? value = HandleResponse(response);
if (value is null)
{
if (isNullable)
{
#pragma warning disable CS8603 // Possible null reference return.
return null;
return null;
#pragma warning restore CS8603 // Possible null reference return.
}
throw new Exception($"Unexpected return type from Glide: got null expected {typeof(T).Name}");
}
throw new Exception($"Unexpected return type from Glide: got null expected {typeof(T).Name}");
return value is R
? converter((value as R)!)
: throw new Exception($"Unexpected return type from Glide: got {value?.GetType().Name} expected {typeof(T).Name}");
}
response = ConvertByteArrayToGlideString(response);
#pragma warning disable IDE0046 // Convert to conditional expression
if (response is R)
finally
{
return converter((response as R)!);
FreeResponse(response);
}
#pragma warning restore IDE0046 // Convert to conditional expression
throw new Exception($"Unexpected return type from Glide: got {response?.GetType().Name} expected {typeof(T).Name}");
}

protected static object? ConvertByteArrayToGlideString(object? response)
{
if (response is null)
{
return null;
}
if (response is byte[] bytes)
{
response = new GlideString(bytes);
}
// TODO handle other types
return response;
}
#endregion protected methods

#region private methods
// TODO rework the callback to handle other response types
private void SuccessCallback(ulong index, int strLen, IntPtr strPtr)
{
object? result = null;
if (strPtr != IntPtr.Zero)
{
byte[] bytes = new byte[strLen];
Marshal.Copy(strPtr, bytes, 0, strLen);
result = bytes;
}
private void SuccessCallback(ulong index, IntPtr ptr) =>
// Work needs to be offloaded from the calling thread, because otherwise we might starve the client's thread pool.
_ = Task.Run(() =>
{
Message message = _messageContainer.GetMessage((int)index);
message.SetResult(result);
});
}
Task.Run(() => _messageContainer.GetMessage((int)index).SetResult(ptr));

private void FailureCallback(ulong index) =>
// Work needs to be offloaded from the calling thread, because otherwise we might starve the client's thread pool.
Task.Run(() =>
{
Message message = _messageContainer.GetMessage((int)index);
message.SetException(new Exception("Operation failed"));
});
Task.Run(() => _messageContainer.GetMessage((int)index).SetException(new Exception("Operation failed")));

~BaseClient() => Dispose();
#endregion private methods
Expand All @@ -203,7 +175,7 @@ private void FailureCallback(ulong index) =>

/// Held as a measure to prevent the delegate being garbage collected. These are delegated once
/// and held in order to prevent the cost of marshalling on each function call.
private readonly StringAction _successCallbackDelegate;
private readonly SuccessAction _successCallbackDelegate;

/// Raw pointer to the underlying native client.
private IntPtr _clientPointer;
Expand All @@ -215,12 +187,15 @@ private void FailureCallback(ulong index) =>

#region FFI function declarations

private delegate void StringAction(ulong index, int strLen, IntPtr strPtr);
private delegate void SuccessAction(ulong index, IntPtr ptr);
private delegate void FailureAction(ulong index);

[DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "command")]
private static extern void CommandFfi(IntPtr client, ulong index, int requestType, IntPtr args, uint argCount, IntPtr argLengths, IntPtr routeInfo);

private delegate void IntAction(IntPtr arg);
[DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "free_respose")]
private static extern void FreeResponse(IntPtr response);

[DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "create_client")]
private static extern IntPtr CreateClientFfi(IntPtr config, IntPtr successCallback, IntPtr failureCallback);

Expand Down
2 changes: 1 addition & 1 deletion csharp/lib/GlideClusterClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ namespace Glide;
public sealed class GlideClusterClient(ClusterClientConfiguration config) : BaseClient(config), IGenericClusterCommands
{
public async Task<object?> CustomCommand(GlideString[] args, Route? route = null)
=> await Command<object?>(RequestType.CustomCommand, args, resp => HandleServerResponse<object?>(resp, true), route);
=> await Command(RequestType.CustomCommand, args, resp => HandleServerResponse<object?>(resp, true), route);
}
6 changes: 3 additions & 3 deletions csharp/lib/Internals/Message.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ internal class Message(int index, MessageContainer container) : INotifyCompletio
private const int COMPLETION_STAGE_NEXT_SHOULD_EXECUTE_CONTINUATION = 1;
private const int COMPLETION_STAGE_CONTINUATION_EXECUTED = 2;
private int _completionState;
private object? _result = default;
private IntPtr _result = default;
private Exception? _exception;
// Holding the client prevents it from being GC'd until all operations complete.
#pragma warning disable IDE0052 // Remove unread private members
Expand All @@ -28,7 +28,7 @@ internal class Message(int index, MessageContainer container) : INotifyCompletio

/// Triggers a succesful completion of the task returned from the latest call
/// to CreateTask.
public void SetResult(object? result)
public void SetResult(IntPtr result)
{
_result = result;
FinishSet();
Expand Down Expand Up @@ -90,5 +90,5 @@ public void OnCompleted(Action continuation)

public bool IsCompleted => _completionState == COMPLETION_STAGE_CONTINUATION_EXECUTED;

public object? GetResult() => _exception is null ? _result : throw _exception;
public IntPtr GetResult() => _exception is null ? _result : throw _exception;
}
99 changes: 99 additions & 0 deletions csharp/lib/Internals/ResponseHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0

using System.Runtime.InteropServices;

namespace Glide.Internals;

internal class ResponseHandler
{
[StructLayout(LayoutKind.Sequential)]
private struct GlideValue
{
public ValueType Type;
public nuint Value;
public uint Size;
}

public enum ValueType : uint
{
Null = 0,
Int = 1,
Float = 2,
Bool = 3,
String = 4,
Array = 5,
Map = 6,
Set = 7,
BulkString = 8,
OK = 9,
}

public static object? HandleResponse(IntPtr valuePtr)
{
GlideValue value = Marshal.PtrToStructure<GlideValue>(valuePtr);
return TraverseValue(value);
}

private static object? TraverseValue(GlideValue value)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use pattern matching switch here?

private static object? TraverseValue(GlideValue value) => value.Type switch
{
    ValueType.Null => null,
    ValueType.Int => (long)value.Value,
    ValueType.Float => (double)value.Value,
    ValueType.Bool => value.Value != 0,
    ValueType.BulkString or ValueType.String => CreateString(value),
    ValueType.Array => CreateArray(value),
    ValueType.Map => CreateMap(value),
    ValueType.Set => CreateSet(value),
    ValueType.OK => new GlideString("OK"),
    _ => throw new NotImplementedException()
};

{
#pragma warning disable IDE0022 // Use expression body for method
switch (value.Type)
{
case ValueType.Null: return null;
case ValueType.Int: return (long)value.Value;
case ValueType.Float: return (double)value.Value;
case ValueType.Bool: return value.Value != 0;
case ValueType.BulkString:
case ValueType.String:
{
byte[] bytes = new byte[value.Size];
Marshal.Copy(new IntPtr((long)value.Value), bytes, 0, (int)value.Size);
return new GlideString(bytes);
}
case ValueType.Array:
{
object?[] values = new object?[value.Size];
IntPtr ptr = new((long)value.Value);
for (int i = 0; i < values.Length; i++)
{
values[i] = HandleResponse(ptr);
ptr += Marshal.SizeOf<GlideValue>();
}

return values;
}
case ValueType.Map:
{
object?[] values = new object?[value.Size];
IntPtr ptr = new((long)value.Value);
for (int i = 0; i < values.Length; i++)
{
values[i] = HandleResponse(ptr);
ptr += Marshal.SizeOf<GlideValue>();
}

Dictionary<GlideString, object?> res = [];
for (int i = 0; i < values.Length; i += 2)
{
res[(GlideString)values[i]!] = values[i + 1];
}
return res;
}
case ValueType.Set:
{
object?[] values = new object?[value.Size];
IntPtr ptr = new((long)value.Value);
for (int i = 0; i < values.Length; i++)
{
values[i] = HandleResponse(ptr);
ptr += Marshal.SizeOf<GlideValue>();
}
return values.ToHashSet();
}
case ValueType.OK: return new GlideString("OK");
default:
throw new NotImplementedException();
}
#pragma warning restore IDE0022 // Use expression body for method
}
}
Loading
Loading