Skip to content

Add Oauth authentication for Azure AD SSO #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Snowflake.Client.Tests/Models/TestConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,10 @@
public class TestConfiguration
{
public SnowflakeConnectionInfo Connection { get; set; }
public string AdClientId { get; set; }
public string AdClientSecret { get; set; }
public string AdServicePrincipalObjectId { get; set; }
public string AdTenantId { get; set; }
public string AdScope { get; set; }
}
}
2 changes: 2 additions & 0 deletions Snowflake.Client.Tests/Snowflake.Client.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Identity.Client" Version="4.53.0" />
<PackageReference Include="moq" Version="4.18.4" />
<PackageReference Include="NUnit" Version="3.13.3" />
<PackageReference Include="NUnit3TestAdapter" Version="4.3.1" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.4.0" />
Expand Down
34 changes: 34 additions & 0 deletions Snowflake.Client.Tests/UnitTests/AzureAdAuthInfoTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using System;
using NUnit.Framework;
using Snowflake.Client.Tests.Models;
using Snowflake.Client.Model;
using System.IO;
using System.Text.Json;

namespace Snowflake.Client.Tests.IntegrationTests
{
[TestFixture]
public class AzureAdAuthInfoTests
{
protected readonly AzureAdAuthInfo _azureAdAuthInfo;

public AzureAdAuthInfoTests()
{
var configJson = File.ReadAllText("testconfig.json");
var testParameters = JsonSerializer.Deserialize<TestConfiguration>(configJson, new JsonSerializerOptions() { PropertyNameCaseInsensitive = true });
var connectionInfo = testParameters.Connection;

_azureAdAuthInfo = new AzureAdAuthInfo(
testParameters.AdClientId,
testParameters.AdClientSecret,
testParameters.AdServicePrincipalObjectId,
testParameters.AdTenantId,
testParameters.AdScope,
connectionInfo.Region,
connectionInfo.Account,
connectionInfo.User,
connectionInfo.Host,
connectionInfo.Role);
}
}
}
32 changes: 32 additions & 0 deletions Snowflake.Client.Tests/UnitTests/AzureAdTokenProviderTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using Microsoft.Identity.Client;
using Moq;
using NUnit.Framework;
using Snowflake.Client;
using Snowflake.Client.Model;
using Snowflake.Client.Tests.IntegrationTests;
using System;
using System.Threading;
using System.Threading.Tasks;

