Skip to content

Commit

Permalink
Fix test and add a zero window test
Browse files Browse the repository at this point in the history
  • Loading branch information
guhetier committed Feb 11, 2025
1 parent 4104203 commit cab13f1
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 27 deletions.
4 changes: 1 addition & 3 deletions src/core/unittest/RecvBufferTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ struct RecvBuffer {
(QUIC_RECV_CHUNK*)CXPLAT_ALLOC_NONPAGED(
sizeof(QUIC_RECV_CHUNK) + AllocBufferLength,
QUIC_POOL_TEST);
QuicRecvChunkInitialize(PreallocChunk, AllocBufferLength, (uint8_t*)(PreallocChunk + 1), FALSE);
}
printf("Initializing: [mode=%u,vlen=%u,alen=%u]\n", RecvMode, VirtualBufferLength, AllocBufferLength);

Expand Down Expand Up @@ -79,9 +80,6 @@ struct RecvBuffer {
CxPlatListInsertTail(&ChunkList, &Chunk2->Link);
}
Result = QuicRecvBufferProvideChunks(&RecvBuf, &ChunkList);
} else {
Result = QuicRecvBufferInitialize(
&RecvBuf, AllocBufferLength, VirtualBufferLength, RecvMode, &AppBufferChunkPool, PreallocChunk);
}

Dump();
Expand Down
3 changes: 3 additions & 0 deletions src/test/MsQuicTests.h
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,9 @@ QuicTestEcn(
void QuicTestStreamAppProvidedBuffers(
);

void QuicTestStreamAppProvidedBuffersZeroWindow(
);

//
// QuicDrill tests
//
Expand Down
11 changes: 11 additions & 0 deletions src/test/bin/quic_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2265,6 +2265,17 @@ TEST(Misc, StreamAppProvidedBuffers) {
QuicTestStreamAppProvidedBuffers();
}
}

TEST(Misc, StreamAppProvidedBuffersZeroWindow) {
TestLogger Logger("StreamAppProvidedBuffersZeroWindow");
if (TestingKernelMode) {
// GTEST_SKIP();
// TODO guhetier: Implement
// ASSERT_TRUE(DriverClient.Run(IOCTL_QUIC_RUN_STREAM_APP_PROVIDED_BUFFERS_ZERO_WINDOW));
} else {
QuicTestStreamAppProvidedBuffersZeroWindow();
}
}
#endif // QUIC_API_ENABLE_PREVIEW_FEATURES

