Skip to content

Commit a4a7f60

Browse files
authoredMar 4, 2025
Add support for specifying allowed auth methods (npgsql#6036)
Closes npgsql#6035
1 parent 061a5f2 commit a4a7f60

6 files changed

+264
-6
lines changed
 

‎src/Npgsql/Internal/NpgsqlConnector.Auth.cs

+21
Original file line numberDiff line numberDiff line change
@@ -18,30 +18,43 @@ partial class NpgsqlConnector
1818
{
1919
async Task Authenticate(string username, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken)
2020
{
21+
var requiredAuthModes = Settings.RequireAuthModes;
22+
if (requiredAuthModes == default)
23+
requiredAuthModes = NpgsqlConnectionStringBuilder.ParseAuthMode(PostgresEnvironment.RequireAuth);
24+
25+
var authenticated = false;
26+
2127
while (true)
2228
{
2329
timeout.CheckAndApply(this);
2430
var msg = ExpectAny<AuthenticationRequestMessage>(await ReadMessage(async).ConfigureAwait(false), this);
2531
switch (msg.AuthRequestType)
2632
{
2733
case AuthenticationRequestType.Ok:
34+
// If we didn't complete authentication, check whether it's allowed
35+
if (!authenticated)
36+
ThrowIfNotAllowed(requiredAuthModes, RequireAuthMode.None);
2837
return;
2938

3039
case AuthenticationRequestType.CleartextPassword:
40+
ThrowIfNotAllowed(requiredAuthModes, RequireAuthMode.Password);
3141
await AuthenticateCleartext(username, async, cancellationToken).ConfigureAwait(false);
3242
break;
3343

3444
case AuthenticationRequestType.MD5Password:
45+
ThrowIfNotAllowed(requiredAuthModes, RequireAuthMode.MD5);
3546
await AuthenticateMD5(username, ((AuthenticationMD5PasswordMessage)msg).Salt, async, cancellationToken).ConfigureAwait(false);
3647
break;
3748

3849
case AuthenticationRequestType.SASL:
50+
ThrowIfNotAllowed(requiredAuthModes, RequireAuthMode.ScramSHA256);
3951
await AuthenticateSASL(((AuthenticationSASLMessage)msg).Mechanisms, username, async,
4052
cancellationToken).ConfigureAwait(false);
4153
break;
4254

4355
case AuthenticationRequestType.GSS:
4456
case AuthenticationRequestType.SSPI:
57+
ThrowIfNotAllowed(requiredAuthModes, msg.AuthRequestType == AuthenticationRequestType.GSS ? RequireAuthMode.GSS : RequireAuthMode.SSPI);
4558
await DataSource.IntegratedSecurityHandler.NegotiateAuthentication(async, this).ConfigureAwait(false);
4659
return;
4760

@@ -51,6 +64,14 @@ await AuthenticateSASL(((AuthenticationSASLMessage)msg).Mechanisms, username, as
5164
default:
5265
throw new NotSupportedException($"Authentication method not supported (Received: {msg.AuthRequestType})");
5366
}
67+
68+
authenticated = true;
69+
}
70+
71+
static void ThrowIfNotAllowed(RequireAuthMode requiredAuthModes, RequireAuthMode requestedAuthMode)
72+
{
73+
if (!requiredAuthModes.HasFlag(requestedAuthMode))
74+
throw new NpgsqlException($"\"{requestedAuthMode}\" authentication method is not allowed. Allowed methods: {requiredAuthModes}");
5475
}
5576
}
5677

‎src/Npgsql/NpgsqlConnectionStringBuilder.cs

+100
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,70 @@ public ChannelBinding ChannelBinding
683683
}
684684
ChannelBinding _channelBinding;
685685

686+
/// <summary>
687+
/// Controls the available authentication methods.
688+
/// </summary>
689+
[Category("Security")]
690+
[Description("Controls the available authentication methods.")]
691+
[DisplayName("Require Auth")]
692+
[NpgsqlConnectionStringProperty]
693+
public string? RequireAuth
694+
{
695+
get => _requireAuth;
696+
set
697+
{
698+
RequireAuthModes = ParseAuthMode(value);
699+
_requireAuth = value;
700+
SetValue(nameof(RequireAuth), value);
701+
}
702+
}
703+
string? _requireAuth;
704+
705+
internal RequireAuthMode RequireAuthModes { get; private set; }
706+
707+
internal static RequireAuthMode ParseAuthMode(string? value)
708+
{
709+
var modes = value?.Split(',', StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries);
710+
if (modes is not { Length: > 0 })
711+
return RequireAuthMode.All;
712+
713+
var isNegative = false;
714+
RequireAuthMode parsedModes = default;
715+
for (var i = 0; i < modes.Length; i++)
716+
{
717+
var mode = modes[i];
718+
var modeToParse = mode.AsSpan();
719+
if (mode.StartsWith('!'))
720+
{
721+
if (i > 0 && !isNegative)
722+
throw new ArgumentException("Mixing both positive and negative authentication methods is not supported");
723+
724+
modeToParse = modeToParse.Slice(1);
725+
isNegative = true;
726+
}
727+
else
728+
{
729+
if (i > 0 && isNegative)
730+
throw new ArgumentException("Mixing both positive and negative authentication methods is not supported");
731+
}
732+
733+
// Explicitly disallow 'All' as libpq doesn't have it
734+
if (!Enum.TryParse<RequireAuthMode>(modeToParse, out var parsedMode) || parsedMode == RequireAuthMode.All)
735+
throw new ArgumentException($"Unable to parse authentication method \"{modeToParse}\"");
736+
737+
parsedModes |= parsedMode;
738+
}
739+
740+
var allowedModes = isNegative
741+
? (RequireAuthMode)(RequireAuthMode.All - parsedModes)
742+
: parsedModes;
743+
744+
if (allowedModes == default)
745+
throw new ArgumentException($"No authentication method is allowed. Check \"{nameof(RequireAuth)}\" in connection string.");
746+
747+
return allowedModes;
748+
}
749+
686750
#endregion
687751

688752
#region Properties - Pooling
@@ -1735,4 +1799,40 @@ enum ReplicationMode
17351799
Logical
17361800
}
17371801

1802+
/// <summary>
1803+
/// Specifies which authentication methods are supported.
1804+
/// </summary>
1805+
[Flags]
1806+
enum RequireAuthMode
1807+
{
1808+
/// <summary>
1809+
/// Plaintext password.
1810+
/// </summary>
1811+
Password = 1,
1812+
/// <summary>
1813+
/// MD5 hashed password.
1814+
/// </summary>
1815+
MD5 = 2,
1816+
/// <summary>
1817+
/// Kerberos.
1818+
/// </summary>
1819+
GSS = 4,
1820+
/// <summary>
1821+
/// Windows SSPI.
1822+
/// </summary>
1823+
SSPI = 8,
1824+
/// <summary>
1825+
/// SASL.
1826+
/// </summary>
1827+
ScramSHA256 = 16,
1828+
/// <summary>
1829+
/// No authentication exchange.
1830+
/// </summary>
1831+
None = 32,
1832+
/// <summary>
1833+
/// All authentication methods. For internal use.
1834+
/// </summary>
1835+
All = Password | MD5 | GSS | SSPI | ScramSHA256 | None
1836+
}
1837+
17381838
#endregion

‎src/Npgsql/PostgresEnvironment.cs

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ internal static string? SslCertRootDefault
5050

5151
internal static string? SslNegotiation => Environment.GetEnvironmentVariable("PGSSLNEGOTIATION");
5252

53+
internal static string? RequireAuth => Environment.GetEnvironmentVariable("PGREQUIREAUTH");
54+
5355
static string? GetHomeDir()
5456
=> Environment.GetEnvironmentVariable(RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "APPDATA" : "HOME");
5557

‎src/Npgsql/PublicAPI.Unshipped.txt

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ abstract Npgsql.NpgsqlDataSource.Clear() -> void
33
Npgsql.NpgsqlConnection.CloneWithAsync(string! connectionString, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask<Npgsql.NpgsqlConnection!>
44
Npgsql.NpgsqlConnection.SslClientAuthenticationOptionsCallback.get -> System.Action<System.Net.Security.SslClientAuthenticationOptions!>?
55
Npgsql.NpgsqlConnection.SslClientAuthenticationOptionsCallback.set -> void
6+
Npgsql.NpgsqlConnectionStringBuilder.RequireAuth.get -> string?
7+
Npgsql.NpgsqlConnectionStringBuilder.RequireAuth.set -> void
68
Npgsql.NpgsqlConnectionStringBuilder.SslNegotiation.get -> Npgsql.SslNegotiation
79
Npgsql.NpgsqlConnectionStringBuilder.SslNegotiation.set -> void
810
Npgsql.NpgsqlDataSourceBuilder.ConfigureTypeLoading(System.Action<Npgsql.NpgsqlTypeLoadingOptionsBuilder!>! configureAction) -> Npgsql.NpgsqlDataSourceBuilder!

‎test/Npgsql.Tests/ConnectionTests.cs

+133
Original file line numberDiff line numberDiff line change
@@ -1679,6 +1679,139 @@ public async Task PhysicalConnectionInitializer_disposes_connection()
16791679

16801680
#endregion Physical connection initialization
16811681

1682+
#region Require auth
1683+
1684+
[Test]
1685+
public async Task Connect_with_any_auth()
1686+
{
1687+
await using var dataSource = CreateDataSource(csb =>
1688+
{
1689+
csb.RequireAuth = $"{RequireAuthMode.Password},{RequireAuthMode.MD5},{RequireAuthMode.GSS},{RequireAuthMode.SSPI},{RequireAuthMode.ScramSHA256},{RequireAuthMode.None}";
1690+
});
1691+
await using var conn = await dataSource.OpenConnectionAsync();
1692+
}
1693+
1694+
[Test]
1695+
[NonParallelizable] // Sets environment variable
1696+
public async Task Connect_with_any_auth_env()
1697+
{
1698+
using var _ = SetEnvironmentVariable("PGREQUIREAUTH", $"{RequireAuthMode.Password},{RequireAuthMode.MD5},{RequireAuthMode.GSS},{RequireAuthMode.SSPI},{RequireAuthMode.ScramSHA256},{RequireAuthMode.None}");
1699+
await using var dataSource = CreateDataSource();
1700+
await using var conn = await dataSource.OpenConnectionAsync();
1701+
}
1702+
1703+
[Test]
1704+
public async Task Connect_with_any_except_none_auth()
1705+
{
1706+
await using var dataSource = CreateDataSource(csb =>
1707+
{
1708+
csb.RequireAuth = $"!{RequireAuthMode.None}";
1709+
});
1710+
await using var conn = await dataSource.OpenConnectionAsync();
1711+
}
1712+
1713+
[Test]
1714+
[NonParallelizable] // Sets environment variable
1715+
public async Task Connect_with_any_except_none_auth_env()
1716+
{
1717+
using var _ = SetEnvironmentVariable("PGREQUIREAUTH", $"!{RequireAuthMode.None}");
1718+
await using var dataSource = CreateDataSource();
1719+
await using var conn = await dataSource.OpenConnectionAsync();
1720+
}
1721+
1722+
[Test]
1723+
public async Task Fail_connect_with_none_auth()
1724+
{
1725+
await using var dataSource = CreateDataSource(csb =>
1726+
{
1727+
csb.RequireAuth = $"{RequireAuthMode.None}";
1728+
});
1729+
var ex = Assert.ThrowsAsync<NpgsqlException>(async () => await dataSource.OpenConnectionAsync())!;
1730+
Assert.That(ex.Message, Does.Contain("authentication method is not allowed"));
1731+
}
1732+
1733+
[Test]
1734+
[NonParallelizable] // Sets environment variable
1735+
public async Task Fail_connect_with_none_auth_env()
1736+
{
1737+
using var _ = SetEnvironmentVariable("PGREQUIREAUTH", $"{RequireAuthMode.None}");
1738+
await using var dataSource = CreateDataSource();
1739+
var ex = Assert.ThrowsAsync<NpgsqlException>(async () => await dataSource.OpenConnectionAsync())!;
1740+
Assert.That(ex.Message, Does.Contain("authentication method is not allowed"));
1741+
}
1742+
1743+
[Test]
1744+
public async Task Connect_with_md5_auth()
1745+
{
1746+
await using var dataSource = CreateDataSource(csb =>
1747+
{
1748+
csb.RequireAuth = $"{RequireAuthMode.MD5}";
1749+
});
1750+
try
1751+
{
1752+
await using var conn = await dataSource.OpenConnectionAsync();
1753+
}
1754+
catch (Exception e) when (!IsOnBuildServer)
1755+
{
1756+
Console.WriteLine(e);
1757+
Assert.Ignore("MD5 authentication doesn't seem to be set up");
1758+
}
1759+
}
1760+
1761+
[Test]
1762+
[NonParallelizable] // Sets environment variable
1763+
public async Task Connect_with_md5_auth_env()
1764+
{
1765+
using var _ = SetEnvironmentVariable("PGREQUIREAUTH", $"{RequireAuthMode.MD5}");
1766+
await using var dataSource = CreateDataSource();
1767+
try
1768+
{
1769+
await using var conn = await dataSource.OpenConnectionAsync();
1770+
}
1771+
catch (Exception e) when (!IsOnBuildServer)
1772+
{
1773+
Console.WriteLine(e);
1774+
Assert.Ignore("MD5 authentication doesn't seem to be set up");
1775+
}
1776+
}
1777+
1778+
[Test]
1779+
public void Mixed_auth_methods_not_supported([Values(
1780+
$"{nameof(RequireAuthMode.ScramSHA256)},!{nameof(RequireAuthMode.None)}",
1781+
$"!{nameof(RequireAuthMode.ScramSHA256)},{nameof(RequireAuthMode.None)}")]
1782+
string authMethods)
1783+
{
1784+
var csb = new NpgsqlConnectionStringBuilder();
1785+
Assert.Throws<ArgumentException>(() => csb.RequireAuth = authMethods);
1786+
}
1787+
1788+
[Test]
1789+
public void Remove_all_auth_methods_throws()
1790+
{
1791+
var csb = new NpgsqlConnectionStringBuilder();
1792+
Assert.Throws<ArgumentException>(() =>
1793+
csb.RequireAuth = $"!{RequireAuthMode.Password},!{RequireAuthMode.MD5},!{RequireAuthMode.GSS},!{RequireAuthMode.SSPI},!{RequireAuthMode.ScramSHA256},!{RequireAuthMode.None}");
1794+
}
1795+
1796+
[Test]
1797+
public void Unknown_auth_method_throws()
1798+
{
1799+
var csb = new NpgsqlConnectionStringBuilder();
1800+
Assert.Throws<ArgumentException>(() => csb.RequireAuth = "SuperSecure");
1801+
}
1802+
1803+
[Test]
1804+
public void Auth_methods_are_trimmed()
1805+
{
1806+
var csb = new NpgsqlConnectionStringBuilder
1807+
{
1808+
RequireAuth = $"{RequireAuthMode.Password} , {RequireAuthMode.MD5}"
1809+
};
1810+
Assert.That(csb.RequireAuthModes, Is.EqualTo(RequireAuthMode.Password | RequireAuthMode.MD5));
1811+
}
1812+
1813+
#endregion Require auth
1814+
16821815
[Test]
16831816
[NonParallelizable] // Modifies global database info factories
16841817
[IssueLink("https://github.com/npgsql/npgsql/issues/4425")]

‎test/Npgsql.Tests/SecurityTests.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,27 @@ namespace Npgsql.Tests;
1313
public class SecurityTests : TestBase
1414
{
1515
[Test, Description("Establishes an SSL connection, assuming a self-signed server certificate")]
16-
public void Basic_ssl()
16+
public async Task Basic_ssl()
1717
{
18-
using var dataSource = CreateDataSource(csb =>
18+
await using var dataSource = CreateDataSource(csb =>
1919
{
2020
csb.SslMode = SslMode.Require;
2121
});
22-
using var conn = dataSource.OpenConnection();
22+
await using var conn = await dataSource.OpenConnectionAsync();
2323
Assert.That(conn.IsSecure, Is.True);
2424
}
2525

2626
[Test, Description("Default user must run with md5 password encryption")]
27-
public void Default_user_uses_md5_password()
27+
public async Task Default_user_uses_md5_password()
2828
{
2929
if (!IsOnBuildServer)
3030
Assert.Ignore("Only executed in CI");
3131

32-
using var dataSource = CreateDataSource(csb =>
32+
await using var dataSource = CreateDataSource(csb =>
3333
{
3434
csb.SslMode = SslMode.Require;
3535
});
36-
using var conn = dataSource.OpenConnection();
36+
await using var conn = await dataSource.OpenConnectionAsync();
3737
Assert.That(conn.IsScram, Is.False);
3838
Assert.That(conn.IsScramPlus, Is.False);
3939
}

0 commit comments

Comments
 (0)
Please sign in to comment.