Skip to content

Commit 6ec779a

Browse files
committed
Replace CredentialProvider with a more powerful AuthHandler
1 parent 6675966 commit 6ec779a

File tree

11 files changed

+349
-189
lines changed

11 files changed

+349
-189
lines changed

driver/driver_options_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,8 @@ func TestDriverOptions_namedValueChecker(t *testing.T) {
266266
}
267267

268268
func createMockServer(t *testing.T) *testServer {
269-
inMemProvider := server.NewInMemoryProvider()
270-
require.NoError(t, inMemProvider.AddUser(*testUser, *testPassword))
269+
authHandler := server.NewInMemoryAuthenticationHandler()
270+
require.NoError(t, authHandler.AddUser(*testUser, *testPassword))
271271
defaultServer := server.NewDefaultServer()
272272

273273
l, err := net.Listen("tcp", "127.0.0.1:3307")
@@ -285,7 +285,7 @@ func createMockServer(t *testing.T) *testServer {
285285
}
286286

287287
go func() {
288-
co, err := s.NewCustomizedConn(conn, inMemProvider, handler)
288+
co, err := s.NewCustomizedConn(conn, authHandler, handler)
289289
if err != nil {
290290
return
291291
}

mysql/util.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,14 @@ func CalcNativePassword(scramble, password []byte) []byte {
5656
return Xor(scrambleHash, stage1)
5757
}
5858

59-
// Xor modifies hash1 in-place with XOR against hash2
59+
// Xor returns a new slice with hash1 XOR hash2
6060
func Xor(hash1 []byte, hash2 []byte) []byte {
6161
l := min(len(hash1), len(hash2))
62+
result := make([]byte, l)
6263
for i := range l {
63-
hash1[i] ^= hash2[i]
64+
result[i] = hash1[i] ^ hash2[i]
6465
}
65-
return hash1
66+
return result
6667
}
6768

6869
// hash_stage1 = xor(reply, sha1(public_seed, hash_stage2))

server/auth.go

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,26 +30,25 @@ func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) err
3030
return c.serverConf.authProvider.Authenticate(c, authPluginName, clientAuthData)
3131
}
3232