TEST(Misc, StreamBlockUnblockUnidiConnFlowControl) {
Expand Down
239 changes: 215 additions & 24 deletions src/test/lib/DataTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4536,24 +4536,37 @@ struct AppBuffersSenderContext {

// Helper context to receive data on a stream
struct AppBuffersReceiverContext {
MsQuicStream* Stream{};

// App buffers to provide when a peer stream is received
QUIC_BUFFER *BuffersForStreamStarted{};
uint32_t NumBuffersForStreamStarted{};

// App buffers to provide when MoreBufferThreshold bytes have been received
QUIC_BUFFER *BuffersForThreshold{};
uint32_t NumBuffersForThreshold{};
uint64_t MoreBufferThreshold{};

uint64_t ReceivedBytes{};
CxPlatEvent StreamClosed{};

// Event to signal when the sender stream get closed
CxPlatEvent SenderStreamClosed{};

// Event to signal when at least
CxPlatEvent ReceivedBytesThresholdReached{};
uint64_t ReceivedBytesThreshold{};

// Accept a stream on the listener side (no need to keep the handle to it)
static QUIC_STATUS ConnCallback(_In_ MsQuicConnection*, _In_opt_ void* Context, _Inout_ QUIC_CONNECTION_EVENT* Event) {
auto ReceiverContext = (AppBuffersReceiverContext*)Context;

if (Event->Type == QUIC_CONNECTION_EVENT_PEER_STREAM_STARTED) {
auto* Stream = new(std::nothrow) MsQuicStream(
ReceiverContext->Stream = new(std::nothrow) MsQuicStream(
Event->PEER_STREAM_STARTED.Stream,
CleanUpAutoDelete,
AppBuffersReceiverContext::StreamCallback,
Context);
Stream->ProvideReceiveBuffers(
ReceiverContext->Stream->ProvideReceiveBuffers(
ReceiverContext->NumBuffersForStreamStarted,
ReceiverContext->BuffersForStreamStarted);
}
Expand All @@ -4564,8 +4577,23 @@ struct AppBuffersReceiverContext {
auto ReceiverContext = (AppBuffersReceiverContext*)Context;
if (Event->Type == QUIC_STREAM_EVENT_RECEIVE) {
ReceiverContext->ReceivedBytes += Event->RECEIVE.TotalBufferLength;
} else if (Event->Type == QUIC_STREAM_EVENT_SHUTDOWN_COMPLETE) {
ReceiverContext->StreamClosed.Set();

if (ReceiverContext->MoreBufferThreshold > 0 &&
ReceiverContext->ReceivedBytes >= ReceiverContext->MoreBufferThreshold) {
// Provide more buffers if needed
ReceiverContext->Stream->ProvideReceiveBuffers(
ReceiverContext->NumBuffersForThreshold,
ReceiverContext->BuffersForThreshold);
ReceiverContext->MoreBufferThreshold = 0;
}

if (ReceiverContext->ReceivedBytesThreshold > 0 &&
ReceiverContext->ReceivedBytes >= ReceiverContext->ReceivedBytesThreshold) {
ReceiverContext->ReceivedBytesThresholdReached.Set();
ReceiverContext->ReceivedBytesThreshold = 0;
}
} else if (Event->Type == QUIC_STREAM_EVENT_PEER_SEND_SHUTDOWN) {
ReceiverContext->SenderStreamClosed.Set();
}
return QUIC_STATUS_SUCCESS;
}
Expand All @@ -4579,33 +4607,37 @@ QuicTestStreamAppProvidedBuffers(
TEST_QUIC_SUCCEEDED(Registration.GetInitStatus());

MsQuicConfiguration ServerConfiguration(Registration, "MsQuicTest",
MsQuicSettings().SetPeerUnidiStreamCount(1).SetPeerBidiStreamCount(1).SetConnFlowControlWindow(0x2000),
MsQuicSettings().SetPeerUnidiStreamCount(1).SetPeerBidiStreamCount(1).SetStreamRecvWindowDefault(0x2000),
ServerSelfSignedCredConfig);
TEST_QUIC_SUCCEEDED(ServerConfiguration.GetInitStatus());

MsQuicConfiguration ClientConfiguration(Registration, "MsQuicTest",
MsQuicSettings().SetPeerUnidiStreamCount(1).SetPeerBidiStreamCount(1).SetConnFlowControlWindow(0x2000),
MsQuicSettings().SetPeerUnidiStreamCount(1).SetPeerBidiStreamCount(1).SetStreamRecvWindowDefault(0x2000),
MsQuicCredentialConfig());
TEST_QUIC_SUCCEEDED(ClientConfiguration.GetInitStatus());

// Client side sending data
{
// Create send and receive buffers
const uint32_t BufferSize = 0x5000;
const uint32_t NumBuffers = 0x10;
uint8_t SendDataBuffer[BufferSize] = {};
for (auto i = 0u; i < BufferSize; ++i) {
SendDataBuffer[i] = static_cast<uint8_t>(i);
}
uint8_t ReceiveDataBuffer[BufferSize] = {};
QUIC_BUFFER QuicBuffers[5]{};
for (auto i = 0u; i < 5; ++i) {
QuicBuffers[i].Buffer = ReceiveDataBuffer + i * BufferSize / 5;
QuicBuffers[i].Length = BufferSize / 5;
QUIC_BUFFER QuicBuffers[NumBuffers]{};
for (auto i = 0u; i < NumBuffers; ++i) {
QuicBuffers[i].Buffer = ReceiveDataBuffer + i * BufferSize / NumBuffers;
QuicBuffers[i].Length = BufferSize / NumBuffers;
}

AppBuffersReceiverContext ReceiveContext;
ReceiveContext.BuffersForStreamStarted = QuicBuffers;
ReceiveContext.NumBuffersForStreamStarted = ARRAYSIZE(QuicBuffers);
ReceiveContext.NumBuffersForStreamStarted = NumBuffers / 2;
ReceiveContext.BuffersForThreshold = QuicBuffers + NumBuffers / 2;
ReceiveContext.NumBuffersForThreshold = NumBuffers / 2;
ReceiveContext.MoreBufferThreshold = 0x1500;

// Setup a listener
MsQuicAutoAcceptListener Listener(Registration, ServerConfiguration, AppBuffersReceiverContext::ConnCallback, &ReceiveContext);
Expand All @@ -4632,9 +4664,9 @@ QuicTestStreamAppProvidedBuffers(

// Send data
QUIC_BUFFER Buffer{BufferSize, SendDataBuffer};
TEST_QUIC_SUCCEEDED(ClientStream.Send(&Buffer, 1));
TEST_QUIC_SUCCEEDED(ClientStream.Send(&Buffer, 1, QUIC_SEND_FLAG_FIN));

ReceiveContext.StreamClosed.WaitTimeout(TestWaitTimeout);
TEST_TRUE(ReceiveContext.SenderStreamClosed.WaitTimeout(TestWaitTimeout));
TEST_EQUAL(ReceiveContext.ReceivedBytes, BufferSize);
TEST_EQUAL(0, memcmp(SendDataBuffer, ReceiveDataBuffer, BufferSize));
}
Expand All @@ -4661,45 +4693,204 @@ QuicTestStreamAppProvidedBuffers(
TEST_TRUE(Connection.HandshakeCompleteEvent.WaitTimeout(TestWaitTimeout));
TEST_TRUE(Connection.HandshakeComplete);


// Create send and receive buffers
const uint32_t BufferSize = 0x5000;
const uint32_t NumBuffers = 0x10;
uint8_t SendDataBuffer[BufferSize] = {};
for (auto i = 0u; i < BufferSize; ++i) {
SendDataBuffer[i] = static_cast<uint8_t>(i);
}

uint8_t ReceiveDataBuffer[BufferSize] = {};
QUIC_BUFFER QuicBuffers[5]{};
for (auto i = 0u; i < 5; ++i) {
QuicBuffers[i].Buffer = ReceiveDataBuffer + i * BufferSize / 5;
QuicBuffers[i].Length = BufferSize / 5;
uint8_t ReceiveDataBuffer[BufferSize]{};
QUIC_BUFFER QuicBuffers[NumBuffers]{};
for (auto i = 0u; i < NumBuffers; ++i) {
QuicBuffers[i].Buffer = ReceiveDataBuffer + i * BufferSize / NumBuffers;
QuicBuffers[i].Length = BufferSize / NumBuffers;
}

// Create and start a stream
AppBuffersReceiverContext ReceiveContext;
ReceiveContext.BuffersForThreshold = QuicBuffers + NumBuffers / 2;
ReceiveContext.NumBuffersForThreshold = NumBuffers / 2;
ReceiveContext.MoreBufferThreshold = 0x1500;

MsQuicStream ClientStream(
Connection,
QUIC_STREAM_OPEN_FLAG_APP_OWNED_BUFFERS,
CleanUpManual,
AppBuffersReceiverContext::StreamCallback,
&ReceiveContext);
TEST_QUIC_SUCCEEDED(ClientStream.GetInitStatus());

ReceiveContext.Stream = &ClientStream;
// Provide some receive buffers before starting the stream
ClientStream.ProvideReceiveBuffers(NumBuffers / 2, QuicBuffers);

TEST_QUIC_SUCCEEDED(ClientStream.Start(QUIC_STREAM_START_FLAG_IMMEDIATE));
TEST_QUIC_SUCCEEDED(ClientStream.Shutdown(QUIC_STATUS_SUCCESS, QUIC_STREAM_SHUTDOWN_FLAG_GRACEFUL));

auto* SenderStream = SenderContext.WaitForSenderStream();
TEST_NOT_EQUAL(SenderStream, nullptr);

// Send data
QUIC_BUFFER Buffer{BufferSize, SendDataBuffer};
TEST_QUIC_SUCCEEDED(SenderStream->Send(&Buffer, 1, QUIC_SEND_FLAG_FIN));

TEST_TRUE(ReceiveContext.SenderStreamClosed.WaitTimeout(TestWaitTimeout));
TEST_EQUAL(ReceiveContext.ReceivedBytes, BufferSize);
TEST_EQUAL(0, memcmp(SendDataBuffer, ReceiveDataBuffer, BufferSize));
}
}

void
QuicTestStreamAppProvidedBuffersZeroWindow(
)
{
MsQuicRegistration Registration(true);
TEST_QUIC_SUCCEEDED(Registration.GetInitStatus());

MsQuicConfiguration ServerConfiguration(Registration, "MsQuicTest",
MsQuicSettings().SetPeerUnidiStreamCount(1).SetPeerBidiStreamCount(1).SetStreamRecvWindowDefault(0x2000),
ServerSelfSignedCredConfig);
TEST_QUIC_SUCCEEDED(ServerConfiguration.GetInitStatus());

MsQuicConfiguration ClientConfiguration(Registration, "MsQuicTest",
MsQuicSettings().SetPeerUnidiStreamCount(1).SetPeerBidiStreamCount(1).SetStreamRecvWindowDefault(0x2000),
MsQuicCredentialConfig());
TEST_QUIC_SUCCEEDED(ClientConfiguration.GetInitStatus());

// Client side sending data
{
// Create send and receive buffers
const uint32_t BufferSize = 0x5000;
const uint32_t NumBuffers = 0x10;
uint8_t SendDataBuffer[BufferSize] = {};
for (auto i = 0u; i < BufferSize; ++i) {
SendDataBuffer[i] = static_cast<uint8_t>(i);
}
uint8_t ReceiveDataBuffer[BufferSize] = {};
QUIC_BUFFER QuicBuffers[NumBuffers]{};
for (auto i = 0u; i < NumBuffers; ++i) {
QuicBuffers[i].Buffer = ReceiveDataBuffer + i * BufferSize / NumBuffers;
QuicBuffers[i].Length = BufferSize / NumBuffers;
}

AppBuffersReceiverContext ReceiveContext;
ReceiveContext.BuffersForStreamStarted = QuicBuffers;
ReceiveContext.NumBuffersForStreamStarted = NumBuffers / 2;
ReceiveContext.BuffersForThreshold = QuicBuffers + NumBuffers / 2;
ReceiveContext.NumBuffersForThreshold = NumBuffers / 2;

// Setup the threshold so that more buffers are provided:
// - only when the receive window reaches zero bytes
// - but inline in the receive callback
ReceiveContext.MoreBufferThreshold = 0x2000;

// Setup a listener
MsQuicAutoAcceptListener Listener(Registration, ServerConfiguration, AppBuffersReceiverContext::ConnCallback, &ReceiveContext);
TEST_QUIC_SUCCEEDED(Listener.GetInitStatus());
TEST_QUIC_SUCCEEDED(Listener.Start("MsQuicTest"));
QuicAddr ServerLocalAddr;
TEST_QUIC_SUCCEEDED(Listener.GetLocalAddr(ServerLocalAddr));

// Setup and start a client connection
MsQuicConnection Connection(Registration);
TEST_QUIC_SUCCEEDED(Connection.GetInitStatus());

TEST_QUIC_SUCCEEDED(Connection.Start(
ClientConfiguration,
ServerLocalAddr.GetFamily(),
QUIC_TEST_LOOPBACK_FOR_AF(ServerLocalAddr.GetFamily()),
ServerLocalAddr.GetPort()));
TEST_TRUE(Connection.HandshakeCompleteEvent.WaitTimeout(TestWaitTimeout));
TEST_TRUE(Connection.HandshakeComplete);

MsQuicStream ClientStream(Connection, QUIC_STREAM_OPEN_FLAG_UNIDIRECTIONAL);
TEST_QUIC_SUCCEEDED(ClientStream.GetInitStatus());
TEST_QUIC_SUCCEEDED(ClientStream.Start(QUIC_STREAM_START_FLAG_IMMEDIATE));

// Send data
QUIC_BUFFER Buffer{BufferSize, SendDataBuffer};
TEST_QUIC_SUCCEEDED(ClientStream.Send(&Buffer, 1, QUIC_SEND_FLAG_FIN));

TEST_TRUE(ReceiveContext.SenderStreamClosed.WaitTimeout(TestWaitTimeout));
TEST_EQUAL(ReceiveContext.ReceivedBytes, BufferSize);
TEST_EQUAL(0, memcmp(SendDataBuffer, ReceiveDataBuffer, BufferSize));
}

// Server side sending data
{
// Setup a listener
AppBuffersSenderContext SenderContext{};
MsQuicAutoAcceptListener Listener(Registration, ServerConfiguration, AppBuffersSenderContext::ConnCallback, &SenderContext);
TEST_QUIC_SUCCEEDED(Listener.GetInitStatus());
TEST_QUIC_SUCCEEDED(Listener.Start("MsQuicTest"));
QuicAddr ServerLocalAddr;
TEST_QUIC_SUCCEEDED(Listener.GetLocalAddr(ServerLocalAddr));

// Provide receive buffers before starting the stream
ClientStream.ProvideReceiveBuffers(ARRAYSIZE(QuicBuffers), QuicBuffers);
// Setup a client connection
MsQuicConnection Connection(Registration);
TEST_QUIC_SUCCEEDED(Connection.GetInitStatus());

TEST_QUIC_SUCCEEDED(Connection.Start(
ClientConfiguration,
ServerLocalAddr.GetFamily(),
QUIC_TEST_LOOPBACK_FOR_AF(ServerLocalAddr.GetFamily()),
ServerLocalAddr.GetPort()));
TEST_TRUE(Connection.HandshakeCompleteEvent.WaitTimeout(TestWaitTimeout));
TEST_TRUE(Connection.HandshakeComplete);

// Create send and receive buffers
const uint32_t BufferSize = 0x5000;
const uint32_t NumBuffers = 0x10;
uint8_t SendDataBuffer[BufferSize] = {};
for (auto i = 0u; i < BufferSize; ++i) {
SendDataBuffer[i] = static_cast<uint8_t>(i);
}

uint8_t ReceiveDataBuffer[BufferSize]{};
QUIC_BUFFER QuicBuffers[NumBuffers]{};
for (auto i = 0u; i < NumBuffers; ++i) {
QuicBuffers[i].Buffer = ReceiveDataBuffer + i * BufferSize / NumBuffers;
QuicBuffers[i].Length = BufferSize / NumBuffers;
}

// Create and start a stream
AppBuffersReceiverContext ReceiveContext;

MsQuicStream ClientStream(
Connection,
QUIC_STREAM_OPEN_FLAG_APP_OWNED_BUFFERS,
CleanUpManual,
AppBuffersReceiverContext::StreamCallback,
&ReceiveContext);
TEST_QUIC_SUCCEEDED(ClientStream.GetInitStatus());

ReceiveContext.Stream = &ClientStream;
// Set the threshold to the amount of provided buffer space to wait until
// the maximum byte offset receivable is received.
ReceiveContext.ReceivedBytesThreshold = BufferSize / 2;

// Provide some receive buffers before starting the stream
TEST_QUIC_SUCCEEDED(ClientStream.ProvideReceiveBuffers(NumBuffers / 2, QuicBuffers));

TEST_QUIC_SUCCEEDED(ClientStream.Start(QUIC_STREAM_START_FLAG_IMMEDIATE));
TEST_QUIC_SUCCEEDED(ClientStream.Shutdown(QUIC_STATUS_SUCCESS, QUIC_STREAM_SHUTDOWN_FLAG_GRACEFUL));

auto* SenderStream = SenderContext.WaitForSenderStream();
TEST_NOT_EQUAL(SenderStream, nullptr);

// Send data
QUIC_BUFFER Buffer{BufferSize, SendDataBuffer};
TEST_QUIC_SUCCEEDED(SenderStream->Send(&Buffer, 1));
TEST_QUIC_SUCCEEDED(SenderStream->Send(&Buffer, 1, QUIC_SEND_FLAG_FIN));

// Wait until enough data is received to fill the window completely
TEST_TRUE(ReceiveContext.ReceivedBytesThresholdReached.WaitTimeout(TestWaitTimeout));

// Provide more buffers out of a callback context and check the remaining data is received
TEST_QUIC_SUCCEEDED(ClientStream.ProvideReceiveBuffers(NumBuffers / 2, QuicBuffers + NumBuffers / 2));

ReceiveContext.StreamClosed.WaitTimeout(TestWaitTimeout);
TEST_TRUE(ReceiveContext.SenderStreamClosed.WaitTimeout(TestWaitTimeout));
TEST_EQUAL(ReceiveContext.ReceivedBytes, BufferSize);
TEST_EQUAL(0, memcmp(SendDataBuffer, ReceiveDataBuffer, BufferSize));
}
Expand Down

0 comments on commit cab13f1

Please sign in to comment.