From 8c8182d178c0efe1cf0a7fc41cc31827d8059710 Mon Sep 17 00:00:00 2001 From: Sean Cunningham Date: Wed, 11 Feb 2026 13:37:01 -0500 Subject: [PATCH 1/2] Support rules with a script node interior type. --- pkg/ast/ast.go | 71 ++++++++++-------- pkg/ast/ast_script.go | 95 ++++++++++++++++++++++++ pkg/ast/ast_test.go | 11 +++ pkg/parser/parse.go | 104 ++++++++++++++------------ pkg/parser/parse_test.go | 143 ++++++++++++++++++++++++++++++++++++ pkg/parser/tree.go | 155 +++++++++++++++++++++++++++++---------- pkg/schema/schema.go | 1 + pkg/testdata/rules.go | 98 +++++++++++++++++++++++++ 8 files changed, 567 insertions(+), 111 deletions(-) create mode 100644 pkg/ast/ast_script.go diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go index bcdd637..b979d1a 100644 --- a/pkg/ast/ast.go +++ b/pkg/ast/ast.go @@ -106,12 +106,9 @@ func (b *builderT) descendTree(fn func() error) error { } func Build(data []byte) (*AstT, error) { - var ( - parseTree *parser.TreeT - err error - ) - if parseTree, err = parser.Parse(data); err != nil { + parseTree, err := parser.Parse(data) + if err != nil { log.Error().Any("err", err).Msg("Parser failed") return nil, err } @@ -157,33 +154,17 @@ func BuildTree(tree *parser.TreeT) (*AstT, error) { func (b *builderT) buildTree(parserNode *parser.NodeT, parentMachineAddress *AstNodeAddressT, termIdx *uint32) (*AstNodeT, error) { var ( - machineMatchNode *AstNodeT - matchNode *AstNodeT - children = make([]*AstNodeT, 0) - machineAddress = b.newAstNodeAddress(parserNode.Metadata.RuleHash, parserNode.Metadata.Type.String(), termIdx) - err error + machineAddress = b.newAstNodeAddress(parserNode.Metadata.RuleHash, parserNode.Metadata.Type.String(), termIdx) ) - // Build children (either matcher children or nested machines) - if parserNode.IsMatcherNode() { - if matchNode, err = b.buildMatcherChildren(parserNode, machineAddress, termIdx); err != nil { - return nil, err - } - children = append(children, matchNode) - } else if parserNode.IsPromNode() { - if matchNode, err = b.buildPromQLNode(parserNode, machineAddress, termIdx); err != nil { - return nil, err - } - children = append(children, matchNode) - - } else { - if children, err = b.buildMachineChildren(parserNode, machineAddress); err != nil { - return nil, err - } + children, err := b.buildChildrenNodes(parserNode, machineAddress, termIdx) + if err != nil { + return nil, err } // Build state machine after recursively building children - if machineMatchNode, err = b.buildStateMachine(parserNode, parentMachineAddress, machineAddress, children); err != nil { + machineMatchNode, err := b.buildStateMachine(parserNode, parentMachineAddress, machineAddress, children) + if err != nil { return nil, err } @@ -192,6 +173,35 @@ func (b *builderT) buildTree(parserNode *parser.NodeT, parentMachineAddress *Ast return machineMatchNode, nil } +func (b *builderT) buildChildrenNodes(parserNode *parser.NodeT, machineAddress *AstNodeAddressT, termIdx *uint32) (children []*AstNodeT, err error) { + + leaf, err := b.buildLeafChild(parserNode, machineAddress, termIdx) + + switch { + case err != nil: + return nil, err + case leaf != nil: + return []*AstNodeT{leaf}, nil + case parserNode.IsScriptNode(): + children, err = b.buildScriptChildren(parserNode, machineAddress) + default: + children, err = b.buildMachineChildren(parserNode, machineAddress) + } + + return children, err +} + +func (b *builderT) buildLeafChild(parserNode *parser.NodeT, machineAddress *AstNodeAddressT, termIdx *uint32) (leaf *AstNodeT, err error) { + + switch { + case parserNode.IsMatcherNode(): + leaf, err = b.buildMatcherChild(parserNode, machineAddress, termIdx) + case parserNode.IsPromNode(): + leaf, err = b.buildPromQLNode(parserNode, machineAddress, termIdx) + } + return +} + func (b *builderT) newAstNodeAddress(ruleHash, name string, termIdx *uint32) *AstNodeAddressT { var address = &AstNodeAddressT{ Version: "v" + strconv.FormatInt(int64(AstVersion), 10), @@ -220,7 +230,7 @@ func newAstNode(parserNode *parser.NodeT, typ schema.NodeTypeT, scope string, pa } } -func (b *builderT) buildMatcherChildren(parserNode *parser.NodeT, machineAddress *AstNodeAddressT, termIdx *uint32) (*AstNodeT, error) { +func (b *builderT) buildMatcherChild(parserNode *parser.NodeT, machineAddress *AstNodeAddressT, termIdx *uint32) (*AstNodeT, error) { var ( matchNode *AstNodeT @@ -360,7 +370,7 @@ func addNegateOpts(assert *AstNodeT, negateOpts *parser.NegateOptsT) { } } -func (b *builderT) buildStateMachine(parserNode *parser.NodeT, parentMachineAddress *AstNodeAddressT, machineAddress *AstNodeAddressT, children []*AstNodeT) (*AstNodeT, error) { +func (b *builderT) buildStateMachine(parserNode *parser.NodeT, parentMachineAddress, machineAddress *AstNodeAddressT, children []*AstNodeT) (*AstNodeT, error) { switch parserNode.Metadata.Type { case schema.NodeTypeSeq, schema.NodeTypeLogSeq: @@ -371,6 +381,9 @@ func (b *builderT) buildStateMachine(parserNode *parser.NodeT, parentMachineAddr return nil, parserNode.WrapError(ErrInvalidWindow) } case schema.NodeTypeSet, schema.NodeTypeLogSet, schema.NodeTypePromQL: + case schema.NodeTypeScript: + return b.buildScriptNode(parserNode, parentMachineAddress, machineAddress) + default: log.Error(). Any("address", machineAddress). diff --git a/pkg/ast/ast_script.go b/pkg/ast/ast_script.go new file mode 100644 index 0000000..a9614fb --- /dev/null +++ b/pkg/ast/ast_script.go @@ -0,0 +1,95 @@ +package ast + +import ( + "time" + + "github.com/prequel-dev/prequel-compiler/pkg/parser" + "github.com/prequel-dev/prequel-compiler/pkg/schema" + "github.com/rs/zerolog/log" +) + +type AstScriptT struct { + Code string + Language string + Timeout time.Duration +} + +// Build the child Ast nodes for the script. +// +// Script nodes are internal nodes with one input node. +// The parser node for a script contains a ScriptT struct as its first child, followed by one input node. +// Build the the child nodes for the script node by building each of the parser node's children; +// the first child is skipped since it is the script definition, and the remaining child is built as the input to the script node. + +func (b *builderT) buildScriptChildren(parserNode *parser.NodeT, machineAddress *AstNodeAddressT) ([]*AstNodeT, error) { + + if len(parserNode.Children) != 2 { + log.Error().Int("child_count", len(parserNode.Children)).Msg("Script node must have two children") + return nil, parserNode.WrapError(ErrInvalidNodeType) + } + + termIdx := uint32(1) + + child := parserNode.Children[1] + parserChildNode, ok := child.(*parser.NodeT) + if !ok { + log.Error().Any("child", child).Msg("Failed to build Script child node") + return nil, parserNode.WrapError(ErrInvalidNodeType) + } + + leaf, err := b.buildLeafChild(parserChildNode, machineAddress, &termIdx) + + switch { + case err != nil: + return nil, err + case leaf != nil: + return []*AstNodeT{leaf}, nil + default: + node, err := b.buildTree(parserChildNode, machineAddress, &termIdx) + if err != nil { + return nil, err + } + + return []*AstNodeT{node}, nil + } +} + +// Validate script definitions and build the script node. + +func (b *builderT) buildScriptNode(parserNode *parser.NodeT, parentMachineAddress, machineAddress *AstNodeAddressT) (*AstNodeT, error) { + + // Expects exactly two children, the first should be parser.ScriptT, the following is the script input node. + + if len(parserNode.Children) != 2 { + log.Error().Int("child_count", len(parserNode.Children)).Msg("Script node must have exactly two children") + return nil, parserNode.WrapError(ErrInvalidNodeType) + } + + scriptNode, ok := parserNode.Children[0].(*parser.ScriptT) + + if !ok { + log.Error().Any("script", parserNode.Children[0]).Msg("Failed to build Script node") + return nil, parserNode.WrapError(ErrMissingScalar) + } + + if scriptNode.Code == "" { + log.Error().Msg("Script code string is empty") + return nil, parserNode.WrapError(ErrMissingScalar) + } + + pn := &AstScriptT{ + Code: scriptNode.Code, + Language: scriptNode.Language, + } + + if scriptNode.Timeout != nil { + pn.Timeout = *scriptNode.Timeout + } + + var ( + node = newAstNode(parserNode, parserNode.Metadata.Type, schema.ScopeCluster, parentMachineAddress, machineAddress) + ) + + node.Object = pn + return node, nil +} diff --git a/pkg/ast/ast_test.go b/pkg/ast/ast_test.go index 58a5bf6..c9935b0 100644 --- a/pkg/ast/ast_test.go +++ b/pkg/ast/ast_test.go @@ -33,6 +33,9 @@ func gatherNodeAddresses(node *AstNodeT, out *[]string) { } *out = append(*out, node.Metadata.Address.String()) + for _, child := range node.Children { + gatherNodeAddresses(child, out) + } } func TestAstSuccess(t *testing.T) { @@ -73,6 +76,14 @@ func TestAstSuccess(t *testing.T) { rule: testdata.TestSuccessSimplePromQL, expectedNodeTypes: []string{"machine_set", "promql", "log_set"}, }, + "Success_ChildScript": { + rule: testdata.TestSuccessChildScript, + expectedNodeTypes: []string{"machine_seq", "script", "log_seq", "log_set"}, + }, + "Success_ChildScriptMultipleInputs": { + rule: testdata.TestSuccessChildScriptMultipleInputs, + expectedNodeTypes: []string{"machine_set", "script", "machine_seq", "log_seq", "log_set"}, + }, } for name, test := range tests { diff --git a/pkg/parser/parse.go b/pkg/parser/parse.go index b41abfc..e88de5a 100644 --- a/pkg/parser/parse.go +++ b/pkg/parser/parse.go @@ -5,6 +5,8 @@ import ( ) // Note that we prefer lower camel case like Kubernetes +// Also, have to keep the JSON tags although we are using YAML. +// The hash function uses JSON serialization, so the JSON tags are required to ensure consistent field names for hashing. const ( docRules = "rules" @@ -92,19 +94,6 @@ type ParseNegateOptsT struct { Absolute bool `yaml:"absolute,omitempty"` } -type ParseTermT struct { - Field string `yaml:"field,omitempty"` - StrValue string `yaml:"value,omitempty"` - JqValue string `yaml:"jq,omitempty"` - RegexValue string `yaml:"regex,omitempty"` - Count int `yaml:"count,omitempty"` - Set *ParseSetT `yaml:"set,omitempty"` - Sequence *ParseSequenceT `yaml:"sequence,omitempty"` - NegateOpts *ParseNegateOptsT `yaml:",inline,omitempty"` - PromQL *ParsePromQL `yaml:"promql,omitempty"` - Extract []ParseExtractT `yaml:"extract,omitempty"` -} - type ParseSetT struct { Window string `yaml:"window,omitempty"` Correlations []string `yaml:"correlations,omitempty"` @@ -126,23 +115,55 @@ type ParsePromQL struct { Event *ParseEventT `yaml:"event,omitempty"` } +type ParseScriptT struct { + Code string `yaml:"code"` + Language string `yaml:"language,omitempty"` // Assumes 'lua' if empty + Timeout string `yaml:"timeout,omitempty"` // Uses default if empty; expects duration string + Input *ParseTermT `yaml:"input"` // Required input +} + +type ParseEventT struct { + Source string `yaml:"source"` + Origin bool `yaml:"origin,omitempty" json:"origin,omitempty"` +} + +type ParseTermT struct { + Field string `yaml:"field,omitempty"` + StrValue string `yaml:"value,omitempty"` + JqValue string `yaml:"jq,omitempty"` + RegexValue string `yaml:"regex,omitempty"` + Count int `yaml:"count,omitempty"` + Set *ParseSetT `yaml:"set,omitempty"` + Sequence *ParseSequenceT `yaml:"sequence,omitempty"` + NegateOpts *ParseNegateOptsT `yaml:",inline,omitempty"` + PromQL *ParsePromQL `yaml:"promql,omitempty"` + Script *ParseScriptT `yaml:"script,omitempty"` + Extract []ParseExtractT `yaml:"extract,omitempty"` +} + func (o *ParseTermT) UnmarshalYAML(unmarshal func(any) error) error { + + // Try to unmarshal as a raw string first. + // If that fails, unmarshal as a struct. + // This allows for a shorthand syntax for simple match terms. var str string if err := unmarshal(&str); err == nil { o.StrValue = str return nil } + var temp struct { - Field string `yaml:"field,omitempty"` - StrValue string `yaml:"value,omitempty"` - JqValue string `yaml:"jq,omitempty"` - RegexValue string `yaml:"regex,omitempty"` - Count int `yaml:"count,omitempty"` - Set *ParseSetT `yaml:"set,omitempty"` - Sequence *ParseSequenceT `yaml:"sequence,omitempty"` - NegateOpts *ParseNegateOptsT `yaml:",inline,omitempty"` - ParsePromQL *ParsePromQL `yaml:"promql,omitempty"` - Extract []ParseExtractT `yaml:"extract,omitempty"` + Field string `yaml:"field"` + StrValue string `yaml:"value"` + JqValue string `yaml:"jq"` + RegexValue string `yaml:"regex"` + Count int `yaml:"count"` + Set *ParseSetT `yaml:"set"` + Sequence *ParseSequenceT `yaml:"sequence"` + NegateOpts *ParseNegateOptsT `yaml:",inline"` + ParsePromQL *ParsePromQL `yaml:"promql"` + Script *ParseScriptT `yaml:"script"` + Extract []ParseExtractT `yaml:"extract"` } if err := unmarshal(&temp); err != nil { return err @@ -156,22 +177,11 @@ func (o *ParseTermT) UnmarshalYAML(unmarshal func(any) error) error { o.Sequence = temp.Sequence o.NegateOpts = temp.NegateOpts o.PromQL = temp.ParsePromQL + o.Script = temp.Script o.Extract = temp.Extract return nil } -type ParseEventT struct { - Source string `yaml:"source"` - Origin bool `yaml:"origin,omitempty" json:"origin,omitempty"` -} - -type RulesT struct { - Rules []ParseRuleT `yaml:"rules"` - Root *yaml.Node `yaml:"-"` - TermsT map[string]ParseTermT `yaml:"terms,omitempty"` - TermsY map[string]*yaml.Node `yaml:"-"` -} - func RootNode(data []byte) (*yaml.Node, error) { var root yaml.Node if err := yaml.Unmarshal(data, &root); err != nil { @@ -180,21 +190,25 @@ func RootNode(data []byte) (*yaml.Node, error) { return &root, nil } -func _parse(data []byte) (RulesT, *yaml.Node, error) { +type RulesT struct { + Rules []ParseRuleT `yaml:"rules"` + Root *yaml.Node `yaml:"-"` + TermsT map[string]ParseTermT `yaml:"terms,omitempty"` + TermsY map[string]*yaml.Node `yaml:"-"` +} - var ( - root yaml.Node - rules RulesT - err error - ) +func _parse(data []byte) (*RulesT, *yaml.Node, error) { - if err = yaml.Unmarshal(data, &root); err != nil { - return RulesT{}, nil, err + root, err := RootNode(data) + if err != nil { + return nil, nil, err } + var rules RulesT if err := root.Decode(&rules); err != nil { - return RulesT{}, nil, err + return nil, nil, err + } - return rules, &root, nil + return &rules, root, nil } diff --git a/pkg/parser/parse_test.go b/pkg/parser/parse_test.go index cc542e7..d70554c 100644 --- a/pkg/parser/parse_test.go +++ b/pkg/parser/parse_test.go @@ -75,6 +75,16 @@ func TestParseSuccess(t *testing.T) { expectedNodeTypes: []string{"machine_set", "promql", "log_set"}, expectedNegIndexes: []int{-1, -1, -1}, }, + "Success_ChildScript": { + rule: testdata.TestSuccessChildScript, + expectedNodeTypes: []string{"machine_seq", "script", "log_seq", "log_set"}, + expectedNegIndexes: []int{-1, -1, -1, -1}, + }, + "Success_ChildScriptMultipleInputs": { + rule: testdata.TestSuccessChildScriptMultipleInputs, + expectedNodeTypes: []string{"machine_set", "script", "machine_seq", "log_seq", "log_set"}, + expectedNegIndexes: []int{-1, -1, -1, -1, -1}, + }, } for name, test := range tests { @@ -217,6 +227,18 @@ func TestParseFail(t *testing.T) { col: 7, err: ErrInvalidRuleHash, }, + "Fail_ScriptRoot": { + rule: testdata.TestFailScriptRoot, + line: 10, + col: 7, + err: ErrNotSupported, + }, + "Fail_ScriptNoInput": { + rule: testdata.TestFailScriptNoInput, + line: 11, + col: 9, + err: ErrMissingInput, + }, } for name, test := range tests { @@ -246,6 +268,127 @@ func TestParseFail(t *testing.T) { } } +const stableRuleYaml = ` +rules: + - cre: + id: PREQUEL-2026-0004 + severity: 3 + title: ArgoCD Excessive Syncs + category: argocd-problems + author: Prequel + description: | + ArgoCD Reconciliation Storm + tags: + - argocd + - sync-loop + - prequel-v0.14+ + mitigation: + Remove "CreateNamespace=true" from applications involved in the sync loop reconciliation storm. + impact: | + The ArgoCD applications are in a sync loop, which means that they are being synced more than once per minute. This increases the load on the ArgoCD server and the Kubernetes cluster. + mitigationScore: 3 + impactScore: 4 + references: + - https://github.com/argoproj/argo-cd/issues/14666#issuecomment-1715538502 + - https://argo-cd.readthedocs.io/en/stable/operator-manual/reconcile/ + applications: + - name: "argocd" + processName: "argocd-application-controller" + processPath: "/app/argocd/argocd-application-controller" + containerName: "argocd-application-controller" + imageUrl: "quay.io/argoproj/argocd:v2.7.5" + repoUrl: "https://github.com/argoproj/argo-cd" + + metadata: + kind: custom + id: NRdyR6FoTTsziQRVrxFMv5 + gen: 1 + rule: + set: + event: + source: cre.kubernetes + correlations: + - appNamespace + - appName + window: 1200s + match: + - jq: | + (.message | test("Initiated automated sync to '.*'")) + and (.source.component == "argocd-application-controller") + extract: + - name: appNamespace + jq: .involvedObject.namespace + - name: appName + jq: .involvedObject.name + count: 3 + - jq: | + (.message | test("(Partial s|S)ync operation to .* succeeded")) + and (.source.component == "argocd-application-controller") + extract: + - name: appNamespace + jq: .involvedObject.namespace + - name: appName + jq: .involvedObject.name + count: 3 + - jq: | + (.message | test("Updated sync status: Synced -> OutOfSync")) + and (.source.component == "argocd-application-controller") + extract: + - name: appNamespace + jq: .involvedObject.namespace + - name: appName + jq: .involvedObject.name + count: 3 + - jq: | + (.message | test("Updated sync status: OutOfSync -> Synced")) + and (.source.component == "argocd-application-controller") + extract: + - name: appNamespace + jq: .involvedObject.namespace + - name: appName + jq: .involvedObject.name + count: 3 +` + +func TestStableHashStability(t *testing.T) { + // Use a stable rule from above for test. + ruleYaml := stableRuleYaml + + // Unmarshal YAML to ParseRuleT + rules, err := Unmarshal([]byte(ruleYaml)) + if err != nil { + t.Fatalf("Failed to unmarshal rule: %v", err) + } + if len(rules.Rules) == 0 { + t.Fatalf("No rules found in testdata") + } + rule := rules.Rules[0] + + // Compute stable hash + hash1, err := StableHash(rule) + if err != nil { + t.Fatalf("Failed to compute stable hash: %v", err) + } + + // Modify non-semantic metadata fields + rule.Metadata.Version = "v2.0.0" + rule.Metadata.Gen = 42 + + // Compute stable hash again + hash2, err := StableHash(rule) + if err != nil { + t.Fatalf("Failed to compute stable hash after metadata change: %v", err) + } + + if hash1 != hash2 { + t.Errorf("StableHash changed after non-semantic metadata update: %s != %s", hash1, hash2) + } + + if hash1 != "QFr5UWZMni8KYe4B7FkYg64p8CaRr6yeuynwDfPXjDj" { + t.Errorf("StableHash value changed unexpectedly: got %s, want %s", hash1, "QFr5UWZMni8KYe4B7FkYg64p8CaRr6yeuynwDfPXjDj") + } +} + func DumpErrorChain(err error) { i := 0 for err != nil { diff --git a/pkg/parser/tree.go b/pkg/parser/tree.go index 699c795..4721fdf 100644 --- a/pkg/parser/tree.go +++ b/pkg/parser/tree.go @@ -25,6 +25,8 @@ var ( ErrTermNotFound = errors.New("term not found") ErrMissingOrder = errors.New("'sequence' missing 'order'") ErrMissingMatch = errors.New("'set' missing 'match'") + ErrMissingInput = errors.New("'script' missing 'input'") + ErrInputType = errors.New("invalid 'script' input type") ErrInvalidWindow = errors.New("invalid 'window'") ErrTermsMapping = errors.New("'terms' must be a mapping") ErrDuplicateTerm = errors.New("duplicate term name") @@ -36,6 +38,7 @@ var ( ErrInvalidRuleHash = errors.New("invalid rule hash (must be base58)") ErrExtractName = errors.New("invalid extract name (alphanumeric and underscores only)") ErrInnerEvent = errors.New("invalid event on inner node") + ErrScriptLanguage = errors.New("invalid script language") ) var ( @@ -110,9 +113,15 @@ type PromQLT struct { Interval *time.Duration `json:"interval,omitempty"` } -// PromQLValidator validates a PromQL expression. -// Hook exposed to avoid importing promql dependencies in compiler. -var PromQLValidator = func(expr string) error { return nil } +type ScriptT struct { + Code string `json:"code"` + Language string `json:"language,omitempty"` + Timeout *time.Duration `json:"timeout,omitempty"` +} + +// Hooks exposed to avoid importing dependencies in compiler. +var PromQLValidator = func(expr string) error { return nil } // PromQLValidator validates a PromQL expression. +var LuaValidator = func(code string) error { return nil } // LuaValidator validates Lua script syntax. func newEvent(t *ParseEventT) *EventT { return &EventT{ @@ -248,6 +257,17 @@ func (node *NodeT) IsPromNode() bool { return allPromQL } +func (node *NodeT) IsScriptNode() bool { + if len(node.Children) != 2 { + return false + } + + // Expect first child to be a script definition and second child to be undefined term. + _, ok := node.Children[0].(*ScriptT) + + return ok +} + func seqNodeProps(node *NodeT, seq *ParseSequenceT, order bool, yn *yaml.Node) error { if !order { @@ -487,17 +507,13 @@ func buildChildren(parent *NodeT, tm map[string]ParseTermT, terms []ParseTermT, for _, term := range terms { var ( - node any - resolvedTerm ParseTermT - t = term - n = yn - ok bool - err error + t = term + n = yn ) if term.StrValue != "" { // If the term is not found in the terms map, then use as str value - if resolvedTerm, ok = tm[term.StrValue]; ok { + if resolvedTerm, ok := tm[term.StrValue]; ok { t = resolvedTerm if n, ok = termsY[term.StrValue]; !ok { return nil, parent.WrapError(ErrTermNotFound) @@ -509,12 +525,12 @@ func buildChildren(parent *NodeT, tm map[string]ParseTermT, terms []ParseTermT, } } - if node, err = nodeFromTerm(parent, tm, t, parentNegate, n, termsY); err != nil { + if node, err := nodeFromTerm(parent, tm, t, parentNegate, n, termsY); err != nil { return nil, err + } else { + children = append(children, node) } - children = append(children, node) - } return children, nil @@ -582,6 +598,9 @@ func nodeFromTerm(parent *NodeT, termsT map[string]ParseTermT, term ParseTermT, case term.PromQL != nil: return nodeFromProm(parent, term, yn) + case term.Script != nil: + return nodeFromScript(parent, termsT, term, yn, termsY) + case term.StrValue != "" || term.JqValue != "" || term.RegexValue != "": return parseValue(term, parentNegate) @@ -754,6 +773,80 @@ func nodeFromProm(parent *NodeT, term ParseTermT, yn *yaml.Node) (*NodeT, error) return node, nil } +// Script nodes are internal nodes with one input node. +// The first child in the resultant NodeT is always the script definition, the following child is the input node. +// The input node can be a matcher, promql, or another script node, but not a value term since values cannot be inputs to scripts. +func nodeFromScript(parent *NodeT, termsT map[string]ParseTermT, term ParseTermT, yn *yaml.Node, termsY map[string]*yaml.Node) (*NodeT, error) { + + var timeout *time.Duration + if term.Script.Timeout != "" { + dur, err := time.ParseDuration(term.Script.Timeout) + if err != nil { + return nil, err + } + timeout = &dur + } + + switch term.Script.Language { + case "", "lua": + if err := LuaValidator(term.Script.Code); err != nil { + return nil, err + } + default: + return nil, parent.WrapError(ErrScriptLanguage) + } + + // Create the script node with metadata from the parent rule. + node, err := initNode(parent.Metadata.RuleId, parent.Metadata.RuleHash, parent.Metadata.CreId, yn) + if err != nil { + return nil, parent.WrapError(err) + } + + // Script node requires one input. + // The input can be a sequence, a set, or a promql term, but not a value term since values cannot be inputs to scripts. + if term.Script.Input == nil { + return nil, parent.WrapError(ErrMissingInput) + } + + // Validator function: only allow terms that could be an input + allowTerm := func(t ParseTermT) bool { + switch { + case t.Sequence != nil: + case t.Set != nil: + case t.PromQL != nil: + case t.Script != nil: + default: + return false + } + return true + } + + // Validate that input is of an allowed type + if !allowTerm(*term.Script.Input) { + return nil, parent.WrapError(ErrInputType) + } + + childNode, err := nodeFromTerm(node, termsT, *term.Script.Input, false, yn, termsY) + switch { + case err != nil: + return nil, err + case childNode == nil: + return nil, parent.WrapError(ErrMissingInput) + } + + // Assign the script node type + node.Metadata.Type = schema.NodeTypeScript + + // Append the script definition as the first child, followed by the input. + node.Children = append(node.Children, &ScriptT{ + Code: term.Script.Code, + Language: term.Script.Language, + Timeout: timeout, + }, childNode) + + return node, nil +} + func parseValue(term ParseTermT, negate bool) (*MatcherT, error) { var ( @@ -804,17 +897,14 @@ func parseValue(term ParseTermT, negate bool) (*MatcherT, error) { } func ParseCres(data []byte) (map[string]ParseCreT, error) { - var ( - config RulesT - cres = make(map[string]ParseCreT) - err error - ) - if config, _, err = _parse(data); err != nil { + cfg, _, err := _parse(data) + if err != nil { return nil, err } - for _, rule := range config.Rules { + cres := make(map[string]ParseCreT, len(cfg.Rules)) + for _, rule := range cfg.Rules { cres[rule.Metadata.Hash] = rule.Cre } @@ -837,32 +927,23 @@ func Parse(data []byte, opts ...ParseOptT) (*TreeT, error) { func Unmarshal(data []byte) (*RulesT, error) { - var ( - docMap *yaml.Node - termsNode *yaml.Node - config RulesT - root *yaml.Node - ok bool - err error - ) - - if config, root, err = _parse(data); err != nil { + cfg, root, err := _parse(data) + if err != nil { return nil, err } - docMap = root.Content[0] + docMap := root.Content[0] - config.Root, ok = findChild(docMap, docRules) - if !ok { + var ok bool + if cfg.Root, ok = findChild(docMap, docRules); !ok { return nil, errors.New("rules not found") } - termsNode, ok = findChild(docMap, docTerms) - if ok { - config.TermsY = collectTermsY(termsNode) + if termsNode, ok := findChild(docMap, docTerms); ok { + cfg.TermsY = collectTermsY(termsNode) } - return &config, nil + return cfg, nil } func Hash(h string) string { diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index 3c9275e..8845a20 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -15,6 +15,7 @@ const ( NodeTypeLogSeq NodeTypeT = "log_seq" NodeTypeLogSet NodeTypeT = "log_set" NodeTypePromQL NodeTypeT = "promql" + NodeTypeScript NodeTypeT = "script" ) func (t NodeTypeT) String() string { diff --git a/pkg/testdata/rules.go b/pkg/testdata/rules.go index cb302be..e9622c3 100644 --- a/pkg/testdata/rules.go +++ b/pkg/testdata/rules.go @@ -362,6 +362,104 @@ rules: jq: ".field1" ` +// We currently do not support script at the root level. +var TestFailScriptRoot = ` +rules: + - cre: + id: TestFailScriptRoot + metadata: + id: "J7uRQTGpGMyL1iFpssnBeS" + hash: "rdJLgqYgkEp8jg8Qks1qiq" + generation: 1 + rule: + script: + code: | + function process(ev) + print("Processing...") + end +` + +// Script requires an input, so this should fail validation. +var TestFailScriptNoInput = ` +rules: + - cre: + id: TestFailScriptNoInput + metadata: + id: "J7uRQTGpGMyL1iFpssnBeS" + hash: "rdJLgqYgkEp8jg8Qks1qiq" + generation: 1 + rule: + set: + match: + - script: + code: "function process(ev) print(\"Processing...\") end" +` + +var TestSuccessChildScript = ` +rules: + - cre: + id: TestSuccessChildScript + metadata: + id: "J7uRQTGpGMyL1iFpssnBeS" + hash: "rdJLgqYgkEp8jg8Qks1qiq" + generation: 1 + rule: + sequence: + window: 30s + order: + - script: + code: "function process(ev) print(\"Processing...\") end" + input: + sequence: + window: 10s + event: + source: kafka + origin: true + order: + - value: "term1" + - value: "term2" + - set: + event: + source: kafka + match: + - value: "term2" +` + +var TestSuccessChildScriptMultipleInputs = ` +--- +rules: + - cre: + id: TestSuccessChildScriptMultipleInputs + metadata: + id: J7uRQTGpGMyL1iFpssnBeS + hash: rdJLgqYgkEp8jg8Qks1qiq + generation: 1 + rule: + set: + match: + - script: + code: function process(ev) print("Processing...") end + input: + sequence: + window: 10s + order: + - sequence: + event: + source: kafka + origin: true + window: 10s + order: + - value: term1 + - value: term2 + - set: + event: + source: kafka + window: 10s + match: + - value: term3 + - value: term4 +` + /* Failure cases */ var TestFailTypo = ` # Line 1 starts here rules: From 9fdcada34551da69a7c047fea2ed518ba63ccbb0 Mon Sep 17 00:00:00 2001 From: Sean Cunningham Date: Wed, 25 Feb 2026 11:04:38 -0500 Subject: [PATCH 2/2] Address child depth issues with scripts. Add tree address validation to test. --- pkg/ast/ast.go | 14 +++++------ pkg/ast/ast_log.go | 4 ++++ pkg/ast/ast_metrics.go | 19 +++++++++++++++ pkg/ast/ast_script.go | 22 +++++++++++------- pkg/ast/ast_test.go | 50 ++++++++++++++++++++++++++++++++++++++++ pkg/parser/parse_test.go | 5 ++++ pkg/testdata/rules.go | 23 ++++++++++++++++++ 7 files changed, 121 insertions(+), 16 deletions(-) diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go index b979d1a..846139a 100644 --- a/pkg/ast/ast.go +++ b/pkg/ast/ast.go @@ -197,7 +197,7 @@ func (b *builderT) buildLeafChild(parserNode *parser.NodeT, machineAddress *AstN case parserNode.IsMatcherNode(): leaf, err = b.buildMatcherChild(parserNode, machineAddress, termIdx) case parserNode.IsPromNode(): - leaf, err = b.buildPromQLNode(parserNode, machineAddress, termIdx) + leaf, err = b.buildPromQLChild(parserNode, machineAddress, termIdx) } return } @@ -248,9 +248,11 @@ func (b *builderT) buildMatcherChild(parserNode *parser.NodeT, machineAddress *A return nil, parserNode.WrapError(ErrInvalidEventType) } - // Implied that the root node has an origin event - b.OriginCnt++ - parserNode.Metadata.Event.Origin = true + // This appears to be a legacy hack to support rules that don't specify origin but have event sources. + // We should consider removing this and requiring explicit origin specification in the rules. + if b.CurrentDepth == 0 && !parserNode.Metadata.Event.Origin { + parserNode.Metadata.Event.Origin = true + } err = b.descendTree(func() error { if matchNode, err = b.buildMatcherNodes(parserNode, machineAddress, termIdx); err != nil { @@ -329,10 +331,6 @@ func (b *builderT) buildMachineChildren(parserNode *parser.NodeT, machineAddress // If the child has an event/data source, then it is not a state machine. Build it via buildMatcherNodes - if parserChildNode.Metadata.Event.Origin { - b.OriginCnt++ - } - if parserChildNode.Metadata.Event.Source == "" { log.Error(). Any("address", machineAddress). diff --git a/pkg/ast/ast_log.go b/pkg/ast/ast_log.go index ca8daba..f050775 100644 --- a/pkg/ast/ast_log.go +++ b/pkg/ast/ast_log.go @@ -152,6 +152,10 @@ func (b *builderT) doBuildLogMatcherNode(parserNode *parser.NodeT, machineAddres Correlations: parserNode.Metadata.Correlations, } + if parserNode.Metadata.Event.Origin { + b.OriginCnt++ + } + return matchNode, nil } diff --git a/pkg/ast/ast_metrics.go b/pkg/ast/ast_metrics.go index b381c4f..d28b2ae 100644 --- a/pkg/ast/ast_metrics.go +++ b/pkg/ast/ast_metrics.go @@ -15,6 +15,22 @@ type AstPromQL struct { Event *AstEventT } +func (b *builderT) buildPromQLChild(parserNode *parser.NodeT, machineAddress *AstNodeAddressT, termIdx *uint32) (*AstNodeT, error) { + var child *AstNodeT + + err := b.descendTree(func() error { + node, err := b.buildPromQLNode(parserNode, machineAddress, termIdx) + if err != nil { + return err + } + child = node + return nil + }) + + return child, err + +} + func (b *builderT) buildPromQLNode(parserNode *parser.NodeT, machineAddress *AstNodeAddressT, termIdx *uint32) (*AstNodeT, error) { // Expects one child of type ParsePromQL @@ -45,6 +61,9 @@ func (b *builderT) buildPromQLNode(parserNode *parser.NodeT, machineAddress *Ast Source: parserNode.Metadata.Event.Source, Origin: parserNode.Metadata.Event.Origin, } + if parserNode.Metadata.Event.Origin { + b.OriginCnt++ + } } if promNode.Interval != nil { diff --git a/pkg/ast/ast_script.go b/pkg/ast/ast_script.go index a9614fb..5bc6cfc 100644 --- a/pkg/ast/ast_script.go +++ b/pkg/ast/ast_script.go @@ -39,19 +39,25 @@ func (b *builderT) buildScriptChildren(parserNode *parser.NodeT, machineAddress leaf, err := b.buildLeafChild(parserChildNode, machineAddress, &termIdx) + var childList []*AstNodeT + switch { case err != nil: - return nil, err + // fallthrough case leaf != nil: - return []*AstNodeT{leaf}, nil + childList = []*AstNodeT{leaf} default: - node, err := b.buildTree(parserChildNode, machineAddress, &termIdx) - if err != nil { - return nil, err - } - - return []*AstNodeT{node}, nil + err = b.descendTree(func() error { + node, err := b.buildTree(parserChildNode, machineAddress, &termIdx) + if err != nil { + return err + } + childList = []*AstNodeT{node} + return nil + }) } + + return childList, err } // Validate script definitions and build the script node. diff --git a/pkg/ast/ast_test.go b/pkg/ast/ast_test.go index c9935b0..c08ee7b 100644 --- a/pkg/ast/ast_test.go +++ b/pkg/ast/ast_test.go @@ -84,6 +84,10 @@ func TestAstSuccess(t *testing.T) { rule: testdata.TestSuccessChildScriptMultipleInputs, expectedNodeTypes: []string{"machine_set", "script", "machine_seq", "log_seq", "log_set"}, }, + "Success_ChildScriptPromQLInput": { + rule: testdata.TestSuccessChildScriptPromQLInput, + expectedNodeTypes: []string{"machine_set", "script", "promql"}, + }, } for name, test := range tests { @@ -104,6 +108,10 @@ func TestAstSuccess(t *testing.T) { t.Fatalf("No nodes found in AST") } + if err = validateTree(ast.Nodes[0]); err != nil { + t.Fatalf("Error validating tree: %v", err) + } + var actualNodes []string gatherNodeTypes(ast.Nodes[0], &actualNodes) @@ -273,3 +281,45 @@ func TestFailureExamples(t *testing.T) { } } } + +// Validate the following invariants on the tree: +// 1. No duplicate addresses +// 2. Root node has no parent address +// 3. Node ids are unique +// 4. Depth is consistent with distance from root + +func validateTree(node *AstNodeT) error { + if node == nil { + return fmt.Errorf("Root node is nil") + } + + if node.Metadata.ParentAddress != nil { + return fmt.Errorf("Root node has parent address: %s", node.Metadata.ParentAddress.String()) + } + + return _validateTree(node, 0, make(map[uint32]struct{})) // start at depth 0 for root +} + +func _validateTree(node *AstNodeT, depth uint32, ids map[uint32]struct{}) error { + + if node == nil { + return nil + } + + if node.Metadata.Address.Depth != depth { + return fmt.Errorf("Node %s has depth %d, expected %d", node.Metadata.Address.String(), node.Metadata.Address.Depth, depth) + } + + if _, exists := ids[node.Metadata.Address.NodeId]; exists { + return fmt.Errorf("Duplicate node ID %d found", node.Metadata.Address.NodeId) + } + ids[node.Metadata.Address.NodeId] = struct{}{} + + for _, child := range node.Children { + if err := _validateTree(child, depth+1, ids); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/parser/parse_test.go b/pkg/parser/parse_test.go index d70554c..f6e70ad 100644 --- a/pkg/parser/parse_test.go +++ b/pkg/parser/parse_test.go @@ -85,6 +85,11 @@ func TestParseSuccess(t *testing.T) { expectedNodeTypes: []string{"machine_set", "script", "machine_seq", "log_seq", "log_set"}, expectedNegIndexes: []int{-1, -1, -1, -1, -1}, }, + "Success_ChildScriptPromQLInput": { + rule: testdata.TestSuccessChildScriptPromQLInput, + expectedNodeTypes: []string{"machine_set", "script", "promql"}, + expectedNegIndexes: []int{-1, -1, -1}, + }, } for name, test := range tests { diff --git a/pkg/testdata/rules.go b/pkg/testdata/rules.go index e9622c3..da55f61 100644 --- a/pkg/testdata/rules.go +++ b/pkg/testdata/rules.go @@ -460,6 +460,29 @@ rules: - value: term4 ` +var TestSuccessChildScriptPromQLInput = ` +rules: + - cre: + id: TestSuccessChildScriptPromQLInput + metadata: + id: "J7uRQTGpGMyL1iFpssnBeS" + hash: "rdJLgqYgkEp8jg8Qks1qiq" + generation: 1 + rule: + set: + window: 30s + match: + - script: + code: "function process(ev) print(\"Processing...\") end" + input: + promql: + event: + source: cre.metrics + origin: true + expr: 'sum(rate(http_requests_total[5m])) by (service)' + interval: 10s +` + /* Failure cases */ var TestFailTypo = ` # Line 1 starts here rules: