Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 75 additions & 41 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (d *Decoder) setDefaults(t reflect.Type, v reflect.Value) MultiError {
if v.Type().Kind() == reflect.Struct {
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.Type().Kind() == reflect.Ptr && field.IsNil() && v.Type().Field(i).Anonymous {
if field.Kind() == reflect.Ptr && field.IsNil() && isAnonymousField(t, i) {
field.Set(reflect.New(field.Type().Elem()))
}
}
Expand All @@ -128,53 +128,64 @@ func (d *Decoder) setDefaults(t reflect.Type, v reflect.Value) MultiError {
for _, f := range struc.fields {
vCurrent := v.FieldByName(f.name)

if vCurrent.Type().Kind() == reflect.Struct && f.defaultValue == "" {
errs.merge(d.setDefaults(vCurrent.Type(), vCurrent))
} else if isPointerToStruct(vCurrent) && f.defaultValue == "" {
errs.merge(d.setDefaults(vCurrent.Elem().Type(), vCurrent.Elem()))
if f.defaultValue == "" {
if vCurrent.Type().Kind() == reflect.Struct {
errs.merge(d.setDefaults(vCurrent.Type(), vCurrent))
} else if isPointerToStruct(vCurrent) {
errs.merge(d.setDefaults(vCurrent.Elem().Type(), vCurrent.Elem()))
}
continue
}

if f.defaultValue != "" && f.isRequired {
errs.merge(MultiError{"default-" + f.name: errors.New("required fields cannot have a default value")})
} else if f.defaultValue != "" && vCurrent.IsZero() && !f.isRequired {
if f.typ.Kind() == reflect.Struct {
errs.merge(MultiError{"default-" + f.name: errors.New("default option is supported only on: bool, float variants, string, unit variants types or their corresponding pointers or slices")})
} else if f.typ.Kind() == reflect.Slice {
vals := strings.Split(f.defaultValue, "|")

// check if slice has one of the supported types for defaults
if _, ok := builtinConverters[f.typ.Elem().Kind()]; !ok {
errs.merge(MultiError{"default-" + f.name: errors.New("default option is supported only on: bool, float variants, string, unit variants types or their corresponding pointers or slices")})
continue
}
if f.isRequired {
errs.add("default-"+f.name, errRequiredFieldsCantHaveDefaults)
continue
}
if !vCurrent.IsZero() {
continue
}

defaultSlice := reflect.MakeSlice(f.typ, 0, cap(vals))
for _, val := range vals {
// this check is to handle if the wrong value is provided
convertedVal := builtinConverters[f.typ.Elem().Kind()](val)
if !convertedVal.IsValid() {
errs.merge(MultiError{"default-" + f.name: fmt.Errorf("failed setting default: %s is not compatible with field %s type", val, f.name)})
break
}
defaultSlice = reflect.Append(defaultSlice, convertedVal)
}
vCurrent.Set(defaultSlice)
} else if f.typ.Kind() == reflect.Ptr {
t1 := f.typ.Elem()
switch f.typ.Kind() {
case reflect.Struct:
errs.add("default-"+f.name, errUnsupportedDefaultFieldType)

if t1.Kind() == reflect.Struct || t1.Kind() == reflect.Slice {
errs.merge(MultiError{"default-" + f.name: errors.New("default option is supported only on: bool, float variants, string, unit variants types or their corresponding pointers or slices")})
}
case reflect.Slice:
vals := strings.Split(f.defaultValue, "|")

// check if slice has one of the supported types for defaults
if _, ok := builtinConverters[f.typ.Elem().Kind()]; !ok {
errs.add("default-"+f.name, errUnsupportedDefaultFieldType)
continue
}

defaultSlice := reflect.MakeSlice(f.typ, 0, cap(vals))
for _, val := range vals {
// this check is to handle if the wrong value is provided
if convertedVal := convertPointer(t1.Kind(), f.defaultValue); convertedVal.IsValid() {
vCurrent.Set(convertedVal)
}
} else {
// this check is to handle if the wrong value is provided
if convertedVal := builtinConverters[f.typ.Kind()](f.defaultValue); convertedVal.IsValid() {
vCurrent.Set(builtinConverters[f.typ.Kind()](f.defaultValue))
convertedVal := builtinConverters[f.typ.Elem().Kind()](val)
if !convertedVal.IsValid() {
errs.add("default-"+f.name, errIncompatibleValue(val, f))
break
}
defaultSlice = reflect.Append(defaultSlice, convertedVal)
}
vCurrent.Set(defaultSlice)

case reflect.Ptr:
t1 := f.typ.Elem()

if t1.Kind() == reflect.Struct || t1.Kind() == reflect.Slice {
errs.add("default-"+f.name, errUnsupportedDefaultFieldType)
}

// this check is to handle if the wrong value is provided
if convertedVal := convertPointer(t1.Kind(), f.defaultValue); convertedVal.IsValid() {
vCurrent.Set(convertedVal)
}

default:
// this check is to handle if the wrong value is provided
if convertedVal := builtinConverters[f.typ.Kind()](f.defaultValue); convertedVal.IsValid() {
vCurrent.Set(builtinConverters[f.typ.Kind()](f.defaultValue))
}
}
}
Expand All @@ -186,6 +197,11 @@ func isPointerToStruct(v reflect.Value) bool {
return !v.IsZero() && v.Type().Kind() == reflect.Ptr && v.Elem().Type().Kind() == reflect.Struct
}

//go:noinline
func isAnonymousField(t reflect.Type, nr int) bool {
return t.Field(nr).Anonymous
}

// checkRequired checks whether required fields are empty
//
// check type t recursively if t has struct fields.
Expand Down Expand Up @@ -546,6 +562,16 @@ type unmarshaler struct {

// Errors ---------------------------------------------------------------------

var (
errRequiredFieldsCantHaveDefaults = errors.New("required fields cannot have a default value")
errUnsupportedDefaultFieldType = errors.New("default option is supported only on: bool, float variants, string, unit variants types or their corresponding pointers or slices")
)

//go:noinline
func errIncompatibleValue(val string, f *fieldInfo) error {
return fmt.Errorf("failed setting default: %s is not compatible with field %s type", val, f.name)
}

// ConversionError stores information about a failed conversion.
type ConversionError struct {
Key string // key from the source map.
Expand Down Expand Up @@ -611,6 +637,14 @@ func (e MultiError) Error() string {
return fmt.Sprintf("%s (and %d other errors)", s, len(e)-1)
}

//go:noinline
func (e MultiError) add(key string, err error) {
if e[key] == nil {
e[key] = err
}
}

//go:noinline
func (e MultiError) merge(errors MultiError) {
for key, err := range errors {
if e[key] == nil {
Expand Down
Loading