diff --git a/array.go b/array.go index bdd97d3f..d9c60389 100644 --- a/array.go +++ b/array.go @@ -4,758 +4,257 @@ import ( "bytes" "database/sql" "database/sql/driver" - "encoding/hex" "fmt" "reflect" "strconv" "strings" + "time" + _ "unsafe" ) -var typeByteSlice = reflect.TypeFor[[]byte]() -var typeDriverValuer = reflect.TypeFor[driver.Valuer]() -var typeSQLScanner = reflect.TypeFor[sql.Scanner]() +var ( + _ sql.Scanner = (*ArrayOf[any])(nil) + _ driver.Valuer = (*ArrayOf[any])(nil) +) -// Array returns the optimal driver.Valuer and sql.Scanner for an array or -// slice of any dimension. -// -// For example: -// -// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) +// TODO: hopefully will be exported in the future? +// https://github.com/golang/go/issues/62146#issuecomment-3921700836 // -// var x []sql.NullInt64 -// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x)) -// -// Scanning multi-dimensional arrays is not supported. Arrays where the lower -// bound is not one (such as `[0:0]={1}') are not supported. -func Array(a any) interface { - driver.Valuer - sql.Scanner -} { - switch a := a.(type) { - case []bool: - return (*BoolArray)(&a) - case []float64: - return (*Float64Array)(&a) - case []float32: - return (*Float32Array)(&a) - case []int64: - return (*Int64Array)(&a) - case []int32: - return (*Int32Array)(&a) - case []string: - return (*StringArray)(&a) - case [][]byte: - return (*ByteaArray)(&a) - - case *[]bool: - return (*BoolArray)(a) - case *[]float64: - return (*Float64Array)(a) - case *[]float32: - return (*Float32Array)(a) - case *[]int64: - return (*Int64Array)(a) - case *[]int32: - return (*Int32Array)(a) - case *[]string: - return (*StringArray)(a) - case *[][]byte: - return (*ByteaArray)(a) - } +//go:linkname convertAssign database/sql.convertAssign +func convertAssign(dest, src any) error - return GenericArray{a} -} - -// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner -// to override the array delimiter used by GenericArray. -type ArrayDelimiter interface { - // ArrayDelimiter returns the delimiter character(s) for this element's type. - ArrayDelimiter() string -} - -// BoolArray represents a one-dimensional array of the PostgreSQL boolean type. -type BoolArray []bool +// ArrayOf wraps a slice with Scan() and Value() methods that use PostgreSQL's +// array syntax. +// +// Types may optionally implement the "ArrayDelimiter() string" method to use a +// different array delimiter. +type ArrayOf[T any] []T -// Scan implements the sql.Scanner interface. -func (a *BoolArray) Scan(src any) error { +func (a *ArrayOf[T]) Scan(src any) error { switch src := src.(type) { case []byte: - return a.scanBytes(src) + return a.scan(src) case string: - return a.scanBytes([]byte(src)) + return a.scan([]byte(src)) case nil: *a = nil return nil } - - return fmt.Errorf("pq: cannot convert %T to BoolArray", src) + return fmt.Errorf("pq: cannot convert %T to %T", src, a) } -func (a *BoolArray) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "BoolArray") +func (a *ArrayOf[T]) scan(src []byte) error { + var zero T + dims, elems, err := parseArray(src, arrayDelimiter(any(zero))) if err != nil { return err } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(BoolArray, len(elems)) - for i, v := range elems { - if len(v) != 1 { - return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) - } - switch v[0] { - case 't': - b[i] = true - case 'f': - b[i] = false - default: - return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) - } - } - *a = b - } - return nil -} - -// Value implements the driver.Valuer interface. -func (a BoolArray) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be exactly two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 1+2*n) - - for i := range n { - b[2*i] = ',' - if a[i] { - b[1+2*i] = 't' - } else { - b[1+2*i] = 'f' - } - } - - b[0] = '{' - b[2*n] = '}' - - return string(b), nil - } - - return "{}", nil -} - -// ByteaArray represents a one-dimensional array of the PostgreSQL bytea type. -type ByteaArray [][]byte - -// Scan implements the sql.Scanner interface. -func (a *ByteaArray) Scan(src any) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil - return nil + if len(dims) > 1 { + return fmt.Errorf("pq: cannot convert ARRAY%s to %T", strings.Replace(fmt.Sprint(dims), " ", "][", -1), a) } - return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) -} - -func (a *ByteaArray) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "ByteaArray") - if err != nil { - return err - } if *a != nil && len(elems) == 0 { *a = (*a)[:0] - } else { - b := make(ByteaArray, len(elems)) - for i, v := range elems { - b[i], err = parseBytea(v) - if err != nil { - return fmt.Errorf("could not parse bytea array index %d: %w", i, err) - } - } - *a = b - } - return nil -} - -// Value implements the driver.Valuer interface. It uses the "hex" format which -// is only supported on PostgreSQL 9.0 or newer. -func (a ByteaArray) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, 2*N bytes of quotes, - // 3*N bytes of hex formatting, and N-1 bytes of delimiters. - size := 1 + 6*n - for _, x := range a { - size += hex.EncodedLen(len(x)) - } - - b := make([]byte, size) - - for i, s := 0, b; i < n; i++ { - o := copy(s, `,"\\x`) - o += hex.Encode(s[o:], a[i]) - s[o] = '"' - s = s[o+1:] - } - - b[0] = '{' - b[size-1] = '}' - - return string(b), nil - } - - return "{}", nil -} - -// Float64Array represents a one-dimensional array of the PostgreSQL double -// precision type. -type Float64Array []float64 - -// Scan implements the sql.Scanner interface. -func (a *Float64Array) Scan(src any) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil return nil } - return fmt.Errorf("pq: cannot convert %T to Float64Array", src) -} - -func (a *Float64Array) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "Float64Array") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(Float64Array, len(elems)) + switch any(zero).(type) { + case time.Time, *time.Time: // convertAssign only scans time.Time if src (v here) is a time.Time. + _, ptr := any(zero).(*time.Time) + b := make([]T, len(elems)) for i, v := range elems { - b[i], err = strconv.ParseFloat(string(v), 64) + if v == nil { // NULL + if ptr { + continue // Use nil zero value. + } + return fmt.Errorf("pq: array index %d: cannot convert NULL to time.Time", i) + } + t, err := ParseTimestamp(nil, string(v)) if err != nil { - return fmt.Errorf("pq: parsing array element index %d: %w", i, err) + return fmt.Errorf("array index %d: %w", i, err) } + x := any(t) + if ptr { + x = any(&t) + } + b[i] = x.(T) } *a = b - } - return nil -} - -// Value implements the driver.Valuer interface. -func (a Float64Array) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 1, 1+2*n) - b[0] = '{' - - b = strconv.AppendFloat(b, a[0], 'f', -1, 64) - for i := 1; i < n; i++ { - b = append(b, ',') - b = strconv.AppendFloat(b, a[i], 'f', -1, 64) - } - - return string(append(b, '}')), nil - } - - return "{}", nil -} - -// Float32Array represents a one-dimensional array of the PostgreSQL double -// precision type. -type Float32Array []float32 - -// Scan implements the sql.Scanner interface. -func (a *Float32Array) Scan(src any) error { - switch src := src.(type) { case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil - return nil - } - - return fmt.Errorf("pq: cannot convert %T to Float32Array", src) -} - -func (a *Float32Array) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "Float32Array") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(Float32Array, len(elems)) + b := make([]T, len(elems)) for i, v := range elems { - x, err := strconv.ParseFloat(string(v), 32) + x, err := parseBytea(v) if err != nil { - return fmt.Errorf("pq: parsing array element index %d: %w", i, err) + return fmt.Errorf("pq: array index %d: %w", i, err) } - b[i] = float32(x) + b[i] = any(x).(T) } *a = b - } - return nil -} - -// Value implements the driver.Valuer interface. -func (a Float32Array) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 1, 1+2*n) - b[0] = '{' - - b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32) - for i := 1; i < n; i++ { - b = append(b, ',') - b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32) - } - - return string(append(b, '}')), nil - } - - return "{}", nil -} - -// GenericArray implements the driver.Valuer and sql.Scanner interfaces for -// an array or slice of any dimension. -type GenericArray struct{ A any } - -func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { - var assign func([]byte, reflect.Value) error - var del = "," - - // TODO calculate the assign function for other types - // TODO repeat this section on the element type of arrays or slices (multidimensional) - { - if reflect.PointerTo(rt).Implements(typeSQLScanner) { - // dest is always addressable because it is an element of a slice. - assign = func(src []byte, dest reflect.Value) (err error) { - ss := dest.Addr().Interface().(sql.Scanner) - if src == nil { - err = ss.Scan(nil) - } else { - err = ss.Scan(src) - } - return - } - goto FoundType - } - - assign = func([]byte, reflect.Value) error { - return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt) + case nil: // Treat any as string, rather than []byte + b := make([]T, len(elems)) + for i, v := range elems { + b[i] = any(string(v)).(T) } - } - -FoundType: - - if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok { - del = ad.ArrayDelimiter() - } - - return rt, assign, del -} - -// Scan implements the sql.Scanner interface. -func (a GenericArray) Scan(src any) error { - dpv := reflect.ValueOf(a.A) - switch { - case dpv.Kind() != reflect.Pointer: - return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) - case dpv.IsNil(): - return fmt.Errorf("pq: destination %T is nil", a.A) - } - - dv := dpv.Elem() - switch dv.Kind() { - case reflect.Slice: - case reflect.Array: + *a = b default: - return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) - } - - switch src := src.(type) { - case []byte: - return a.scanBytes(src, dv) - case string: - return a.scanBytes([]byte(src), dv) - case nil: - if dv.Kind() == reflect.Slice { - dv.Set(reflect.Zero(dv.Type())) - return nil - } - } - - return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) -} - -func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { - dtype, assign, del := a.evaluateDestination(dv.Type().Elem()) - dims, elems, err := parseArray(src, []byte(del)) - if err != nil { - return err - } - - // TODO allow multidimensional - - if len(dims) > 1 { - return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented", - strings.Replace(fmt.Sprint(dims), " ", "][", -1)) - } - - // Treat a zero-dimensional array like an array with a single dimension of zero. - if len(dims) == 0 { - dims = append(dims, 0) - } - - for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() { - switch rt.Kind() { - case reflect.Slice: - case reflect.Array: - if rt.Len() != dims[i] { - return fmt.Errorf("pq: cannot convert ARRAY%s to %s", - strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type()) + b := make([]T, len(elems)) + for i, v := range elems { + if v == nil { // NULL + if !isPointer(zero) { + return fmt.Errorf("pq: array index %d: cannot convert NULL to %T", i, zero) + } + continue // Just use zero value } - default: - // TODO handle multidimensional - } - } - - values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) - for i, e := range elems { - err := assign(e, values.Index(i)) - if err != nil { - return fmt.Errorf("pq: parsing array element index %d: %w", i, err) - } - } - - // TODO handle multidimensional - - switch dv.Kind() { - case reflect.Slice: - dv.Set(values.Slice(0, dims[0])) - case reflect.Array: - for i := 0; i < dims[0]; i++ { - dv.Index(i).Set(values.Index(i)) - } - } - - return nil -} - -// Value implements the driver.Valuer interface. -func (a GenericArray) Value() (driver.Value, error) { - if a.A == nil { - return nil, nil - } - - rv := reflect.ValueOf(a.A) - - switch rv.Kind() { - case reflect.Slice: - if rv.IsNil() { - return nil, nil - } - case reflect.Array: - default: - return nil, fmt.Errorf("pq: unable to convert %T to array", a.A) - } - - if n := rv.Len(); n > 0 { - // There will be at least two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 0, 1+2*n) - b, _, err := appendArray(b, rv, n) - return string(b), err - } - - return "{}", nil -} - -// Int64Array represents a one-dimensional array of the PostgreSQL integer types. -type Int64Array []int64 - -// Scan implements the sql.Scanner interface. -func (a *Int64Array) Scan(src any) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil - return nil - } - - return fmt.Errorf("pq: cannot convert %T to Int64Array", src) -} - -func (a *Int64Array) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "Int64Array") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(Int64Array, len(elems)) - for i, v := range elems { - b[i], err = strconv.ParseInt(string(v), 10, 64) + err := convertAssign(&b[i], v) if err != nil { - return fmt.Errorf("pq: parsing array element index %d: %w", i, err) + return fmt.Errorf("pq: array index %d: %s", i, strings.TrimPrefix(err.Error(), "sql/driver: ")) } } *a = b } + return nil } -// Value implements the driver.Valuer interface. -func (a Int64Array) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 1, 1+2*n) - b[0] = '{' - - b = strconv.AppendInt(b, a[0], 10) - for i := 1; i < n; i++ { - b = append(b, ',') - b = strconv.AppendInt(b, a[i], 10) - } - - return string(append(b, '}')), nil +func isPointer(t any) bool { + switch t.(type) { + case *string, *bool, *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, *float32, *float64: + return true } - - return "{}", nil + return reflect.ValueOf(t).Kind() == reflect.Ptr } -// Int32Array represents a one-dimensional array of the PostgreSQL integer types. -type Int32Array []int32 - -// Scan implements the sql.Scanner interface. -func (a *Int32Array) Scan(src any) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) - case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil - return nil +func arrayDelimiter(v any) []byte { + if d, ok := v.(interface{ ArrayDelimiter() string }); ok { + return []byte(d.ArrayDelimiter()) } - - return fmt.Errorf("pq: cannot convert %T to Int32Array", src) + return []byte{','} } -func (a *Int32Array) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "Int32Array") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(Int32Array, len(elems)) - for i, v := range elems { - x, err := strconv.ParseInt(string(v), 10, 32) - if err != nil { - return fmt.Errorf("pq: parsing array element index %d: %w", i, err) - } - b[i] = int32(x) - } - *a = b - } - return nil -} +const hextable = "0123456789abcdef" -// Value implements the driver.Valuer interface. -func (a Int32Array) Value() (driver.Value, error) { +func (a ArrayOf[T]) Value() (driver.Value, error) { if a == nil { return nil, nil } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, N bytes of values, - // and N-1 bytes of delimiters. - b := make([]byte, 1, 1+2*n) - b[0] = '{' - - b = strconv.AppendInt(b, int64(a[0]), 10) - for i := 1; i < n; i++ { - b = append(b, ',') - b = strconv.AppendInt(b, int64(a[i]), 10) - } - - return string(append(b, '}')), nil + if len(a) == 0 { + return "{}", nil } - return "{}", nil -} - -// StringArray represents a one-dimensional array of the PostgreSQL character types. -type StringArray []string + var zero T + del := arrayDelimiter(any(zero)) -// Scan implements the sql.Scanner interface. -func (a *StringArray) Scan(src any) error { - switch src := src.(type) { - case []byte: - return a.scanBytes(src) + // Pick a reasonable initial buffer length. + sz := 2 + (len(a)-1)*len(del) // Start/end {} and n-1 delimiters. + switch any(zero).(type) { + case bool: + sz += len(a) // Always 1 byte. case string: - return a.scanBytes([]byte(src)) - case nil: - *a = nil - return nil - } - - return fmt.Errorf("pq: cannot convert %T to StringArray", src) -} - -func (a *StringArray) scanBytes(src []byte) error { - elems, err := scanLinearArray(src, []byte{','}, "StringArray") - if err != nil { - return err - } - if *a != nil && len(elems) == 0 { - *a = (*a)[:0] - } else { - b := make(StringArray, len(elems)) - for i, v := range elems { - if b[i] = string(v); v == nil { - return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i) - } - } - *a = b - } - return nil -} - -// Value implements the driver.Valuer interface. -func (a StringArray) Value() (driver.Value, error) { - if a == nil { - return nil, nil - } - - if n := len(a); n > 0 { - // There will be at least two curly brackets, 2*N bytes of quotes, - // and N-1 bytes of delimiters. - b := make([]byte, 1, 1+3*n) - b[0] = '{' - - b = appendArrayQuotedBytes(b, []byte(a[0])) - for i := 1; i < n; i++ { - b = append(b, ',') - b = appendArrayQuotedBytes(b, []byte(a[i])) + sz += len(a) * 4 // Start/end quote, assume 2 bytes per entry. + case []byte: + sz += len(a) * 3 // Start \\x + for _, aa := range a { + sz += len(any(aa).([]byte)) * 2 } - - return string(append(b, '}')), nil + case float32, float64: + sz += len(a) * 3 // Assume 3 bytes per entry. + case time.Time: + sz += len(a) * 22 // 2 quotes and assumed 20 bytes per entry (timestamp w/o subseconds but with "Z") + default: + sz += len(a) * 2 // Assume 2 bytes per entry. } - - return "{}", nil -} - -// appendArray appends rv to the buffer, returning the extended buffer and the -// delimiter used between elements. -// -// Returns an error when n <= 0 or rv is not a reflect.Array or reflect.Slice. -func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { - var del string - var err error + b := make([]byte, 0, sz) b = append(b, '{') - - if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil { - return b, del, err - } - - for i := 1; i < n; i++ { - b = append(b, del...) - if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil { - return b, del, err + for i, aa := range a { + if i > 0 { + b = append(b, del...) } - } - - return append(b, '}'), del, nil -} -// appendArrayElement appends rv to the buffer, returning the extended buffer -// and the delimiter to use before the next element. -// -// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted -// using driver.DefaultParameterConverter and the resulting []byte or string -// is double-quoted. -// -// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO -func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { - if k := rv.Kind(); k == reflect.Array || k == reflect.Slice { - if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) { - if n := rv.Len(); n > 0 { - return appendArray(b, rv, n) + swval := any(aa) + if v, ok := swval.(driver.Valuer); ok { + var err error + swval, err = v.Value() + if err != nil { + return nil, fmt.Errorf("pq: %w", err) } - - return b, "", nil } - } - - var del = "," - var err error - var iv = rv.Interface() - - if ad, ok := iv.(ArrayDelimiter); ok { - del = ad.ArrayDelimiter() - } - - if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil { - return b, del, err - } - switch v := iv.(type) { - case nil: - return append(b, "NULL"...), del, nil - case []byte: - return appendArrayQuotedBytes(b, v), del, nil - case string: - return appendArrayQuotedBytes(b, []byte(v)), del, nil + restart: + switch v := swval.(type) { + default: + rv := reflect.ValueOf(aa) + switch rv.Kind() { + case reflect.String: + b = appendArrayQuotedText(b, []byte(rv.String())) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + b = strconv.AppendInt(b, rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + b = strconv.AppendUint(b, rv.Uint(), 10) + case reflect.Float32, reflect.Float64: + b = strconv.AppendFloat(b, rv.Float(), 'f', -1, 32) + case reflect.Ptr: + if rv.IsNil() { + b = append(b, "NULL"...) + continue + } + swval = rv.Elem().Interface() + goto restart + default: + return nil, fmt.Errorf("pq: unsupported array type %T", zero) + } + case []byte: + b = append(b, `"\\x`...) + for _, c := range v { + b = append(b, hextable[c>>4], hextable[c&0x0f]) + } + b = append(b, `"`...) + case string: + b = appendArrayQuotedText(b, []byte(v)) + case int: + b = strconv.AppendInt(b, int64(v), 10) + case int8: + b = strconv.AppendInt(b, int64(v), 10) + case int16: + b = strconv.AppendInt(b, int64(v), 10) + case int32: + b = strconv.AppendInt(b, int64(v), 10) + case int64: + b = strconv.AppendInt(b, v, 10) + case uint: + b = strconv.AppendUint(b, uint64(v), 10) + case uint8: + b = strconv.AppendUint(b, uint64(v), 10) + case uint16: + b = strconv.AppendUint(b, uint64(v), 10) + case uint32: + b = strconv.AppendUint(b, uint64(v), 10) + case uint64: + b = strconv.AppendUint(b, v, 10) + case float32: + b = strconv.AppendFloat(b, float64(v), 'f', -1, 32) + case float64: + b = strconv.AppendFloat(b, v, 'f', -1, 64) + case bool: + if any(aa).(bool) { + b = append(b, 't') + } else { + b = append(b, 'f') + } + case time.Time: + b = append(b, '"') + b = append(b, FormatTimestamp(v)...) + b = append(b, '"') + } } + b = append(b, '}') - b, err = appendValue(b, iv) - return b, del, err + return string(b), nil } -func appendArrayQuotedBytes(b, v []byte) []byte { +func appendArrayQuotedText(b, v []byte) []byte { b = append(b, '"') for { i := bytes.IndexAny(v, `"\`) @@ -772,14 +271,6 @@ func appendArrayQuotedBytes(b, v []byte) []byte { return append(b, '"') } -func appendValue(b []byte, v driver.Value) ([]byte, error) { - enc, err := encode(v, 0) - if err != nil { - return nil, err - } - return append(b, enc...), nil -} - // parseArray extracts the dimensions and elements of an array represented in // text format. Only representations emitted by the backend are supported. // Notably, whitespace around brackets and delimiters is significant, and NULL @@ -787,12 +278,11 @@ func appendValue(b []byte, v driver.Value) ([]byte, error) { // // See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) { - var depth, i int - if len(src) < 1 || src[0] != '{' { return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0) } + var depth, i int Open: for i < len(src) { switch src[i] { @@ -890,14 +380,3 @@ Close: } return } - -func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) { - dims, elems, err := parseArray(src, del) - if err != nil { - return nil, err - } - if len(dims) > 1 { - return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ) - } - return elems, err -} diff --git a/array_test.go b/array_test.go index 55fb8cff..2e736687 100644 --- a/array_test.go +++ b/array_test.go @@ -9,6 +9,7 @@ import ( "reflect" "strings" "testing" + "time" "github.com/lib/pq/internal/pqtest" ) @@ -97,39 +98,42 @@ func TestArrayParseError(t *testing.T) { func TestArrayFunc(t *testing.T) { tests := []struct { - in any - want any + in any + want any + scan string + wantScanErr string + wantValueErr string }{ - {[]bool{}, &BoolArray{}}, - {[]float64{}, &Float64Array{}}, - {[]int64{}, &Int64Array{}}, - {[]float32{}, &Float32Array{}}, - {[]int32{}, &Int32Array{}}, - {[]string{}, &StringArray{}}, - {[][]byte{}, &ByteaArray{}}, - {nil, GenericArray{nil}}, - {[]driver.Value{}, GenericArray{[]driver.Value{}}}, - {[][]bool{}, GenericArray{[][]bool{}}}, - {[][]float64{}, GenericArray{[][]float64{}}}, - {[][]int64{}, GenericArray{[][]int64{}}}, - {[][]float32{}, GenericArray{[][]float32{}}}, - {[][]int32{}, GenericArray{[][]int32{}}}, - {[][]string{}, GenericArray{[][]string{}}}, - - {&[]bool{}, &BoolArray{}}, - {&[]float64{}, &Float64Array{}}, - {&[]int64{}, &Int64Array{}}, - {&[]float32{}, &Float32Array{}}, - {&[]int32{}, &Int32Array{}}, - {&[]string{}, &StringArray{}}, - {&[][]byte{}, &ByteaArray{}}, - {&[]sql.Scanner{}, GenericArray{&[]sql.Scanner{}}}, - {&[][]bool{}, GenericArray{&[][]bool{}}}, - {&[][]float64{}, GenericArray{&[][]float64{}}}, - {&[][]int64{}, GenericArray{&[][]int64{}}}, - {&[][]float32{}, GenericArray{&[][]float32{}}}, - {&[][]int32{}, GenericArray{&[][]int32{}}}, - {&[][]string{}, GenericArray{&[][]string{}}}, + {[]bool{}, &BoolArray{}, `{f,t}`, ``, ``}, + {[]float64{}, &Float64Array{}, `{1.1,1.2}`, ``, ``}, + {[]int64{}, &Int64Array{}, `{2,3}`, ``, ``}, + {[]float32{}, &Float32Array{}, `{1.1,1.2}`, ``, ``}, + {[]int32{}, &Int32Array{}, `{2,3}`, ``, ``}, + {[]string{}, &StringArray{}, `{"a","b"}`, ``, ``}, + {[][]byte{}, &ByteaArray{}, `{"\\x10","\\x11"}`, ``, ``}, + {nil, GenericArray{nil}, `{}`, `pq: destination is not a pointer to array or slice`, ``}, + {[]driver.Value{}, GenericArray{[]driver.Value{}}, `{}`, `[]driver.Value is not a pointer`, ``}, + {[][]bool{}, GenericArray{[][]bool{}}, `{}`, `[][]bool is not a pointer`, ``}, + {[][]float64{}, GenericArray{[][]float64{}}, `{}`, `[][]float64 is not a pointer`, ``}, + {[][]int64{}, GenericArray{[][]int64{}}, `{}`, `[][]int64 is not a pointer`, ``}, + {[][]float32{}, GenericArray{[][]float32{}}, `{}`, `[][]float32 is not a pointer`, ``}, + {[][]int32{}, GenericArray{[][]int32{}}, `{}`, `[][]int32 is not a pointer`, ``}, + {[][]string{}, GenericArray{[][]string{}}, `{}`, `[][]string is not a pointer`, ``}, + + {&[]bool{}, &BoolArray{}, `{f,t}`, ``, ``}, + {&[]float64{}, &Float64Array{}, `{1.1,1.2}`, ``, ``}, + {&[]int64{}, &Int64Array{}, `{1,2}`, ``, ``}, + {&[]float32{}, &Float32Array{}, `{1.1,1.2}`, ``, ``}, + {&[]int32{}, &Int32Array{}, `{1,2}`, ``, ``}, + {&[]string{}, &StringArray{}, `{"a","b"}`, ``, ``}, + {&[][]byte{}, &ByteaArray{}, `{"\\x10","\\x11"}`, ``, ``}, + {&[]sql.Scanner{}, GenericArray{&[]sql.Scanner{}}, `{}`, ``, `pq: unable to convert *[]sql.Scanner to array`}, + {&[][]bool{}, GenericArray{&[][]bool{}}, `{}`, ``, `*[][]bool to array`}, + {&[][]float64{}, GenericArray{&[][]float64{}}, `{}`, ``, `*[][]float64 to array`}, + {&[][]int64{}, GenericArray{&[][]int64{}}, `{}`, ``, `*[][]int64 to array`}, + {&[][]float32{}, GenericArray{&[][]float32{}}, `{}`, ``, `*[][]float32 to array`}, + {&[][]int32{}, GenericArray{&[][]int32{}}, `{}`, ``, `*[][]int32 to array`}, + {&[][]string{}, GenericArray{&[][]string{}}, `{}`, ``, `*[][]string to array`}, } for _, tt := range tests { @@ -138,11 +142,16 @@ func TestArrayFunc(t *testing.T) { if !reflect.DeepEqual(have, tt.want) { t.Errorf("\nhave: %#v\nwant: %#v", have, tt.want) } - if _, ok := have.(sql.Scanner); !ok { - t.Error("not a sql.Scanner") + err := have.Scan(tt.scan) + if !pqtest.ErrorContains(err, tt.wantScanErr) { + t.Errorf("wrong Scan() error:\nhave: %s\nwant: %s", err, tt.wantScanErr) + } + v, err := have.Value() + if !pqtest.ErrorContains(err, tt.wantValueErr) { + t.Errorf("wrong Value() error:\nhave: %s\nwant: %s", err, tt.wantValueErr) } - if _, ok := have.(driver.Valuer); !ok { - t.Error("not a driver.Valuer") + if tt.wantScanErr == "" && tt.wantValueErr == "" && v != tt.scan { + t.Errorf("not equal after Scan()/Value()\nhave: %v\nwant: %v", v, tt.scan) } }) } @@ -209,14 +218,14 @@ func TestArrayScan(t *testing.T) { {&BoolArray{true, true, true}, `{t}`, &BoolArray{true}, ``}, {&BoolArray{true, true, true}, `{f,t}`, &BoolArray{false, true}, ``}, - {&BoolArray{}, 1, &BoolArray{}, `int to BoolArray`}, + {&BoolArray{}, 1, &BoolArray{}, `int to *pq.ArrayOf[bool]`}, {newBool(), ``, newBool(), `unable to parse array`}, {newBool(), `{`, newBool(), `unable to parse array`}, - {newBool(), `{{t},{f}}`, newBool(), `cannot convert ARRAY[2][1] to BoolArray`}, - {newBool(), `{NULL}`, newBool(), `could not parse boolean array index 0: invalid boolean ""`}, - {newBool(), `{a}`, newBool(), `could not parse boolean array index 0: invalid boolean "a"`}, - {newBool(), `{t,b}`, newBool(), `could not parse boolean array index 1: invalid boolean "b"`}, - {newBool(), `{t,f,cd}`, newBool(), `could not parse boolean array index 2: invalid boolean "cd"`}, + {newBool(), `{{t},{f}}`, newBool(), `cannot convert ARRAY[2][1] to *pq.ArrayOf[bool]`}, + {newBool(), `{NULL}`, newBool(), `array index 0: cannot convert NULL to bool`}, + {newBool(), `{a}`, newBool(), `array index 0: couldn't convert "a" into type bool`}, + {newBool(), `{t,b}`, newBool(), `array index 1: couldn't convert "b" into type bool`}, + {newBool(), `{t,f,cd}`, newBool(), `array index 2: couldn't convert "cd" into type bool`}, {&ByteaArray{}, nil, new(ByteaArray), ``}, {&ByteaArray{}, `{}`, &ByteaArray{}, ``}, @@ -226,8 +235,8 @@ func TestArrayScan(t *testing.T) { {newBytea(), `{"\\xdead","\\xbeef"}`, &ByteaArray{{'\xDE', '\xAD'}, {'\xBE', '\xEF'}}, ``}, {&ByteaArray{{2}, {6}, {0, 0}}, ``, newBytea(), `unable to parse array`}, {&ByteaArray{{2}, {6}, {0, 0}}, `{`, newBytea(), `unable to parse array`}, - {&ByteaArray{{2}, {6}, {0, 0}}, `{{"\\xfeff"},{"\\xbeef"}}`, newBytea(), `cannot convert ARRAY[2][1] to ByteaArray`}, - {&ByteaArray{{2}, {6}, {0, 0}}, `{"\\abc"}`, newBytea(), `could not parse bytea array index 0: could not parse bytea value`}, + {&ByteaArray{{2}, {6}, {0, 0}}, `{{"\\xfeff"},{"\\xbeef"}}`, newBytea(), `cannot convert ARRAY[2][1] to *pq.ArrayOf[[]uint8]`}, + {&ByteaArray{{2}, {6}, {0, 0}}, `{"\\abc"}`, newBytea(), `array index 0: could not parse bytea value: strconv.ParseUint: parsing "abc": invalid syntax`}, {&StringArray{}, nil, new(StringArray), ``}, {&StringArray{}, `{}`, &StringArray{}, ``}, @@ -236,13 +245,13 @@ func TestArrayScan(t *testing.T) { {newString(), `{t}`, &StringArray{"t"}, ``}, {newString(), `{f,1}`, &StringArray{"f", "1"}, ``}, {newString(), `{"a\\b","c d",","}`, &StringArray{"a\\b", "c d", ","}, ``}, - {newString(), true, newString(), `cannot convert bool to StringArray`}, + {newString(), true, newString(), `cannot convert bool to *pq.ArrayOf[string]`}, {newString(), ``, newString(), `unable to parse array`}, {newString(), `{`, newString(), `unable to parse array`}, - {newString(), `{{a},{b}}`, newString(), `cannot convert ARRAY[2][1] to StringArray`}, - {newString(), `{NULL}`, newString(), `parsing array element index 0: cannot convert nil to string`}, - {newString(), `{a,NULL}`, newString(), `parsing array element index 1: cannot convert nil to string`}, - {newString(), `{a,b,NULL}`, newString(), `parsing array element index 2: cannot convert nil to string`}, + {newString(), `{{a},{b}}`, newString(), `cannot convert ARRAY[2][1] to *pq.ArrayOf[string]`}, + {newString(), `{NULL}`, newString(), `array index 0: cannot convert NULL to string`}, + {newString(), `{a,NULL}`, newString(), `array index 1: cannot convert NULL to string`}, + {newString(), `{a,b,NULL}`, newString(), `array index 2: cannot convert NULL to string`}, {&Int64Array{}, nil, new(Int64Array), ``}, {&Int64Array{}, `{}`, &Int64Array{}, ``}, @@ -250,14 +259,14 @@ func TestArrayScan(t *testing.T) { {newInt64(), `{}`, &Int64Array{}, ``}, {newInt64(), `{12}`, &Int64Array{12}, ``}, {newInt64(), `{345,678}`, &Int64Array{345, 678}, ``}, - {newInt64(), true, newInt64(), `cannot convert bool to Int64Array`}, + {newInt64(), true, newInt64(), `cannot convert bool to *pq.ArrayOf[int64]`}, {newInt64(), ``, newInt64(), `unable to parse array`}, {newInt64(), `{`, newInt64(), `unable to parse array`}, - {newInt64(), `{{5},{6}}`, newInt64(), `cannot convert ARRAY[2][1] to Int64Array`}, - {newInt64(), `{NULL}`, newInt64(), `parsing array element index 0:`}, - {newInt64(), `{a}`, newInt64(), `parsing array element index 0:`}, - {newInt64(), `{5,a}`, newInt64(), `parsing array element index 1:`}, - {newInt64(), `{5,6,a}`, newInt64(), `parsing array element index 2:`}, + {newInt64(), `{{5},{6}}`, newInt64(), `cannot convert ARRAY[2][1] to *pq.ArrayOf[int64]`}, + {newInt64(), `{NULL}`, newInt64(), `array index 0:`}, + {newInt64(), `{a}`, newInt64(), `array index 0:`}, + {newInt64(), `{5,a}`, newInt64(), `array index 1:`}, + {newInt64(), `{5,6,a}`, newInt64(), `array index 2:`}, {&Int32Array{}, nil, new(Int32Array), ``}, {&Int32Array{}, `{}`, &Int32Array{}, ``}, @@ -265,14 +274,14 @@ func TestArrayScan(t *testing.T) { {newInt32(), `{}`, &Int32Array{}, ``}, {newInt32(), `{12}`, &Int32Array{12}, ``}, {newInt32(), `{345,678}`, &Int32Array{345, 678}, ``}, - {newInt32(), true, newInt32(), `cannot convert bool to Int32Array`}, + {newInt32(), true, newInt32(), `cannot convert bool to *pq.ArrayOf[int32]`}, {newInt32(), ``, newInt32(), `unable to parse array`}, {newInt32(), `{`, newInt32(), `unable to parse array`}, - {newInt32(), `{{5},{6}}`, newInt32(), `cannot convert ARRAY[2][1] to Int32Array`}, - {newInt32(), `{NULL}`, newInt32(), `parsing array element index 0:`}, - {newInt32(), `{a}`, newInt32(), `parsing array element index 0:`}, - {newInt32(), `{5,a}`, newInt32(), `parsing array element index 1:`}, - {newInt32(), `{5,6,a}`, newInt32(), `parsing array element index 2:`}, + {newInt32(), `{{5},{6}}`, newInt32(), `cannot convert ARRAY[2][1] to *pq.ArrayOf[int32]`}, + {newInt32(), `{NULL}`, newInt32(), `array index 0:`}, + {newInt32(), `{a}`, newInt32(), `array index 0:`}, + {newInt32(), `{5,a}`, newInt32(), `array index 1:`}, + {newInt32(), `{5,6,a}`, newInt32(), `array index 2:`}, {&Float64Array{}, nil, new(Float64Array), ``}, {&Float64Array{}, `{}`, &Float64Array{}, ``}, @@ -281,14 +290,14 @@ func TestArrayScan(t *testing.T) { {newFloat64(), `{1.2}`, &Float64Array{1.2}, ``}, {newFloat64(), `{3.456,7.89}`, &Float64Array{3.456, 7.89}, ``}, {newFloat64(), `{3,1,2}`, &Float64Array{3, 1, 2}, ``}, - {newFloat64(), true, newFloat64(), `cannot convert bool to Float64Array`}, + {newFloat64(), true, newFloat64(), `cannot convert bool to *pq.ArrayOf[float64]`}, {newFloat64(), ``, newFloat64(), `unable to parse array`}, {newFloat64(), `{`, newFloat64(), `unable to parse array`}, - {newFloat64(), `{{5.6},{7.8}}`, newFloat64(), `cannot convert ARRAY[2][1] to Float64Array`}, - {newFloat64(), `{NULL}`, newFloat64(), `parsing array element index 0:`}, - {newFloat64(), `{a}`, newFloat64(), `parsing array element index 0:`}, - {newFloat64(), `{5.6,a}`, newFloat64(), `parsing array element index 1:`}, - {newFloat64(), `{5.6,7.8,a}`, newFloat64(), `parsing array element index 2:`}, + {newFloat64(), `{{5.6},{7.8}}`, newFloat64(), `cannot convert ARRAY[2][1] to *pq.ArrayOf[float64]`}, + {newFloat64(), `{NULL}`, newFloat64(), `array index 0:`}, + {newFloat64(), `{a}`, newFloat64(), `array index 0:`}, + {newFloat64(), `{5.6,a}`, newFloat64(), `array index 1:`}, + {newFloat64(), `{5.6,7.8,a}`, newFloat64(), `array index 2:`}, {&Float32Array{}, nil, new(Float32Array), ``}, {&Float32Array{}, `{}`, &Float32Array{}, ``}, @@ -297,14 +306,14 @@ func TestArrayScan(t *testing.T) { {newFloat32(), `{1.2}`, &Float32Array{1.2}, ``}, {newFloat32(), `{3.456,7.89}`, &Float32Array{3.456, 7.89}, ``}, {newFloat32(), `{3,1,2}`, &Float32Array{3, 1, 2}, ``}, - {newFloat32(), true, newFloat32(), `cannot convert bool to Float32Array`}, + {newFloat32(), true, newFloat32(), `cannot convert bool to *pq.ArrayOf[float32]`}, {newFloat32(), ``, newFloat32(), `unable to parse array`}, {newFloat32(), `{`, newFloat32(), `unable to parse array`}, - {newFloat32(), `{{5.6},{7.8}}`, newFloat32(), `cannot convert ARRAY[2][1] to Float32Array`}, - {newFloat32(), `{NULL}`, newFloat32(), `parsing array element index 0:`}, - {newFloat32(), `{a}`, newFloat32(), `parsing array element index 0:`}, - {newFloat32(), `{5.6,a}`, newFloat32(), `parsing array element index 1:`}, - {newFloat32(), `{5.6,7.8,a}`, newFloat32(), `parsing array element index 2:`}, + {newFloat32(), `{{5.6},{7.8}}`, newFloat32(), `cannot convert ARRAY[2][1] to *pq.ArrayOf[float32]`}, + {newFloat32(), `{NULL}`, newFloat32(), `array index 0:`}, + {newFloat32(), `{a}`, newFloat32(), `array index 0:`}, + {newFloat32(), `{5.6,a}`, newFloat32(), `array index 1:`}, + {newFloat32(), `{5.6,7.8,a}`, newFloat32(), `array index 2:`}, { &GenericArray{ptr([]sql.NullString{})}, @@ -357,20 +366,20 @@ func TestArrayScan(t *testing.T) { t.Run(strings.TrimPrefix(fmt.Sprintf("%T", tt.array), "*pq."), func(t *testing.T) { err := tt.array.Scan(tt.in) if !pqtest.ErrorContains(err, tt.wantErr) { - t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr) + t.Fatalf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr) } if !reflect.DeepEqual(tt.array, tt.want) { - t.Errorf("\nhave: %#v\nwant: %#v", tt.array, tt.want) + t.Fatalf("\nhave: %#v\nwant: %#v", tt.array, tt.want) } // Run again but with []byte input instead of string. if str, ok := tt.in.(string); ok { err := tt.array.Scan([]byte(str)) if !pqtest.ErrorContains(err, tt.wantErr) { - t.Errorf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr) + t.Fatalf("wrong error:\nhave: %s\nwant: %s", err, tt.wantErr) } if !reflect.DeepEqual(tt.array, tt.want) { - t.Errorf("\nhave: %#v\nwant: %#v", tt.array, tt.want) + t.Fatalf("\nhave: %#v\nwant: %#v", tt.array, tt.want) } } }) @@ -527,6 +536,92 @@ func TestArrayValueBackend(t *testing.T) { } } +type ( + typedInt int32 + typedUint uint8 + typedFloat float32 + pipeString string + myType struct{ field any } +) + +func (pipeString) ArrayDelimiter() string { return "||" } +func (m *myType) Scan(src any) error { m.field = fmt.Sprintf("%s", src); return nil } +func (m myType) Value() (driver.Value, error) { return fmt.Sprintf("%s", m.field), nil } + +func TestArrayOf(t *testing.T) { + tests := []struct { + arr any + scan string + wantErr string + }{ + + {&ArrayOf[[]byte]{}, `{}`, ``}, + {&ArrayOf[string]{}, `{}`, ``}, + {&ArrayOf[int]{}, `{}`, ``}, + + {&ArrayOf[[]byte]{}, `{"\\x10","\\x11"}`, ``}, + {&ArrayOf[string]{}, `{"a","b"}`, ``}, + {&ArrayOf[int]{}, `{1,2}`, ``}, + {&ArrayOf[int8]{}, `{1,2}`, ``}, + {&ArrayOf[int16]{}, `{1,2}`, ``}, + {&ArrayOf[int32]{}, `{1,2}`, ``}, + {&ArrayOf[int64]{}, `{1,2}`, ``}, + {&ArrayOf[uint]{}, `{1,2}`, ``}, + {&ArrayOf[uint8]{}, `{1,2}`, ``}, + {&ArrayOf[uint16]{}, `{1,2}`, ``}, + {&ArrayOf[uint32]{}, `{1,2}`, ``}, + {&ArrayOf[uint64]{}, `{1,2}`, ``}, + {&ArrayOf[float32]{}, `{1.1,2.2}`, ``}, + {&ArrayOf[float64]{}, `{1.1,2.2}`, ``}, + {&ArrayOf[bool]{}, `{f,t}`, ``}, + {&ArrayOf[time.Time]{}, `{"2020-02-03 19:20:21Z"}`, ``}, + {&ArrayOf[pipeString]{}, `{"a"||"b"}`, ``}, + {&ArrayOf[any]{}, `{"1","2"}`, ``}, + {&ArrayOf[typedInt]{}, `{1,2}`, ``}, + {&ArrayOf[typedUint]{}, `{3,4}`, ``}, + {&ArrayOf[typedFloat]{}, `{1.1,2.2}`, ``}, + {&ArrayOf[myType]{}, `{"abc","def"}`, ``}, + + {&ArrayOf[*int]{}, `{1,NULL,2}`, ``}, + {&ArrayOf[*string]{}, `{"a",NULL,"b"}`, ``}, + {&ArrayOf[*time.Time]{}, `{"2020-02-03 19:20:21Z",NULL,"2021-02-03 19:20:21Z"}`, ``}, + + {&ArrayOf[int]{}, `{1,NULL,2}`, `array index 1: cannot convert NULL to int`}, + {&ArrayOf[string]{}, `{"a",NULL,"b"}`, `array index 1: cannot convert NULL to string`}, + {&ArrayOf[int]{}, `{"asd"}`, `array index 0: converting driver.Value type []uint8 ("asd") to a int: invalid syntax`}, + {&ArrayOf[time.Time]{}, `{"2020-02-03 19:20:21Z",NULL,"2021-02-03 19:20:21Z"}`, `array index 1: cannot convert NULL to time.Time`}, + {&ArrayOf[time.Time]{}, `{"asd"}`, `array index 0: invalid timestamp`}, + } + + for _, tt := range tests { + n := strings.ReplaceAll(strings.TrimPrefix(strings.TrimPrefix(fmt.Sprintf("%T", tt.arr), "*pq.ArrayOf")[1:], "github.com/lib/pq."), "interface {}", "any") + n = n[:len(n)-1] + t.Run(n, func(t *testing.T) { + sv, ok := tt.arr.(interface { + sql.Scanner + driver.Valuer + }) + if !ok { + t.Fatalf("not a sql.Scanner or driver.Valuer") + } + err := sv.Scan(tt.scan) + if !pqtest.ErrorContains(err, tt.wantErr) { + t.Fatalf("wrong Scan() error:\nhave: %s\nwant: %s", err, tt.wantErr) + } + v, err := sv.Value() + if err != nil { + t.Fatalf("Value error: %s", err) + } + if tt.wantErr != "" { + tt.scan = `{}` + } + if v != tt.scan { + t.Errorf("not equal after Scan()/Value()\nhave: %v\nwant: %v\nsv: %#v", v, tt.scan, sv) + } + }) + } +} + func BenchmarkArray(b *testing.B) { tests := []struct { arr interface { diff --git a/deprecated.go b/deprecated.go index 86107677..9276eca9 100644 --- a/deprecated.go +++ b/deprecated.go @@ -4,6 +4,9 @@ import ( "bytes" "database/sql" "database/sql/driver" + "fmt" + "reflect" + "strings" "github.com/lib/pq/pqerror" ) @@ -138,3 +141,301 @@ func makeStmt(b *bytes.Buffer, columns ...string) { } b.WriteString(") FROM STDIN") } + +// ArrayDelimiter may be optionally implemented to override the array delimiter. +// +// Deprecated: this doesn't need to be exported. +type ArrayDelimiter interface{ ArrayDelimiter() string } + +// Array returns the optimal driver.Valuer and sql.Scanner for an array or +// slice of any dimension. +// +// For example: +// +// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) +// +// var x []sql.NullInt64 +// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x)) +// +// Scanning multi-dimensional arrays is not supported. Arrays where the lower +// bound is not one (such as `[0:0]={1}') are not supported. +// +// Deprecated: use ArrayOf[T] +func Array(a any) interface { + driver.Valuer + sql.Scanner +} { + switch a := a.(type) { + case []bool: + return (*ArrayOf[bool])(&a) + case []float64: + return (*ArrayOf[float64])(&a) + case []float32: + return (*ArrayOf[float32])(&a) + case []int64: + return (*ArrayOf[int64])(&a) + case []int32: + return (*ArrayOf[int32])(&a) + case []string: + return (*ArrayOf[string])(&a) + case [][]byte: + return (*ArrayOf[[]byte])(&a) + case *[]bool: + return (*BoolArray)(a) + case *[]float64: + return (*Float64Array)(a) + case *[]float32: + return (*Float32Array)(a) + case *[]int64: + return (*Int64Array)(a) + case *[]int32: + return (*Int32Array)(a) + case *[]string: + return (*StringArray)(a) + case *[][]byte: + return (*ByteaArray)(a) + } + return GenericArray{a} +} + +// BoolArray represents a one-dimensional array of the PostgreSQL boolean type. +// +// Deprecated: use ArrayOf[bool] +// +//go:fix inline +type BoolArray = ArrayOf[bool] + +// StringArray represents a one-dimensional array of the PostgreSQL character types. +// +// Deprecated: use ArrayOf[string] +// +//go:fix inline +type StringArray = ArrayOf[string] + +// ByteaArray represents a one-dimensional array of the PostgreSQL bytea type. +// +// Deprecated: use ArrayOf[[]byte] +// +//go:fix inline +type ByteaArray = ArrayOf[[]byte] + +// Float64Array represents a one-dimensional array of the PostgreSQL double precision type. +// +// Deprecated: use ArrayOf[float32] +// +//go:fix inline +type Float64Array = ArrayOf[float64] + +// Float32Array represents a one-dimensional array of the PostgreSQL double precision type. +// +// Deprecated: use ArrayOf[float32] +// +//go:fix inline +type Float32Array = ArrayOf[float32] + +// Int64Array represents a one-dimensional array of the PostgreSQL integer type. +// +// Deprecated: use ArrayOf[int64] +// +//go:fix inline +type Int64Array = ArrayOf[int64] + +// Int32Array represents a one-dimensional array of the PostgreSQL integer type. +// +// Deprecated: use ArrayOf[int32] +// +//go:fix inline +type Int32Array = ArrayOf[int32] + +// GenericArray implements the driver.Valuer and sql.Scanner interfaces for an +// array or slice of any dimension. +// +// Deprecated: use ArrayOf[myType] or ArrayOf[sql.Scanner] +type GenericArray struct{ A any } + +var ( + typeByteSlice = reflect.TypeOf([]byte{}) + typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() +) + +// Value implements the driver.Valuer interface. +func (a GenericArray) Value() (driver.Value, error) { + if a.A == nil { + return nil, nil + } + + rv := reflect.ValueOf(a.A) + switch rv.Kind() { + default: + return nil, fmt.Errorf("pq: unable to convert %T to array", a.A) + case reflect.Slice: + if rv.IsNil() { + return nil, nil + } + case reflect.Array: + // Do nothing + } + + l := rv.Len() + if l == 0 { + return "{}", nil + } + + b := make([]byte, 0, 1+2*l) + b, _, err := appendArray(b, rv, l) + return string(b), err +} + +// Scan implements the sql.Scanner interface. +func (a GenericArray) Scan(src any) error { + dpv := reflect.ValueOf(a.A) + switch { + case dpv.Kind() != reflect.Pointer: + return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) + case dpv.IsNil(): + return fmt.Errorf("pq: destination %T is nil", a.A) + } + + dv := dpv.Elem() + switch dv.Kind() { + case reflect.Slice: + case reflect.Array: + default: + return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) + } + + switch src := src.(type) { + case []byte: + return a.scanBytes(src, dv) + case string: + return a.scanBytes([]byte(src), dv) + case nil: + if dv.Kind() == reflect.Slice { + dv.Set(reflect.Zero(dv.Type())) + return nil + } + } + return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) +} + +func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { + dtype := dv.Type().Elem() + dims, elems, err := parseArray(src, arrayDelimiter(reflect.Zero(dtype).Interface())) + if err != nil { + return err + } + if len(dims) > 1 { + return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented", + strings.Replace(fmt.Sprint(dims), " ", "][", -1)) + } + // Treat a zero-dimensional array like an array with a single dimension of zero. + if len(dims) == 0 { + dims = append(dims, 0) + } + + for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() { + switch rt.Kind() { + case reflect.Slice: + case reflect.Array: + if rt.Len() != dims[i] { + return fmt.Errorf("pq: cannot convert ARRAY%s to %s", + strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type()) + } + default: + } + } + + assign := func([]byte, reflect.Value) error { + return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", dtype) + } + if reflect.PointerTo(dtype).Implements(typeSQLScanner) { + // dest is always addressable because it is an element of a slice. + assign = func(src []byte, dest reflect.Value) error { + ss := dest.Addr().Interface().(sql.Scanner) + if src == nil { + return ss.Scan(nil) + } + return ss.Scan(src) + } + } + values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) + for i, e := range elems { + err := assign(e, values.Index(i)) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %w", i, err) + } + } + + switch dv.Kind() { + case reflect.Slice: + dv.Set(values.Slice(0, dims[0])) + case reflect.Array: + for i := 0; i < dims[0]; i++ { + dv.Index(i).Set(values.Index(i)) + } + } + return nil +} + +// appendArray appends rv to the buffer, returning the extended buffer and the +// delimiter used between elements. +// +// Returns an error when n <= 0 or rv is not a reflect.Array or reflect.Slice. +func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { + b = append(b, '{') + + b, del, err := appendArrayElement(b, rv.Index(0)) + if err != nil { + return b, del, err + } + for i := 1; i < n; i++ { + b = append(b, del...) + b, del, err = appendArrayElement(b, rv.Index(i)) + if err != nil { + return b, del, err + } + } + return append(b, '}'), del, nil +} + +// appendArrayElement appends rv to the buffer, returning the extended buffer +// and the delimiter to use before the next element. +// +// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted +// using driver.DefaultParameterConverter and the resulting []byte or string is +// double-quoted. +// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { + if k := rv.Kind(); k == reflect.Array || k == reflect.Slice { + if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) { + if n := rv.Len(); n > 0 { + return appendArray(b, rv, n) + } + return b, "", nil + } + } + + iv := rv.Interface() + del := string(arrayDelimiter(iv)) + iv, err := driver.DefaultParameterConverter.ConvertValue(iv) + if err != nil { + return b, del, err + } + + switch v := iv.(type) { + case nil: + return append(b, "NULL"...), del, nil + case []byte: + return appendArrayQuotedText(b, v), del, nil + case string: + return appendArrayQuotedText(b, []byte(v)), del, nil + } + + enc, err := encode(iv, 0) + if err != nil { + return b, del, err + } + return append(b, enc...), del, err +}