diff --git a/CHANGELOG.md b/CHANGELOG.md index 624095be..4a7f1c0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,11 +9,17 @@ unreleased ### Features +- Implement `require_auth` connection parameter ([#1310]). + ### Fixes - Add Redshift-specific OID mappings ([#1291], [#1317]). +- Use correct environment variable name for `PGSSLMINPROTOCOLVERSION` and + `PGSSLMAXPROTOCOLVERSION` ([#1310]). + [#1291]: https://github.com/lib/pq/pull/1291 +[#1310]: https://github.com/lib/pq/pull/1310 [#1317]: https://github.com/lib/pq/pull/1317 diff --git a/conn.go b/conn.go index e4661c84..667c5fa9 100644 --- a/conn.go +++ b/conn.go @@ -15,6 +15,7 @@ import ( "net" "os" "reflect" + "slices" "strconv" "strings" "sync" @@ -1255,6 +1256,7 @@ func (cn *conn) startup(cfg Config) error { return err } + var didauth bool for { t, r, err := cn.recv() if err != nil { @@ -1271,7 +1273,11 @@ func (cn *conn) startup(cfg Config) error { case proto.ParameterStatus: cn.processParameterStatus(r) case proto.AuthenticationRequest: - err := cn.auth(r, cfg) + code := proto.AuthCode(r.int32()) + if code != proto.AuthReqOk { + didauth = true + } + err := cn.auth(code, r, cfg) if err != nil { return err } @@ -1282,6 +1288,9 @@ func (cn *conn) startup(cfg Config) error { return fmt.Errorf("pq: protocol version mismatch: min_protocol_version=%s; server supports up to 3.%d", cfg.MinProtocolVersion, newestMinor) } case proto.ReadyForQuery: + if len(cn.cfg.RequireAuth) > 0 && !didauth && !slices.Contains(cn.cfg.RequireAuth, RequireAuthNone) { + return fmt.Errorf("pq: authentication method requirement %q failed: server did not perform any authentication", cn.cfg.RequireAuth) + } cn.processReadyForQuery(r) return nil default: @@ -1290,8 +1299,8 @@ func (cn *conn) startup(cfg Config) error { } } -func (cn *conn) auth(r *readBuf, cfg Config) error { - switch code := proto.AuthCode(r.int32()); code { +func (cn *conn) auth(code proto.AuthCode, r *readBuf, cfg Config) error { + switch code { default: return fmt.Errorf("pq: unknown authentication response: %s", code) case proto.AuthReqKrb4, proto.AuthReqKrb5, proto.AuthReqCrypt, proto.AuthReqSSPI: @@ -1300,6 +1309,9 @@ func (cn *conn) auth(r *readBuf, cfg Config) error { return nil case proto.AuthReqPassword: + if len(cn.cfg.RequireAuth) > 0 && !slices.Contains(cn.cfg.RequireAuth, RequireAuthPassword) && !slices.Contains(cn.cfg.RequireAuth, RequireAuthAny) { + return fmt.Errorf("pq: authentication method requirement %q failed: server requested %q", cn.cfg.RequireAuth, RequireAuthPassword) + } w := cn.writeBuf(proto.PasswordMessage) w.string(cfg.Password) // Don't need to check AuthOk response here; auth() is called in a loop, @@ -1307,6 +1319,9 @@ func (cn *conn) auth(r *readBuf, cfg Config) error { return cn.send(w) case proto.AuthReqMD5: + if len(cn.cfg.RequireAuth) > 0 && !slices.Contains(cn.cfg.RequireAuth, RequireAuthMD5) && !slices.Contains(cn.cfg.RequireAuth, RequireAuthAny) { + return fmt.Errorf("pq: authentication method requirement %q failed: server requested %q", cn.cfg.RequireAuth, RequireAuthMD5) + } s := string(r.next(4)) w := cn.writeBuf(proto.PasswordMessage) w.string("md5" + md5s(md5s(cfg.Password+cfg.User)+s)) @@ -1369,6 +1384,9 @@ func (cn *conn) auth(r *readBuf, cfg Config) error { return nil case proto.AuthReqSASL: + if len(cn.cfg.RequireAuth) > 0 && !slices.Contains(cn.cfg.RequireAuth, RequireAuthScramSHA256) && !slices.Contains(cn.cfg.RequireAuth, RequireAuthAny) { + return fmt.Errorf("pq: authentication method requirement %q failed: server requested %q", cn.cfg.RequireAuth, RequireAuthScramSHA256) + } sc := scram.NewClient(sha256.New, cfg.User, cfg.Password) sc.Step(nil) if sc.Err() != nil { diff --git a/conn_test.go b/conn_test.go index 09b2337f..1a39ecfc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1607,18 +1607,18 @@ func TestCommitInFailedTransactionWithCancelContext(t *testing.T) { func TestAuth(t *testing.T) { tests := []struct { - buf readBuf + code proto.AuthCode wantErr string }{ - {readBuf{0, 0, 0, 9}, `pq: unsupported authentication method: SSPI (9)`}, - {readBuf{0, 0, 0, 99}, `unknown authentication response: (99)`}, + {proto.AuthCode(9), `pq: unsupported authentication method: SSPI (9)`}, + {proto.AuthCode(99), `unknown authentication response: (99)`}, } t.Parallel() for _, tt := range tests { t.Run("", func(t *testing.T) { t.Run("unsupported auth", func(t *testing.T) { - err := (&conn{}).auth(&tt.buf, Config{}) + err := (&conn{}).auth(tt.code, &readBuf{}, Config{}) if !pqtest.ErrorContains(err, tt.wantErr) { t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr) } @@ -1646,10 +1646,34 @@ func TestAuth(t *testing.T) { {"user=pqgoscram password=wordpass", ``}, {"user=pqgounknown password=wordpass", `or:role "pqgounknown" does not exist|password authentication failed for user pqgounknown`}, + {"user=pqgounknown password=wordpass require_auth=md5", `or:role "pqgounknown" does not exist|password authentication failed for user pqgounknown`}, + + // require_auth + {"user=pqgomd5 password=wordpass require_auth=md5,password", ``}, + {"user=pqgopassword password=wordpass require_auth=md5,password", ``}, + {"user=pqgoscram password=wordpass require_auth=md5,password,scram-sha-256", ``}, + {"user=pqgomd5 password=wordpass require_auth=!none", ``}, + {"user=pqgopassword password=wordpass require_auth=!none", ``}, + {"user=pqgoscram password=wordpass require_auth=!none", ``}, + + {"user=pqgomd5 password=wordpass require_auth=password", `"password" failed: server requested "md5"`}, + {"user=pqgopassword password=wordpass require_auth=md5", `"md5" failed: server requested "password"`}, + {"user=pqgoscram password=wordpass require_auth=md5,password", `authentication method requirement "md5,password" failed: server requested "scram-sha-256"`}, + {"user=pqgomd5 password=wordpass require_auth=!md5,!password", `"!md5,!password" failed: server requested "md5"`}, + {"user=pqgopassword password=wordpass require_auth=!md5,!password", `"!md5,!password" failed: server requested "password"`}, + {"user=pqgoscram password=wordpass require_auth=!md5,!password,!scram-sha-256", `"!md5,!password,!scram-sha-256" failed: server requested "scram-sha-256"`}, + {"user=pqgomd5 password=wordpass require_auth=password", `"password" failed: server requested "md5"`}, + + {"user=pqgo password=unused require_auth=none", ``}, + {"user=pqgo password=unused require_auth=!none", `"!none" failed: server did not perform any authentication`}, + {"user=pqgo password=unused require_auth=md5,password,scram-sha-256", `"md5,password,scram-sha-256" failed: server did not perform any authentication`}, } for _, tt := range tests { t.Run(tt.conn, func(t *testing.T) { + if strings.Contains(tt.conn, "md5") { + pqtest.SkipCockroach(t) // md5 not supported + } _, err := pqtest.DB(t, tt.conn) if !pqtest.ErrorContains(err, tt.wantErr) { t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr) diff --git a/connector.go b/connector.go index 01050db0..32bfe890 100644 --- a/connector.go +++ b/connector.go @@ -45,6 +45,12 @@ type ( // SSLProtocolVersion is a ssl_min_protocol_version or // ssl_max_protocol_version setting. SSLProtocolVersion string + + // RequireAuth is a require_auth setting. + RequireAuth string + + // RequireAuths is a require_auth setting. + RequireAuths []RequireAuth ) // Values for [SSLMode] that pq supports. @@ -179,6 +185,41 @@ func (s SSLProtocolVersion) tlsconf() uint16 { } } +// Values for [RequireAuth] that pq supports. +const ( + RequireAuthNone = RequireAuth("none") + RequireAuthPassword = RequireAuth("password") + RequireAuthMD5 = RequireAuth("md5") + RequireAuthGSS = RequireAuth("gss") + RequireAuthScramSHA256 = RequireAuth("scram-sha-256") + RequireAuthAny = RequireAuth("!none") + RequireAuthNotPassword = RequireAuth("!password") + RequireAuthNotMD5 = RequireAuth("!md5") + RequireAuthNotGSS = RequireAuth("!gss") + RequireAuthNotScramSHA256 = RequireAuth("!scram-sha-256") + + // Not (yet) supported by pq + // RequireAuthSSPI = "sspi" + // RequireAuthOAuth = "oauth" + // RequireAuthNotSSPI = "!sspi" + // RequireAuthNotOAuth = "!oauth" +) + +var requireAuths = []RequireAuth{RequireAuthNone, RequireAuthPassword, RequireAuthMD5, + RequireAuthGSS, RequireAuthScramSHA256, RequireAuthAny, RequireAuthNotPassword, + RequireAuthNotMD5, RequireAuthNotGSS, RequireAuthNotScramSHA256} + +func (r RequireAuths) String() string { + var b strings.Builder + for i, rr := range r { + if i > 0 { + b.WriteString(",") + } + b.WriteString(string(rr)) + } + return b.String() +} + // Connector represents a fixed configuration for the pq driver with a given // dsn. Connector satisfies the [database/sql/driver.Connector] interface and // can be used to create any number of DB Conn's via [sql.OpenDB]. @@ -341,14 +382,14 @@ type Config struct { // // The default is determined by [tls.Config.MinVersion], which is TLSv1.2 at // the time of writing. - SSLMinProtocolVersion SSLProtocolVersion `postgres:"ssl_min_protocol_version" env:"SSLPGMINPROTOCOLVERSION"` + SSLMinProtocolVersion SSLProtocolVersion `postgres:"ssl_min_protocol_version" env:"PGSSLMINPROTOCOLVERSION"` // Maximum SSL/TLS protocol version to allow for the connection. If not set, // this parameter is ignored and the connection will use the maximum bound // defined by the backend, if set. Setting the maximum protocol version is // mainly useful for testing or if some component has issues working with a // newer protocol. - SSLMaxProtocolVersion SSLProtocolVersion `postgres:"ssl_max_protocol_version" env:"SSLPGMAXPROTOCOLVERSION"` + SSLMaxProtocolVersion SSLProtocolVersion `postgres:"ssl_max_protocol_version" env:"PGSSLMAXPROTOCOLVERSION"` // Interpert sslcert and sslkey as PEM encoded data, rather than a path to a // PEM file. This is a pq extension, not supported in libpq. @@ -431,6 +472,25 @@ type Config struct { // Path to connection service file. Defaults to ~/.pg_service.conf. ServiceFile string `postgres:"-" env:"PGSERVICEFILE"` + // Require an authentication method from the server and refuse to connect if + // the server does not use the requested method. + // + // This accepts a comma-separated list. + // + // Methods may be negated with a ! prefix, in which case the server must + // *not* attempt the listed method, and the server is free not to + // authenticate the client at all. Negated and non-negated forms may not be + // combined in the same setting with a comma-separated list. + // + // As a special case the "none" method requires the server not to use an + // authentication challenge. This does not prohibit client certificate + // authentication via TLS or GSS authentication via its encrypted transport. + // This can be negated to require some form of authentication. + // + // By default any authentication method is accepted and the server is free + // to skip authentication altogether. + RequireAuth RequireAuths `postgres:"require_auth" env:"PGREQUIREAUTH"` + // Runtime parameters: any unrecognized parameter in the DSN will be added // to this and sent to PostgreSQL during startup. Runtime map[string]string `postgres:"-" env:"-"` @@ -517,7 +577,8 @@ func NewConfig(dsn string) (Config, error) { // Clone returns a copy of the [Config]. func (cfg Config) Clone() Config { c := cfg - c.Runtime, c.Multi, c.set = maps.Clone(cfg.Runtime), slices.Clone(cfg.Multi), slices.Clone(cfg.set) + c.Runtime, c.Multi, c.RequireAuth, c.set = maps.Clone(cfg.Runtime), slices.Clone(cfg.Multi), + slices.Clone(cfg.RequireAuth), slices.Clone(cfg.set) return c } @@ -672,8 +733,8 @@ func (cfg *Config) fromEnv(env []string) error { switch k { case "PGREQUIRESSL", "PGSSLCOMPRESSION", // Deprecated. "PGREALM", "PGGSSENCMODE", "PGGSSDELEGATION", "PGGSSLIB", // krb stuff - "PGREQUIREAUTH", "PGCHANNELBINDING", - "PGSSLCERTMODE", "PGSSLCRL", "PGSSLCRLDIR", "PGREQUIREPEER": + "PGCHANNELBINDING", "PGSSLCRL", "PGSSLCRLDIR", + "PGSSLCERTMODE", "PGREQUIREPEER": return fmt.Errorf("pq: environment variable $%s is not supported", k) case "PGKRBSRVNAME": if newGss == nil { @@ -833,8 +894,9 @@ func (cfg *Config) setFromTag(o map[string]string, tag string, service bool) err loadbalancehosts = (tag == "postgres" && k == "load_balance_hosts") || (tag == "env" && k == "PGLOADBALANCEHOSTS") minprotocolversion = (tag == "postgres" && k == "min_protocol_version") || (tag == "env" && k == "PGMINPROTOCOLVERSION") maxprotocolversion = (tag == "postgres" && k == "max_protocol_version") || (tag == "env" && k == "PGMAXPROTOCOLVERSION") - sslminprotocolversion = (tag == "postgres" && k == "ssl_min_protocol_version") || (tag == "env" && k == "SSLPGMINPROTOCOLVERSION") - sslmaxprotocolversion = (tag == "postgres" && k == "ssl_max_protocol_version") || (tag == "env" && k == "SSLPGMAXPROTOCOLVERSION") + sslminprotocolversion = (tag == "postgres" && k == "ssl_min_protocol_version") || (tag == "env" && k == "PGSSLMINPROTOCOLVERSION") + sslmaxprotocolversion = (tag == "postgres" && k == "ssl_max_protocol_version") || (tag == "env" && k == "PGSSLMAXPROTOCOLVERSION") + requireauth = (tag == "postgres" && k == "require_auth") || (tag == "env" && k == "PGREQUIREAUTH") ) if k == "" || k == "-" { continue @@ -908,6 +970,31 @@ func (cfg *Config) setFromTag(o map[string]string, tag string, service bool) err cfg.multiHost = append(cfg.multiHost, vv[1:]...) } rv.SetString(v) + case reflect.Slice: + if requireauth { + if v == "" { + rv.Set(reflect.ValueOf((RequireAuths)(nil))) + continue + } + var ( + vv = strings.Split(v, ",") + s = make(RequireAuths, len(vv)) + neg = len(vv) > 0 && strings.HasPrefix(vv[0], "!") + ) + for i := range vv { + if !slices.Contains(requireAuths, RequireAuth(vv[i])) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, vv[i], pqutil.Join(requireAuths)) + } + if neg && !strings.HasPrefix(vv[i], "!") { + return fmt.Errorf(f+`require_auth method %q cannot be mixed with negative methods`, k, vv[i]) + } + if !neg && strings.HasPrefix(vv[i], "!") { + return fmt.Errorf(f+`negative require_auth method %q cannot be mixed with non-negative methods`, k, vv[i]) + } + s[i] = RequireAuth(vv[i]) + } + rv.Set(reflect.ValueOf(s)) + } case reflect.Int64: n, err := strconv.ParseInt(v, 10, 64) if err != nil { diff --git a/connector_test.go b/connector_test.go index a3cf1cb5..b9699b59 100644 --- a/connector_test.go +++ b/connector_test.go @@ -435,6 +435,15 @@ func TestNewConfig(t *testing.T) { {"", []string{"PGMINPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMINPROTOCOLVERSION: "bogus" is not supported`}, {"", []string{"PGMAXPROTOCOLVERSION=bogus"}, "", `pq: wrong value for $PGMAXPROTOCOLVERSION: "bogus" is not supported`}, {"min_protocol_version=3.2 max_protocol_version=3.0", nil, "", `min_protocol_version "3.2" cannot be greater than max_protocol_version "3.0"`}, + + // requireauth + {"require_auth=", nil, "require_auth=''", ``}, + {"require_auth=none", nil, "require_auth=none", ""}, + {"require_auth=md5,scram-sha-256", nil, "require_auth=md5,scram-sha-256", ""}, + {"require_auth=md5,scram-sha256", nil, "", `wrong value for "require_auth": "scram-sha256" is not supported`}, + {"require_auth=!md5,!scram-sha-256", nil, "require_auth=!md5,!scram-sha-256", ""}, + {"require_auth=md5,!password", nil, "", `negative require_auth method "!password" cannot be mixed with non-negative methods`}, + {"require_auth=!md5,password", nil, "", `require_auth method "password" cannot be mixed with negative methods`}, } t.Parallel()