namespace Snowflake.Client.Tests
{
public class AzureAdTokenProviderTests : AzureAdAuthInfoTests
{
[Test]
public async Task GetAzureAdAccessTokenAsync_ReturnsAccessToken()
{
var expectedAccessToken = "accessToken";
var mockTokenProvider = new Mock<IAzureAdTokenProvider>();

mockTokenProvider
.Setup(provider => provider.GetAzureAdAccessTokenAsync(It.IsAny<AzureAdAuthInfo>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(expectedAccessToken);

// Act
string actualAccessToken = await mockTokenProvider.Object.GetAzureAdAccessTokenAsync(_azureAdAuthInfo);

// Assert
Assert.AreEqual(expectedAccessToken, actualAccessToken);
}
}
}
41 changes: 41 additions & 0 deletions Snowflake.Client/AzureAdTokenProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using Microsoft.Identity.Client;
using System;
using System.Threading;
using System.Threading.Tasks;
using Snowflake.Client.Model;

namespace Snowflake.Client
{
public class AzureAdTokenProvider : IAzureAdTokenProvider
{
public async Task<string> GetAzureAdAccessTokenAsync(AzureAdAuthInfo authInfo, CancellationToken ct = default)
{
try
{
if (authInfo.ClientId == null || authInfo.ClientSecret == null || authInfo.ServicePrincipalObjectId == null || authInfo.TenantId == null || authInfo.Scope == null)
{
throw new SnowflakeException("Error: One or more required environment variables are missing.", 400);
}

return await GetAccessTokenAsync(authInfo.ClientId, authInfo.ClientSecret, authInfo.ServicePrincipalObjectId, authInfo.TenantId, authInfo.Scope);
}
catch (Exception ex)
{
throw new SnowflakeException($"Failed getting the Azure Token. Message: {ex.Message}", ex);
}
}

private async Task<string> GetAccessTokenAsync(string clientId, string clientSecret, string servicePrincipalObjectId, string tenantId, string scope)
{
IConfidentialClientApplication app = ConfidentialClientApplicationBuilder.Create(clientId)
.WithClientSecret(clientSecret)
.WithAuthority(new Uri($"https://login.microsoftonline.com/{tenantId}/"))
.Build();

var scopes = new[] { scope };

AuthenticationResult result = await app.AcquireTokenForClient(scopes).ExecuteAsync();
return result.AccessToken;
}
}
}
11 changes: 11 additions & 0 deletions Snowflake.Client/IAzureAdTokenProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using System.Threading;
using System.Threading.Tasks;
using Snowflake.Client.Model;

namespace Snowflake.Client
{
public interface IAzureAdTokenProvider
{
Task<string> GetAzureAdAccessTokenAsync(AzureAdAuthInfo authInfo, CancellationToken ct = default);
}
}
2 changes: 1 addition & 1 deletion Snowflake.Client/Model/AuthInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
/// <summary>
/// Snowflake Authentication information.
/// </summary>
public class AuthInfo
public class AuthInfo : IAuthInfo
{
/// <summary>
/// Your Snowflake account name
Expand Down
29 changes: 29 additions & 0 deletions Snowflake.Client/Model/AzureAdAuthInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
namespace Snowflake.Client.Model
{
public class AzureAdAuthInfo : AuthInfo
{
public string ClientId { get; set; }
public string ClientSecret { get; set; }
public string ServicePrincipalObjectId { get; set; }
public string TenantId { get; set; }
public string Scope { get; set; }
public string Host {get; set; }
public string Role {get; set; }


public AzureAdAuthInfo(string clientId, string clientSecret, string servicePrincipalObjectId, string tenantId, string scope, string region, string account, string user, string host, string role)
: base(user, account, region)
{
ClientId = clientId;
ClientSecret = clientSecret;
ServicePrincipalObjectId = servicePrincipalObjectId;
TenantId = tenantId;
Scope = scope;
Region = region;
Account = account;
User = user;
Host = host;
Role = role;
}
}
}
11 changes: 11 additions & 0 deletions Snowflake.Client/Model/IAuthInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
namespace Snowflake.Client.Model
{
public interface IAuthInfo
{
string Account { get; set; }
string User { get; set; }
string Region { get; set; }

string ToString();
}
}
28 changes: 18 additions & 10 deletions Snowflake.Client/RequestBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,27 @@ internal void ClearSessionTokens()
_masterToken = null;
}

internal HttpRequestMessage BuildLoginRequest(AuthInfo authInfo, SessionInfo sessionInfo)
internal HttpRequestMessage BuildLoginRequest(AuthInfo authInfo, SessionInfo sessionInfo, String azureAdAccessToken = null)
{
var requestUri = BuildLoginUrl(sessionInfo);
var data = new LoginRequestData();

if (authInfo is AzureAdAuthInfo azureAdAuthInfo) {
data = new LoginRequestData() {
Authenticator = "OAUTH",
Token = azureAdAccessToken,
};
} else {
data = new LoginRequestData() {
Password = authInfo.Password,
};
}

var data = new LoginRequestData()
{
LoginName = authInfo.User,
Password = authInfo.Password,
AccountName = authInfo.Account,
ClientAppId = _clientInfo.DriverName,
ClientAppVersion = _clientInfo.DriverVersion,
ClientEnvironment = _clientInfo.Environment
};
data.LoginName = authInfo.User;
data.AccountName = authInfo.Account;
data.ClientAppId = _clientInfo.DriverName;
data.ClientAppVersion = _clientInfo.DriverVersion;
data.ClientEnvironment = _clientInfo.Environment;

var requestBody = new LoginRequest() { Data = data };
var jsonBody = JsonSerializer.Serialize(requestBody, _jsonSerializerOptions);
Expand Down
4 changes: 4 additions & 0 deletions Snowflake.Client/Snowflake.Client.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,8 @@ Provides straightforward and efficient way to execute SQL queries in Snowflake a
</None>
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Identity.Client" Version="4.53.0" />
</ItemGroup>

</Project>
51 changes: 51 additions & 0 deletions Snowflake.Client/SnowflakeClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client;

namespace Snowflake.Client
{
Expand All @@ -23,11 +24,41 @@ public class SnowflakeClient : ISnowflakeClient
/// </summary>
public SnowflakeClientSettings Settings => _clientSettings;

/// <summary>
/// Azure AD Token Provider
/// </summary>
private readonly AzureAdTokenProvider _azureAdTokenProvider;

private SnowflakeSession _snowflakeSession;
private readonly RestClient _restClient;
private readonly RequestBuilder _requestBuilder;
private readonly SnowflakeClientSettings _clientSettings;

/// <summary>
/// Creates new Snowflake client.
/// </summary>
/// <param name="clientId">Client ID</param>
/// <param name="clientSecret">Client Secret</param>
/// <param name="servicePrincipalObjectId">Service Principal Object ID</param>
/// <param name="tenantId">Tenant ID</param>
/// <param name="scope">Scope</param>
/// <param name="region">Region: "us-east-1", etc. Required for all except for US West Oregon (us-west-2).</param>
/// <param name="account">Account</param>
/// <param name="user">Username</param>
/// <param name="host">Host</param>
/// <param name="role">Role</param>
public SnowflakeClient(string clientId, string clientSecret, string servicePrincipalObjectId, string tenantId, string scope, string region, string account, string user, string host, string role)
: this(new AzureAdAuthInfo(clientId, clientSecret, servicePrincipalObjectId, tenantId, scope, region, account, user, host, role), urlInfo: new UrlInfo
{
Host = host,
},
sessionInfo: new SessionInfo
{
Role = role,
})
{
}

/// <summary>
/// Creates new Snowflake client.
/// </summary>
Expand All @@ -52,6 +83,11 @@ public SnowflakeClient(AuthInfo authInfo, SessionInfo sessionInfo = null, UrlInf
{
}

public SnowflakeClient(AzureAdAuthInfo authInfo, SessionInfo sessionInfo = null, UrlInfo urlInfo = null, JsonSerializerOptions jsonMapperOptions = null)
: this(new SnowflakeClientSettings(authInfo, sessionInfo, urlInfo, jsonMapperOptions))
{
}

/// <summary>
/// Creates new Snowflake client.
/// </summary>
Expand All @@ -63,6 +99,7 @@ public SnowflakeClient(SnowflakeClientSettings settings)
_clientSettings = settings;
_restClient = new RestClient();
_requestBuilder = new RequestBuilder(settings.UrlInfo);
_azureAdTokenProvider = new AzureAdTokenProvider();

SnowflakeDataMapper.Configure(settings.JsonMapperOptions);
ChunksDownloader.Configure(settings.ChunksDownloaderOptions);
Expand Down Expand Up @@ -104,10 +141,24 @@ public async Task<bool> InitNewSessionAsync(CancellationToken ct = default)
return true;
}

/// <summary>
/// Authenticates user and returns new Snowflake session.
/// </summary>
/// <returns>New Snowflake session</returns>
private async Task<SnowflakeSession> AuthenticateAsync(AuthInfo authInfo, SessionInfo sessionInfo, CancellationToken ct)
{
var loginRequest = _requestBuilder.BuildLoginRequest(authInfo, sessionInfo);

if(authInfo is AzureAdAuthInfo azureAdAuthInfo)
{
var azureAdAccessToken = await _azureAdTokenProvider.GetAzureAdAccessTokenAsync(azureAdAuthInfo, ct).ConfigureAwait(false);
loginRequest = _requestBuilder.BuildLoginRequest(authInfo, sessionInfo, azureAdAccessToken);
}
else
{
loginRequest = _requestBuilder.BuildLoginRequest(authInfo, sessionInfo);
}

var response = await _restClient.SendAsync<LoginResponse>(loginRequest, ct).ConfigureAwait(false);

if (!response.Success)
Expand Down