Skip to content

Commit 1c8196f

Browse files
authored
Enrich policy input (#540)
* Update SDK * Add hook to policy input for enrichment * Fix bug in checking the verdict of policy execution This currently only considers boolean values, as it is trying to figure out if it should include `terminal` in the result output. * Update method signatures to include hook struct * Fix action runs and skip actions with false verdicts * Update tests to reflect changes * Add test case to show how to use hook info
1 parent 94a747f commit 1c8196f

File tree

8 files changed

+233
-44
lines changed

8 files changed

+233
-44
lines changed

act/registry.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414

1515
type IRegistry interface {
1616
Add(policy *sdkAct.Policy)
17-
Apply(signals []sdkAct.Signal) []*sdkAct.Output
17+
Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output
1818
Run(output *sdkAct.Output, params ...sdkAct.Parameter) (any, *gerr.GatewayDError)
1919
}
2020

@@ -107,11 +107,11 @@ func (r *Registry) Add(policy *sdkAct.Policy) {
107107
}
108108

109109
// Apply applies the signals to the registry and returns the outputs.
110-
func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
110+
func (r *Registry) Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output {
111111
// If there are no signals, apply the default policy.
112112
if len(signals) == 0 {
113113
r.Logger.Debug().Msg("No signals provided, applying default signal")
114-
return r.Apply([]sdkAct.Signal{*r.DefaultSignal})
114+
return r.Apply([]sdkAct.Signal{*r.DefaultSignal}, hook)
115115
}
116116

117117
// Separate terminal and non-terminal signals to find contradictions.
@@ -139,7 +139,7 @@ func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
139139
}
140140

141141
// Apply the signal and append the output to the list of outputs.
142-
output, err := r.apply(signal)
142+
output, err := r.apply(signal, hook)
143143
if err != nil {
144144
r.Logger.Error().Err(err).Str("name", signal.Name).Msg("Error applying signal")
145145
// If there is an error evaluating the policy, continue to the next signal.
@@ -155,14 +155,16 @@ func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
155155
}
156156

157157
if len(outputs) == 0 && !evalErr {
158-
return r.Apply([]sdkAct.Signal{*r.DefaultSignal})
158+
return r.Apply([]sdkAct.Signal{*r.DefaultSignal}, hook)
159159
}
160160

161161
return outputs
162162
}
163163

