Skip to content

Commit 8de167d

Browse files
C#: Custom command for standalone client. (#3335)
* Custom command for standalone. Signed-off-by: Yury-Fridlyand <[email protected]>
1 parent fbd5836 commit 8de167d

File tree

8 files changed

+133
-84
lines changed

8 files changed

+133
-84
lines changed

csharp/lib/BaseClient.cs

+24-25
Original file line numberDiff line numberDiff line change
@@ -29,41 +29,40 @@ protected BaseClient(BaseClientConfiguration config)
2929
}
3030
}
3131

32-
protected async Task<string?> Command(IntPtr[] args, int argsCount, RequestType requestType)
32+
protected async Task<T> Command<T>(string[] arguments, RequestType requestType) where T : class?
3333
{
34+
IntPtr[] args = _arrayPool.Rent(arguments.Length);
35+
for (int i = 0; i < arguments.Length; i++)
36+
{
37+
args[i] = Marshal.StringToHGlobalAnsi(arguments[i]);
38+
}
3439
// We need to pin the array in place, in order to ensure that the GC doesn't move it while the operation is running.
3540
GCHandle pinnedArray = GCHandle.Alloc(args, GCHandleType.Pinned);
3641
IntPtr pointer = pinnedArray.AddrOfPinnedObject();
37-
Message<string> message = _messageContainer.GetMessageForCall(args, argsCount);
38-
CommandFfi(_clientPointer, (ulong)message.Index, (int)requestType, pointer, (uint)argsCount);
39-
string? result = await message;
42+
Message message = _messageContainer.GetMessageForCall<T>(args);
43+
CommandFfi(_clientPointer, (ulong)message.Index, (int)requestType, pointer, (uint)arguments.Length);
44+
for (int i = 0; i < arguments.Length; i++)
45+
{
46+
Marshal.FreeHGlobal(args[i]);
47+
}
4048
pinnedArray.Free();
41-
return result;
42-
}
43-
44-
public async Task<string?> Set(string key, string value)
45-
{
46-
IntPtr[] args = _arrayPool.Rent(2);
47-
args[0] = Marshal.StringToHGlobalAnsi(key);
48-
args[1] = Marshal.StringToHGlobalAnsi(value);
49-
string? result = await Command(args, 2, RequestType.Set);
5049
_arrayPool.Return(args);
51-
return result;
50+
#pragma warning disable CS8603 // Possible null reference return.
51+
return await message as T;
52+
#pragma warning restore CS8603 // Possible null reference return.
5253
}
5354

55+
public async Task<string> Set(string key, string value)
56+
=> await Command<string>([key, value], RequestType.Set);
57+
5458
public async Task<string?> Get(string key)
55-
{
56-
IntPtr[] args = _arrayPool.Rent(1);
57-
args[0] = Marshal.StringToHGlobalAnsi(key);
58-
string? result = await Command(args, 1, RequestType.Get);
59-
_arrayPool.Return(args);
60-
return result;
61-
}
59+
=> await Command<string?>([key], RequestType.Get);
6260

6361
private readonly object _lock = new();
6462

