Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2361,7 +2361,10 @@ private Task CopyColumnsAsync(int col, TaskCompletionSource<object> source = nul
// This is in its own method to avoid always allocating the lambda in CopyColumnsAsync
private void CopyColumnsAsyncSetupContinuation(TaskCompletionSource<object> source, Task task, int i)
{
AsyncHelper.ContinueTaskWithState(task, source, this,
AsyncHelper.ContinueTaskWithState(
task,
source,
state: this,
onSuccess: (object state) =>
{
SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state;
Expand All @@ -2373,9 +2376,7 @@ private void CopyColumnsAsyncSetupContinuation(TaskCompletionSource<object> sour
{
source.SetResult(null);
}
},
connectionToDoom: _connection.GetOpenTdsConnection()
);
});
}

// The notification logic.
Expand Down Expand Up @@ -2510,10 +2511,11 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts,
}
resultTask = source.Task;

AsyncHelper.ContinueTaskWithState(readTask, source, this,
onSuccess: (object state) => ((SqlBulkCopy)state).CopyRowsAsync(i + 1, totalRows, cts, source),
connectionToDoom: _connection.GetOpenTdsConnection()
);
AsyncHelper.ContinueTaskWithState(
readTask,
source,
state: this,
onSuccess: (object state) => ((SqlBulkCopy)state).CopyRowsAsync(i + 1, totalRows, cts, source));
return resultTask; // Associated task will be completed when all rows are copied to server/exception/cancelled.
}
}
Expand All @@ -2535,14 +2537,13 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts,
}
else
{
AsyncHelper.ContinueTaskWithState(readTask, source, sqlBulkCopy,
onSuccess: (object state2) => ((SqlBulkCopy)state2).CopyRowsAsync(i + 1, totalRows, cts, source),
connectionToDoom: _connection.GetOpenTdsConnection()
);
AsyncHelper.ContinueTaskWithState(
readTask,
source,
state: sqlBulkCopy,
onSuccess: (object state2) => ((SqlBulkCopy)state2).CopyRowsAsync(i + 1, totalRows, cts, source));
}
},
connectionToDoom: _connection.GetOpenTdsConnection()
);
});
return resultTask;
}
}
Expand Down Expand Up @@ -2611,7 +2612,10 @@ private Task CopyBatchesAsync(BulkCopySimpleResultSet internalResults, string up
source = new TaskCompletionSource<object>();
}

AsyncHelper.ContinueTaskWithState(commandTask, source, this,
AsyncHelper.ContinueTaskWithState(
commandTask,
source,
state: this,
onSuccess: (object state) =>
{
SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state;
Expand All @@ -2621,9 +2625,7 @@ private Task CopyBatchesAsync(BulkCopySimpleResultSet internalResults, string up
// Continuation finished sync, recall into CopyBatchesAsync to continue
sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source);
}
},
connectionToDoom: _connection.GetOpenTdsConnection()
);
});
return source.Task;
}
}
Expand Down Expand Up @@ -2677,7 +2679,10 @@ private Task CopyBatchesAsyncContinued(BulkCopySimpleResultSet internalResults,
{ // First time only
source = new TaskCompletionSource<object>();
}
AsyncHelper.ContinueTaskWithState(task, source, this,
AsyncHelper.ContinueTaskWithState(
task,
source,
state: this,
onSuccess: (object state) =>
{
SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state;
Expand All @@ -2689,9 +2694,7 @@ private Task CopyBatchesAsyncContinued(BulkCopySimpleResultSet internalResults,
}
},
onFailure: static (Exception _, object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: false),
onCancellation: static (object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: true),
connectionToDoom: _connection.GetOpenTdsConnection()
);
onCancellation: static (object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: true));

return source.Task;
}
Expand Down Expand Up @@ -2738,7 +2741,10 @@ private Task CopyBatchesAsyncContinuedOnSuccess(BulkCopySimpleResultSet internal
source = new TaskCompletionSource<object>();
}

