diff --git a/server/auth.go b/server/auth.go index 8d6830727..e675ed7ef 100644 --- a/server/auth.go +++ b/server/auth.go @@ -45,14 +45,6 @@ func (c *Conn) acquirePassword() error { return nil } -func errAccessDenied(credential Credential) error { - if credential.Password == "" { - return ErrAccessDeniedNoPassword - } - - return ErrAccessDenied -} - func scrambleValidation(cached, nonce, scramble []byte) bool { // SHA256(SHA256(SHA256(STORED_PASSWORD)), NONCE) crypt := sha256.New() @@ -74,14 +66,21 @@ func scrambleValidation(cached, nonce, scramble []byte) bool { } func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, credential Credential) error { + if len(clientAuthData) == 0 { + if credential.Password == "" { + return nil + } + return ErrAccessDeniedNoPassword + } + password, err := mysql.DecodePasswordHex(c.credential.Password) if err != nil { - return errAccessDenied(credential) + return ErrAccessDenied } if mysql.CompareNativePassword(clientAuthData, password, c.salt) { return nil } - return errAccessDenied(credential) + return ErrAccessDenied } func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential Credential) error { @@ -90,7 +89,7 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential C if credential.Password == "" { return nil } - return ErrAccessDenied + return ErrAccessDeniedNoPassword } if tlsConn, ok := c.Conn.Conn.(*tls.Conn); ok { if !tlsConn.ConnectionState().HandshakeComplete { @@ -129,7 +128,7 @@ func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error { if c.credential.Password == "" { return nil } - return ErrAccessDenied + return ErrAccessDeniedNoPassword } // the caching of 'caching_sha2_password' in MySQL, see: https://dev.mysql.com/worklog/task/?id=9591 // check if we have a cached value @@ -141,7 +140,7 @@ func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error { return c.writeAuthMoreDataFastAuth() } - return errAccessDenied(c.credential) + return ErrAccessDenied } // cache miss, do full auth if err := c.writeAuthMoreDataFullAuth(); err != nil { diff --git a/server/auth_switch_response.go b/server/auth_switch_response.go index 5de841acb..79ed4c34c 100644 --- a/server/auth_switch_response.go +++ b/server/auth_switch_response.go @@ -71,11 +71,18 @@ func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error { } func (c *Conn) checkSha2CacheCredentials(clientAuthData []byte, credential Credential) error { + if len(clientAuthData) == 0 { + if credential.Password == "" { + return nil + } + return ErrAccessDeniedNoPassword + } + match, err := auth.CheckHashingPassword([]byte(credential.Password), string(clientAuthData), mysql.AUTH_CACHING_SHA2_PASSWORD) if match && err == nil { return nil } - return errAccessDenied(credential) + return ErrAccessDenied } func (c *Conn) writeCachingSha2Cache(authData []byte) { diff --git a/server/auth_switch_response_test.go b/server/auth_switch_response_test.go new file mode 100644 index 000000000..9b77ef260 --- /dev/null +++ b/server/auth_switch_response_test.go @@ -0,0 +1,43 @@ +package server + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCheckSha2CacheCredentials_EmptyPassword(t *testing.T) { + tests := []struct { + name string + clientAuthData []byte + serverPassword string + wantErr error + }{ + { + name: "empty client auth, empty server password", + clientAuthData: []byte{}, + serverPassword: "", + wantErr: nil, + }, + { + name: "empty client auth, non-empty server password", + clientAuthData: []byte{}, + serverPassword: "secret", + wantErr: ErrAccessDeniedNoPassword, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Conn{ + credential: Credential{Password: tt.serverPassword}, + } + err := c.checkSha2CacheCredentials(tt.clientAuthData, c.credential) + if tt.wantErr == nil { + require.NoError(t, err) + } else { + require.ErrorIs(t, err, tt.wantErr) + } + }) + } +} diff --git a/server/auth_test.go b/server/auth_test.go new file mode 100644 index 000000000..a7f24227b --- /dev/null +++ b/server/auth_test.go @@ -0,0 +1,122 @@ +package server + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestErrAccessDenied(t *testing.T) { + require.True(t, errors.Is(ErrAccessDenied, ErrAccessDenied)) + require.True(t, errors.Is(ErrAccessDeniedNoPassword, ErrAccessDenied)) + require.False(t, errors.Is(ErrAccessDenied, ErrAccessDeniedNoPassword)) +} + +func TestCompareNativePasswordAuthData_EmptyPassword(t *testing.T) { + tests := []struct { + name string + clientAuthData []byte + serverPassword string + wantErr error + }{ + { + name: "empty client auth, empty server password", + clientAuthData: []byte{}, + serverPassword: "", + wantErr: nil, + }, + { + name: "empty client auth, non-empty server password", + clientAuthData: []byte{}, + serverPassword: "secret", + wantErr: ErrAccessDeniedNoPassword, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Conn{ + credential: Credential{Password: tt.serverPassword}, + } + err := c.compareNativePasswordAuthData(tt.clientAuthData, c.credential) + if tt.wantErr == nil { + require.NoError(t, err) + } else { + require.ErrorIs(t, err, tt.wantErr) + } + }) + } +} + +func TestCompareSha256PasswordAuthData_EmptyPassword(t *testing.T) { + tests := []struct { + name string + clientAuthData []byte + serverPassword string + wantErr error + }{ + { + name: "empty client auth, empty server password", + clientAuthData: []byte{}, + serverPassword: "", + wantErr: nil, + }, + { + name: "empty client auth, non-empty server password", + clientAuthData: []byte{}, + serverPassword: "secret", + wantErr: ErrAccessDeniedNoPassword, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Conn{ + credential: Credential{Password: tt.serverPassword}, + } + err := c.compareSha256PasswordAuthData(tt.clientAuthData, c.credential) + if tt.wantErr == nil { + require.NoError(t, err) + } else { + require.ErrorIs(t, err, tt.wantErr) + } + }) + } +} + +func TestCompareCacheSha2PasswordAuthData_EmptyPassword(t *testing.T) { + tests := []struct { + name string + clientAuthData []byte + serverPassword string + wantErr error + }{ + { + name: "empty client auth, empty server password", + clientAuthData: []byte{}, + serverPassword: "", + wantErr: nil, + }, + { + name: "empty client auth, non-empty server password", + clientAuthData: []byte{}, + serverPassword: "secret", + wantErr: ErrAccessDeniedNoPassword, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Conn{ + credential: Credential{Password: tt.serverPassword}, + } + err := c.compareCacheSha2PasswordAuthData(tt.clientAuthData) + if tt.wantErr == nil { + require.NoError(t, err) + } else { + require.ErrorIs(t, err, tt.wantErr) + } + }) + } +}