Skip to content

fix: fix downloading different URLs to same destination #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 37 additions & 27 deletions Tests.Vpn.Service/DownloaderTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,34 @@ public async Task Download(CancellationToken ct)
Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test"));
}

[Test(Description = "Perform 2 downloads with the same destination")]
[CancelAfter(30_000)]
public async Task DownloadSameDest(CancellationToken ct)
{
using var httpServer = EchoServer();
var url0 = new Uri(httpServer.BaseUrl + "/test0");
var url1 = new Uri(httpServer.BaseUrl + "/test1");
var destPath = Path.Combine(_tempDir, "test");

var manager = new Downloader(NullLogger<Downloader>.Instance);
var startTask0 = manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url0), destPath,
NullDownloadValidator.Instance, ct);
var startTask1 = manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url1), destPath,
NullDownloadValidator.Instance, ct);
var dlTask0 = await startTask0;
await dlTask0.Task;
Assert.That(dlTask0.TotalBytes, Is.EqualTo(5));
Assert.That(dlTask0.BytesRead, Is.EqualTo(5));
Assert.That(dlTask0.Progress, Is.EqualTo(1));
Assert.That(dlTask0.IsCompleted, Is.True);
var dlTask1 = await startTask1;
await dlTask1.Task;
Assert.That(dlTask1.TotalBytes, Is.EqualTo(5));
Assert.That(dlTask1.BytesRead, Is.EqualTo(5));
Assert.That(dlTask1.Progress, Is.EqualTo(1));
Assert.That(dlTask1.IsCompleted, Is.True);
}

[Test(Description = "Download with custom headers")]
[CancelAfter(30_000)]
public async Task WithHeaders(CancellationToken ct)
Expand Down Expand Up @@ -347,17 +375,17 @@ public async Task DownloadExistingDifferentContent(CancellationToken ct)

[Test(Description = "Unexpected response code from server")]
[CancelAfter(30_000)]
public void UnexpectedResponseCode(CancellationToken ct)
public async Task UnexpectedResponseCode(CancellationToken ct)
{
using var httpServer = new TestHttpServer(ctx => { ctx.Response.StatusCode = 404; });
var url = new Uri(httpServer.BaseUrl + "/test");
var destPath = Path.Combine(_tempDir, "test");

var manager = new Downloader(NullLogger<Downloader>.Instance);
// The "outer" Task should fail.
var ex = Assert.ThrowsAsync<HttpRequestException>(async () =>
await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
NullDownloadValidator.Instance, ct));
// The "inner" Task should fail.
var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
NullDownloadValidator.Instance, ct);
var ex = Assert.ThrowsAsync<HttpRequestException>(async () => await dlTask.Task);
Assert.That(ex.Message, Does.Contain("404"));
}

Expand All @@ -384,22 +412,6 @@ public async Task MismatchedETag(CancellationToken ct)
Assert.That(ex.Message, Does.Contain("ETag does not match SHA1 hash of downloaded file").And.Contains("beef"));
}

[Test(Description = "Timeout on response headers")]
[CancelAfter(30_000)]
public void CancelledOuter(CancellationToken ct)
{
using var httpServer = new TestHttpServer(async _ => { await Task.Delay(TimeSpan.FromSeconds(5), ct); });
var url = new Uri(httpServer.BaseUrl + "/test");
var destPath = Path.Combine(_tempDir, "test");

var manager = new Downloader(NullLogger<Downloader>.Instance);
// The "outer" Task should fail.
var smallerCt = new CancellationTokenSource(TimeSpan.FromSeconds(1)).Token;
Assert.ThrowsAsync<TaskCanceledException>(
async () => await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
NullDownloadValidator.Instance, smallerCt));
}

[Test(Description = "Timeout on response body")]
[CancelAfter(30_000)]
public async Task CancelledInner(CancellationToken ct)
Expand Down Expand Up @@ -451,12 +463,10 @@ public async Task ValidationFailureExistingFile(CancellationToken ct)
await File.WriteAllTextAsync(destPath, "test", ct);

var manager = new Downloader(NullLogger<Downloader>.Instance);
// The "outer" Task should fail because the inner task never starts.
var ex = Assert.ThrowsAsync<Exception>(async () =>
{
await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
new TestDownloadValidator(new Exception("test exception")), ct);
});
var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
new TestDownloadValidator(new Exception("test exception")), ct);
// The "inner" Task should fail.
var ex = Assert.ThrowsAsync<Exception>(async () => { await dlTask.Task; });
Assert.That(ex.Message, Does.Contain("Existing file failed validation"));
Assert.That(ex.InnerException, Is.Not.Null);
Assert.That(ex.InnerException!.Message, Is.EqualTo("test exception"));
Expand Down
41 changes: 29 additions & 12 deletions Vpn.Service/Downloader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Formats.Asn1;
using System.Net;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using Coder.Desktop.Vpn.Utilities;
Expand Down Expand Up @@ -288,7 +289,26 @@ public async Task<DownloadTask> StartDownloadAsync(HttpRequestMessage req, strin
{
var task = _downloads.GetOrAdd(destinationPath,
_ => new DownloadTask(_logger, req, destinationPath, validator));
await task.EnsureStartedAsync(ct);
// EnsureStarted is a no-op if we didn't create a new DownloadTask.
// So, we will only remove the destination once for each time we start a new task.
task.EnsureStarted(tsk =>
{
// remove the key first, before checking the exception, to ensure
// we still clean up.
_downloads.TryRemove(destinationPath, out _);
if (tsk.Exception == null)
{
return;
}

if (tsk.Exception.InnerException != null)
{
ExceptionDispatchInfo.Capture(tsk.Exception.InnerException).Throw();
}

// not sure if this is hittable, but just in case:
throw tsk.Exception;
}, ct);

// If the existing (or new) task is for the same URL, return it.
if (task.Request.RequestUri == req.RequestUri)
Expand Down Expand Up @@ -357,21 +377,19 @@ internal DownloadTask(ILogger logger, HttpRequestMessage req, string destination
".download-" + Path.GetRandomFileName());
}

internal async Task<Task> EnsureStartedAsync(CancellationToken ct = default)
internal void EnsureStarted(Action<Task> continuation, CancellationToken ct = default)
{
using var _ = await _semaphore.LockAsync(ct);
using var _ = _semaphore.Lock();
if (Task == null!)
Task = await StartDownloadAsync(ct);

return Task;
Task = Start(ct).ContinueWith(continuation, ct);
}

/// <summary>
/// Starts downloading the file. The request will be performed in this task, but once started, the task will complete
/// and the download will continue in the background. The provided CancellationToken can be used to cancel the
/// download.
/// </summary>
private async Task<Task> StartDownloadAsync(CancellationToken ct = default)
private async Task Start(CancellationToken ct = default)
{
Directory.CreateDirectory(_destinationDirectory);

Expand All @@ -398,8 +416,7 @@ private async Task<Task> StartDownloadAsync(CancellationToken ct = default)
throw new Exception("Existing file failed validation after 304 Not Modified", e);
}

Task = Task.CompletedTask;
return Task;
return;
}

if (res.StatusCode != HttpStatusCode.OK)
Expand Down Expand Up @@ -432,11 +449,11 @@ private async Task<Task> StartDownloadAsync(CancellationToken ct = default)
throw;
}

Task = DownloadAsync(res, tempFile, ct);
return Task;
await Download(res, tempFile, ct);
return;
}

private async Task DownloadAsync(HttpResponseMessage res, FileStream tempFile, CancellationToken ct)
private async Task Download(HttpResponseMessage res, FileStream tempFile, CancellationToken ct)
{
try
{
Expand Down
Loading