AsyncHelper.ContinueTaskWithState(writeTask, source, this,
AsyncHelper.ContinueTaskWithState(
writeTask,
source,
state: this,
onSuccess: (object state) =>
{
SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state;
Expand All @@ -2756,9 +2762,7 @@ private Task CopyBatchesAsyncContinuedOnSuccess(BulkCopySimpleResultSet internal
// Always call back into CopyBatchesAsync
sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source);
},
onFailure: static (Exception _, object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: false),
connectionToDoom: _connection.GetOpenTdsConnection()
);
onFailure: static (Exception _, object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: false));
return source.Task;
}
}
Expand Down Expand Up @@ -2859,7 +2863,10 @@ private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet int
{
source = new TaskCompletionSource<object>();
}
AsyncHelper.ContinueTaskWithState(task, source, this,
AsyncHelper.ContinueTaskWithState(
task,
source,
state: this,
onSuccess: (object state) =>
{
SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state;
Expand Down Expand Up @@ -2902,9 +2909,7 @@ private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet int
}
}
}
},
connectionToDoom: _connection.GetOpenTdsConnection()
);
});
return;
}
else
Expand Down Expand Up @@ -3029,14 +3034,9 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio
_parserLock.Wait(canReleaseFromAnyThread: true);
WriteToServerInternalRestAsync(cts, source);
},
connectionToAbort: _connection,
onFailure: static (_, state) => ((StrongBox<CancellationTokenRegistration>)state).Value.Dispose(),
onCancellation: static state => ((StrongBox<CancellationTokenRegistration>)state).Value.Dispose(),
#if NET
exceptionConverter: ex => SQL.BulkLoadInvalidDestinationTable(_destinationTableName, ex)
#else
exceptionConverter: (ex, _) => SQL.BulkLoadInvalidDestinationTable(_destinationTableName, ex)
#endif
);
return;
}
Expand Down Expand Up @@ -3085,10 +3085,11 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio

if (internalResultsTask != null)
{
AsyncHelper.ContinueTaskWithState(internalResultsTask, source, this,
onSuccess: (object state) => ((SqlBulkCopy)state).WriteToServerInternalRestContinuedAsync(internalResultsTask.Result, cts, source),
connectionToDoom: _connection.GetOpenTdsConnection()
);
AsyncHelper.ContinueTaskWithState(
internalResultsTask,
source,
state: this,
onSuccess: (object state) => ((SqlBulkCopy)state).WriteToServerInternalRestContinuedAsync(internalResultsTask.Result, cts, source));
}
else
{
Expand Down Expand Up @@ -3169,9 +3170,7 @@ private Task WriteToServerInternalAsync(CancellationToken ctoken)
{
sqlBulkCopy.WriteToServerInternalRestAsync(ctoken, source); // Passing the same completion which will be completed by the Callee.
}
},
connectionToDoom: _connection.GetOpenTdsConnection()
);
});
return resultTask;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1749,11 +1749,7 @@ private SqlDataReader RunExecuteReaderTdsWithTransparentParameterEncryption(
onCancellation: static state =>
{
((SqlCommand)state).CachedAsyncState?.ResetAsyncState();
}
#if NETFRAMEWORK
, connectionToAbort: _activeConnection
#endif
);
});