6563
public void Dispose()
6664
{
65+
GC.SuppressFinalize(this);
6766
lock (_lock)
6867
{
6968
if (_clientPointer == IntPtr.Zero)
@@ -79,14 +78,14 @@ public void Dispose()
7978
#endregion public methods
8079

8180
#region private methods
82-
81+
// TODO rework the callback to handle other response types
8382
private void SuccessCallback(ulong index, IntPtr str)
8483
{
8584
string? result = str == IntPtr.Zero ? null : Marshal.PtrToStringAnsi(str);
8685
// Work needs to be offloaded from the calling thread, because otherwise we might starve the client's thread pool.
8786
_ = Task.Run(() =>
8887
{
89-
Message<string> message = _messageContainer.GetMessage((int)index);
88+
Message message = _messageContainer.GetMessage((int)index);
9089
message.SetResult(result);
9190
});
9291
}
@@ -95,7 +94,7 @@ private void FailureCallback(ulong index) =>
9594
// Work needs to be offloaded from the calling thread, because otherwise we might starve the client's thread pool.
9695
Task.Run(() =>
9796
{
98-
Message<string> message = _messageContainer.GetMessage((int)index);
97+
Message message = _messageContainer.GetMessage((int)index);
9998
message.SetException(new Exception("Operation failed"));
10099
});
101100

@@ -114,7 +113,7 @@ private void FailureCallback(ulong index) =>
114113

115114
/// Raw pointer to the underlying native client.
116115
private IntPtr _clientPointer;
117-
private readonly MessageContainer<string> _messageContainer = new();
116+
private readonly MessageContainer _messageContainer = new();
118117
private readonly ArrayPool<IntPtr> _arrayPool = ArrayPool<IntPtr>.Shared;
119118

120119
#endregion private fields
+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0
2+
3+
namespace Glide.Commands;
4+
5+
public interface IGenericCommands
6+
{
7+
/// <summary>
8+
/// Executes a single command, without checking inputs. Every part of the command, including subcommands,
9+
/// should be added as a separate value in <paramref name="args"/>.
10+
/// See the <see href="https://github.com/valkey-io/valkey-glide/wiki/General-Concepts#custom-command">Valkey GLIDE Wiki</see>.
11+
/// for details on the restrictions and limitations of the custom command API.
12+
/// <para />
13+
/// This function should only be used for single-response commands. Commands that don't return complete response and awaits
14+
/// (such as SUBSCRIBE), or that return potentially more than a single response (such as XREAD), or that change the client's
15+
/// behavior (such as entering pub/sub mode on RESP2 connections) shouldn't be called using this function.
16+
/// <example>
17+
/// <code>
18+
/// // Query all pub/sub clients
19+
/// object result = await client.CustomCommand(["CLIENT", "LIST", "TYPE", "PUBSUB"]);
20+
/// </code>
21+
/// </example>
22+
/// </summary>
23+
/// <param name="args">A list including the command name and arguments for the custom command.</param>
24+
/// <returns>The returning value depends on the executed command.</returns>
25+
Task<object?> CustomCommand(string[] args);
26+
}

csharp/lib/Commands/IStringBaseCommands.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ internal interface IStringBaseCommands
2020
/// <param name="key">The <paramref name="key"/> to store.</param>
2121
/// <param name="value">The value to store with the given <paramref name="key"/>.</param>
2222
/// <returns>A simple <c>"OK"</c> response.</returns>
23-
Task<string?> Set(string key, string value);
23+
Task<string> Set(string key, string value);
2424

2525
/// <summary>
2626
/// Gets the value associated with the given <paramref name="key"/>, or <see langword="null"/> if no such <paramref name="key"/> exists.

csharp/lib/GlideClient.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
namespace Glide;
88

9-
public sealed class GlideClient(StandaloneClientConfiguration config) : BaseClient(config), IConnectionManagementCommands
9+
public sealed class GlideClient(StandaloneClientConfiguration config) : BaseClient(config), IConnectionManagementCommands, IGenericCommands
1010
{
11+
public async Task<object?> CustomCommand(string[] args)
12+
=> await Command<object?>(args, RequestType.CustomCommand);
1113
}

csharp/lib/Internals/Message.cs

+14-27
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,35 @@
22

33
using System.Diagnostics;
44
using System.Runtime.CompilerServices;
5-
using System.Runtime.InteropServices;
65

76
using Glide.Internals;
87

98
/// Reusable source of ValueTask. This object can be allocated once and then reused
109
/// to create multiple asynchronous operations, as long as each call to CreateTask
1110
/// is awaited to completion before the next call begins.
12-
internal class Message<T>(int index, MessageContainer<T> container) : INotifyCompletion
11+
internal class Message(int index, MessageContainer container) : INotifyCompletion
1312
{
1413
/// This is the index of the message in an external array, that allows the user to
1514
/// know how to find the message and set its result.
1615
public int Index { get; } = index;
17-
/// The array holding the pointers to the unmanaged memory that contains the operation's arguments.
16+
/// The array holding the pointers to the unmanaged memory that contains the operation's arguments.
1817
public IntPtr[]? Args { get; private set; }
19-
// We need to save the args count, because sometimes we get arrays that are larger than they need to be. We can't rely on `this.args.Length`, due to it coming from an array pool.
20-
private int _argsCount;
21-
private MessageContainer<T> Container { get; } = container;
18+
private MessageContainer Container { get; } = container;
2219
private Action? _continuation = () => { };
2320
private const int COMPLETION_STAGE_STARTED = 0;
2421
private const int COMPLETION_STAGE_NEXT_SHOULD_EXECUTE_CONTINUATION = 1;
2522
private const int COMPLETION_STAGE_CONTINUATION_EXECUTED = 2;
2623
private int _completionState;
27-
private T? _result;
24+
private object? _result = default;
2825
private Exception? _exception;
26+
// Holding the client prevents it from being GC'd until all operations complete.
27+
#pragma warning disable IDE0052 // Remove unread private members
28+
private object? _client;
29+
#pragma warning restore IDE0052 // Remove unread private members
2930

3031
/// Triggers a succesful completion of the task returned from the latest call
3132
/// to CreateTask.
32-
public void SetResult(T? result)
33+
public void SetResult(object? result)
3334
{
3435
_result = result;
3536
FinishSet();
@@ -45,7 +46,7 @@ public void SetException(Exception exc)
4546

4647
private void FinishSet()
4748
{
48-
FreePointers();
49+
CleanUp();
4950

5051
CheckRaceAndCallContinuation();
5152
}
@@ -67,44 +68,30 @@ private void CheckRaceAndCallContinuation()
6768
}
6869
}
6970

70-
public Message<T> GetAwaiter() => this;
71+
public Message GetAwaiter() => this;
7172

7273
/// This returns a task that will complete once SetException / SetResult are called,
7374
/// and ensures that the internal state of the message is set-up before the task is created,
7475
/// and cleaned once it is complete.
75-
public void SetupTask(IntPtr[] arguments, int argsCount, object client)
76+
public void SetupTask(IntPtr[] arguments, object client)
7677
{
7778
_continuation = null;
7879
_completionState = COMPLETION_STAGE_STARTED;
7980
_result = default;
8081
_exception = null;
8182
_client = client;
8283
Args = arguments;
83-
_argsCount = argsCount;
8484
}
8585

86-
// This function isn't thread-safe. Access to it should be from a single thread, and only once per operation.
87-
// For the sake of performance, this responsibility is on the caller, and the function doesn't contain any safety measures.
88-
private void FreePointers()
86+
private void CleanUp()
8987
{
9088
if (Args is { })
9189
{
92-
for (int i = 0; i < _argsCount; i++)
93-
{
94-
Marshal.FreeHGlobal(Args[i]);
95-
}
9690
Args = null;
97-
_argsCount = 0;
9891
}
9992
_client = null;
10093
}
10194

102-
// Holding the client prevents it from being GC'd until all operations complete.
103-
#pragma warning disable IDE0052 // Remove unread private members
104-
private object? _client;
105-
#pragma warning restore IDE0052 // Remove unread private members
106-
107-
10895
public void OnCompleted(Action continuation)
10996
{
11097
_continuation = continuation;
@@ -113,5 +100,5 @@ public void OnCompleted(Action continuation)
113100

114101
public bool IsCompleted => _completionState == COMPLETION_STAGE_CONTINUATION_EXECUTED;
115102

116-
public T? GetResult() => _exception is null ? _result : throw _exception;
103+
public object? GetResult() => _exception is null ? _result : throw _exception;
117104
}

csharp/lib/Internals/MessageContainer.cs

+13-12
Original file line numberDiff line numberDiff line change
@@ -5,38 +5,39 @@
55
namespace Glide.Internals;
66

77

8-
internal class MessageContainer<T>
8+
internal class MessageContainer
99
{
10-
internal Message<T> GetMessage(int index) => _messages[index];
10+
internal Message GetMessage(int index) => _messages[index];
1111

12-
internal Message<T> GetMessageForCall(nint[] args, int argsCount)
12+
internal Message GetMessageForCall<T>(nint[] args)
1313
{
14-
Message<T> message = GetFreeMessage();
15-
message.SetupTask(args, argsCount, this);
14+
Message message = GetFreeMessage();
15+
message.SetupTask(args, this);
1616
return message;
1717
}
1818

19-
private Message<T> GetFreeMessage()
19+
private Message GetFreeMessage()
2020
{
21-
if (!_availableMessages.TryDequeue(out Message<T>? message))
21+
if (!_availableMessages.TryDequeue(out Message? message))
2222
{
2323
lock (_messages)
2424
{
2525
int index = _messages.Count;
26-
message = new Message<T>(index, this);
26+
message = new Message(index, this);
2727
_messages.Add(message);
2828
}
2929
}
3030
return message;
3131
}
3232

33-
public void ReturnFreeMessage(Message<T> message) => _availableMessages.Enqueue(message);
33+
public void ReturnFreeMessage(Message message)
34+
=> _availableMessages.Enqueue((Message)(object)message);
3435

3536
internal void DisposeWithError(Exception? error)
3637
{
3738
lock (_messages)
3839
{
39-
foreach (Message<T>? message in _messages.Where(message => !message.IsCompleted))
40+
foreach (Message? message in _messages.Where(message => !message.IsCompleted))
4041
{
4142
try
4243
{
@@ -52,9 +53,9 @@ internal void DisposeWithError(Exception? error)
5253
/// This list allows us random-access to the message in each index,
5354
/// which means that once we receive a callback with an index, we can
5455
/// find the message to resolve in constant time.
55-
private readonly List<Message<T>> _messages = [];
56+
private readonly List<Message> _messages = [];
5657

5758
/// This queue contains the messages that were created and are currently unused by any task,
5859
/// so they can be reused y new tasks instead of allocating new messages.
59-
private readonly ConcurrentQueue<Message<T>> _availableMessages = new();
60+
private readonly ConcurrentQueue<Message> _availableMessages = new();
6061
}

csharp/lib/src/lib.rs

+15-18
Original file line numberDiff line numberDiff line change
@@ -115,26 +115,20 @@ pub unsafe extern "C" fn command(
115115
Arc::from_raw(client_ptr as *mut Client)
116116
};
117117

118-
// The safety of these needs to be ensured by the calling code. Cannot dispose of the pointer before all operations have completed.
119-
let args_address = args as usize;
118+
let Some(mut cmd) = request_type.get_command() else {
119+
unsafe {
120+
(client.core.failure_callback)(callback_index); // TODO - report errors
121+
return;
122+
}
123+
};
124+
let args_slice = unsafe { std::slice::from_raw_parts(args, arg_count as usize) };
125+
for arg in args_slice {
126+
let c_str = unsafe { CStr::from_ptr(*arg as *mut c_char) };
127+
cmd.arg(c_str.to_bytes());
128+
}
120129

121130
let core = client.core.clone();
122131
client.runtime.spawn(async move {
123-
let Some(mut cmd) = request_type.get_command() else {
124-
unsafe {
125-
(core.failure_callback)(callback_index); // TODO - report errors
126-
return;
127-
}
128-
};
129-
130-
let args_slice = unsafe {
131-
std::slice::from_raw_parts(args_address as *const *mut c_char, arg_count as usize)
132-
};
133-
for arg in args_slice {
134-
let c_str = unsafe { CStr::from_ptr(*arg as *mut c_char) };
135-
cmd.arg(c_str.to_bytes());
136-
}
137-
138132
let result = core
139133
.client
140134
.clone()
@@ -145,7 +139,10 @@ pub unsafe extern "C" fn command(
145139
match result {
146140
Ok(None) => (core.success_callback)(callback_index, std::ptr::null()),
147141
Ok(Some(c_str)) => (core.success_callback)(callback_index, c_str.as_ptr()),
148-
Err(_) => (core.failure_callback)(callback_index), // TODO - report errors
142+
Err(err) => {
143+
dbg!(err); // TODO - report errors
144+
(core.failure_callback)(callback_index)
145+
}
149146
};
150147
}
151148
});

0 commit comments

Comments
 (0)