diff --git a/conn.go b/conn.go index 688329cde..78687c155 100644 --- a/conn.go +++ b/conn.go @@ -87,6 +87,7 @@ type parameterStatus struct { serverVersion int currentLocation *time.Location inHotStandby, defaultTransactionReadOnly sql.NullBool + isRedshift bool } type format int @@ -1558,6 +1559,8 @@ func (cn *conn) processParameterStatus(r *readBuf) { switch r.string() { default: // ignore + case "padb_version": + cn.parameterStatus.isRedshift = true case "server_version": var major1, major2 int _, err := fmt.Sscanf(r.string(), "%d.%d", &major1, &major2) diff --git a/oid/doc.go b/oid/doc.go deleted file mode 100644 index a48650663..000000000 --- a/oid/doc.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:generate go run ./gen.go - -// Package oid contains OID constants as defined by the Postgres server. -package oid - -// Oid is a Postgres Object ID. -type Oid uint32 diff --git a/oid/gen.go b/oid/gen.go deleted file mode 100644 index 519bddb16..000000000 --- a/oid/gen.go +++ /dev/null @@ -1,81 +0,0 @@ -//go:build ignore - -// Generate the table of OID values -// Run with 'go run gen.go'. -package main - -import ( - "database/sql" - "fmt" - "log" - "os" - "os/exec" - "strings" - - _ "github.com/lib/pq" - "github.com/lib/pq/internal/pqtest" -) - -// OID represent a postgres Object Identifier Type. -type OID struct { - ID int - Type string -} - -// Name returns an upper case version of the oid type. -func (o OID) Name() string { - return strings.ToUpper(o.Type) -} - -func main() { - db, err := sql.Open("postgres", pqtest.DSN()) - if err != nil { - log.Fatal(err) - } - defer db.Close() - - rows, err := db.Query(`select typname, oid from pg_type where oid < 10000 order by oid`) - if err != nil { - log.Fatal(err) - } - oids := make([]*OID, 0) - for rows.Next() { - var oid OID - if err = rows.Scan(&oid.Type, &oid.ID); err != nil { - log.Fatal(err) - } - oids = append(oids, &oid) - } - if err = rows.Err(); err != nil { - log.Fatal(err) - } - cmd := exec.Command("gofmt") - cmd.Stderr = os.Stderr - w, err := cmd.StdinPipe() - if err != nil { - log.Fatal(err) - } - f, err := os.Create("types.go") - if err != nil { - log.Fatal(err) - } - cmd.Stdout = f - err = cmd.Start() - if err != nil { - log.Fatal(err) - } - fmt.Fprintln(w, "// Code generated by gen.go. DO NOT EDIT.") - fmt.Fprintln(w, "\npackage oid") - fmt.Fprintln(w, "const (") - for _, oid := range oids { - fmt.Fprintf(w, "T_%s Oid = %d\n", oid.Type, oid.ID) - } - fmt.Fprintln(w, ")") - fmt.Fprintln(w, "var TypeName = map[Oid]string{") - for _, oid := range oids { - fmt.Fprintf(w, "T_%s: \"%s\",\n", oid.Type, oid.Name()) - } - fmt.Fprintln(w, "}") - w.Close() - cmd.Wait() -} diff --git a/oid/types.go b/oid/types.go index ecc84c2c8..44ad358f5 100644 --- a/oid/types.go +++ b/oid/types.go @@ -1,7 +1,9 @@ -// Code generated by gen.go. DO NOT EDIT. - +// Package oid contains OID constants as defined by the Postgres server. package oid +// Oid is a Postgres Object ID. +type Oid uint32 + const ( T_bool Oid = 16 T_bytea Oid = 17 diff --git a/rows.go b/rows.go index 2029bfed2..60a613471 100644 --- a/rows.go +++ b/rows.go @@ -161,6 +161,11 @@ func (rs *rows) ColumnTypeScanType(index int) reflect.Type { // ColumnTypeDatabaseTypeName return the database system type name. func (rs *rows) ColumnTypeDatabaseTypeName(index int) string { + if rs.cn.parameterStatus.isRedshift { + if n, ok := redshiftTypeName[rs.colTyps[index].OID]; ok { + return n + } + } return rs.colTyps[index].Name() } @@ -243,3 +248,26 @@ func (fd fieldDesc) PrecisionScale() (precision, scale int64, ok bool) { return 0, 0, false } } + +var redshiftTypeName = map[oid.Oid]string{ + 86: "PG_SHADOW", + 87: "PG_GROUP", + 88: "PG_DATABASE", + 90: "PG_TABLESPACE", + 635: "_SPECTRUM_ARRAY", + 636: "_SPECTRUM_MAP", + 637: "_SPECTRUM_STRUCT", + 1188: "INTERVALY2M", + 1189: "_INTERVALY2M", + 1190: "INTERVALD2S", + 1191: "_INTERVALD2S", + 2935: "HLLSKETCH", + 3000: "GEOMETRY", + 3001: "GEOGRAPHY", + 4000: "SUPER", + 4600: "USERITEM", + 4601: "_USERITEM", + 4602: "ROLEITEM", + 4603: "_ROLEITEM", + 6551: "VARBYTE", +} diff --git a/rows_test.go b/rows_test.go index 491a9bcce..d73769def 100644 --- a/rows_test.go +++ b/rows_test.go @@ -13,30 +13,42 @@ import ( func TestDataTypeName(t *testing.T) { tests := []struct { - typ oid.Oid - name string + typ oid.Oid + name string + redshift bool }{ - {oid.T_int8, "INT8"}, - {oid.T_int4, "INT4"}, - {oid.T_int2, "INT2"}, - {oid.T_varchar, "VARCHAR"}, - {oid.T_text, "TEXT"}, - {oid.T_bit, "BIT"}, - {oid.T_varbit, "VARBIT"}, - {oid.T_bool, "BOOL"}, - {oid.T_numeric, "NUMERIC"}, - {oid.T_date, "DATE"}, - {oid.T_time, "TIME"}, - {oid.T_timetz, "TIMETZ"}, - {oid.T_timestamp, "TIMESTAMP"}, - {oid.T_timestamptz, "TIMESTAMPTZ"}, - {oid.T_bytea, "BYTEA"}, + {oid.T_int8, "INT8", false}, + {oid.T_int4, "INT4", false}, + {oid.T_int2, "INT2", false}, + {oid.T_varchar, "VARCHAR", false}, + {oid.T_text, "TEXT", false}, + {oid.T_bit, "BIT", false}, + {oid.T_varbit, "VARBIT", false}, + {oid.T_bool, "BOOL", false}, + {oid.T_numeric, "NUMERIC", false}, + {oid.T_date, "DATE", false}, + {oid.T_time, "TIME", false}, + {oid.T_timetz, "TIMETZ", false}, + {oid.T_timestamp, "TIMESTAMP", false}, + {oid.T_timestamptz, "TIMESTAMPTZ", false}, + {oid.T_bytea, "BYTEA", false}, + + {oid.T_int8, "INT8", true}, + {635, "_SPECTRUM_ARRAY", true}, + {636, "_SPECTRUM_MAP", true}, + {637, "_SPECTRUM_STRUCT", true}, + {4000, "SUPER", true}, + + {635, "", false}, } for _, tt := range tests { t.Run("", func(t *testing.T) { - have := fieldDesc{OID: tt.typ} - if name := have.Name(); name != tt.name { + have := &rows{ + cn: &conn{parameterStatus: parameterStatus{isRedshift: tt.redshift}}, + rowsHeader: rowsHeader{colTyps: []fieldDesc{{OID: tt.typ}}}, + } + if name := have.ColumnTypeDatabaseTypeName(0); name != tt.name { t.Errorf("\nhave: %s\nwant: %s", name, tt.name) } })