Skip to content

Commit b987c81

Browse files
Preserve User Intent When Parsing Pointers (#336)
* Only set a field if the flag was set by the user, otherwise set it to nil. This prevents accidental overrides from Cobra defaults (eg. if --auto-deploy isn't set the field will be nil), allows for better type support and cleaner parsing code * Refactor types/service.go to use typed input from cobra parsing for serviceType, Runtime and region * Use typed values in clone and add test helper GitOrigin-RevId: cf88047a428e2d80ec414f7e0f2eacd93f968eb1
1 parent 859ce9d commit b987c81

13 files changed

Lines changed: 290 additions & 107 deletions

File tree

pkg/command/inputs.go

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,50 @@ func getBoolValue(flags *pflag.FlagSet, args []string, tag string) (*bool, error
171171
return &val, nil
172172
}
173173

174+
func shouldSetPointerField(flags *pflag.FlagSet, cliTag string) bool {
175+
return isArg(cliTag) || flags.Changed(cliTag)
176+
}
177+
178+
func setStringPointerField(elemField reflect.Value, shouldSet bool, val *string) {
179+
if !shouldSet || val == nil {
180+
elemField.SetZero()
181+
return
182+
}
183+
ptr := reflect.New(elemField.Type().Elem())
184+
ptr.Elem().SetString(*val)
185+
elemField.Set(ptr)
186+
}
187+
188+
func setIntPointerField(elemField reflect.Value, shouldSet bool, val *int) {
189+
if !shouldSet || val == nil {
190+
elemField.SetZero()
191+
return
192+
}
193+
ptr := reflect.New(elemField.Type().Elem())
194+
ptr.Elem().SetInt(int64(*val))
195+
elemField.Set(ptr)
196+
}
197+
198+
func setFloat64PointerField(elemField reflect.Value, shouldSet bool, val *float64) {
199+
if !shouldSet || val == nil {
200+
elemField.SetZero()
201+
return
202+
}
203+
ptr := reflect.New(elemField.Type().Elem())
204+
ptr.Elem().SetFloat(*val)
205+
elemField.Set(ptr)
206+
}
207+
208+
func setBoolPointerField(elemField reflect.Value, shouldSet bool, val *bool) {
209+
if !shouldSet || val == nil {
210+
elemField.SetZero()
211+
return
212+
}
213+
ptr := reflect.New(elemField.Type().Elem())
214+
ptr.Elem().SetBool(*val)
215+
elemField.Set(ptr)
216+
}
217+
174218
func ParseCommandInteractiveOnly(cmd *cobra.Command, args []string, v any) error {
175219
format := GetFormatFromContext(cmd.Context())
176220
if !format.Interactive() {
@@ -221,25 +265,25 @@ func ParseCommand(cmd *cobra.Command, args []string, v any) error {
221265
if err != nil {
222266
return err
223267
}
224-
elemField.Set(reflect.ValueOf(val))
268+
setStringPointerField(elemField, shouldSetPointerField(flags, cliTag), val)
225269
case reflect.Int:
226270
val, err := getIntValue(flags, args, cliTag)
227271
if err != nil {
228272
return err
229273
}
230-
elemField.Set(reflect.ValueOf(val))
274+
setIntPointerField(elemField, shouldSetPointerField(flags, cliTag), val)
231275
case reflect.Float64:
232276
val, err := getFloat64Value(flags, args, cliTag)
233277
if err != nil {
234278
return err
235279
}
236-
elemField.Set(reflect.ValueOf(val))
280+
setFloat64PointerField(elemField, shouldSetPointerField(flags, cliTag), val)
237281
case reflect.Bool:
238282
val, err := getBoolValue(flags, args, cliTag)
239283
if err != nil {
240284
return err
241285
}
242-
elemField.Set(reflect.ValueOf(val))
286+
setBoolPointerField(elemField, shouldSetPointerField(flags, cliTag), val)
243287
}
244288
case reflect.Slice:
245289
switch field.Type.Elem().Kind() {

pkg/command/inputs_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,36 @@ func TestParseCommand(t *testing.T) {
9595
require.Equal(t, "bar", *v.Foo)
9696
})
9797

98+
t.Run("unset pointer flag remains nil", func(t *testing.T) {
99+
type testStruct struct {
100+
Foo *bool `cli:"foo"`
101+
}
102+
var v testStruct
103+
cmd := &cobra.Command{}
104+
cmd.Flags().Bool("foo", true, "")
105+
require.NoError(t, cmd.ParseFlags([]string{}))
106+
107+
err := command.ParseCommand(cmd, []string{}, &v)
108+
require.NoError(t, err)
109+
require.Nil(t, v.Foo)
110+
})
111+
112+
t.Run("parse pointer alias type", func(t *testing.T) {
113+
type myString string
114+
type testStruct struct {
115+
Foo *myString `cli:"foo"`
116+
}
117+
var v testStruct
118+
cmd := &cobra.Command{}
119+
cmd.Flags().String("foo", "", "")
120+
require.NoError(t, cmd.ParseFlags([]string{"--foo", "bar"}))
121+
122+
err := command.ParseCommand(cmd, []string{}, &v)
123+
require.NoError(t, err)
124+
require.NotNil(t, v.Foo)
125+
require.Equal(t, myString("bar"), *v.Foo)
126+
})
127+
98128
t.Run("parse slice", func(t *testing.T) {
99129
type testStruct struct {
100130
Foo []string `cli:"foo"`

pkg/service/clone.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func extractCloneSourceDefaults(source *client.Service) sourceDefaults {
7575
}
7676

7777
func applyBaseDefaults(input *servicetypes.Service, defaults sourceDefaults) {
78-
input.Type = withDefaultFromValue(input.Type, defaults.serviceType)
78+
input.Type = withDefaultFromValue(input.Type, servicetypes.ServiceType(defaults.serviceType))
7979
input.RootDirectory = withDefault(input.RootDirectory, defaults.rootDirectory)
8080
input.EnvironmentID = withDefault(input.EnvironmentID, defaults.environmentID)
8181
}
@@ -112,7 +112,7 @@ func applyRuntimeDefaults(input *servicetypes.Service, defaults sourceDefaults)
112112
return
113113
}
114114

115-
input.Runtime = withDefaultFromValue(input.Runtime, *defaults.runtime)
115+
input.Runtime = withDefaultFromValue(input.Runtime, servicetypes.ServiceRuntime(*defaults.runtime))
116116
}
117117

118118
func applyRegistryCredentialDefault(input *servicetypes.Service, defaults sourceDefaults) {
@@ -136,8 +136,11 @@ func withDefault(dst *string, src *string) *string {
136136
return pointers.From(*src)
137137
}
138138

139-
func withDefaultFromValue[S ~string](dst *string, src S) *string {
140-
return withDefault(dst, pointers.From(string(src)))
139+
func withDefaultFromValue[T ~string](dst *T, src T) *T {
140+
if dst != nil {
141+
return dst
142+
}
143+
return pointers.From(src)
141144
}
142145

143146
// RuntimeFromSourceService extracts runtime from a service when that service type has a runtime field.

pkg/service/clone_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ func TestCLIInputFromSource(t *testing.T) {
3232

3333
ServiceFromAPI(&input, source)
3434

35-
require.Equal(t, "web_service", *input.Type)
35+
require.Equal(t, servicetypes.ServiceType("web_service"), *input.Type)
3636
require.Equal(t, "https://github.com/renderinc/api", *input.Repo)
3737
require.Equal(t, "master", *input.Branch)
3838
require.Equal(t, "services/api", *input.RootDirectory)
3939
require.Equal(t, "evm-123", *input.EnvironmentID)
40-
require.Equal(t, "node", *input.Runtime)
40+
require.Equal(t, servicetypes.ServiceRuntime("node"), *input.Runtime)
4141
})
4242

4343
t.Run("hydrates image-backed defaults", func(t *testing.T) {
@@ -61,7 +61,7 @@ func TestCLIInputFromSource(t *testing.T) {
6161

6262
require.Equal(t, "docker.io/org/app:latest", *input.Image)
6363
require.Equal(t, "rgc-123", *input.RegistryCredential)
64-
require.Equal(t, "image", *input.Runtime)
64+
require.Equal(t, servicetypes.ServiceRuntime("image"), *input.Runtime)
6565
})
6666

6767
t.Run("hydrates docker runtime registry credential from docker details", func(t *testing.T) {
@@ -95,7 +95,7 @@ func TestCLIInputFromSource(t *testing.T) {
9595

9696
ServiceFromAPI(&input, source)
9797

98-
require.Equal(t, "docker", *input.Runtime)
98+
require.Equal(t, servicetypes.ServiceRuntime("docker"), *input.Runtime)
9999
require.Equal(t, "rgc-456", *input.RegistryCredential)
100100
})
101101

@@ -119,12 +119,12 @@ func TestCLIInputFromSource(t *testing.T) {
119119

120120
ServiceFromAPI(&input, source)
121121

122-
require.Equal(t, "background_worker", *input.Type)
122+
require.Equal(t, servicetypes.ServiceType("background_worker"), *input.Type)
123123
require.Equal(t, "https://github.com/renderinc/worker", *input.Repo)
124124
require.Equal(t, "main", *input.Branch)
125125
require.Equal(t, "workers/processor", *input.RootDirectory)
126126
require.Equal(t, "evm-456", *input.EnvironmentID)
127-
require.Equal(t, "python", *input.Runtime)
127+
require.Equal(t, servicetypes.ServiceRuntime("python"), *input.Runtime)
128128
})
129129

130130
t.Run("hydrates static site defaults", func(t *testing.T) {
@@ -147,7 +147,7 @@ func TestCLIInputFromSource(t *testing.T) {
147147

148148
ServiceFromAPI(&input, source)
149149

150-
require.Equal(t, "static_site", *input.Type)
150+
require.Equal(t, servicetypes.ServiceType("static_site"), *input.Type)
151151
require.Equal(t, "https://github.com/renderinc/docs", *input.Repo)
152152
require.Equal(t, "main", *input.Branch)
153153
require.Equal(t, "website", *input.RootDirectory)
@@ -182,18 +182,18 @@ func TestCLIInputFromSource(t *testing.T) {
182182
input := servicetypes.Service{
183183
Name: "clone-explicit",
184184
From: pointers.From("srv-source"),
185-
Type: pointers.From("private_service"),
185+
Type: svcTypeRaw("private_service"),
186186
Repo: pointers.From("https://github.com/org/custom"),
187187
Branch: pointers.From("feature-x"),
188-
Runtime: pointers.From("docker"),
188+
Runtime: svcRuntime(servicetypes.ServiceRuntimeDocker),
189189
}
190190

191191
ServiceFromAPI(&input, source)
192192

193-
require.Equal(t, "private_service", *input.Type)
193+
require.Equal(t, servicetypes.ServiceType("private_service"), *input.Type)
194194
require.Equal(t, "https://github.com/org/custom", *input.Repo)
195195
require.Equal(t, "feature-x", *input.Branch)
196-
require.Equal(t, "docker", *input.Runtime)
196+
require.Equal(t, servicetypes.ServiceRuntime("docker"), *input.Runtime)
197197
})
198198

199199
t.Run("does not copy repo defaults when image is explicitly provided", func(t *testing.T) {
@@ -214,7 +214,7 @@ func TestCLIInputFromSource(t *testing.T) {
214214
require.Equal(t, "docker.io/custom/image:latest", *input.Image)
215215
require.Nil(t, input.Repo)
216216
require.Nil(t, input.Branch)
217-
require.Equal(t, "image", *input.Runtime)
217+
require.Equal(t, servicetypes.ServiceRuntime("image"), *input.Runtime)
218218
})
219219

220220
t.Run("does not copy image defaults when repo is explicitly provided", func(t *testing.T) {

pkg/service/create.go

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55

66
"github.com/render-oss/cli/pkg/client"
77
"github.com/render-oss/cli/pkg/pointers"
8-
"github.com/render-oss/cli/pkg/types"
8+
types "github.com/render-oss/cli/pkg/types"
99
servicetypes "github.com/render-oss/cli/pkg/types/service"
1010
)
1111

@@ -29,10 +29,10 @@ func BuildCreateRequest(cliInput servicetypes.Service, ownerID string) (client.C
2929
return client.CreateServiceJSONRequestBody{}, err
3030
}
3131

32-
typedServiceType := toTypedFromAlias[servicetypes.ServiceType, client.ServiceType](serviceType)
33-
typedRuntime := toTypedFromAlias[servicetypes.ServiceRuntime, client.ServiceRuntime](runtime)
34-
typedRegion := toTypedFromAlias[types.Region, client.Region](region)
35-
typedPlan := toTypedFromAlias[string, client.Plan](cliInput.Plan)
32+
typedServiceType := toClientType(serviceType)
33+
typedRuntime := toClientRuntime(runtime)
34+
typedRegion := toClientRegion(region)
35+
typedPlan := toClientPlan(cliInput.Plan)
3636

3737
// When an image is provided without an explicit runtime, default to "image" runtime.
3838
// This allows users to omit --runtime when using --image.
@@ -99,14 +99,6 @@ func BuildCreateRequest(cliInput servicetypes.Service, ownerID string) (client.C
9999
return body, nil
100100
}
101101

102-
func toTypedFromAlias[S ~string, T ~string](value *S) *T {
103-
if value == nil {
104-
return nil
105-
}
106-
typed := T(*value)
107-
return &typed
108-
}
109-
110102
func parseEnvVarInputs(values []string) ([]client.EnvVarInput, error) {
111103
if len(values) == 0 {
112104
return nil, nil
@@ -310,3 +302,35 @@ func buildCronEnvSpecificDetails(buildCommand *string, cronCommand *string) (*cl
310302

311303
return envSpecificDetails, nil
312304
}
305+
306+
func toClientType(value *servicetypes.ServiceType) *client.ServiceType {
307+
if value == nil {
308+
return nil
309+
}
310+
typed := client.ServiceType(*value)
311+
return &typed
312+
}
313+
314+
func toClientRuntime(value *servicetypes.ServiceRuntime) *client.ServiceRuntime {
315+
if value == nil {
316+
return nil
317+
}
318+
typed := client.ServiceRuntime(*value)
319+
return &typed
320+
}
321+
322+
func toClientRegion(value *types.Region) *client.Region {
323+
if value == nil {
324+
return nil
325+
}
326+
typed := client.Region(*value)
327+
return &typed
328+
}
329+
330+
func toClientPlan(value *string) *client.Plan {
331+
if value == nil {
332+
return nil
333+
}
334+
typed := client.Plan(*value)
335+
return &typed
336+
}

0 commit comments

Comments
 (0)