Skip to content

Commit 1f57d73

Browse files
committed
protect missed locations where network reads can happen
1 parent 69dc7d4 commit 1f57d73

File tree

4 files changed

+48
-17
lines changed

4 files changed

+48
-17
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.Windows.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ internal void PostReadAsyncForMars()
2424

2525
_pMarsPhysicalConObj.IncrementPendingCallbacks();
2626
SessionHandle handle = _pMarsPhysicalConObj.SessionHandle;
27+
// we do not need to consider partial packets when making this read because we
28+
// expect this read to pend. a partial packet should not exist at setup of the
29+
// parser
30+
Debug.Assert(_physicalStateObj.PartialPacket==null);
2731
temp = _pMarsPhysicalConObj.ReadAsync(handle, out error);
2832

2933
Debug.Assert(temp.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,14 +322,22 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error)
322322
stateObj.SendAttention(mustTakeWriteLock: true);
323323

324324
PacketHandle syncReadPacket = default;
325+
bool readFromNetwork = true;
325326
RuntimeHelpers.PrepareConstrainedRegions();
326327
bool shouldDecrement = false;
327328
try
328329
{
329330
Interlocked.Increment(ref _readingCount);
330331
shouldDecrement = true;
331-
332-
syncReadPacket = ReadSyncOverAsync(stateObj.GetTimeoutRemaining(), out error);
332+
readFromNetwork = !PartialPacketContainsCompletePacket();
333+
if (readFromNetwork)
334+
{
335+
syncReadPacket = ReadSyncOverAsync(stateObj.GetTimeoutRemaining(), out error);
336+
}
337+
else
338+
{
339+
error = TdsEnums.SNI_SUCCESS;
340+
}
333341

334342
Interlocked.Decrement(ref _readingCount);
335343
shouldDecrement = false;
@@ -342,7 +350,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error)
342350
}
343351
else
344352
{
345-
Debug.Assert(!IsValidPacket(syncReadPacket), "unexpected syncReadPacket without corresponding SNIPacketRelease");
353+
Debug.Assert(!readFromNetwork || !IsValidPacket(syncReadPacket), "unexpected syncReadPacket without corresponding SNIPacketRelease");
346354
fail = true; // Subsequent read failed, time to give up.
347355
}
348356
}
@@ -353,7 +361,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error)
353361
Interlocked.Decrement(ref _readingCount);
354362
}
355363

356-
if (!IsPacketEmpty(syncReadPacket))
364+
if (readFromNetwork && !IsPacketEmpty(syncReadPacket))
357365
{
358366
ReleasePacket(syncReadPacket);
359367
}

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -452,14 +452,22 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error)
452452
stateObj.SendAttention(mustTakeWriteLock: true);
453453

454454
PacketHandle syncReadPacket = default;
455+
bool readFromNetwork = true;
455456
RuntimeHelpers.PrepareConstrainedRegions();
456457
bool shouldDecrement = false;
457458
try
458459
{
459460
Interlocked.Increment(ref _readingCount);
460461
shouldDecrement = true;
461-
462-
syncReadPacket = ReadSyncOverAsync(stateObj.GetTimeoutRemaining(), out error);
462+
readFromNetwork = !PartialPacketContainsCompletePacket();
463+
if (readFromNetwork)
464+
{
465+
syncReadPacket = ReadSyncOverAsync(stateObj.GetTimeoutRemaining(), out error);
466+
}
467+
else
468+
{
469+
error = TdsEnums.SNI_SUCCESS;
470+
}
463471

464472
Interlocked.Decrement(ref _readingCount);
465473
shouldDecrement = false;
@@ -472,7 +480,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error)
472480
}
473481
else
474482
{
475-
Debug.Assert(!IsValidPacket(syncReadPacket), "unexpected syncReadPacket without corresponding SNIPacketRelease");
483+
Debug.Assert(!readFromNetwork || !IsValidPacket(syncReadPacket), "unexpected syncReadPacket without corresponding SNIPacketRelease");
476484
fail = true; // Subsequent read failed, time to give up.
477485
}
478486
}
@@ -483,7 +491,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error)
483491
Interlocked.Decrement(ref _readingCount);
484492
}
485493