task = completion.Task;
return ds;
Expand Down
109 changes: 20 additions & 89 deletions src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,10 @@ internal static ArgumentOutOfRangeException InvalidMinAndMaxPair(string minParam

internal static class AsyncHelper
{
internal static Task CreateContinuationTask(Task task, Action onSuccess,
#if NETFRAMEWORK
SqlInternalConnectionTds connectionToDoom = null,
#endif
Action<Exception> onFailure = null)
internal static Task CreateContinuationTask(
Task task,
Action onSuccess,
Action<Exception> onFailure = null)
{
if (task == null)
{
Expand All @@ -65,8 +64,9 @@ internal static Task CreateContinuationTask(Task task, Action onSuccess,
else
{
TaskCompletionSource<object> completion = new TaskCompletionSource<object>();
#if NET
ContinueTaskWithState(task, completion,
ContinueTaskWithState(
task,
completion,
state: Tuple.Create(onSuccess, onFailure, completion),
onSuccess: static (object state) =>
{
Expand All @@ -82,16 +82,6 @@ internal static Task CreateContinuationTask(Task task, Action onSuccess,
Action<Exception> failure = parameters.Item2;
failure?.Invoke(exception);
}
#else
ContinueTask(task, completion,
onSuccess: () =>
{
onSuccess();
completion.SetResult(null);
},
onFailure: onFailure,
connectionToDoom: connectionToDoom
#endif
);
return completion.Task;
}
Expand Down Expand Up @@ -119,32 +109,23 @@ internal static Task CreateContinuationTaskWithState(Task task, object state, Ac
}
}

internal static Task CreateContinuationTask<T1, T2>(Task task, Action<T1, T2> onSuccess, T1 arg1, T2 arg2, SqlInternalConnectionTds connectionToDoom = null, Action<Exception> onFailure = null)
internal static Task CreateContinuationTask<T1, T2>(
Task task,
Action<T1, T2> onSuccess,
T1 arg1,
T2 arg2,
Action<Exception> onFailure = null)
{
return CreateContinuationTask(task, () => onSuccess(arg1, arg2),
#if NETFRAMEWORK
connectionToDoom,
#endif
onFailure);
return CreateContinuationTask(task, () => onSuccess(arg1, arg2), onFailure);
}

internal static void ContinueTask(Task task,
TaskCompletionSource<object> completion,
Action onSuccess,
Action<Exception> onFailure = null,
Action onCancellation = null,
Func<Exception, Exception> exceptionConverter = null,
#if NET
SqlInternalConnectionTds connectionToDoom = null
#else
SqlInternalConnectionTds connectionToDoom = null,
SqlConnection connectionToAbort = null
#endif
)
Func<Exception, Exception> exceptionConverter = null)
{
#if NETFRAMEWORK
Debug.Assert((connectionToAbort == null) || (connectionToDoom == null), "Should not specify both connectionToDoom and connectionToAbort");
#endif
task.ContinueWith(
tsk =>
{
Expand Down Expand Up @@ -177,42 +158,16 @@ internal static void ContinueTask(Task task,
}
else
{
#if NETFRAMEWORK
if (connectionToDoom != null || connectionToAbort != null)
{
try
{
onSuccess();
}
// @TODO: CER Exception Handling was removed here (see GH#3581)
catch (Exception e)
{
completion.SetException(e);
}
}
else
{ // no connection to doom - reliability section not required
try
{
onSuccess();
}
catch (Exception e)
{
completion.SetException(e);
}
}
}
#else
try
{
onSuccess();
}
// @TODO: CER Exception Handling was removed here (see GH#3581)
catch (Exception e)
{
completion.SetException(e);
}
}
#endif
}, TaskScheduler.Default
);
}
Expand All @@ -225,18 +180,8 @@ internal static void ContinueTaskWithState(Task task,
Action<object> onSuccess,
Action<Exception, object> onFailure = null,
Action<object> onCancellation = null,
#if NET
Func<Exception, Exception> exceptionConverter = null,
#else
Func<Exception, object, Exception> exceptionConverter = null,
#endif
SqlInternalConnectionTds connectionToDoom = null,
SqlConnection connectionToAbort = null
)
Func<Exception, Exception> exceptionConverter = null)
{
#if NETFRAMEWORK
Debug.Assert((connectionToAbort == null) || (connectionToDoom == null), "Should not specify both connectionToDoom and connectionToAbort");
#endif
task.ContinueWith(
(Task tsk, object state2) =>
{
Expand All @@ -245,12 +190,9 @@ internal static void ContinueTaskWithState(Task task,
Exception exc = tsk.Exception.InnerException;
if (exceptionConverter != null)
{
exc = exceptionConverter(exc
#if NETFRAMEWORK
, state2
#endif
);
exc = exceptionConverter(exc);
}

try
{
onFailure?.Invoke(exc, state2);
Expand All @@ -271,24 +213,13 @@ internal static void ContinueTaskWithState(Task task,
completion.TrySetCanceled();
}
}
else if (connectionToDoom != null || connectionToAbort != null)
{
try
{
onSuccess(state2);
}
// @TODO: CER Exception Handling was removed here (see GH#3581)
catch (Exception e)
{
completion.SetException(e);
}
}
else
{
try
{
onSuccess(state2);
}
// @TODO: CER Exception Handling was removed here (see GH#3581)
catch (Exception e)
{
completion.SetException(e);
Expand Down
Loading
Loading