From 83e88ad787a8b7ae0cd52e8725cc13987f8386de Mon Sep 17 00:00:00 2001 From: Duc NH Date: Sun, 4 Aug 2024 22:59:36 +0700 Subject: [PATCH] Recognize float types in ColumnTypeScanType() --- CHANGELOG.md | 3 + rows.go | 4 ++ rows_test.go | 151 +++++++++++++++++++++++++-------------------------- 3 files changed, 82 insertions(+), 76 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 094dcb58..dfa65421 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,8 @@ newer. Previously PostgreSQL 8.4 and newer were supported. - Decode bpchar into a string ([#949]). +- Recognize float types in ColumnTypeScanType() ([#1166]). + [#595]: https://github.com/lib/pq/pull/595 [#745]: https://github.com/lib/pq/pull/745 [#949]: https://github.com/lib/pq/pull/949 @@ -39,6 +41,7 @@ newer. Previously PostgreSQL 8.4 and newer were supported. [#1129]: https://github.com/lib/pq/pull/1129 [#1133]: https://github.com/lib/pq/pull/1133 [#1161]: https://github.com/lib/pq/pull/1161 +[#1166]: https://github.com/lib/pq/pull/1166 [#1179]: https://github.com/lib/pq/pull/1179 [#1184]: https://github.com/lib/pq/pull/1184 [#1211]: https://github.com/lib/pq/pull/1211 diff --git a/rows.go b/rows.go index 5900fccc..08732332 100644 --- a/rows.go +++ b/rows.go @@ -29,6 +29,10 @@ func (fd fieldDesc) Type() reflect.Type { return reflect.TypeOf(int32(0)) case oid.T_int2: return reflect.TypeOf(int16(0)) + case oid.T_float8: + return reflect.TypeOf(float64(0)) + case oid.T_float4: + return reflect.TypeOf(float32(0)) case oid.T_varchar, oid.T_text: return reflect.TypeOf("") case oid.T_bool: diff --git a/rows_test.go b/rows_test.go index 5065b685..f6fd27c1 100644 --- a/rows_test.go +++ b/rows_test.go @@ -114,106 +114,105 @@ func TestDataTypePrecisionScale(t *testing.T) { } func TestRowsColumnTypes(t *testing.T) { - columnTypesTests := []struct { - Name string - TypeName string - Length struct { + type ( + length struct { Len int64 OK bool } - DecimalSize struct { + decimalSize struct { Precision int64 Scale int64 OK bool } - ScanType reflect.Type + ) + tests := []struct { + Name string + TypeName string + Length length + DecimalSize decimalSize + ScanType reflect.Type }{ { - Name: "a", - TypeName: "INT4", - Length: struct { - Len int64 - OK bool - }{ - Len: 0, - OK: false, - }, - DecimalSize: struct { - Precision int64 - Scale int64 - OK bool - }{ - Precision: 0, - Scale: 0, - OK: false, - }, - ScanType: reflect.TypeOf(int32(0)), - }, { - Name: "bar", - TypeName: "TEXT", - Length: struct { - Len int64 - OK bool - }{ - Len: math.MaxInt64, - OK: true, - }, - DecimalSize: struct { - Precision int64 - Scale int64 - OK bool - }{ - Precision: 0, - Scale: 0, - OK: false, - }, - ScanType: reflect.TypeOf(""), + Name: "a", + TypeName: "INT4", + Length: length{Len: 0, OK: false}, + DecimalSize: decimalSize{Precision: 0, Scale: 0, OK: false}, + ScanType: reflect.TypeOf(int32(0)), + }, + { + Name: "bar", + TypeName: "TEXT", + Length: length{Len: math.MaxInt64, OK: true}, + DecimalSize: decimalSize{Precision: 0, Scale: 0, OK: false}, + ScanType: reflect.TypeOf(""), + }, + { + Name: "dec", + TypeName: "NUMERIC", + Length: length{Len: 0, OK: false}, + DecimalSize: decimalSize{Precision: 9, Scale: 2, OK: true}, + ScanType: reflect.TypeOf(new(any)).Elem(), + }, + { + Name: "f", + TypeName: "FLOAT8", + Length: length{Len: 0, OK: false}, + DecimalSize: decimalSize{Precision: 0, Scale: 0, OK: false}, + ScanType: reflect.TypeOf(float64(0)), }, } db := pqtest.MustDB(t) defer db.Close() - rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec") + rows, err := db.Query(`select + 1 as a, + text 'bar' as bar, + 1.28::numeric(9, 2) as dec, + 3.1415::float8 as f + `) if err != nil { t.Fatal(err) } + defer rows.Close() columns, err := rows.ColumnTypes() if err != nil { t.Fatal(err) } - if len(columns) != 3 { - t.Errorf("expected 3 columns found %d", len(columns)) + if len(columns) != 4 { + t.Errorf("expected 4 columns found %d", len(columns)) } - for i, tt := range columnTypesTests { - c := columns[i] - if c.Name() != tt.Name { - t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name) - } - if c.DatabaseTypeName() != tt.TypeName { - t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName) - } - l, ok := c.Length() - if l != tt.Length.Len { - t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len) - } - if ok != tt.Length.OK { - t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK) - } - p, s, ok := c.DecimalSize() - if p != tt.DecimalSize.Precision { - t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision) - } - if s != tt.DecimalSize.Scale { - t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale) - } - if ok != tt.DecimalSize.OK { - t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK) - } - if c.ScanType() != tt.ScanType { - t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType) - } + for i, tt := range tests { + t.Run("", func(t *testing.T) { + c := columns[i] + if c.Name() != tt.Name { + t.Errorf("have: %s, want: %s", c.Name(), tt.Name) + } + if c.DatabaseTypeName() != tt.TypeName { + t.Errorf("have: %s, want: %s", c.DatabaseTypeName(), tt.TypeName) + } + l, ok := c.Length() + if l != tt.Length.Len { + t.Errorf("have: %d, want: %d", l, tt.Length.Len) + } + if ok != tt.Length.OK { + t.Errorf("have: %t, want: %t", ok, tt.Length.OK) + } + p, s, ok := c.DecimalSize() + if p != tt.DecimalSize.Precision { + t.Errorf("have: %d, want: %d", p, tt.DecimalSize.Precision) + } + if s != tt.DecimalSize.Scale { + t.Errorf("have: %d, want: %d", s, tt.DecimalSize.Scale) + } + if ok != tt.DecimalSize.OK { + t.Errorf("have: %t, want: %t", ok, tt.DecimalSize.OK) + } + if c.ScanType() != tt.ScanType { + t.Errorf("have: %v, want: %v", c.ScanType(), tt.ScanType) + } + }) } }