486-
if (!IsPacketEmpty(syncReadPacket))
494+
if (readFromNetwork && !IsPacketEmpty(syncReadPacket))
487495
{
488496
// Be sure to release packet, otherwise it will be leaked by native.
489497
ReleasePacket(syncReadPacket);

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,6 +2382,7 @@ internal void ReadSni(TaskCompletionSource<object> completion)
23822382
PacketHandle readPacket = default;
23832383

23842384
uint error = 0;
2385+
bool readFromNetwork = true;
23852386

23862387
RuntimeHelpers.PrepareConstrainedRegions();
23872388
try
@@ -2427,17 +2428,27 @@ internal void ReadSni(TaskCompletionSource<object> completion)
24272428
Interlocked.Increment(ref _readingCount);
24282429

24292430
handle = SessionHandle;
2430-
if (!handle.IsNull)
2431+
2432+
readFromNetwork = !PartialPacketContainsCompletePacket();
2433+
if (readFromNetwork)
24312434
{
2432-
IncrementPendingCallbacks();
2435+
if (!handle.IsNull)
2436+
{
2437+
IncrementPendingCallbacks();
24332438

2434-
readPacket = ReadAsync(handle, out error);
2439+
readPacket = ReadAsync(handle, out error);
24352440

2436-
if (!(TdsEnums.SNI_SUCCESS == error || TdsEnums.SNI_SUCCESS_IO_PENDING == error))
2437-
{
2438-
DecrementPendingCallbacks(false); // Failure - we won't receive callback!
2441+
if (!(TdsEnums.SNI_SUCCESS == error || TdsEnums.SNI_SUCCESS_IO_PENDING == error))
2442+
{
2443+
DecrementPendingCallbacks(false); // Failure - we won't receive callback!
2444+
}
24392445
}
24402446
}
2447+
else
2448+
{
2449+
readPacket = default;
2450+
error = TdsEnums.SNI_SUCCESS;
2451+
}
24412452

24422453
Interlocked.Decrement(ref _readingCount);
24432454
}
@@ -2449,12 +2460,12 @@ internal void ReadSni(TaskCompletionSource<object> completion)
24492460

24502461
if (TdsEnums.SNI_SUCCESS == error)
24512462
{ // Success - process results!
2452-
Debug.Assert(IsValidPacket(readPacket), "ReadNetworkPacket should not have been null on this async operation!");
2463+
Debug.Assert(!readFromNetwork || IsValidPacket(readPacket) , "ReadNetworkPacket should not have been null on this async operation!");
24532464
// Evaluate this condition for MANAGED_SNI. This may not be needed because the network call is happening Async and only the callback can receive a success.
24542465
ReadAsyncCallback(IntPtr.Zero, readPacket, 0);
24552466

24562467
// Only release packet for Managed SNI as for Native SNI packet is released in finally block.
2457-
if (TdsParserStateObjectFactory.UseManagedSNI && !IsPacketEmpty(readPacket))
2468+
if (TdsParserStateObjectFactory.UseManagedSNI && readFromNetwork && !IsPacketEmpty(readPacket))
24582469
{
24592470
ReleasePacket(readPacket);
24602471
}
@@ -2492,7 +2503,7 @@ internal void ReadSni(TaskCompletionSource<object> completion)
24922503
{
24932504
if (!TdsParserStateObjectFactory.UseManagedSNI)
24942505
{
2495-
if (!IsPacketEmpty(readPacket))
2506+
if (readFromNetwork && !IsPacketEmpty(readPacket))
24962507
{
24972508
// Be sure to release packet, otherwise it will be leaked by native.
24982509
ReleasePacket(readPacket);

0 commit comments

Comments
 (0)