33-
func (c *Conn) acquirePassword() error {
34-
if c.credential.Password != "" {
33+
func (c *Conn) acquireCredential() error {
34+
if len(c.credential.Passwords) > 0 {
3535
return nil
3636
}
37-
credential, found, err := c.credentialProvider.GetCredential(c.user)
37+
credential, found, err := c.authHandler.GetCredential(c.user)
3838
if err != nil {
3939
return err
4040
}
41-
if !found {
41+
if !found || len(credential.Passwords) == 0 {
4242
return mysql.NewDefaultError(mysql.ER_NO_SUCH_USER, c.user, c.RemoteAddr().String())
4343
}
4444
c.credential = credential
4545
return nil
4646
}
4747

4848
func errAccessDenied(credential Credential) error {
49-
if credential.Password == "" {
49+
if credential.HasEmptyPassword() {
5050
return ErrAccessDeniedNoPassword
5151
}
52-
5352
return ErrAccessDenied
5453
}
5554

@@ -74,20 +73,26 @@ func scrambleValidation(cached, nonce, scramble []byte) bool {
7473
}
7574

7675
func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, credential Credential) error {
77-
password, err := mysql.DecodePasswordHex(c.credential.Password)
78-
if err != nil {
79-
return errAccessDenied(credential)
80-
}
81-
if mysql.CompareNativePassword(clientAuthData, password, c.salt) {
82-
return nil
76+
for _, password := range credential.Passwords {
77+
hash, err := credential.HashPassword(password)
78+
if err != nil {
79+
continue
80+
}
81+
decoded, err := mysql.DecodePasswordHex(hash)
82+
if err != nil {
83+
continue
84+
}
85+
if mysql.CompareNativePassword(clientAuthData, decoded, c.salt) {
86+
return nil
87+
}
8388
}
8489
return errAccessDenied(credential)
8590
}
8691

8792
func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential Credential) error {
8893
// Empty passwords are not hashed, but sent as empty string
8994
if len(clientAuthData) == 0 {
90-
if credential.Password == "" {
95+
if credential.HasEmptyPassword() {
9196
return nil
9297
}
9398
return ErrAccessDenied
@@ -113,20 +118,26 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential C
113118
clientAuthData = clientAuthData[:l-1]
114119
}
115120
}
116-
check, err := mysql.Check256HashingPassword([]byte(credential.Password), string(clientAuthData))
117-
if err != nil {
118-
return err
119-
}
120-
if check {
121-
return nil
121+
for _, password := range credential.Passwords {
122+
hash, err := credential.HashPassword(password)
123+
if err != nil {
124+
continue
125+
}
126+
check, err := mysql.Check256HashingPassword([]byte(hash), string(clientAuthData))
127+
if err != nil {
128+
continue
129+
}
130+
if check {
131+
return nil
132+
}
122133
}
123-
return ErrAccessDenied
134+
return errAccessDenied(credential)
124135
}
125136

126137
func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error {
127138
// Empty passwords are not hashed, but sent as empty string
128139
if len(clientAuthData) == 0 {
129-
if c.credential.Password == "" {
140+
if c.credential.HasEmptyPassword() {
130141
return nil
131142
}
132143
return ErrAccessDenied

server/auth_handler.go

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
package server
2+
3+
import (
4+
"sync"
5+
6+
"github.com/go-mysql-org/go-mysql/mysql"
7+
"github.com/pingcap/errors"
8+
"github.com/pingcap/tidb/pkg/parser/auth"
9+
)
10+
11+
// AuthenticationHandler provides user credentials and authentication lifecycle hooks.
12+
//
13+
// # Important Note
14+
//
15+
// if the password in a third-party auth handler could be updated at runtime, we have to invalidate the caching
16+
// for 'caching_sha2_password' by calling 'func (s *Server)InvalidateCache(string, string)'.
17+
type AuthenticationHandler interface {
18+
// get user credential (supports multiple valid passwords per user)
19+
GetCredential(username string) (credential Credential, found bool, err error)
20+
21+
// OnAuthSuccess is called after successful authentication, before the OK packet.
22+
// Return an error to reject the connection (error will be sent to client instead of OK).
23+
// Return nil to proceed with sending the OK packet.
24+
OnAuthSuccess(conn *Conn) error
25+
26+
// OnAuthFailure is called after authentication fails, before the error packet.
27+
// This is informational only - the connection will be closed regardless.
28+
OnAuthFailure(conn *Conn, err error)
29+
}
30+
31+
func NewInMemoryAuthenticationHandler(defaultAuthMethod ...string) *InMemoryAuthenticationHandler {
32+
d := mysql.AUTH_CACHING_SHA2_PASSWORD
33+
if len(defaultAuthMethod) > 0 {
34+
d = defaultAuthMethod[0]
35+
}
36+
return &InMemoryAuthenticationHandler{
37+
userPool: sync.Map{},
38+
defaultAuthMethod: d,
39+
}
40+
}
41+
42+
type Credential struct {
43+
Passwords []string // raw passwords, hashed on demand during comparison
44+
AuthPluginName string
45+
}
46+
47+
// HashPassword computes the password hash for a given password using the credential's auth plugin.
48+
func (c Credential) HashPassword(password string) (string, error) {
49+
if password == "" {
50+
return "", nil
51+
}
52+
53+
switch c.AuthPluginName {
54+
case mysql.AUTH_NATIVE_PASSWORD:
55+
return mysql.EncodePasswordHex(mysql.NativePasswordHash([]byte(password))), nil
56+
57+
case mysql.AUTH_CACHING_SHA2_PASSWORD:
58+
return auth.NewHashPassword(password, mysql.AUTH_CACHING_SHA2_PASSWORD), nil
59+
60+
case mysql.AUTH_SHA256_PASSWORD:
61+
return mysql.NewSha256PasswordHash(password)
62+
63+
case mysql.AUTH_CLEAR_PASSWORD:
64+
return password, nil
65+
66+
default:
67+
return "", errors.Errorf("unknown authentication plugin name '%s'", c.AuthPluginName)
68+
}
69+
}
70+
71+
// HasEmptyPassword returns true if any password in the credential is empty.
72+
func (c Credential) HasEmptyPassword() bool {
73+
for _, p := range c.Passwords {
74+
if p == "" {
75+
return true
76+
}
77+
}
78+
return false
79+
}
80+
81+
// InMemoryAuthenticationHandler implements AuthenticationHandler with in-memory credential storage.
82+
type InMemoryAuthenticationHandler struct {
83+
userPool sync.Map // username -> Credential
84+
defaultAuthMethod string
85+
}
86+
87+
func (h *InMemoryAuthenticationHandler) CheckUsername(username string) (found bool, err error) {
88+
_, ok := h.userPool.Load(username)
89+
return ok, nil
90+
}
91+
92+
func (h *InMemoryAuthenticationHandler) GetCredential(username string) (credential Credential, found bool, err error) {
93+
v, ok := h.userPool.Load(username)
94+
if !ok {
95+
return Credential{}, false, nil
96+
}
97+
c, valid := v.(Credential)
98+
if !valid {
99+
return Credential{}, true, errors.Errorf("invalid credential")
100+
}
101+
return c, true, nil
102+
}
103+
104+
func (h *InMemoryAuthenticationHandler) AddUser(username, password string, optionalAuthPluginName ...string) error {
105+
authPluginName := h.defaultAuthMethod
106+
if len(optionalAuthPluginName) > 0 {
107+
authPluginName = optionalAuthPluginName[0]
108+
}
109+
110+
if !isAuthMethodSupported(authPluginName) {
111+
return errors.Errorf("unknown authentication plugin name '%s'", authPluginName)
112+
}
113+
114+
h.userPool.Store(username, Credential{
115+
Passwords: []string{password},
116+
AuthPluginName: authPluginName,
117+
})
118+
return nil
119+
}
120+
121+
func (h *InMemoryAuthenticationHandler) OnAuthSuccess(conn *Conn) error {
122+
return nil
123+
}
124+
125+
func (h *InMemoryAuthenticationHandler) OnAuthFailure(conn *Conn, err error) {
126+
}

server/auth_handler_test.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package server
2+
3+
import (
4+
"database/sql"
5+
"net"
6+
"sync/atomic"
7+
"testing"
8+
"time"
9+
10+
_ "github.com/go-sql-driver/mysql"
11+
"github.com/go-mysql-org/go-mysql/mysql"
12+
"github.com/pingcap/errors"
13+
"github.com/stretchr/testify/require"
14+
)
15+
16+
type hookTrackingAuthenticationHandler struct {
17+
*InMemoryAuthenticationHandler
18+
onSuccessCalled atomic.Int32
19+
onFailureCalled atomic.Int32
20+
rejectOnSuccess bool
21+
}
22+
23+
func (h *hookTrackingAuthenticationHandler) OnAuthSuccess(conn *Conn) error {
24+
h.onSuccessCalled.Add(1)
25+
if h.rejectOnSuccess {
26+
return errors.New("connection rejected by policy")
27+
}
28+
return nil
29+
}
30+
31+
func (h *hookTrackingAuthenticationHandler) OnAuthFailure(conn *Conn, err error) {
32+
h.onFailureCalled.Add(1)
33+
}
34+
35+
func TestOnAuthSuccessCalled(t *testing.T) {
36+
handler := &hookTrackingAuthenticationHandler{
37+
InMemoryAuthenticationHandler: NewInMemoryAuthenticationHandler(mysql.AUTH_NATIVE_PASSWORD),
38+
}
39+
require.NoError(t, handler.AddUser("testuser", "testpass"))
40+
41+
l, err := net.Listen("tcp", "127.0.0.1:0")
42+
require.NoError(t, err)
43+
defer l.Close()
44+
45+
go func() {
46+
conn, _ := l.Accept()
47+
co, _ := NewDefaultServer().NewCustomizedConn(conn, handler, &EmptyHandler{})
48+
if co != nil {
49+
for co.HandleCommand() == nil {
50+
}
51+
}
52+
}()
53+
54+
db, err := sql.Open("mysql", "testuser:testpass@tcp("+l.Addr().String()+")/test")
55+
require.NoError(t, err)
56+
defer db.Close()
57+
db.SetConnMaxLifetime(time.Second)
58+
59+
require.NoError(t, db.Ping())
60+
require.Equal(t, int32(1), handler.onSuccessCalled.Load())
61+
require.Equal(t, int32(0), handler.onFailureCalled.Load())
62+
}
63+
64+
func TestOnAuthSuccessCanReject(t *testing.T) {
65+
handler := &hookTrackingAuthenticationHandler{
66+
InMemoryAuthenticationHandler: NewInMemoryAuthenticationHandler(mysql.AUTH_NATIVE_PASSWORD),
67+
rejectOnSuccess: true,
68+
}
69+
require.NoError(t, handler.AddUser("testuser", "testpass"))
70+
71+
l, err := net.Listen("tcp", "127.0.0.1:0")
72+
require.NoError(t, err)
73+
defer l.Close()
74+
75+
go func() {
76+
conn, _ := l.Accept()
77+
co, _ := NewDefaultServer().NewCustomizedConn(conn, handler, &EmptyHandler{})
78+
if co != nil {
79+
for co.HandleCommand() == nil {
80+
}
81+
}
82+
}()
83+
84+
db, err := sql.Open("mysql", "testuser:testpass@tcp("+l.Addr().String()+")/test")
85+
require.NoError(t, err)
86+
defer db.Close()
87+
db.SetConnMaxLifetime(time.Second)
88+
89+
err = db.Ping()
90+
require.Error(t, err)
91+
require.Contains(t, err.Error(), "connection rejected by policy")
92+
require.Equal(t, int32(1), handler.onSuccessCalled.Load())
93+
}
94+
95+
func TestOnAuthFailureCalled(t *testing.T) {
96+
handler := &hookTrackingAuthenticationHandler{
97+
InMemoryAuthenticationHandler: NewInMemoryAuthenticationHandler(mysql.AUTH_NATIVE_PASSWORD),
98+
}
99+
require.NoError(t, handler.AddUser("testuser", "testpass"))
100+
101+
l, err := net.Listen("tcp", "127.0.0.1:0")
102+
require.NoError(t, err)
103+
defer l.Close()
104+
105+
go func() {
106+
conn, _ := l.Accept()
107+
co, _ := NewDefaultServer().NewCustomizedConn(conn, handler, &EmptyHandler{})
108+
if co != nil {
109+
for co.HandleCommand() == nil {
110+
}
111+
}
112+
}()
113+
114+
db, err := sql.Open("mysql", "testuser:wrongpass@tcp("+l.Addr().String()+")/test")
115+
require.NoError(t, err)
116+
defer db.Close()
117+
db.SetConnMaxLifetime(time.Second)
118+
119+
require.Error(t, db.Ping())
120+
require.Equal(t, int32(0), handler.onSuccessCalled.Load())
121+
require.Equal(t, int32(1), handler.onFailureCalled.Load())
122+
}

0 commit comments

Comments
 (0)