164164
// apply applies the signal to the registry and returns the output.
165-
func (r *Registry) apply(signal sdkAct.Signal) (*sdkAct.Output, *gerr.GatewayDError) {
165+
func (r *Registry) apply(
166+
signal sdkAct.Signal, hook sdkAct.Hook,
167+
) (*sdkAct.Output, *gerr.GatewayDError) {
166168
action, exists := r.Actions[signal.Name]
167169
if !exists {
168170
return nil, gerr.ErrActionNotMatched
@@ -178,12 +180,12 @@ func (r *Registry) apply(signal sdkAct.Signal) (*sdkAct.Output, *gerr.GatewayDEr
178180
defer cancel()
179181

180182
// Evaluate the policy.
181-
// TODO: Policy should be able to receive other parameters like server and client IPs, etc.
182183
verdict, err := policy.Eval(
183184
ctx, sdkAct.Input{
184185
Name: signal.Name,
185186
Policy: policy.Metadata,
186187
Signal: signal.Metadata,
188+
Hook: hook,
187189
// Action dictates the sync mode, not the signal.
188190
Sync: action.Sync,
189191
},

act/registry_test.go

Lines changed: 169 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,17 @@ func Test_Apply(t *testing.T) {
196196
})
197197
assert.NotNil(t, actRegistry)
198198

199-
outputs := actRegistry.Apply([]sdkAct.Signal{
200-
*sdkAct.Passthrough(),
201-
})
199+
outputs := actRegistry.Apply(
200+
[]sdkAct.Signal{
201+
*sdkAct.Passthrough(),
202+
},
203+
sdkAct.Hook{
204+
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
205+
Priority: 1000,
206+
Params: map[string]any{},
207+
Result: map[string]any{},
208+
},
209+
)
202210
assert.NotNil(t, outputs)
203211
assert.Len(t, outputs, 1)
204212
assert.Equal(t, "passthrough", outputs[0].MatchedPolicy)
@@ -225,7 +233,15 @@ func Test_Apply_NoSignals(t *testing.T) {
225233
})
226234
assert.NotNil(t, actRegistry)
227235

228-
outputs := actRegistry.Apply([]sdkAct.Signal{})
236+
outputs := actRegistry.Apply(
237+
[]sdkAct.Signal{},
238+
sdkAct.Hook{
239+
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
240+
Priority: 1000,
241+
Params: map[string]any{},
242+
Result: map[string]any{},
243+
},
244+
)
229245
assert.NotNil(t, outputs)
230246
assert.Len(t, outputs, 1)
231247
assert.Equal(t, "passthrough", outputs[0].MatchedPolicy)
@@ -272,7 +288,12 @@ func Test_Apply_ContradictorySignals(t *testing.T) {
272288
assert.NotNil(t, actRegistry)
273289

274290
for _, s := range signals {
275-
outputs := actRegistry.Apply(s)
291+
outputs := actRegistry.Apply(s, sdkAct.Hook{
292+
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
293+
Priority: 1000,
294+
Params: map[string]any{},
295+
Result: map[string]any{},
296+
})
276297
assert.NotNil(t, outputs)
277298
assert.Len(t, outputs, 2)
278299
assert.Equal(t, "terminate", outputs[0].MatchedPolicy)
@@ -318,6 +339,11 @@ func Test_Apply_ActionNotMatched(t *testing.T) {
318339

319340
outputs := actRegistry.Apply([]sdkAct.Signal{
320341
{Name: "non-existent"},
342+
}, sdkAct.Hook{
343+
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
344+
Priority: 1000,
345+
Params: map[string]any{},
346+
Result: map[string]any{},
321347
})
322348
assert.NotNil(t, outputs)
323349
assert.Len(t, outputs, 1)
@@ -351,6 +377,11 @@ func Test_Apply_PolicyNotMatched(t *testing.T) {
351377

352378
outputs := actRegistry.Apply([]sdkAct.Signal{
353379
*sdkAct.Terminate(),
380+
}, sdkAct.Hook{
381+
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
382+
Priority: 1000,
383+
Params: map[string]any{},
384+
Result: map[string]any{},
354385
})
355386
assert.NotNil(t, outputs)
356387
assert.Len(t, outputs, 1)
@@ -399,6 +430,11 @@ func Test_Apply_NonBoolPolicy(t *testing.T) {
399430

400431
outputs := actRegistry.Apply([]sdkAct.Signal{
401432
*sdkAct.Passthrough(),
433+
}, sdkAct.Hook{
434+
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
435+
Priority: 1000,
436+
Params: map[string]any{},
437+
Result: map[string]any{},
402438
})
403439
assert.NotNil(t, outputs)
404440
assert.Len(t, outputs, 1)
@@ -447,6 +483,110 @@ func Test_Apply_BadPolicy(t *testing.T) {
447483
}
448484
}
449485

486+
// Test_Apply_Hook tests the Apply function of the act registry with a policy that
487+
// has the hook info and makes use of it.
488+
func Test_Apply_Hook(t *testing.T) {
489+
buf := bytes.Buffer{}
490+
logger := zerolog.New(&buf)
491+
492+
// Custom policy leveraging the hook info.
493+
policies := map[string]*sdkAct.Policy{
494+
"passthrough": sdkAct.MustNewPolicy(
495+
"passthrough",
496+
"true",
497+
nil,
498+
),
499+
"log": sdkAct.MustNewPolicy(
500+
"log",
501+
`Signal.log == true && Policy.log == "enabled" &&
502+
split(Hook.Params.client.remote, ":")[0] == "192.168.0.1"`,
503+
map[string]any{
504+
"log": "enabled",
505+
},
506+
),
507+
}
508+
509+
actRegistry := NewActRegistry(
510+
Registry{
511+
Signals: BuiltinSignals(),
512+
Policies: policies,
513+
Actions: BuiltinActions(),
514+
DefaultPolicyName: config.DefaultPolicy,
515+
PolicyTimeout: config.DefaultPolicyTimeout,
516+
DefaultActionTimeout: config.DefaultActionTimeout,
517+
Logger: logger,
518+
})
519+
assert.NotNil(t, actRegistry)
520+
521+
hook := sdkAct.Hook{
522+
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
523+
Priority: 1000,
524+
// Input parameters for the hook.
525+
Params: map[string]any{
526+
"field": "value",
527+
"server": map[string]any{
528+
"local": "value",
529+
"remote": "value",
530+
},
531+
"client": map[string]any{
532+
"local": "value",
533+
"remote": "192.168.0.1:15432",
534+
},
535+
"request": "Base64EncodedRequest",
536+
"error": "",
537+
},
538+
// Output parameters for the hook.
539+
Result: map[string]any{
540+
"field": "value",
541+
"server": map[string]any{
542+
"local": "value",
543+
"remote": "value",
544+
},
545+
"client": map[string]any{
546+
"local": "value",
547+
"remote": "value",
548+
},
549+
"request": "Base64EncodedRequest",
550+
"error": "",
551+
sdkAct.Signals: []any{
552+
sdkAct.Log("error", "error message", map[string]any{"key": "value"}).ToMap(),
553+
},
554+
"response": "Base64EncodedResponse",
555+
},
556+
}
557+
558+
outputs := actRegistry.Apply(
559+
[]sdkAct.Signal{
560+
*sdkAct.Log(
561+
"error",
562+
"policy matched from incoming address 192.168.0.1, so we are seeing this error message",
563+
map[string]any{"key": "value"},
564+
),
565+
},
566+
hook,
567+
)
568+
assert.NotNil(t, outputs)
569+
assert.Len(t, outputs, 1)
570+
assert.Equal(t, "log", outputs[0].MatchedPolicy)
571+
assert.Equal(t, outputs[0].Metadata, map[string]any{
572+
"key": "value",
573+
"level": "error",
574+
"log": true,
575+
"message": "policy matched from incoming address 192.168.0.1, so we are seeing this error message",
576+
})
577+
assert.False(t, outputs[0].Sync) // Asynchronous action.
578+
assert.True(t, cast.ToBool(outputs[0].Verdict))
579+
assert.False(t, outputs[0].Terminal)
580+
581+
result, err := actRegistry.Run(outputs[0], WithResult(hook.Result))
582+
assert.Equal(t, err, gerr.ErrAsyncAction, "expected async action sentinel error")
583+
assert.Nil(t, result, "expected nil result")
584+
585+
time.Sleep(time.Millisecond) // wait for async action to complete
586+
587+
assert.Contains(t, buf.String(), `{"level":"error","key":"value","message":"policy matched from incoming address 192.168.0.1, so we are seeing this error message"}`) //nolint:lll
588+
}
589+
450590
// Test_Run tests the Run function of the act registry with a non-terminal action.
451591
func Test_Run(t *testing.T) {
452592
logger := zerolog.Logger{}
@@ -464,6 +604,11 @@ func Test_Run(t *testing.T) {
464604

465605
outputs := actRegistry.Apply([]sdkAct.Signal{
466606
*sdkAct.Passthrough(),
607+
}, sdkAct.Hook{
608+
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
609+
Priority: 1000,
610+
Params: map[string]any{},
611+
Result: map[string]any{},
467612
})
468613
assert.NotNil(t, outputs)
469614

@@ -489,6 +634,11 @@ func Test_Run_Terminate(t *testing.T) {
489634

490635
outputs := actRegistry.Apply([]sdkAct.Signal{
491636
*sdkAct.Terminate(),
637+
}, sdkAct.Hook{
638+
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
639+
Priority: 1000,
640+
Params: map[string]any{},
641+
Result: map[string]any{},
492642
})
493643
assert.NotNil(t, outputs)
494644
assert.Equal(t, "terminate", outputs[0].MatchedPolicy)
@@ -522,6 +672,11 @@ func Test_Run_Async(t *testing.T) {
522672

523673
outputs := actRegistry.Apply([]sdkAct.Signal{
524674
*sdkAct.Log("info", "test", map[string]any{"async": true}),
675+
}, sdkAct.Hook{
676+
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
677+
Priority: 1000,
678+
Params: map[string]any{},
679+
Result: map[string]any{},
525680
})
526681
assert.NotNil(t, outputs)
527682
assert.Equal(t, "log", outputs[0].MatchedPolicy)
@@ -647,7 +802,15 @@ func Test_Run_Timeout(t *testing.T) {
647802
})
648803
assert.NotNil(t, actRegistry)
649804

650-
outputs := actRegistry.Apply([]sdkAct.Signal{*signals[name]})
805+
outputs := actRegistry.Apply(
806+
[]sdkAct.Signal{*signals[name]},
807+
sdkAct.Hook{
808+
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
809+
Priority: 1000,
810+
Params: map[string]any{},
811+
Result: map[string]any{},
812+
},
813+
)
651814
assert.NotNil(t, outputs)
652815
assert.Equal(t, name, outputs[0].MatchedPolicy)
653816
assert.Equal(t,

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ require (
88
github.com/codingsince1985/checksum v1.3.0
99
github.com/cybercyst/go-scaffold v0.0.0-20240404115540-744e601147cd
1010
github.com/envoyproxy/protoc-gen-validate v1.0.4
11-
github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.13
11+
github.com/gatewayd-io/gatewayd-plugin-sdk v0.2.14
1212
github.com/getsentry/sentry-go v0.27.0
1313
github.com/go-co-op/gocron v1.37.0
1414
github.com/google/go-github/v53 v53.2.0

go.sum

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

network/proxy.go

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -856,31 +856,35 @@ func (pr *Proxy) shouldTerminate(result map[string]interface{}) (bool, map[strin
856856
// The Terminal field is only present if the action wants to terminate the request,
857857
// that is the `__terminal__` field is set in one of the outputs.
858858
keys := maps.Keys(result)
859-
if slices.Contains(keys, sdkAct.Terminal) {
860-
var actionResult map[string]interface{}
861-
for _, output := range outputs {
862-
actRes, err := pr.PluginRegistry.ActRegistry.Run(
863-
output, act.WithResult(result))
864-
// If the action is async and we received a sentinel error,
865-
// don't log the error.
866-
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
867-
pr.Logger.Error().Err(err).Msg("Error running policy")
868-
}
869-
// The terminate action should return a map.
870-
if v, ok := actRes.(map[string]interface{}); ok {
871-
actionResult = v
872-
}
859+
terminate := slices.Contains(keys, sdkAct.Terminal) && cast.ToBool(result[sdkAct.Terminal])
860+
actionResult := make(map[string]interface{})
861+
for _, output := range outputs {
862+
if !cast.ToBool(output.Verdict) {
863+
pr.Logger.Debug().Msg(
864+
"Skipping the action, because the verdict of the policy execution is false")
865+
continue
866+
}
867+
actRes, err := pr.PluginRegistry.ActRegistry.Run(
868+
output, act.WithResult(result))
869+
// If the action is async and we received a sentinel error,
870+
// don't log the error.
871+
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
872+
pr.Logger.Error().Err(err).Msg("Error running policy")
873873
}
874+
// The terminate action should return a map.
875+
if v, ok := actRes.(map[string]interface{}); ok {
876+
actionResult = v
877+
}
878+
}
879+
if terminate {
874880
pr.Logger.Debug().Fields(
875881
map[string]interface{}{
876882
"function": "proxy.passthrough",
877883
"reason": "terminate",
878884
},
879885
).Msg("Terminating request")
880-
return cast.ToBool(result[sdkAct.Terminal]), actionResult
881886
}
882-
883-
return false, result
887+
return terminate, actionResult
884888
}
885889

886890
// getPluginModifiedRequest is a function that retrieves the modified request

0 commit comments

Comments
 (0)