diff --git a/go.mod b/go.mod index aba71dee..d4fe0a68 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 github.com/apparentlymart/go-textseg v1.0.0 github.com/creachadair/jrpc2 v0.32.0 - github.com/expr-lang/expr v1.16.9 + github.com/expr-lang/expr v1.17.2 github.com/gertd/go-pluralize v0.2.1 github.com/google/go-cmp v0.7.0 github.com/hashicorp/go-version v1.7.0 diff --git a/go.sum b/go.sum index ec73fa66..f6b77a24 100644 --- a/go.sum +++ b/go.sum @@ -151,8 +151,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= -github.com/expr-lang/expr v1.16.9 h1:WUAzmR0JNI9JCiF0/ewwHB1gmcGw5wW7nWt8gc6PpCI= -github.com/expr-lang/expr v1.16.9/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= +github.com/expr-lang/expr v1.17.2 h1:o0A99O/Px+/DTjEnQiodAgOIK9PPxL8DtXhBRKC+Iso= +github.com/expr-lang/expr v1.17.2/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= diff --git a/vendor/github.com/expr-lang/expr/.gitattributes b/vendor/github.com/expr-lang/expr/.gitattributes new file mode 100644 index 00000000..efd30300 --- /dev/null +++ b/vendor/github.com/expr-lang/expr/.gitattributes @@ -0,0 +1 @@ +*\[generated\].go linguist-language=txt diff --git a/vendor/github.com/expr-lang/expr/README.md b/vendor/github.com/expr-lang/expr/README.md index 6c56c67b..c0778a84 100644 --- a/vendor/github.com/expr-lang/expr/README.md +++ b/vendor/github.com/expr-lang/expr/README.md @@ -162,9 +162,15 @@ func main() { * [Visually.io](https://visually.io) employs Expr as a business rule engine for its personalization targeting algorithm. * [Akvorado](https://github.com/akvorado/akvorado) utilizes Expr to classify exporters and interfaces in network flows. * [keda.sh](https://keda.sh) uses Expr to allow customization of its Kubernetes-based event-driven autoscaling. -* [Span Digital](https://spandigital.com/) uses Expr in it's Knowledge Management products. +* [Span Digital](https://spandigital.com/) uses Expr in its Knowledge Management products. * [Xiaohongshu](https://www.xiaohongshu.com/) combining yaml with Expr for dynamically policies delivery. * [Melrōse](https://melrōse.org) uses Expr to implement its music programming language. +* [Tork](https://www.tork.run/) integrates Expr into its workflow execution. +* [Critical Moments](https://criticalmoments.io) uses Expr for its mobile realtime conditional targeting system. +* [WoodpeckerCI](https://woodpecker-ci.org) uses Expr for [filtering workflows/steps](https://woodpecker-ci.org/docs/usage/workflow-syntax#evaluate). +* [FastSchema](https://github.com/fastschema/fastschema) - A BaaS leveraging Expr for its customizable and dynamic Access Control system. +* [WunderGraph Cosmo](https://github.com/wundergraph/cosmo) - GraphQL Federeration Router uses Expr to customize Middleware behaviour +* [SOLO](https://solo.one) uses Expr interally to allow dynamic code execution with custom defined functions. [Add your company too](https://github.com/expr-lang/expr/edit/master/README.md) diff --git a/vendor/github.com/expr-lang/expr/SECURITY.md b/vendor/github.com/expr-lang/expr/SECURITY.md index 8d692a39..e18771f5 100644 --- a/vendor/github.com/expr-lang/expr/SECURITY.md +++ b/vendor/github.com/expr-lang/expr/SECURITY.md @@ -11,11 +11,8 @@ unless this is not possible or feasible with a reasonable effort. | Version | Supported | |---------|--------------------| -| 1.16 | :white_check_mark: | -| 1.15 | :white_check_mark: | -| 1.14 | :white_check_mark: | -| 1.13 | :white_check_mark: | -| < 1.13 | :x: | +| 1.x | :white_check_mark: | +| 0.x | :x: | ## Reporting a Vulnerability diff --git a/vendor/github.com/expr-lang/expr/ast/find.go b/vendor/github.com/expr-lang/expr/ast/find.go new file mode 100644 index 00000000..247ff6c0 --- /dev/null +++ b/vendor/github.com/expr-lang/expr/ast/find.go @@ -0,0 +1,18 @@ +package ast + +func Find(node Node, fn func(node Node) bool) Node { + v := &finder{fn: fn} + Walk(&node, v) + return v.node +} + +type finder struct { + node Node + fn func(node Node) bool +} + +func (f *finder) Visit(node *Node) { + if f.fn(*node) { + f.node = *node + } +} diff --git a/vendor/github.com/expr-lang/expr/ast/node.go b/vendor/github.com/expr-lang/expr/ast/node.go index 03e8cf62..02923ac5 100644 --- a/vendor/github.com/expr-lang/expr/ast/node.go +++ b/vendor/github.com/expr-lang/expr/ast/node.go @@ -3,13 +3,20 @@ package ast import ( "reflect" + "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/file" ) +var ( + anyType = reflect.TypeOf(new(any)).Elem() +) + // Node represents items of abstract syntax tree. type Node interface { Location() file.Location SetLocation(file.Location) + Nature() nature.Nature + SetNature(nature.Nature) Type() reflect.Type SetType(reflect.Type) String() string @@ -25,8 +32,8 @@ func Patch(node *Node, newNode Node) { // base is a base struct for all nodes. type base struct { - loc file.Location - nodeType reflect.Type + loc file.Location + nature nature.Nature } // Location returns the location of the node in the source code. @@ -39,14 +46,27 @@ func (n *base) SetLocation(loc file.Location) { n.loc = loc } +// Nature returns the nature of the node. +func (n *base) Nature() nature.Nature { + return n.nature +} + +// SetNature sets the nature of the node. +func (n *base) SetNature(nature nature.Nature) { + n.nature = nature +} + // Type returns the type of the node. func (n *base) Type() reflect.Type { - return n.nodeType + if n.nature.Type == nil { + return anyType + } + return n.nature.Type } // SetType sets the type of the node. func (n *base) SetType(t reflect.Type) { - n.nodeType = t + n.nature.Type = t } // NilNode represents nil. @@ -163,13 +183,13 @@ type BuiltinNode struct { Map Node // Used by optimizer to fold filter() and map() builtins. } -// ClosureNode represents a predicate. +// PredicateNode represents a predicate. // Example: // // filter(foo, .bar == 1) // // The predicate is ".bar == 1". -type ClosureNode struct { +type PredicateNode struct { base Node Node // Node of the predicate body. } @@ -196,6 +216,13 @@ type VariableDeclaratorNode struct { Expr Node // Expression of the variable. Like "foo + 1" in "let foo = 1; foo + 1". } +// SequenceNode represents a sequence of nodes separated by semicolons. +// All nodes are executed, only the last node will be returned. +type SequenceNode struct { + base + Nodes []Node +} + // ArrayNode represents an array. type ArrayNode struct { base diff --git a/vendor/github.com/expr-lang/expr/ast/print.go b/vendor/github.com/expr-lang/expr/ast/print.go index 6a7d698a..e4e45f0f 100644 --- a/vendor/github.com/expr-lang/expr/ast/print.go +++ b/vendor/github.com/expr-lang/expr/ast/print.go @@ -45,13 +45,21 @@ func (n *ConstantNode) String() string { } func (n *UnaryNode) String() string { - op := "" + op := n.Operator if n.Operator == "not" { op = fmt.Sprintf("%s ", n.Operator) - } else { - op = fmt.Sprintf("%s", n.Operator) } - if _, ok := n.Node.(*BinaryNode); ok { + wrap := false + switch b := n.Node.(type) { + case *BinaryNode: + if operator.Binary[b.Operator].Precedence < + operator.Unary[n.Operator].Precedence { + wrap = true + } + case *ConditionalNode: + wrap = true + } + if wrap { return fmt.Sprintf("%s(%s)", op, n.Node.String()) } return fmt.Sprintf("%s%s", op, n.Node.String()) @@ -65,10 +73,21 @@ func (n *BinaryNode) String() string { var lhs, rhs string var lwrap, rwrap bool + if l, ok := n.Left.(*UnaryNode); ok { + if operator.Unary[l.Operator].Precedence < + operator.Binary[n.Operator].Precedence { + lwrap = true + } + } if lb, ok := n.Left.(*BinaryNode); ok { if operator.Less(lb.Operator, n.Operator) { lwrap = true } + if operator.Binary[lb.Operator].Precedence == + operator.Binary[n.Operator].Precedence && + operator.Binary[n.Operator].Associativity == operator.Right { + lwrap = true + } if lb.Operator == "??" { lwrap = true } @@ -80,6 +99,11 @@ func (n *BinaryNode) String() string { if operator.Less(rb.Operator, n.Operator) { rwrap = true } + if operator.Binary[rb.Operator].Precedence == + operator.Binary[n.Operator].Precedence && + operator.Binary[n.Operator].Associativity == operator.Left { + rwrap = true + } if operator.IsBoolean(rb.Operator) && n.Operator != rb.Operator { rwrap = true } @@ -162,7 +186,7 @@ func (n *BuiltinNode) String() string { return fmt.Sprintf("%s(%s)", n.Name, strings.Join(arguments, ", ")) } -func (n *ClosureNode) String() string { +func (n *PredicateNode) String() string { return n.Node.String() } @@ -174,6 +198,14 @@ func (n *VariableDeclaratorNode) String() string { return fmt.Sprintf("let %s = %s; %s", n.Name, n.Value.String(), n.Expr.String()) } +func (n *SequenceNode) String() string { + nodes := make([]string, len(n.Nodes)) + for i, node := range n.Nodes { + nodes[i] = node.String() + } + return strings.Join(nodes, "; ") +} + func (n *ConditionalNode) String() string { var cond, exp1, exp2 string if _, ok := n.Cond.(*ConditionalNode); ok { diff --git a/vendor/github.com/expr-lang/expr/ast/visitor.go b/vendor/github.com/expr-lang/expr/ast/visitor.go index 90bc9f1d..72cd6366 100644 --- a/vendor/github.com/expr-lang/expr/ast/visitor.go +++ b/vendor/github.com/expr-lang/expr/ast/visitor.go @@ -45,12 +45,16 @@ func Walk(node *Node, v Visitor) { for i := range n.Arguments { Walk(&n.Arguments[i], v) } - case *ClosureNode: + case *PredicateNode: Walk(&n.Node, v) case *PointerNode: case *VariableDeclaratorNode: Walk(&n.Value, v) Walk(&n.Expr, v) + case *SequenceNode: + for i := range n.Nodes { + Walk(&n.Nodes[i], v) + } case *ConditionalNode: Walk(&n.Cond, v) Walk(&n.Exp1, v) diff --git a/vendor/github.com/expr-lang/expr/builtin/builtin.go b/vendor/github.com/expr-lang/expr/builtin/builtin.go index cc6f197c..c23daf46 100644 --- a/vendor/github.com/expr-lang/expr/builtin/builtin.go +++ b/vendor/github.com/expr-lang/expr/builtin/builtin.go @@ -493,6 +493,9 @@ var Builtins = []*Function{ } return anyType, fmt.Errorf("invalid number of arguments (expected 0, got %d)", len(args)) }, + Deref: func(i int, arg reflect.Type) bool { + return false + }, }, { Name: "duration", @@ -567,6 +570,12 @@ var Builtins = []*Function{ } return timeType, nil }, + Deref: func(i int, arg reflect.Type) bool { + if arg.AssignableTo(locationType) { + return false + } + return true + }, }, { Name: "timezone", @@ -627,14 +636,7 @@ var Builtins = []*Function{ }, { Name: "get", - Func: func(args ...any) (out any, err error) { - defer func() { - if r := recover(); r != nil { - return - } - }() - return runtime.Fetch(args[0], args[1]), nil - }, + Func: get, }, { Name: "take", @@ -650,10 +652,13 @@ var Builtins = []*Function{ if !n.CanInt() { return nil, fmt.Errorf("cannot take %s elements", n.Kind()) } + to := 0 if n.Int() > int64(v.Len()) { - return args[0], nil + to = v.Len() + } else { + to = int(n.Int()) } - return v.Slice(0, int(n.Int())).Interface(), nil + return v.Slice(0, to).Interface(), nil }, Validate: func(args []reflect.Type) (reflect.Type, error) { if len(args) != 2 { @@ -798,14 +803,14 @@ var Builtins = []*Function{ }, { Name: "reverse", - Func: func(args ...any) (any, error) { + Safe: func(args ...any) (any, uint, error) { if len(args) != 1 { - return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) + return nil, 0, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) } v := reflect.ValueOf(args[0]) if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot reverse %s", v.Kind()) + return nil, 0, fmt.Errorf("cannot reverse %s", v.Kind()) } size := v.Len() @@ -815,7 +820,7 @@ var Builtins = []*Function{ arr[i] = v.Index(size - i - 1).Interface() } - return arr, nil + return arr, uint(size), nil }, Validate: func(args []reflect.Type) (reflect.Type, error) { @@ -830,6 +835,57 @@ var Builtins = []*Function{ } }, }, + + { + Name: "uniq", + Func: func(args ...any) (any, error) { + if len(args) != 1 { + return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) + } + + v := reflect.ValueOf(args[0]) + if v.Kind() != reflect.Array && v.Kind() != reflect.Slice { + return nil, fmt.Errorf("cannot uniq %s", v.Kind()) + } + + size := v.Len() + ret := []any{} + + eq := func(i int) bool { + for _, r := range ret { + if runtime.Equal(v.Index(i).Interface(), r) { + return true + } + } + + return false + } + + for i := 0; i < size; i += 1 { + if eq(i) { + continue + } + + ret = append(ret, v.Index(i).Interface()) + } + + return ret, nil + }, + + Validate: func(args []reflect.Type) (reflect.Type, error) { + if len(args) != 1 { + return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) + } + + switch kind(args[0]) { + case reflect.Interface, reflect.Slice, reflect.Array: + return arrayType, nil + default: + return anyType, fmt.Errorf("cannot uniq %s", args[0]) + } + }, + }, + { Name: "concat", Safe: func(args ...any) (any, uint, error) { @@ -841,7 +897,7 @@ var Builtins = []*Function{ var arr []any for _, arg := range args { - v := reflect.ValueOf(deref.Deref(arg)) + v := reflect.ValueOf(arg) if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { return nil, 0, fmt.Errorf("cannot concat %s", v.Kind()) @@ -863,7 +919,7 @@ var Builtins = []*Function{ } for _, arg := range args { - switch kind(deref.Type(arg)) { + switch kind(arg) { case reflect.Interface, reflect.Slice, reflect.Array: default: return anyType, fmt.Errorf("cannot concat %s", arg) @@ -873,6 +929,37 @@ var Builtins = []*Function{ return arrayType, nil }, }, + { + Name: "flatten", + Safe: func(args ...any) (any, uint, error) { + var size uint + if len(args) != 1 { + return nil, 0, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) + } + v := reflect.ValueOf(args[0]) + if v.Kind() != reflect.Array && v.Kind() != reflect.Slice { + return nil, size, fmt.Errorf("cannot flatten %s", v.Kind()) + } + ret := flatten(v) + size = uint(len(ret)) + return ret, size, nil + }, + Validate: func(args []reflect.Type) (reflect.Type, error) { + if len(args) != 1 { + return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) + } + + for _, arg := range args { + switch kind(arg) { + case reflect.Interface, reflect.Slice, reflect.Array: + default: + return anyType, fmt.Errorf("cannot flatten %s", arg) + } + } + + return arrayType, nil + }, + }, { Name: "sort", Safe: func(args ...any) (any, uint, error) { diff --git a/vendor/github.com/expr-lang/expr/builtin/function.go b/vendor/github.com/expr-lang/expr/builtin/function.go index d4d78b1c..6634ac3f 100644 --- a/vendor/github.com/expr-lang/expr/builtin/function.go +++ b/vendor/github.com/expr-lang/expr/builtin/function.go @@ -11,6 +11,7 @@ type Function struct { Safe func(args ...any) (any, uint, error) Types []reflect.Type Validate func(args []reflect.Type) (reflect.Type, error) + Deref func(i int, arg reflect.Type) bool Predicate bool } diff --git a/vendor/github.com/expr-lang/expr/builtin/lib.go b/vendor/github.com/expr-lang/expr/builtin/lib.go index e3cd61b9..5a70a6b9 100644 --- a/vendor/github.com/expr-lang/expr/builtin/lib.go +++ b/vendor/github.com/expr-lang/expr/builtin/lib.go @@ -5,15 +5,19 @@ import ( "math" "reflect" "strconv" + "unicode/utf8" "github.com/expr-lang/expr/internal/deref" + "github.com/expr-lang/expr/vm/runtime" ) func Len(x any) any { v := reflect.ValueOf(x) switch v.Kind() { - case reflect.Array, reflect.Slice, reflect.Map, reflect.String: + case reflect.Array, reflect.Slice, reflect.Map: return v.Len() + case reflect.String: + return utf8.RuneCountInString(v.String()) default: panic(fmt.Sprintf("invalid argument for len (type %T)", x)) } @@ -24,15 +28,6 @@ func Type(arg any) any { return "nil" } v := reflect.ValueOf(arg) - for { - if v.Kind() == reflect.Ptr { - v = v.Elem() - } else if v.Kind() == reflect.Interface { - v = v.Elem() - } else { - break - } - } if v.Type().Name() != "" && v.Type().PkgPath() != "" { return fmt.Sprintf("%s.%s", v.Type().PkgPath(), v.Type().Name()) } @@ -261,7 +256,7 @@ func String(arg any) any { func minMax(name string, fn func(any, any) bool, args ...any) (any, error) { var val any for _, arg := range args { - rv := reflect.ValueOf(deref.Deref(arg)) + rv := reflect.ValueOf(arg) switch rv.Kind() { case reflect.Array, reflect.Slice: size := rv.Len() @@ -304,7 +299,7 @@ func mean(args ...any) (int, float64, error) { var count int for _, arg := range args { - rv := reflect.ValueOf(deref.Deref(arg)) + rv := reflect.ValueOf(arg) switch rv.Kind() { case reflect.Array, reflect.Slice: size := rv.Len() @@ -336,7 +331,7 @@ func median(args ...any) ([]float64, error) { var values []float64 for _, arg := range args { - rv := reflect.ValueOf(deref.Deref(arg)) + rv := reflect.ValueOf(arg) switch rv.Kind() { case reflect.Array, reflect.Slice: size := rv.Len() @@ -359,3 +354,80 @@ func median(args ...any) ([]float64, error) { } return values, nil } + +func flatten(arg reflect.Value) []any { + ret := []any{} + for i := 0; i < arg.Len(); i++ { + v := deref.Value(arg.Index(i)) + if v.Kind() == reflect.Array || v.Kind() == reflect.Slice { + x := flatten(v) + ret = append(ret, x...) + } else { + ret = append(ret, v.Interface()) + } + } + return ret +} + +func get(params ...any) (out any, err error) { + from := params[0] + i := params[1] + v := reflect.ValueOf(from) + + if v.Kind() == reflect.Invalid { + panic(fmt.Sprintf("cannot fetch %v from %T", i, from)) + } + + // Methods can be defined on any type. + if v.NumMethod() > 0 { + if methodName, ok := i.(string); ok { + method := v.MethodByName(methodName) + if method.IsValid() { + return method.Interface(), nil + } + } + } + + switch v.Kind() { + case reflect.Array, reflect.Slice, reflect.String: + index := runtime.ToInt(i) + l := v.Len() + if index < 0 { + index = l + index + } + if 0 <= index && index < l { + value := v.Index(index) + if value.IsValid() { + return value.Interface(), nil + } + } + + case reflect.Map: + var value reflect.Value + if i == nil { + value = v.MapIndex(reflect.Zero(v.Type().Key())) + } else { + value = v.MapIndex(reflect.ValueOf(i)) + } + if value.IsValid() { + return value.Interface(), nil + } + + case reflect.Struct: + fieldName := i.(string) + value := v.FieldByNameFunc(func(name string) bool { + field, _ := v.Type().FieldByName(name) + if field.Tag.Get("expr") == fieldName { + return true + } + return name == fieldName + }) + if value.IsValid() { + return value.Interface(), nil + } + } + + // Main difference from runtime.Fetch + // is that we return `nil` instead of panic. + return nil, nil +} diff --git a/vendor/github.com/expr-lang/expr/builtin/utils.go b/vendor/github.com/expr-lang/expr/builtin/utils.go index 29a95731..262bb379 100644 --- a/vendor/github.com/expr-lang/expr/builtin/utils.go +++ b/vendor/github.com/expr-lang/expr/builtin/utils.go @@ -4,6 +4,8 @@ import ( "fmt" "reflect" "time" + + "github.com/expr-lang/expr/internal/deref" ) var ( @@ -20,6 +22,7 @@ func kind(t reflect.Type) reflect.Kind { if t == nil { return reflect.Invalid } + t = deref.Type(t) return t.Kind() } diff --git a/vendor/github.com/expr-lang/expr/checker/checker.go b/vendor/github.com/expr-lang/expr/checker/checker.go index c71a98f0..f4923413 100644 --- a/vendor/github.com/expr-lang/expr/checker/checker.go +++ b/vendor/github.com/expr-lang/expr/checker/checker.go @@ -7,12 +7,46 @@ import ( "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" + . "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" - "github.com/expr-lang/expr/internal/deref" "github.com/expr-lang/expr/parser" ) +// Run visitors in a given config over the given tree +// runRepeatable controls whether to filter for only vistors that require multiple passes or not +func runVisitors(tree *parser.Tree, config *conf.Config, runRepeatable bool) { + for { + more := false + for _, v := range config.Visitors { + // We need to perform types check, because some visitors may rely on + // types information available in the tree. + _, _ = Check(tree, config) + + r, repeatable := v.(interface { + Reset() + ShouldRepeat() bool + }) + + if repeatable { + if runRepeatable { + r.Reset() + ast.Walk(&tree.Node, v) + more = more || r.ShouldRepeat() + } + } else { + if !runRepeatable { + ast.Walk(&tree.Node, v) + } + } + } + + if !more { + break + } + } +} + // ParseCheck parses input expression and checks its types. Also, it applies // all provided patchers. In case of error, it returns error with a tree. func ParseCheck(input string, config *conf.Config) (*parser.Tree, error) { @@ -22,25 +56,11 @@ func ParseCheck(input string, config *conf.Config) (*parser.Tree, error) { } if len(config.Visitors) > 0 { - for i := 0; i < 1000; i++ { - more := false - for _, v := range config.Visitors { - // We need to perform types check, because some visitors may rely on - // types information available in the tree. - _, _ = Check(tree, config) - - ast.Walk(&tree.Node, v) - - if v, ok := v.(interface { - ShouldRepeat() bool - }); ok { - more = more || v.ShouldRepeat() - } - } - if !more { - break - } - } + // Run all patchers that dont support being run repeatedly first + runVisitors(tree, config, false) + + // Run patchers that require multiple passes next (currently only Operator patching) + runVisitors(tree, config, true) } _, err = Check(tree, config) if err != nil { @@ -52,14 +72,20 @@ func ParseCheck(input string, config *conf.Config) (*parser.Tree, error) { // Check checks types of the expression tree. It returns type of the expression // and error if any. If config is nil, then default configuration will be used. -func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) { +func Check(tree *parser.Tree, config *conf.Config) (reflect.Type, error) { if config == nil { config = conf.New(nil) } v := &checker{config: config} - t, _ = v.visit(tree.Node) + nt := v.visit(tree.Node) + + // To keep compatibility with previous versions, we should return any, if nature is unknown. + t := nt.Type + if t == nil { + t = anyType + } if v.err != nil { return t, v.err.Bind(tree.Source) @@ -67,23 +93,20 @@ func Check(tree *parser.Tree, config *conf.Config) (t reflect.Type, err error) { if v.config.Expect != reflect.Invalid { if v.config.ExpectAny { - if isAny(t) { + if isUnknown(nt) { return t, nil } } switch v.config.Expect { case reflect.Int, reflect.Int64, reflect.Float64: - if !isNumber(t) { - return nil, fmt.Errorf("expected %v, but got %v", v.config.Expect, t) + if !isNumber(nt) { + return nil, fmt.Errorf("expected %v, but got %v", v.config.Expect, nt) } default: - if t != nil { - if t.Kind() == v.config.Expect { - return t, nil - } + if nt.Kind() != v.config.Expect { + return nil, fmt.Errorf("expected %v, but got %s", v.config.Expect, nt) } - return nil, fmt.Errorf("expected %v, but got %v", v.config.Expect, t) } } @@ -98,14 +121,13 @@ type checker struct { } type predicateScope struct { - vtype reflect.Type - vars map[string]reflect.Type + collection Nature + vars map[string]Nature } type varScope struct { - name string - vtype reflect.Type - info info + name string + nature Nature } type info struct { @@ -119,285 +141,290 @@ type info struct { elem reflect.Type } -func (v *checker) visit(node ast.Node) (reflect.Type, info) { - var t reflect.Type - var i info +func (v *checker) visit(node ast.Node) Nature { + var nt Nature switch n := node.(type) { case *ast.NilNode: - t, i = v.NilNode(n) + nt = v.NilNode(n) case *ast.IdentifierNode: - t, i = v.IdentifierNode(n) + nt = v.IdentifierNode(n) case *ast.IntegerNode: - t, i = v.IntegerNode(n) + nt = v.IntegerNode(n) case *ast.FloatNode: - t, i = v.FloatNode(n) + nt = v.FloatNode(n) case *ast.BoolNode: - t, i = v.BoolNode(n) + nt = v.BoolNode(n) case *ast.StringNode: - t, i = v.StringNode(n) + nt = v.StringNode(n) case *ast.ConstantNode: - t, i = v.ConstantNode(n) + nt = v.ConstantNode(n) case *ast.UnaryNode: - t, i = v.UnaryNode(n) + nt = v.UnaryNode(n) case *ast.BinaryNode: - t, i = v.BinaryNode(n) + nt = v.BinaryNode(n) case *ast.ChainNode: - t, i = v.ChainNode(n) + nt = v.ChainNode(n) case *ast.MemberNode: - t, i = v.MemberNode(n) + nt = v.MemberNode(n) case *ast.SliceNode: - t, i = v.SliceNode(n) + nt = v.SliceNode(n) case *ast.CallNode: - t, i = v.CallNode(n) + nt = v.CallNode(n) case *ast.BuiltinNode: - t, i = v.BuiltinNode(n) - case *ast.ClosureNode: - t, i = v.ClosureNode(n) + nt = v.BuiltinNode(n) + case *ast.PredicateNode: + nt = v.PredicateNode(n) case *ast.PointerNode: - t, i = v.PointerNode(n) + nt = v.PointerNode(n) case *ast.VariableDeclaratorNode: - t, i = v.VariableDeclaratorNode(n) + nt = v.VariableDeclaratorNode(n) + case *ast.SequenceNode: + nt = v.SequenceNode(n) case *ast.ConditionalNode: - t, i = v.ConditionalNode(n) + nt = v.ConditionalNode(n) case *ast.ArrayNode: - t, i = v.ArrayNode(n) + nt = v.ArrayNode(n) case *ast.MapNode: - t, i = v.MapNode(n) + nt = v.MapNode(n) case *ast.PairNode: - t, i = v.PairNode(n) + nt = v.PairNode(n) default: panic(fmt.Sprintf("undefined node type (%T)", node)) } - node.SetType(t) - return t, i + node.SetNature(nt) + return nt } -func (v *checker) error(node ast.Node, format string, args ...any) (reflect.Type, info) { +func (v *checker) error(node ast.Node, format string, args ...any) Nature { if v.err == nil { // show first error v.err = &file.Error{ Location: node.Location(), Message: fmt.Sprintf(format, args...), } } - return anyType, info{} // interface represent undefined type + return unknown } -func (v *checker) NilNode(*ast.NilNode) (reflect.Type, info) { - return nilType, info{} +func (v *checker) NilNode(*ast.NilNode) Nature { + return nilNature } -func (v *checker) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info) { - if s, ok := v.lookupVariable(node.Value); ok { - return s.vtype, s.info +func (v *checker) IdentifierNode(node *ast.IdentifierNode) Nature { + if variable, ok := v.lookupVariable(node.Value); ok { + return variable.nature } if node.Value == "$env" { - return mapType, info{} + return unknown } - return v.ident(node, node.Value, true, true) + + return v.ident(node, node.Value, v.config.Strict, true) } // ident method returns type of environment variable, builtin or function. -func (v *checker) ident(node ast.Node, name string, strict, builtins bool) (reflect.Type, info) { - if t, ok := v.config.Types[name]; ok { - if t.Ambiguous { - return v.error(node, "ambiguous identifier %v", name) - } - return t.Type, info{method: t.Method} +func (v *checker) ident(node ast.Node, name string, strict, builtins bool) Nature { + if nt, ok := v.config.Env.Get(name); ok { + return nt } if builtins { if fn, ok := v.config.Functions[name]; ok { - return fn.Type(), info{fn: fn} + return Nature{Type: fn.Type(), Func: fn} } if fn, ok := v.config.Builtins[name]; ok { - return fn.Type(), info{fn: fn} + return Nature{Type: fn.Type(), Func: fn} } } if v.config.Strict && strict { return v.error(node, "unknown name %v", name) } - if v.config.DefaultType != nil { - return v.config.DefaultType, info{} - } - return anyType, info{} + return unknown } -func (v *checker) IntegerNode(*ast.IntegerNode) (reflect.Type, info) { - return integerType, info{} +func (v *checker) IntegerNode(*ast.IntegerNode) Nature { + return integerNature } -func (v *checker) FloatNode(*ast.FloatNode) (reflect.Type, info) { - return floatType, info{} +func (v *checker) FloatNode(*ast.FloatNode) Nature { + return floatNature } -func (v *checker) BoolNode(*ast.BoolNode) (reflect.Type, info) { - return boolType, info{} +func (v *checker) BoolNode(*ast.BoolNode) Nature { + return boolNature } -func (v *checker) StringNode(*ast.StringNode) (reflect.Type, info) { - return stringType, info{} +func (v *checker) StringNode(*ast.StringNode) Nature { + return stringNature } -func (v *checker) ConstantNode(node *ast.ConstantNode) (reflect.Type, info) { - return reflect.TypeOf(node.Value), info{} +func (v *checker) ConstantNode(node *ast.ConstantNode) Nature { + return Nature{Type: reflect.TypeOf(node.Value)} } -func (v *checker) UnaryNode(node *ast.UnaryNode) (reflect.Type, info) { - t, _ := v.visit(node.Node) - t = deref.Type(t) +func (v *checker) UnaryNode(node *ast.UnaryNode) Nature { + nt := v.visit(node.Node) + nt = nt.Deref() switch node.Operator { case "!", "not": - if isBool(t) { - return boolType, info{} + if isBool(nt) { + return boolNature } - if isAny(t) { - return boolType, info{} + if isUnknown(nt) { + return boolNature } case "+", "-": - if isNumber(t) { - return t, info{} + if isNumber(nt) { + return nt } - if isAny(t) { - return anyType, info{} + if isUnknown(nt) { + return unknown } default: return v.error(node, "unknown operator (%v)", node.Operator) } - return v.error(node, `invalid operation: %v (mismatched type %v)`, node.Operator, t) + return v.error(node, `invalid operation: %v (mismatched type %s)`, node.Operator, nt) } -func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) { - l, _ := v.visit(node.Left) - r, ri := v.visit(node.Right) +func (v *checker) BinaryNode(node *ast.BinaryNode) Nature { + l := v.visit(node.Left) + r := v.visit(node.Right) - l = deref.Type(l) - r = deref.Type(r) + l = l.Deref() + r = r.Deref() switch node.Operator { case "==", "!=": if isComparable(l, r) { - return boolType, info{} + return boolNature } case "or", "||", "and", "&&": if isBool(l) && isBool(r) { - return boolType, info{} + return boolNature } if or(l, r, isBool) { - return boolType, info{} + return boolNature } case "<", ">", ">=", "<=": if isNumber(l) && isNumber(r) { - return boolType, info{} + return boolNature } if isString(l) && isString(r) { - return boolType, info{} + return boolNature } if isTime(l) && isTime(r) { - return boolType, info{} + return boolNature } - if or(l, r, isNumber, isString, isTime) { - return boolType, info{} + if isDuration(l) && isDuration(r) { + return boolNature + } + if or(l, r, isNumber, isString, isTime, isDuration) { + return boolNature } case "-": if isNumber(l) && isNumber(r) { - return combined(l, r), info{} + return combined(l, r) } if isTime(l) && isTime(r) { - return durationType, info{} + return durationNature } if isTime(l) && isDuration(r) { - return timeType, info{} + return timeNature + } + if isDuration(l) && isDuration(r) { + return durationNature } - if or(l, r, isNumber, isTime) { - return anyType, info{} + if or(l, r, isNumber, isTime, isDuration) { + return unknown } case "*": if isNumber(l) && isNumber(r) { - return combined(l, r), info{} + return combined(l, r) } - if or(l, r, isNumber) { - return anyType, info{} + if isNumber(l) && isDuration(r) { + return durationNature + } + if isDuration(l) && isNumber(r) { + return durationNature + } + if isDuration(l) && isDuration(r) { + return durationNature + } + if or(l, r, isNumber, isDuration) { + return unknown } case "/": if isNumber(l) && isNumber(r) { - return floatType, info{} + return floatNature } if or(l, r, isNumber) { - return floatType, info{} + return floatNature } case "**", "^": if isNumber(l) && isNumber(r) { - return floatType, info{} + return floatNature } if or(l, r, isNumber) { - return floatType, info{} + return floatNature } case "%": if isInteger(l) && isInteger(r) { - return combined(l, r), info{} + return integerNature } if or(l, r, isInteger) { - return anyType, info{} + return integerNature } case "+": if isNumber(l) && isNumber(r) { - return combined(l, r), info{} + return combined(l, r) } if isString(l) && isString(r) { - return stringType, info{} + return stringNature } if isTime(l) && isDuration(r) { - return timeType, info{} + return timeNature } if isDuration(l) && isTime(r) { - return timeType, info{} + return timeNature + } + if isDuration(l) && isDuration(r) { + return durationNature } if or(l, r, isNumber, isString, isTime, isDuration) { - return anyType, info{} + return unknown } case "in": - if (isString(l) || isAny(l)) && isStruct(r) { - return boolType, info{} + if (isString(l) || isUnknown(l)) && isStruct(r) { + return boolNature } if isMap(r) { - if l == nil { // It is possible to compare with nil. - return boolType, info{} - } - if !isAny(l) && !l.AssignableTo(r.Key()) { + if !isUnknown(l) && !l.AssignableTo(r.Key()) { return v.error(node, "cannot use %v as type %v in map key", l, r.Key()) } - return boolType, info{} + return boolNature } if isArray(r) { - if l == nil { // It is possible to compare with nil. - return boolType, info{} - } if !isComparable(l, r.Elem()) { return v.error(node, "cannot use %v as type %v in array", l, r.Elem()) } - if !isComparable(l, ri.elem) { - return v.error(node, "cannot use %v as type %v in array", l, ri.elem) - } - return boolType, info{} + return boolNature } - if isAny(l) && anyOf(r, isString, isArray, isMap) { - return boolType, info{} + if isUnknown(l) && anyOf(r, isString, isArray, isMap) { + return boolNature } - if isAny(r) { - return boolType, info{} + if isUnknown(r) { + return boolNature } case "matches": @@ -408,43 +435,42 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) { } } if isString(l) && isString(r) { - return boolType, info{} + return boolNature } if or(l, r, isString) { - return boolType, info{} + return boolNature } case "contains", "startsWith", "endsWith": if isString(l) && isString(r) { - return boolType, info{} + return boolNature } if or(l, r, isString) { - return boolType, info{} + return boolNature } case "..": - ret := reflect.SliceOf(integerType) if isInteger(l) && isInteger(r) { - return ret, info{} + return arrayOf(integerNature) } if or(l, r, isInteger) { - return ret, info{} + return arrayOf(integerNature) } case "??": - if l == nil && r != nil { - return r, info{} + if isNil(l) && !isNil(r) { + return r } - if l != nil && r == nil { - return l, info{} + if !isNil(l) && isNil(r) { + return l } - if l == nil && r == nil { - return nilType, info{} + if isNil(l) && isNil(r) { + return nilNature } if r.AssignableTo(l) { - return l, info{} + return l } - return anyType, info{} + return unknown default: return v.error(node, "unknown operator (%v)", node.Operator) @@ -454,11 +480,11 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) { return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r) } -func (v *checker) ChainNode(node *ast.ChainNode) (reflect.Type, info) { +func (v *checker) ChainNode(node *ast.ChainNode) Nature { return v.visit(node.Node) } -func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) { +func (v *checker) MemberNode(node *ast.MemberNode) Nature { // $env variable if an, ok := node.Node.(*ast.IdentifierNode); ok && an.Value == "$env" { if name, ok := node.Property.(*ast.StringNode); ok { @@ -472,59 +498,55 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) { } return v.ident(node, name.Value, strict, false /* no builtins and no functions */) } - return anyType, info{} + return unknown } - base, _ := v.visit(node.Node) - prop, _ := v.visit(node.Property) + base := v.visit(node.Node) + prop := v.visit(node.Property) + + if isUnknown(base) { + return unknown + } if name, ok := node.Property.(*ast.StringNode); ok { - if base == nil { - return v.error(node, "type %v has no field %v", base, name.Value) + if isNil(base) { + return v.error(node, "type nil has no field %v", name.Value) } + // First, check methods defined on base type itself, // independent of which type it is. Without dereferencing. if m, ok := base.MethodByName(name.Value); ok { - if kind(base) == reflect.Interface { - // In case of interface type method will not have a receiver, - // and to prevent checker decreasing numbers of in arguments - // return method type as not method (second argument is false). - - // Also, we can not use m.Index here, because it will be - // different indexes for different types which implement - // the same interface. - return m.Type, info{} - } else { - return m.Type, info{method: true} - } + return m } } - if kind(base) == reflect.Ptr { - base = base.Elem() - } - - switch kind(base) { - case reflect.Interface: - return anyType, info{} + base = base.Deref() + switch base.Kind() { case reflect.Map: - if prop != nil && !prop.AssignableTo(base.Key()) && !isAny(prop) { + if !prop.AssignableTo(base.Key()) && !isUnknown(prop) { return v.error(node.Property, "cannot use %v to get an element from %v", prop, base) } - return base.Elem(), info{} + if prop, ok := node.Property.(*ast.StringNode); ok { + if field, ok := base.Fields[prop.Value]; ok { + return field + } else if base.Strict { + return v.error(node.Property, "unknown field %v", prop.Value) + } + } + return base.Elem() case reflect.Array, reflect.Slice: - if !isInteger(prop) && !isAny(prop) { + if !isInteger(prop) && !isUnknown(prop) { return v.error(node.Property, "array elements can only be selected using an integer (got %v)", prop) } - return base.Elem(), info{} + return base.Elem() case reflect.Struct: if name, ok := node.Property.(*ast.StringNode); ok { propertyName := name.Value - if field, ok := fetchField(base, propertyName); ok { - return field.Type, info{} + if field, ok := base.FieldByName(propertyName); ok { + return Nature{Type: field.Type} } if node.Method { return v.error(node, "type %v has no method %v", base, propertyName) @@ -533,38 +555,50 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) { } } + // Not found. + + if name, ok := node.Property.(*ast.StringNode); ok { + if node.Method { + return v.error(node, "type %v has no method %v", base, name.Value) + } + return v.error(node, "type %v has no field %v", base, name.Value) + } return v.error(node, "type %v[%v] is undefined", base, prop) } -func (v *checker) SliceNode(node *ast.SliceNode) (reflect.Type, info) { - t, _ := v.visit(node.Node) +func (v *checker) SliceNode(node *ast.SliceNode) Nature { + nt := v.visit(node.Node) - switch kind(t) { - case reflect.Interface: - // ok + if isUnknown(nt) { + return unknown + } + + switch nt.Kind() { case reflect.String, reflect.Array, reflect.Slice: // ok default: - return v.error(node, "cannot slice %v", t) + return v.error(node, "cannot slice %s", nt) } if node.From != nil { - from, _ := v.visit(node.From) - if !isInteger(from) && !isAny(from) { + from := v.visit(node.From) + if !isInteger(from) && !isUnknown(from) { return v.error(node.From, "non-integer slice index %v", from) } } + if node.To != nil { - to, _ := v.visit(node.To) - if !isInteger(to) && !isAny(to) { + to := v.visit(node.To) + if !isInteger(to) && !isUnknown(to) { return v.error(node.To, "non-integer slice index %v", to) } } - return t, info{} + + return nt } -func (v *checker) CallNode(node *ast.CallNode) (reflect.Type, info) { - t, i := v.functionReturnType(node) +func (v *checker) CallNode(node *ast.CallNode) Nature { + nt := v.functionReturnType(node) // Check if type was set on node (for example, by patcher) // and use node type instead of function return type. @@ -578,17 +612,17 @@ func (v *checker) CallNode(node *ast.CallNode) (reflect.Type, info) { // checker pass we should replace anyType on method node // with new correct function return type. if node.Type() != nil && node.Type() != anyType { - return node.Type(), i + return node.Nature() } - return t, i + return nt } -func (v *checker) functionReturnType(node *ast.CallNode) (reflect.Type, info) { - fn, fnInfo := v.visit(node.Callee) +func (v *checker) functionReturnType(node *ast.CallNode) Nature { + nt := v.visit(node.Callee) - if fnInfo.fn != nil { - return v.checkFunction(fnInfo.fn, node, node.Arguments) + if nt.Func != nil { + return v.checkFunction(nt.Func, node, node.Arguments) } fnName := "function" @@ -601,240 +635,243 @@ func (v *checker) functionReturnType(node *ast.CallNode) (reflect.Type, info) { } } - if fn == nil { + if isUnknown(nt) { + return unknown + } + + if isNil(nt) { return v.error(node, "%v is nil; cannot call nil as function", fnName) } - switch fn.Kind() { - case reflect.Interface: - return anyType, info{} + switch nt.Kind() { case reflect.Func: - outType, err := v.checkArguments(fnName, fn, fnInfo.method, node.Arguments, node) + outType, err := v.checkArguments(fnName, nt, node.Arguments, node) if err != nil { if v.err == nil { v.err = err } - return anyType, info{} + return unknown } - return outType, info{} + return outType } - return v.error(node, "%v is not callable", fn) + return v.error(node, "%s is not callable", nt) } -func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) { +func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature { switch node.Name { case "all", "none", "any", "one": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]).Deref() + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) + if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) } - return boolType, info{} + return boolNature } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "filter": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]).Deref() + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) + if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) } - if isAny(collection) { - return arrayType, info{} + if isUnknown(collection) { + return arrayNature } - return arrayType, info{} + return arrayOf(collection.Elem()) } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "map": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]).Deref() + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } - v.begin(collection, scopeVar{"index", integerType}) - closure, _ := v.visit(node.Arguments[1]) + v.begin(collection, scopeVar{"index", integerNature}) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - return arrayType, info{} + return arrayOf(*predicate.PredicateOut) } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "count": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]).Deref() + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } if len(node.Arguments) == 1 { - return integerType, info{} + return integerNature } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { - if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) } - return integerType, info{} + return integerNature } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "sum": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]).Deref() + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } if len(node.Arguments) == 2 { v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { - return closure.Out(0), info{} + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { + return predicate.Out(0) } } else { - if isAny(collection) { - return anyType, info{} + if isUnknown(collection) { + return unknown } - return collection.Elem(), info{} + return collection.Elem() } case "find", "findLast": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]).Deref() + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) + if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) } - if isAny(collection) { - return anyType, info{} + if isUnknown(collection) { + return unknown } - return collection.Elem(), info{} + return collection.Elem() } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "findIndex", "findLastIndex": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]).Deref() + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - if !isBool(closure.Out(0)) && !isAny(closure.Out(0)) { - return v.error(node.Arguments[1], "predicate should return boolean (got %v)", closure.Out(0).String()) + if !isBool(predicate.Out(0)) && !isUnknown(predicate.Out(0)) { + return v.error(node.Arguments[1], "predicate should return boolean (got %v)", predicate.Out(0).String()) } - return integerType, info{} + return integerNature } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "groupBy": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]).Deref() + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - return reflect.TypeOf(map[any][]any{}), info{} + groups := arrayOf(collection.Elem()) + return Nature{Type: reflect.TypeOf(map[any][]any{}), ArrayOf: &groups} } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "sortBy": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]).Deref() + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } v.begin(collection) - closure, _ := v.visit(node.Arguments[1]) + predicate := v.visit(node.Arguments[1]) v.end() if len(node.Arguments) == 3 { - _, _ = v.visit(node.Arguments[2]) + _ = v.visit(node.Arguments[2]) } - if isFunc(closure) && - closure.NumOut() == 1 && - closure.NumIn() == 1 && isAny(closure.In(0)) { + if isFunc(predicate) && + predicate.NumOut() == 1 && + predicate.NumIn() == 1 && isUnknown(predicate.In(0)) { - return reflect.TypeOf([]any{}), info{} + return collection } return v.error(node.Arguments[1], "predicate should has one input and one output param") case "reduce": - collection, _ := v.visit(node.Arguments[0]) - if !isArray(collection) && !isAny(collection) { + collection := v.visit(node.Arguments[0]).Deref() + if !isArray(collection) && !isUnknown(collection) { return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) } - v.begin(collection, scopeVar{"index", integerType}, scopeVar{"acc", anyType}) - closure, _ := v.visit(node.Arguments[1]) + v.begin(collection, scopeVar{"index", integerNature}, scopeVar{"acc", unknown}) + predicate := v.visit(node.Arguments[1]) v.end() if len(node.Arguments) == 3 { - _, _ = v.visit(node.Arguments[2]) + _ = v.visit(node.Arguments[2]) } - if isFunc(closure) && closure.NumOut() == 1 { - return closure.Out(0), info{} + if isFunc(predicate) && predicate.NumOut() == 1 { + return *predicate.PredicateOut } return v.error(node.Arguments[1], "predicate should has two input and one output param") @@ -852,14 +889,14 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) { } type scopeVar struct { - name string - vtype reflect.Type + varName string + varNature Nature } -func (v *checker) begin(vtype reflect.Type, vars ...scopeVar) { - scope := predicateScope{vtype: vtype, vars: make(map[string]reflect.Type)} +func (v *checker) begin(collectionNature Nature, vars ...scopeVar) { + scope := predicateScope{collection: collectionNature, vars: make(map[string]Nature)} for _, v := range vars { - scope.vars[v.name] = v.vtype + scope.vars[v.varName] = v.varNature } v.predicateScopes = append(v.predicateScopes, scope) } @@ -868,83 +905,91 @@ func (v *checker) end() { v.predicateScopes = v.predicateScopes[:len(v.predicateScopes)-1] } -func (v *checker) checkBuiltinGet(node *ast.BuiltinNode) (reflect.Type, info) { +func (v *checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { if len(node.Arguments) != 2 { return v.error(node, "invalid number of arguments (expected 2, got %d)", len(node.Arguments)) } - val := node.Arguments[0] - prop := node.Arguments[1] - if id, ok := val.(*ast.IdentifierNode); ok && id.Value == "$env" { - if s, ok := prop.(*ast.StringNode); ok { - return v.config.Types[s.Value].Type, info{} + base := v.visit(node.Arguments[0]) + prop := v.visit(node.Arguments[1]) + + if id, ok := node.Arguments[0].(*ast.IdentifierNode); ok && id.Value == "$env" { + if s, ok := node.Arguments[1].(*ast.StringNode); ok { + if nt, ok := v.config.Env.Get(s.Value); ok { + return nt + } } - return anyType, info{} + return unknown } - t, _ := v.visit(val) + if isUnknown(base) { + return unknown + } - switch kind(t) { - case reflect.Interface: - return anyType, info{} + switch base.Kind() { case reflect.Slice, reflect.Array: - p, _ := v.visit(prop) - if p == nil { - return v.error(prop, "cannot use nil as slice index") + if !isInteger(prop) && !isUnknown(prop) { + return v.error(node.Arguments[1], "non-integer slice index %s", prop) } - if !isInteger(p) && !isAny(p) { - return v.error(prop, "non-integer slice index %v", p) - } - return t.Elem(), info{} + return base.Elem() case reflect.Map: - p, _ := v.visit(prop) - if p == nil { - return v.error(prop, "cannot use nil as map index") - } - if !p.AssignableTo(t.Key()) && !isAny(p) { - return v.error(prop, "cannot use %v to get an element from %v", p, t) + if !prop.AssignableTo(base.Key()) && !isUnknown(prop) { + return v.error(node.Arguments[1], "cannot use %s to get an element from %s", prop, base) } - return t.Elem(), info{} + return base.Elem() } - return v.error(val, "type %v does not support indexing", t) + return v.error(node.Arguments[0], "type %v does not support indexing", base) } -func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments []ast.Node) (reflect.Type, info) { +func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments []ast.Node) Nature { if f.Validate != nil { args := make([]reflect.Type, len(arguments)) for i, arg := range arguments { - args[i], _ = v.visit(arg) + argNature := v.visit(arg) + if isUnknown(argNature) { + args[i] = anyType + } else { + args[i] = argNature.Type + } } t, err := f.Validate(args) if err != nil { return v.error(node, "%v", err) } - return t, info{} + return Nature{Type: t} } else if len(f.Types) == 0 { - t, err := v.checkArguments(f.Name, f.Type(), false, arguments, node) + nt, err := v.checkArguments(f.Name, Nature{Type: f.Type()}, arguments, node) if err != nil { if v.err == nil { v.err = err } - return anyType, info{} + return unknown } // No type was specified, so we assume the function returns any. - return t, info{} + return nt } var lastErr *file.Error for _, t := range f.Types { - outType, err := v.checkArguments(f.Name, t, false, arguments, node) + outNature, err := v.checkArguments(f.Name, Nature{Type: t}, arguments, node) if err != nil { lastErr = err continue } - return outType, info{} + + // As we found the correct function overload, we can stop the loop. + // Also, we need to set the correct nature of the callee so compiler, + // can correctly handle OpDeref opcode. + if callNode, ok := node.(*ast.CallNode); ok { + callNode.Callee.SetType(t) + } + + return outNature } if lastErr != nil { if v.err == nil { v.err = lastErr } - return anyType, info{} + return unknown } return v.error(node, "no matching overload for %v", f.Name) @@ -952,23 +997,22 @@ func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments [] func (v *checker) checkArguments( name string, - fn reflect.Type, - method bool, + fn Nature, arguments []ast.Node, node ast.Node, -) (reflect.Type, *file.Error) { - if isAny(fn) { - return anyType, nil +) (Nature, *file.Error) { + if isUnknown(fn) { + return unknown, nil } if fn.NumOut() == 0 { - return anyType, &file.Error{ + return unknown, &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v doesn't return value", name), } } if numOut := fn.NumOut(); numOut > 2 { - return anyType, &file.Error{ + return unknown, &file.Error{ Location: node.Location(), Message: fmt.Sprintf("func %v returns more then two values", name), } @@ -977,12 +1021,12 @@ func (v *checker) checkArguments( // If func is method on an env, first argument should be a receiver, // and actual arguments less than fnNumIn by one. fnNumIn := fn.NumIn() - if method { + if fn.Method { // TODO: Move subtraction to the Nature.NumIn() and Nature.In() methods. fnNumIn-- } // Skip first argument in case of the receiver. fnInOffset := 0 - if method { + if fn.Method { fnInOffset = 1 } @@ -1013,15 +1057,15 @@ func (v *checker) checkArguments( // If we have an error, we should still visit all arguments to // type check them, as a patch can fix the error later. for _, arg := range arguments { - _, _ = v.visit(arg) + _ = v.visit(arg) } return fn.Out(0), err } for i, arg := range arguments { - t, _ := v.visit(arg) + argNature := v.visit(arg) - var in reflect.Type + var in Nature if fn.IsVariadic() && i >= fnNumIn-1 { // For variadic arguments fn(xs ...int), go replaces type of xs (int) with ([]int). // As we compare arguments one by one, we need underling type. @@ -1030,24 +1074,40 @@ func (v *checker) checkArguments( in = fn.In(i + fnInOffset) } - if isFloat(in) && isInteger(t) { + if isFloat(in) && isInteger(argNature) { traverseAndReplaceIntegerNodesWithFloatNodes(&arguments[i], in) continue } - if isInteger(in) && isInteger(t) && kind(t) != kind(in) { + if isInteger(in) && isInteger(argNature) && argNature.Kind() != in.Kind() { traverseAndReplaceIntegerNodesWithIntegerNodes(&arguments[i], in) continue } - if t == nil { - continue + if isNil(argNature) { + if in.Kind() == reflect.Ptr || in.Kind() == reflect.Interface { + continue + } + return unknown, &file.Error{ + Location: arg.Location(), + Message: fmt.Sprintf("cannot use nil as argument (type %s) to call %v", in, name), + } } - if !(t.AssignableTo(in) || deref.Type(t).AssignableTo(in)) && kind(t) != reflect.Interface { - return anyType, &file.Error{ + // Check if argument is assignable to the function input type. + // We check original type (like *time.Time), not dereferenced type, + // as function input type can be pointer to a struct. + assignable := argNature.AssignableTo(in) + + // We also need to check if dereference arg type is assignable to the function input type. + // For example, func(int) and argument *int. In this case we will add OpDeref to the argument, + // so we can call the function with *int argument. + assignable = assignable || argNature.Deref().AssignableTo(in) + + if !assignable && !isUnknown(argNature) { + return unknown, &file.Error{ Location: arg.Location(), - Message: fmt.Sprintf("cannot use %v as argument (type %v) to call %v ", t, in, name), + Message: fmt.Sprintf("cannot use %s as argument (type %s) to call %v ", argNature, in, name), } } } @@ -1055,75 +1115,82 @@ func (v *checker) checkArguments( return fn.Out(0), nil } -func traverseAndReplaceIntegerNodesWithFloatNodes(node *ast.Node, newType reflect.Type) { +func traverseAndReplaceIntegerNodesWithFloatNodes(node *ast.Node, newNature Nature) { switch (*node).(type) { case *ast.IntegerNode: *node = &ast.FloatNode{Value: float64((*node).(*ast.IntegerNode).Value)} - (*node).SetType(newType) + (*node).SetType(newNature.Type) case *ast.UnaryNode: unaryNode := (*node).(*ast.UnaryNode) - traverseAndReplaceIntegerNodesWithFloatNodes(&unaryNode.Node, newType) + traverseAndReplaceIntegerNodesWithFloatNodes(&unaryNode.Node, newNature) case *ast.BinaryNode: binaryNode := (*node).(*ast.BinaryNode) switch binaryNode.Operator { case "+", "-", "*": - traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Left, newType) - traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Right, newType) + traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Left, newNature) + traverseAndReplaceIntegerNodesWithFloatNodes(&binaryNode.Right, newNature) } } } -func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newType reflect.Type) { +func traverseAndReplaceIntegerNodesWithIntegerNodes(node *ast.Node, newNature Nature) { switch (*node).(type) { case *ast.IntegerNode: - (*node).SetType(newType) + (*node).SetType(newNature.Type) case *ast.UnaryNode: - (*node).SetType(newType) + (*node).SetType(newNature.Type) unaryNode := (*node).(*ast.UnaryNode) - traverseAndReplaceIntegerNodesWithIntegerNodes(&unaryNode.Node, newType) + traverseAndReplaceIntegerNodesWithIntegerNodes(&unaryNode.Node, newNature) case *ast.BinaryNode: // TODO: Binary node return type is dependent on the type of the operands. We can't just change the type of the node. binaryNode := (*node).(*ast.BinaryNode) switch binaryNode.Operator { case "+", "-", "*": - traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Left, newType) - traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Right, newType) + traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Left, newNature) + traverseAndReplaceIntegerNodesWithIntegerNodes(&binaryNode.Right, newNature) } } } -func (v *checker) ClosureNode(node *ast.ClosureNode) (reflect.Type, info) { - t, _ := v.visit(node.Node) - if t == nil { - return v.error(node.Node, "closure cannot be nil") +func (v *checker) PredicateNode(node *ast.PredicateNode) Nature { + nt := v.visit(node.Node) + var out []reflect.Type + if isUnknown(nt) { + out = append(out, anyType) + } else if !isNil(nt) { + out = append(out, nt.Type) + } + return Nature{ + Type: reflect.FuncOf([]reflect.Type{anyType}, out, false), + PredicateOut: &nt, } - return reflect.FuncOf([]reflect.Type{anyType}, []reflect.Type{t}, false), info{} } -func (v *checker) PointerNode(node *ast.PointerNode) (reflect.Type, info) { +func (v *checker) PointerNode(node *ast.PointerNode) Nature { if len(v.predicateScopes) == 0 { - return v.error(node, "cannot use pointer accessor outside closure") + return v.error(node, "cannot use pointer accessor outside predicate") } scope := v.predicateScopes[len(v.predicateScopes)-1] if node.Name == "" { - switch scope.vtype.Kind() { - case reflect.Interface: - return anyType, info{} + if isUnknown(scope.collection) { + return unknown + } + switch scope.collection.Kind() { case reflect.Array, reflect.Slice: - return scope.vtype.Elem(), info{} + return scope.collection.Elem() } return v.error(node, "cannot use %v as array", scope) } if scope.vars != nil { if t, ok := scope.vars[node.Name]; ok { - return t, info{} + return t } } return v.error(node, "unknown pointer #%v", node.Name) } -func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) (reflect.Type, info) { - if _, ok := v.config.Types[node.Name]; ok { +func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) Nature { + if _, ok := v.config.Env.Get(node.Name); ok { return v.error(node, "cannot redeclare %v", node.Name) } if _, ok := v.config.Functions[node.Name]; ok { @@ -1135,11 +1202,22 @@ func (v *checker) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) (refl if _, ok := v.lookupVariable(node.Name); ok { return v.error(node, "cannot redeclare variable %v", node.Name) } - vtype, vinfo := v.visit(node.Value) - v.varScopes = append(v.varScopes, varScope{node.Name, vtype, vinfo}) - t, i := v.visit(node.Expr) + varNature := v.visit(node.Value) + v.varScopes = append(v.varScopes, varScope{node.Name, varNature}) + exprNature := v.visit(node.Expr) v.varScopes = v.varScopes[:len(v.varScopes)-1] - return t, i + return exprNature +} + +func (v *checker) SequenceNode(node *ast.SequenceNode) Nature { + if len(node.Nodes) == 0 { + return v.error(node, "empty sequence expression") + } + var last Nature + for _, node := range node.Nodes { + last = v.visit(node) + } + return last } func (v *checker) lookupVariable(name string) (varScope, bool) { @@ -1151,59 +1229,57 @@ func (v *checker) lookupVariable(name string) (varScope, bool) { return varScope{}, false } -func (v *checker) ConditionalNode(node *ast.ConditionalNode) (reflect.Type, info) { - c, _ := v.visit(node.Cond) - if !isBool(c) && !isAny(c) { +func (v *checker) ConditionalNode(node *ast.ConditionalNode) Nature { + c := v.visit(node.Cond) + if !isBool(c) && !isUnknown(c) { return v.error(node.Cond, "non-bool expression (type %v) used as condition", c) } - t1, _ := v.visit(node.Exp1) - t2, _ := v.visit(node.Exp2) + t1 := v.visit(node.Exp1) + t2 := v.visit(node.Exp2) - if t1 == nil && t2 != nil { - return t2, info{} + if isNil(t1) && !isNil(t2) { + return t2 } - if t1 != nil && t2 == nil { - return t1, info{} + if !isNil(t1) && isNil(t2) { + return t1 } - if t1 == nil && t2 == nil { - return nilType, info{} + if isNil(t1) && isNil(t2) { + return nilNature } if t1.AssignableTo(t2) { - return t1, info{} + return t1 } - return anyType, info{} + return unknown } -func (v *checker) ArrayNode(node *ast.ArrayNode) (reflect.Type, info) { - var prev reflect.Type +func (v *checker) ArrayNode(node *ast.ArrayNode) Nature { + var prev Nature allElementsAreSameType := true for i, node := range node.Nodes { - curr, _ := v.visit(node) + curr := v.visit(node) if i > 0 { - if curr == nil || prev == nil { - allElementsAreSameType = false - } else if curr.Kind() != prev.Kind() { + if curr.Kind() != prev.Kind() { allElementsAreSameType = false } } prev = curr } - if allElementsAreSameType && prev != nil { - return arrayType, info{elem: prev} + if allElementsAreSameType { + return arrayOf(prev) } - return arrayType, info{} + return arrayNature } -func (v *checker) MapNode(node *ast.MapNode) (reflect.Type, info) { +func (v *checker) MapNode(node *ast.MapNode) Nature { for _, pair := range node.Pairs { v.visit(pair) } - return mapType, info{} + return mapNature } -func (v *checker) PairNode(node *ast.PairNode) (reflect.Type, info) { +func (v *checker) PairNode(node *ast.PairNode) Nature { v.visit(node.Key) v.visit(node.Value) - return nilType, info{} + return nilNature } diff --git a/vendor/github.com/expr-lang/expr/checker/info.go b/vendor/github.com/expr-lang/expr/checker/info.go index 112bfab3..f1cc92eb 100644 --- a/vendor/github.com/expr-lang/expr/checker/info.go +++ b/vendor/github.com/expr-lang/expr/checker/info.go @@ -4,26 +4,26 @@ import ( "reflect" "github.com/expr-lang/expr/ast" - "github.com/expr-lang/expr/conf" + . "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/vm" ) -func FieldIndex(types conf.TypesTable, node ast.Node) (bool, []int, string) { +func FieldIndex(env Nature, node ast.Node) (bool, []int, string) { switch n := node.(type) { case *ast.IdentifierNode: - if t, ok := types[n.Value]; ok && len(t.FieldIndex) > 0 { - return true, t.FieldIndex, n.Value + if env.Kind() == reflect.Struct { + if field, ok := env.Get(n.Value); ok && len(field.FieldIndex) > 0 { + return true, field.FieldIndex, n.Value + } } case *ast.MemberNode: - base := n.Node.Type() - if kind(base) == reflect.Ptr { - base = base.Elem() - } - if kind(base) == reflect.Struct { + base := n.Node.Nature() + base = base.Deref() + if base.Kind() == reflect.Struct { if prop, ok := n.Property.(*ast.StringNode); ok { name := prop.Value - if field, ok := fetchField(base, name); ok { - return true, field.Index, name + if field, ok := base.FieldByName(name); ok { + return true, field.FieldIndex, name } } } @@ -31,11 +31,13 @@ func FieldIndex(types conf.TypesTable, node ast.Node) (bool, []int, string) { return false, nil, "" } -func MethodIndex(types conf.TypesTable, node ast.Node) (bool, int, string) { +func MethodIndex(env Nature, node ast.Node) (bool, int, string) { switch n := node.(type) { case *ast.IdentifierNode: - if t, ok := types[n.Value]; ok { - return t.Method, t.MethodIndex, n.Value + if env.Kind() == reflect.Struct { + if m, ok := env.Get(n.Value); ok { + return m.Method, m.MethodIndex, n.Value + } } case *ast.MemberNode: if name, ok := n.Property.(*ast.StringNode); ok { @@ -114,8 +116,7 @@ func IsFastFunc(fn reflect.Type, method bool) bool { if method { numIn = 2 } - if !isAny(fn) && - fn.IsVariadic() && + if fn.IsVariadic() && fn.NumIn() == numIn && fn.NumOut() == 1 && fn.Out(0).Kind() == reflect.Interface { diff --git a/vendor/github.com/expr-lang/expr/checker/nature/nature.go b/vendor/github.com/expr-lang/expr/checker/nature/nature.go new file mode 100644 index 00000000..993c9fcf --- /dev/null +++ b/vendor/github.com/expr-lang/expr/checker/nature/nature.go @@ -0,0 +1,261 @@ +package nature + +import ( + "reflect" + + "github.com/expr-lang/expr/builtin" + "github.com/expr-lang/expr/internal/deref" +) + +var ( + unknown = Nature{} +) + +type Nature struct { + Type reflect.Type // Type of the value. If nil, then value is unknown. + Func *builtin.Function // Used to pass function type from callee to CallNode. + ArrayOf *Nature // Elem nature of array type (usually Type is []any, but ArrayOf can be any nature). + PredicateOut *Nature // Out nature of predicate. + Fields map[string]Nature // Fields of map type. + DefaultMapValue *Nature // Default value of map type. + Strict bool // If map is types.StrictMap. + Nil bool // If value is nil. + Method bool // If value retrieved from method. Usually used to determine amount of in arguments. + MethodIndex int // Index of method in type. + FieldIndex []int // Index of field in type. +} + +func (n Nature) IsAny() bool { + return n.Kind() == reflect.Interface && n.NumMethods() == 0 +} + +func (n Nature) IsUnknown() bool { + switch { + case n.Type == nil && !n.Nil: + return true + case n.IsAny(): + return true + } + return false +} + +func (n Nature) String() string { + if n.Type != nil { + return n.Type.String() + } + return "unknown" +} + +func (n Nature) Deref() Nature { + if n.Type != nil { + n.Type = deref.Type(n.Type) + } + return n +} + +func (n Nature) Kind() reflect.Kind { + if n.Type != nil { + return n.Type.Kind() + } + return reflect.Invalid +} + +func (n Nature) Key() Nature { + if n.Kind() == reflect.Map { + return Nature{Type: n.Type.Key()} + } + return unknown +} + +func (n Nature) Elem() Nature { + switch n.Kind() { + case reflect.Ptr: + return Nature{Type: n.Type.Elem()} + case reflect.Map: + if n.DefaultMapValue != nil { + return *n.DefaultMapValue + } + return Nature{Type: n.Type.Elem()} + case reflect.Array, reflect.Slice: + if n.ArrayOf != nil { + return *n.ArrayOf + } + return Nature{Type: n.Type.Elem()} + } + return unknown +} + +func (n Nature) AssignableTo(nt Nature) bool { + if n.Nil { + // Untyped nil is assignable to any interface, but implements only the empty interface. + if nt.IsAny() { + return true + } + } + if n.Type == nil || nt.Type == nil { + return false + } + return n.Type.AssignableTo(nt.Type) +} + +func (n Nature) NumMethods() int { + if n.Type == nil { + return 0 + } + return n.Type.NumMethod() +} + +func (n Nature) MethodByName(name string) (Nature, bool) { + if n.Type == nil { + return unknown, false + } + method, ok := n.Type.MethodByName(name) + if !ok { + return unknown, false + } + + if n.Type.Kind() == reflect.Interface { + // In case of interface type method will not have a receiver, + // and to prevent checker decreasing numbers of in arguments + // return method type as not method (second argument is false). + + // Also, we can not use m.Index here, because it will be + // different indexes for different types which implement + // the same interface. + return Nature{Type: method.Type}, true + } else { + return Nature{ + Type: method.Type, + Method: true, + MethodIndex: method.Index, + }, true + } +} + +func (n Nature) NumIn() int { + if n.Type == nil { + return 0 + } + return n.Type.NumIn() +} + +func (n Nature) In(i int) Nature { + if n.Type == nil { + return unknown + } + return Nature{Type: n.Type.In(i)} +} + +func (n Nature) NumOut() int { + if n.Type == nil { + return 0 + } + return n.Type.NumOut() +} + +func (n Nature) Out(i int) Nature { + if n.Type == nil { + return unknown + } + return Nature{Type: n.Type.Out(i)} +} + +func (n Nature) IsVariadic() bool { + if n.Type == nil { + return false + } + return n.Type.IsVariadic() +} + +func (n Nature) FieldByName(name string) (Nature, bool) { + if n.Type == nil { + return unknown, false + } + field, ok := fetchField(n.Type, name) + return Nature{Type: field.Type, FieldIndex: field.Index}, ok +} + +func (n Nature) PkgPath() string { + if n.Type == nil { + return "" + } + return n.Type.PkgPath() +} + +func (n Nature) IsFastMap() bool { + if n.Type == nil { + return false + } + if n.Type.Kind() == reflect.Map && + n.Type.Key().Kind() == reflect.String && + n.Type.Elem().Kind() == reflect.Interface { + return true + } + return false +} + +func (n Nature) Get(name string) (Nature, bool) { + if n.Type == nil { + return unknown, false + } + + if m, ok := n.MethodByName(name); ok { + return m, true + } + + t := deref.Type(n.Type) + + switch t.Kind() { + case reflect.Struct: + if f, ok := fetchField(t, name); ok { + return Nature{ + Type: f.Type, + FieldIndex: f.Index, + }, true + } + case reflect.Map: + if f, ok := n.Fields[name]; ok { + return f, true + } + } + return unknown, false +} + +func (n Nature) All() map[string]Nature { + table := make(map[string]Nature) + + if n.Type == nil { + return table + } + + for i := 0; i < n.Type.NumMethod(); i++ { + method := n.Type.Method(i) + table[method.Name] = Nature{ + Type: method.Type, + Method: true, + MethodIndex: method.Index, + } + } + + t := deref.Type(n.Type) + + switch t.Kind() { + case reflect.Struct: + for name, nt := range StructFields(t) { + if _, ok := table[name]; ok { + continue + } + table[name] = nt + } + + case reflect.Map: + for key, nt := range n.Fields { + if _, ok := table[key]; ok { + continue + } + table[key] = nt + } + } + + return table +} diff --git a/vendor/github.com/expr-lang/expr/checker/nature/utils.go b/vendor/github.com/expr-lang/expr/checker/nature/utils.go new file mode 100644 index 00000000..c242f91a --- /dev/null +++ b/vendor/github.com/expr-lang/expr/checker/nature/utils.go @@ -0,0 +1,76 @@ +package nature + +import ( + "reflect" + + "github.com/expr-lang/expr/internal/deref" +) + +func fieldName(field reflect.StructField) string { + if taggedName := field.Tag.Get("expr"); taggedName != "" { + return taggedName + } + return field.Name +} + +func fetchField(t reflect.Type, name string) (reflect.StructField, bool) { + // First check all structs fields. + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + // Search all fields, even embedded structs. + if fieldName(field) == name { + return field, true + } + } + + // Second check fields of embedded structs. + for i := 0; i < t.NumField(); i++ { + anon := t.Field(i) + if anon.Anonymous { + anonType := anon.Type + if anonType.Kind() == reflect.Pointer { + anonType = anonType.Elem() + } + if field, ok := fetchField(anonType, name); ok { + field.Index = append(anon.Index, field.Index...) + return field, true + } + } + } + + return reflect.StructField{}, false +} + +func StructFields(t reflect.Type) map[string]Nature { + table := make(map[string]Nature) + + t = deref.Type(t) + if t == nil { + return table + } + + switch t.Kind() { + case reflect.Struct: + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + + if f.Anonymous { + for name, typ := range StructFields(f.Type) { + if _, ok := table[name]; ok { + continue + } + typ.FieldIndex = append(f.Index, typ.FieldIndex...) + table[name] = typ + } + } + + table[fieldName(f)] = Nature{ + Type: f.Type, + FieldIndex: f.Index, + } + + } + } + + return table +} diff --git a/vendor/github.com/expr-lang/expr/checker/types.go b/vendor/github.com/expr-lang/expr/checker/types.go index d10736a7..09896de5 100644 --- a/vendor/github.com/expr-lang/expr/checker/types.go +++ b/vendor/github.com/expr-lang/expr/checker/types.go @@ -4,207 +4,162 @@ import ( "reflect" "time" - "github.com/expr-lang/expr/conf" + . "github.com/expr-lang/expr/checker/nature" +) + +var ( + unknown = Nature{} + nilNature = Nature{Nil: true} + boolNature = Nature{Type: reflect.TypeOf(true)} + integerNature = Nature{Type: reflect.TypeOf(0)} + floatNature = Nature{Type: reflect.TypeOf(float64(0))} + stringNature = Nature{Type: reflect.TypeOf("")} + arrayNature = Nature{Type: reflect.TypeOf([]any{})} + mapNature = Nature{Type: reflect.TypeOf(map[string]any{})} + timeNature = Nature{Type: reflect.TypeOf(time.Time{})} + durationNature = Nature{Type: reflect.TypeOf(time.Duration(0))} ) var ( - nilType = reflect.TypeOf(nil) - boolType = reflect.TypeOf(true) - integerType = reflect.TypeOf(0) - floatType = reflect.TypeOf(float64(0)) - stringType = reflect.TypeOf("") - arrayType = reflect.TypeOf([]any{}) - mapType = reflect.TypeOf(map[string]any{}) anyType = reflect.TypeOf(new(any)).Elem() timeType = reflect.TypeOf(time.Time{}) durationType = reflect.TypeOf(time.Duration(0)) + arrayType = reflect.TypeOf([]any{}) ) -func combined(a, b reflect.Type) reflect.Type { - if a.Kind() == b.Kind() { - return a +func arrayOf(nt Nature) Nature { + return Nature{ + Type: arrayType, + ArrayOf: &nt, + } +} + +func isNil(nt Nature) bool { + return nt.Nil +} + +func combined(l, r Nature) Nature { + if isUnknown(l) || isUnknown(r) { + return unknown } - if isFloat(a) || isFloat(b) { - return floatType + if isFloat(l) || isFloat(r) { + return floatNature } - return integerType + return integerNature } -func anyOf(t reflect.Type, fns ...func(reflect.Type) bool) bool { +func anyOf(nt Nature, fns ...func(Nature) bool) bool { for _, fn := range fns { - if fn(t) { + if fn(nt) { return true } } return false } -func or(l, r reflect.Type, fns ...func(reflect.Type) bool) bool { - if isAny(l) && isAny(r) { +func or(l, r Nature, fns ...func(Nature) bool) bool { + if isUnknown(l) && isUnknown(r) { return true } - if isAny(l) && anyOf(r, fns...) { + if isUnknown(l) && anyOf(r, fns...) { return true } - if isAny(r) && anyOf(l, fns...) { + if isUnknown(r) && anyOf(l, fns...) { return true } return false } -func isAny(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Interface: - return true - } - } - return false +func isUnknown(nt Nature) bool { + return nt.IsUnknown() } -func isInteger(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fallthrough - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return true - } +func isInteger(nt Nature) bool { + switch nt.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fallthrough + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return nt.PkgPath() == "" } return false } -func isFloat(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Float32, reflect.Float64: - return true - } +func isFloat(nt Nature) bool { + switch nt.Kind() { + case reflect.Float32, reflect.Float64: + return nt.PkgPath() == "" } return false } -func isNumber(t reflect.Type) bool { - return isInteger(t) || isFloat(t) +func isNumber(nt Nature) bool { + return isInteger(nt) || isFloat(nt) } -func isTime(t reflect.Type) bool { - if t != nil { - switch t { - case timeType: - return true - } +func isTime(nt Nature) bool { + switch nt.Type { + case timeType: + return true } return false } -func isDuration(t reflect.Type) bool { - if t != nil { - switch t { - case durationType: - return true - } +func isDuration(nt Nature) bool { + switch nt.Type { + case durationType: + return true } return false } -func isBool(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Bool: - return true - } +func isBool(nt Nature) bool { + switch nt.Kind() { + case reflect.Bool: + return true } return false } -func isString(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.String: - return true - } +func isString(nt Nature) bool { + switch nt.Kind() { + case reflect.String: + return true } return false } -func isArray(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Ptr: - return isArray(t.Elem()) - case reflect.Slice, reflect.Array: - return true - } +func isArray(nt Nature) bool { + switch nt.Kind() { + case reflect.Slice, reflect.Array: + return true } return false } -func isMap(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Ptr: - return isMap(t.Elem()) - case reflect.Map: - return true - } +func isMap(nt Nature) bool { + switch nt.Kind() { + case reflect.Map: + return true } return false } -func isStruct(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Ptr: - return isStruct(t.Elem()) - case reflect.Struct: - return true - } +func isStruct(nt Nature) bool { + switch nt.Kind() { + case reflect.Struct: + return true } return false } -func isFunc(t reflect.Type) bool { - if t != nil { - switch t.Kind() { - case reflect.Ptr: - return isFunc(t.Elem()) - case reflect.Func: - return true - } +func isFunc(nt Nature) bool { + switch nt.Kind() { + case reflect.Func: + return true } return false } -func fetchField(t reflect.Type, name string) (reflect.StructField, bool) { - if t != nil { - // First check all structs fields. - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - // Search all fields, even embedded structs. - if conf.FieldName(field) == name { - return field, true - } - } - - // Second check fields of embedded structs. - for i := 0; i < t.NumField(); i++ { - anon := t.Field(i) - if anon.Anonymous { - anonType := anon.Type - if anonType.Kind() == reflect.Pointer { - anonType = anonType.Elem() - } - if field, ok := fetchField(anonType, name); ok { - field.Index = append(anon.Index, field.Index...) - return field, true - } - } - } - } - return reflect.StructField{}, false -} - func kind(t reflect.Type) reflect.Kind { if t == nil { return reflect.Invalid @@ -212,17 +167,24 @@ func kind(t reflect.Type) reflect.Kind { return t.Kind() } -func isComparable(l, r reflect.Type) bool { - if l == nil || r == nil { +func isComparable(l, r Nature) bool { + if isUnknown(l) || isUnknown(r) { return true } - switch { - case l.Kind() == r.Kind(): + if isNil(l) || isNil(r) { return true - case isNumber(l) && isNumber(r): + } + if isNumber(l) && isNumber(r) { return true - case isAny(l) || isAny(r): + } + if isDuration(l) && isDuration(r) { return true } - return false + if isTime(l) && isTime(r) { + return true + } + if isArray(l) && isArray(r) { + return true + } + return l.AssignableTo(r) } diff --git a/vendor/github.com/expr-lang/expr/compiler/compiler.go b/vendor/github.com/expr-lang/expr/compiler/compiler.go index 720f6a26..595355d2 100644 --- a/vendor/github.com/expr-lang/expr/compiler/compiler.go +++ b/vendor/github.com/expr-lang/expr/compiler/compiler.go @@ -9,6 +9,7 @@ import ( "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/checker" + . "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" "github.com/expr-lang/expr/parser" @@ -259,12 +260,14 @@ func (c *compiler) compile(node ast.Node) { c.CallNode(n) case *ast.BuiltinNode: c.BuiltinNode(n) - case *ast.ClosureNode: - c.ClosureNode(n) + case *ast.PredicateNode: + c.PredicateNode(n) case *ast.PointerNode: c.PointerNode(n) case *ast.VariableDeclaratorNode: c.VariableDeclaratorNode(n) + case *ast.SequenceNode: + c.SequenceNode(n) case *ast.ConditionalNode: c.ConditionalNode(n) case *ast.ArrayNode: @@ -292,21 +295,19 @@ func (c *compiler) IdentifierNode(node *ast.IdentifierNode) { return } - var mapEnv bool - var types conf.TypesTable + var env Nature if c.config != nil { - mapEnv = c.config.MapEnv - types = c.config.Types + env = c.config.Env } - if mapEnv { + if env.IsFastMap() { c.emit(OpLoadFast, c.addConstant(node.Value)) - } else if ok, index, name := checker.FieldIndex(types, node); ok { + } else if ok, index, name := checker.FieldIndex(env, node); ok { c.emit(OpLoadField, c.addConstant(&runtime.Field{ Index: index, Path: []string{name}, })) - } else if ok, index, name := checker.MethodIndex(types, node); ok { + } else if ok, index, name := checker.MethodIndex(env, node); ok { c.emit(OpLoadMethod, c.addConstant(&runtime.Method{ Name: name, Index: index, @@ -377,16 +378,13 @@ func (c *compiler) IntegerNode(node *ast.IntegerNode) { } func (c *compiler) FloatNode(node *ast.FloatNode) { - t := node.Type() - if t == nil { - c.emitPush(node.Value) - return - } - switch t.Kind() { + switch node.Type().Kind() { case reflect.Float32: c.emitPush(float32(node.Value)) case reflect.Float64: c.emitPush(node.Value) + default: + c.emitPush(node.Value) } } @@ -403,6 +401,10 @@ func (c *compiler) StringNode(node *ast.StringNode) { } func (c *compiler) ConstantNode(node *ast.ConstantNode) { + if node.Value == nil { + c.emit(OpNil) + return + } c.emitPush(node.Value) } @@ -646,12 +648,12 @@ func (c *compiler) ChainNode(node *ast.ChainNode) { } func (c *compiler) MemberNode(node *ast.MemberNode) { - var types conf.TypesTable + var env Nature if c.config != nil { - types = c.config.Types + env = c.config.Env } - if ok, index, name := checker.MethodIndex(types, node); ok { + if ok, index, name := checker.MethodIndex(env, node); ok { c.compile(node.Node) c.emit(OpMethod, c.addConstant(&runtime.Method{ Name: name, @@ -662,14 +664,14 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { op := OpFetch base := node.Node - ok, index, nodeName := checker.FieldIndex(types, node) + ok, index, nodeName := checker.FieldIndex(env, node) path := []string{nodeName} if ok { op = OpFetchField for !node.Optional { if ident, isIdent := base.(*ast.IdentifierNode); isIdent { - if ok, identIndex, name := checker.FieldIndex(types, ident); ok { + if ok, identIndex, name := checker.FieldIndex(env, ident); ok { index = append(identIndex, index...) path = append([]string{name}, path...) c.emitLocation(ident.Location(), OpLoadField, c.addConstant( @@ -680,7 +682,7 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { } if member, isMember := base.(*ast.MemberNode); isMember { - if ok, memberIndex, name := checker.FieldIndex(types, member); ok { + if ok, memberIndex, name := checker.FieldIndex(env, member); ok { index = append(memberIndex, index...) path = append([]string{name}, path...) node = member @@ -695,7 +697,9 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { } c.compile(base) - if node.Optional { + // If the field is optional, we need to jump over the fetch operation. + // If no ChainNode (none c.chains) is used, do not compile the optional fetch. + if node.Optional && len(c.chains) > 0 { ph := c.emit(OpJumpIfNil, placeholder) c.chains[len(c.chains)-1] = append(c.chains[len(c.chains)-1], ph) } @@ -727,7 +731,7 @@ func (c *compiler) SliceNode(node *ast.SliceNode) { func (c *compiler) CallNode(node *ast.CallNode) { fn := node.Callee.Type() - if kind(fn) == reflect.Func { + if fn.Kind() == reflect.Func { fnInOffset := 0 fnNumIn := fn.NumIn() switch callee := node.Callee.(type) { @@ -739,24 +743,22 @@ func (c *compiler) CallNode(node *ast.CallNode) { } } case *ast.IdentifierNode: - if t, ok := c.config.Types[callee.Value]; ok && t.Method { + if t, ok := c.config.Env.MethodByName(callee.Value); ok && t.Method { fnInOffset = 1 fnNumIn-- } } for i, arg := range node.Arguments { c.compile(arg) - if k := kind(arg.Type()); k == reflect.Ptr || k == reflect.Interface { - var in reflect.Type - if fn.IsVariadic() && i >= fnNumIn-1 { - in = fn.In(fn.NumIn() - 1).Elem() - } else { - in = fn.In(i + fnInOffset) - } - if k = kind(in); k != reflect.Ptr && k != reflect.Interface { - c.emit(OpDeref) - } + + var in reflect.Type + if fn.IsVariadic() && i >= fnNumIn-1 { + in = fn.In(fn.NumIn() - 1).Elem() + } else { + in = fn.In(i + fnInOffset) } + + c.derefParam(in, arg) } } else { for _, arg := range node.Arguments { @@ -774,12 +776,16 @@ func (c *compiler) CallNode(node *ast.CallNode) { } c.compile(node.Callee) - isMethod, _, _ := checker.MethodIndex(c.config.Types, node.Callee) - if index, ok := checker.TypedFuncIndex(node.Callee.Type(), isMethod); ok { - c.emit(OpCallTyped, index) - return - } else if checker.IsFastFunc(node.Callee.Type(), isMethod) { - c.emit(OpCallFast, len(node.Arguments)) + if c.config != nil { + isMethod, _, _ := checker.MethodIndex(c.config.Env, node.Callee) + if index, ok := checker.TypedFuncIndex(node.Callee.Type(), isMethod); ok { + c.emit(OpCallTyped, index) + return + } else if checker.IsFastFunc(node.Callee.Type(), isMethod) { + c.emit(OpCallFast, len(node.Arguments)) + } else { + c.emit(OpCall, len(node.Arguments)) + } } else { c.emit(OpCall, len(node.Arguments)) } @@ -789,6 +795,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { switch node.Name { case "all": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) var loopBreak int c.emitLoop(func() { @@ -803,6 +810,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "none": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) var loopBreak int c.emitLoop(func() { @@ -818,6 +826,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "any": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) var loopBreak int c.emitLoop(func() { @@ -832,6 +841,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "one": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) c.emitLoop(func() { c.compile(node.Arguments[1]) @@ -847,6 +857,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "filter": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) c.emitLoop(func() { c.compile(node.Arguments[1]) @@ -866,6 +877,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "map": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) c.emitLoop(func() { c.compile(node.Arguments[1]) @@ -877,6 +889,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "count": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) c.emitLoop(func() { if len(node.Arguments) == 2 { @@ -894,6 +907,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "sum": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) c.emit(OpInt, 0) c.emit(OpSetAcc) @@ -913,6 +927,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "find": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) var loopBreak int c.emitLoop(func() { @@ -940,6 +955,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "findIndex": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) var loopBreak int c.emitLoop(func() { @@ -958,6 +974,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "findLast": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) var loopBreak int c.emitLoopBackwards(func() { @@ -985,6 +1002,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "findLastIndex": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) var loopBreak int c.emitLoopBackwards(func() { @@ -1003,6 +1021,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "groupBy": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) c.emit(OpCreate, 1) c.emit(OpSetAcc) @@ -1016,6 +1035,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "sortBy": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) if len(node.Arguments) == 3 { c.compile(node.Arguments[2]) @@ -1034,9 +1054,11 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { case "reduce": c.compile(node.Arguments[0]) + c.derefInNeeded(node.Arguments[0]) c.emit(OpBegin) if len(node.Arguments) == 3 { c.compile(node.Arguments[2]) + c.derefInNeeded(node.Arguments[2]) c.emit(OpSetAcc) } else { c.emit(OpPointer) @@ -1055,8 +1077,19 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { if id, ok := builtin.Index[node.Name]; ok { f := builtin.Builtins[id] - for _, arg := range node.Arguments { + for i, arg := range node.Arguments { c.compile(arg) + argType := arg.Type() + if argType.Kind() == reflect.Ptr || arg.Nature().IsUnknown() { + if f.Deref == nil { + // By default, builtins expect arguments to be dereferenced. + c.emit(OpDeref) + } else { + if f.Deref(i, argType) { + c.emit(OpDeref) + } + } + } } if f.Fast != nil { @@ -1114,7 +1147,7 @@ func (c *compiler) emitLoopBackwards(body func()) { c.patchJump(end) } -func (c *compiler) ClosureNode(node *ast.ClosureNode) { +func (c *compiler) PredicateNode(node *ast.PredicateNode) { c.compile(node.Node) } @@ -1140,6 +1173,15 @@ func (c *compiler) VariableDeclaratorNode(node *ast.VariableDeclaratorNode) { c.endScope() } +func (c *compiler) SequenceNode(node *ast.SequenceNode) { + for i, n := range node.Nodes { + c.compile(n) + if i < len(node.Nodes)-1 { + c.emit(OpPop) + } + } +} + func (c *compiler) beginScope(name string, index int) { c.scopes = append(c.scopes, scope{name, index}) } @@ -1196,12 +1238,27 @@ func (c *compiler) PairNode(node *ast.PairNode) { } func (c *compiler) derefInNeeded(node ast.Node) { - switch kind(node.Type()) { + if node.Nature().Nil { + return + } + switch node.Type().Kind() { case reflect.Ptr, reflect.Interface: c.emit(OpDeref) } } +func (c *compiler) derefParam(in reflect.Type, param ast.Node) { + if param.Nature().Nil { + return + } + if param.Type().AssignableTo(in) { + return + } + if in.Kind() != reflect.Ptr && param.Type().Kind() == reflect.Ptr { + c.emit(OpDeref) + } +} + func (c *compiler) optimize() { for i, op := range c.bytecode { switch op { diff --git a/vendor/github.com/expr-lang/expr/conf/config.go b/vendor/github.com/expr-lang/expr/conf/config.go index 01a407a1..d629958e 100644 --- a/vendor/github.com/expr-lang/expr/conf/config.go +++ b/vendor/github.com/expr-lang/expr/conf/config.go @@ -6,37 +6,47 @@ import ( "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" + "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/vm/runtime" ) +const ( + // DefaultMemoryBudget represents an upper limit of memory usage + DefaultMemoryBudget uint = 1e6 + + // DefaultMaxNodes represents an upper limit of AST nodes + DefaultMaxNodes uint = 10000 +) + type FunctionsTable map[string]*builtin.Function type Config struct { - Env any - Types TypesTable - MapEnv bool - DefaultType reflect.Type - Expect reflect.Kind - ExpectAny bool - Optimize bool - Strict bool - Profile bool - ConstFns map[string]reflect.Value - Visitors []ast.Visitor - Functions FunctionsTable - Builtins FunctionsTable - Disabled map[string]bool // disabled builtins + EnvObject any + Env nature.Nature + Expect reflect.Kind + ExpectAny bool + Optimize bool + Strict bool + Profile bool + MaxNodes uint + MemoryBudget uint + ConstFns map[string]reflect.Value + Visitors []ast.Visitor + Functions FunctionsTable + Builtins FunctionsTable + Disabled map[string]bool // disabled builtins } // CreateNew creates new config with default values. func CreateNew() *Config { c := &Config{ - Optimize: true, - Types: make(TypesTable), - ConstFns: make(map[string]reflect.Value), - Functions: make(map[string]*builtin.Function), - Builtins: make(map[string]*builtin.Function), - Disabled: make(map[string]bool), + Optimize: true, + MaxNodes: DefaultMaxNodes, + MemoryBudget: DefaultMemoryBudget, + ConstFns: make(map[string]reflect.Value), + Functions: make(map[string]*builtin.Function), + Builtins: make(map[string]*builtin.Function), + Disabled: make(map[string]bool), } for _, f := range builtin.Builtins { c.Builtins[f.Name] = f @@ -52,31 +62,16 @@ func New(env any) *Config { } func (c *Config) WithEnv(env any) { - var mapEnv bool - var mapValueType reflect.Type - if _, ok := env.(map[string]any); ok { - mapEnv = true - } else { - if reflect.ValueOf(env).Kind() == reflect.Map { - mapValueType = reflect.TypeOf(env).Elem() - } - } - - c.Env = env - types := CreateTypesTable(env) - for name, t := range types { - c.Types[name] = t - } - c.MapEnv = mapEnv - c.DefaultType = mapValueType - c.Strict = true + c.EnvObject = env + c.Env = Env(env) + c.Strict = c.Env.Strict } func (c *Config) ConstExpr(name string) { - if c.Env == nil { + if c.EnvObject == nil { panic("no environment is specified for ConstExpr()") } - fn := reflect.ValueOf(runtime.Fetch(c.Env, name)) + fn := reflect.ValueOf(runtime.Fetch(c.EnvObject, name)) if fn.Kind() != reflect.Func { panic(fmt.Errorf("const expression %q must be a function", name)) } @@ -99,7 +94,7 @@ func (c *Config) IsOverridden(name string) bool { if _, ok := c.Functions[name]; ok { return true } - if _, ok := c.Types[name]; ok { + if _, ok := c.Env.Get(name); ok { return true } return false diff --git a/vendor/github.com/expr-lang/expr/conf/env.go b/vendor/github.com/expr-lang/expr/conf/env.go new file mode 100644 index 00000000..8b13df1e --- /dev/null +++ b/vendor/github.com/expr-lang/expr/conf/env.go @@ -0,0 +1,68 @@ +package conf + +import ( + "fmt" + "reflect" + + . "github.com/expr-lang/expr/checker/nature" + "github.com/expr-lang/expr/internal/deref" + "github.com/expr-lang/expr/types" +) + +func Env(env any) Nature { + if env == nil { + return Nature{ + Type: reflect.TypeOf(map[string]any{}), + Strict: true, + } + } + + switch env := env.(type) { + case types.Map: + return env.Nature() + } + + v := reflect.ValueOf(env) + d := deref.Value(v) + + switch d.Kind() { + case reflect.Struct: + return Nature{ + Type: v.Type(), + Strict: true, + } + + case reflect.Map: + n := Nature{ + Type: v.Type(), + Fields: make(map[string]Nature, v.Len()), + Strict: true, + } + + for _, key := range v.MapKeys() { + elem := v.MapIndex(key) + if !elem.IsValid() || !elem.CanInterface() { + panic(fmt.Sprintf("invalid map value: %s", key)) + } + + face := elem.Interface() + + switch face := face.(type) { + case types.Map: + n.Fields[key.String()] = face.Nature() + + default: + if face == nil { + n.Fields[key.String()] = Nature{Nil: true} + continue + } + n.Fields[key.String()] = Nature{Type: reflect.TypeOf(face)} + } + + } + + return n + } + + panic(fmt.Sprintf("unknown type %T", env)) +} diff --git a/vendor/github.com/expr-lang/expr/conf/types_table.go b/vendor/github.com/expr-lang/expr/conf/types_table.go deleted file mode 100644 index a42a4287..00000000 --- a/vendor/github.com/expr-lang/expr/conf/types_table.go +++ /dev/null @@ -1,121 +0,0 @@ -package conf - -import ( - "reflect" - - "github.com/expr-lang/expr/internal/deref" -) - -type TypesTable map[string]Tag - -type Tag struct { - Type reflect.Type - Ambiguous bool - FieldIndex []int - Method bool - MethodIndex int -} - -// CreateTypesTable creates types table for type checks during parsing. -// If struct is passed, all fields will be treated as variables, -// as well as all fields of embedded structs and struct itself. -// -// If map is passed, all items will be treated as variables -// (key as name, value as type). -func CreateTypesTable(i any) TypesTable { - if i == nil { - return nil - } - - types := make(TypesTable) - v := reflect.ValueOf(i) - t := reflect.TypeOf(i) - - d := t - if t.Kind() == reflect.Ptr { - d = t.Elem() - } - - switch d.Kind() { - case reflect.Struct: - types = FieldsFromStruct(d) - - // Methods of struct should be gathered from original struct with pointer, - // as methods maybe declared on pointer receiver. Also this method retrieves - // all embedded structs methods as well, no need to recursion. - for i := 0; i < t.NumMethod(); i++ { - m := t.Method(i) - types[m.Name] = Tag{ - Type: m.Type, - Method: true, - MethodIndex: i, - } - } - - case reflect.Map: - for _, key := range v.MapKeys() { - value := v.MapIndex(key) - if key.Kind() == reflect.String && value.IsValid() && value.CanInterface() { - if key.String() == "$env" { // Could check for all keywords here - panic("attempt to misuse env keyword as env map key") - } - types[key.String()] = Tag{Type: reflect.TypeOf(value.Interface())} - } - } - - // A map may have method too. - for i := 0; i < t.NumMethod(); i++ { - m := t.Method(i) - types[m.Name] = Tag{ - Type: m.Type, - Method: true, - MethodIndex: i, - } - } - } - - return types -} - -func FieldsFromStruct(t reflect.Type) TypesTable { - types := make(TypesTable) - t = deref.Type(t) - if t == nil { - return types - } - - switch t.Kind() { - case reflect.Struct: - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) - - if f.Anonymous { - for name, typ := range FieldsFromStruct(f.Type) { - if _, ok := types[name]; ok { - types[name] = Tag{Ambiguous: true} - } else { - typ.FieldIndex = append(f.Index, typ.FieldIndex...) - types[name] = typ - } - } - } - if fn := FieldName(f); fn == "$env" { // Could check for all keywords here - panic("attempt to misuse env keyword as env struct field tag") - } else { - types[FieldName(f)] = Tag{ - Type: f.Type, - FieldIndex: f.Index, - } - } - } - } - - return types -} - -func FieldName(field reflect.StructField) string { - if taggedName := field.Tag.Get("expr"); taggedName != "" { - return taggedName - } - return field.Name -} diff --git a/vendor/github.com/expr-lang/expr/expr.go b/vendor/github.com/expr-lang/expr/expr.go index 8c619e1c..33b7cf35 100644 --- a/vendor/github.com/expr-lang/expr/expr.go +++ b/vendor/github.com/expr-lang/expr/expr.go @@ -13,6 +13,7 @@ import ( "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" "github.com/expr-lang/expr/optimizer" + "github.com/expr-lang/expr/parser" "github.com/expr-lang/expr/patcher" "github.com/expr-lang/expr/vm" ) @@ -45,7 +46,7 @@ func Operator(operator string, fn ...string) Option { p := &patcher.OperatorOverloading{ Operator: operator, Overloads: fn, - Types: c.Types, + Env: &c.Env, Functions: c.Functions, } c.Visitors = append(c.Visitors, p) @@ -240,7 +241,12 @@ func Eval(input string, env any) (any, error) { return nil, fmt.Errorf("misused expr.Eval: second argument (env) should be passed without expr.Env") } - program, err := Compile(input) + tree, err := parser.Parse(input) + if err != nil { + return nil, err + } + + program, err := compiler.Compile(tree, nil) if err != nil { return nil, err } diff --git a/vendor/github.com/expr-lang/expr/internal/deref/deref.go b/vendor/github.com/expr-lang/expr/internal/deref/deref.go index acdc8981..da3e28ce 100644 --- a/vendor/github.com/expr-lang/expr/internal/deref/deref.go +++ b/vendor/github.com/expr-lang/expr/internal/deref/deref.go @@ -5,7 +5,7 @@ import ( "reflect" ) -func Deref(p any) any { +func Interface(p any) any { if p == nil { return nil } diff --git a/vendor/github.com/expr-lang/expr/optimizer/const_expr.go b/vendor/github.com/expr-lang/expr/optimizer/const_expr.go index 501ea3c5..1b45385f 100644 --- a/vendor/github.com/expr-lang/expr/optimizer/const_expr.go +++ b/vendor/github.com/expr-lang/expr/optimizer/const_expr.go @@ -30,11 +30,6 @@ func (c *constExpr) Visit(node *Node) { } }() - patch := func(newNode Node) { - c.applied = true - Patch(node, newNode) - } - if call, ok := (*node).(*CallNode); ok { if name, ok := call.Callee.(*IdentifierNode); ok { fn, ok := c.fns[name.Value] @@ -78,7 +73,8 @@ func (c *constExpr) Visit(node *Node) { return } constNode := &ConstantNode{Value: value} - patch(constNode) + patchWithType(node, constNode) + c.applied = true } } } diff --git a/vendor/github.com/expr-lang/expr/optimizer/filter_first.go b/vendor/github.com/expr-lang/expr/optimizer/filter_first.go index 7ea8f6fa..b04a5cb3 100644 --- a/vendor/github.com/expr-lang/expr/optimizer/filter_first.go +++ b/vendor/github.com/expr-lang/expr/optimizer/filter_first.go @@ -12,7 +12,7 @@ func (*filterFirst) Visit(node *Node) { if filter, ok := member.Node.(*BuiltinNode); ok && filter.Name == "filter" && len(filter.Arguments) == 2 { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "find", Arguments: filter.Arguments, Throws: true, // to match the behavior of filter()[0] @@ -27,7 +27,7 @@ func (*filterFirst) Visit(node *Node) { if filter, ok := first.Arguments[0].(*BuiltinNode); ok && filter.Name == "filter" && len(filter.Arguments) == 2 { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "find", Arguments: filter.Arguments, Throws: false, // as first() will return nil if not found diff --git a/vendor/github.com/expr-lang/expr/optimizer/filter_last.go b/vendor/github.com/expr-lang/expr/optimizer/filter_last.go index 9a1cc5e2..8c046bf8 100644 --- a/vendor/github.com/expr-lang/expr/optimizer/filter_last.go +++ b/vendor/github.com/expr-lang/expr/optimizer/filter_last.go @@ -12,7 +12,7 @@ func (*filterLast) Visit(node *Node) { if filter, ok := member.Node.(*BuiltinNode); ok && filter.Name == "filter" && len(filter.Arguments) == 2 { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "findLast", Arguments: filter.Arguments, Throws: true, // to match the behavior of filter()[-1] @@ -27,7 +27,7 @@ func (*filterLast) Visit(node *Node) { if filter, ok := first.Arguments[0].(*BuiltinNode); ok && filter.Name == "filter" && len(filter.Arguments) == 2 { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "findLast", Arguments: filter.Arguments, Throws: false, // as last() will return nil if not found diff --git a/vendor/github.com/expr-lang/expr/optimizer/filter_len.go b/vendor/github.com/expr-lang/expr/optimizer/filter_len.go index 6577163e..c66fde96 100644 --- a/vendor/github.com/expr-lang/expr/optimizer/filter_len.go +++ b/vendor/github.com/expr-lang/expr/optimizer/filter_len.go @@ -13,7 +13,7 @@ func (*filterLen) Visit(node *Node) { if filter, ok := ln.Arguments[0].(*BuiltinNode); ok && filter.Name == "filter" && len(filter.Arguments) == 2 { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "count", Arguments: filter.Arguments, }) diff --git a/vendor/github.com/expr-lang/expr/optimizer/filter_map.go b/vendor/github.com/expr-lang/expr/optimizer/filter_map.go index d988dc69..17659a91 100644 --- a/vendor/github.com/expr-lang/expr/optimizer/filter_map.go +++ b/vendor/github.com/expr-lang/expr/optimizer/filter_map.go @@ -9,17 +9,25 @@ type filterMap struct{} func (*filterMap) Visit(node *Node) { if mapBuiltin, ok := (*node).(*BuiltinNode); ok && mapBuiltin.Name == "map" && - len(mapBuiltin.Arguments) == 2 { - if closure, ok := mapBuiltin.Arguments[1].(*ClosureNode); ok { + len(mapBuiltin.Arguments) == 2 && + Find(mapBuiltin.Arguments[1], isIndexPointer) == nil { + if predicate, ok := mapBuiltin.Arguments[1].(*PredicateNode); ok { if filter, ok := mapBuiltin.Arguments[0].(*BuiltinNode); ok && filter.Name == "filter" && filter.Map == nil /* not already optimized */ { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "filter", Arguments: filter.Arguments, - Map: closure.Node, + Map: predicate.Node, }) } } } } + +func isIndexPointer(node Node) bool { + if pointer, ok := node.(*PointerNode); ok && pointer.Name == "index" { + return true + } + return false +} diff --git a/vendor/github.com/expr-lang/expr/optimizer/fold.go b/vendor/github.com/expr-lang/expr/optimizer/fold.go index 910c9240..bb40eab9 100644 --- a/vendor/github.com/expr-lang/expr/optimizer/fold.go +++ b/vendor/github.com/expr-lang/expr/optimizer/fold.go @@ -1,20 +1,12 @@ package optimizer import ( - "fmt" "math" - "reflect" . "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/file" ) -var ( - integerType = reflect.TypeOf(0) - floatType = reflect.TypeOf(float64(0)) - stringType = reflect.TypeOf("") -) - type fold struct { applied bool err *file.Error @@ -23,20 +15,11 @@ type fold struct { func (fold *fold) Visit(node *Node) { patch := func(newNode Node) { fold.applied = true - Patch(node, newNode) + patchWithType(node, newNode) } - patchWithType := func(newNode Node) { - patch(newNode) - switch newNode.(type) { - case *IntegerNode: - newNode.SetType(integerType) - case *FloatNode: - newNode.SetType(floatType) - case *StringNode: - newNode.SetType(stringType) - default: - panic(fmt.Sprintf("unknown type %T", newNode)) - } + patchCopy := func(newNode Node) { + fold.applied = true + patchCopyType(node, newNode) } switch n := (*node).(type) { @@ -44,17 +27,17 @@ func (fold *fold) Visit(node *Node) { switch n.Operator { case "-": if i, ok := n.Node.(*IntegerNode); ok { - patchWithType(&IntegerNode{Value: -i.Value}) + patch(&IntegerNode{Value: -i.Value}) } if i, ok := n.Node.(*FloatNode); ok { - patchWithType(&FloatNode{Value: -i.Value}) + patch(&FloatNode{Value: -i.Value}) } case "+": if i, ok := n.Node.(*IntegerNode); ok { - patchWithType(&IntegerNode{Value: i.Value}) + patch(&IntegerNode{Value: i.Value}) } if i, ok := n.Node.(*FloatNode); ok { - patchWithType(&FloatNode{Value: i.Value}) + patch(&FloatNode{Value: i.Value}) } case "!", "not": if a := toBool(n.Node); a != nil { @@ -69,28 +52,28 @@ func (fold *fold) Visit(node *Node) { a := toInteger(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&IntegerNode{Value: a.Value + b.Value}) + patch(&IntegerNode{Value: a.Value + b.Value}) } } { a := toInteger(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: float64(a.Value) + b.Value}) + patch(&FloatNode{Value: float64(a.Value) + b.Value}) } } { a := toFloat(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value + float64(b.Value)}) + patch(&FloatNode{Value: a.Value + float64(b.Value)}) } } { a := toFloat(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value + b.Value}) + patch(&FloatNode{Value: a.Value + b.Value}) } } { @@ -105,28 +88,28 @@ func (fold *fold) Visit(node *Node) { a := toInteger(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&IntegerNode{Value: a.Value - b.Value}) + patch(&IntegerNode{Value: a.Value - b.Value}) } } { a := toInteger(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: float64(a.Value) - b.Value}) + patch(&FloatNode{Value: float64(a.Value) - b.Value}) } } { a := toFloat(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value - float64(b.Value)}) + patch(&FloatNode{Value: a.Value - float64(b.Value)}) } } { a := toFloat(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value - b.Value}) + patch(&FloatNode{Value: a.Value - b.Value}) } } case "*": @@ -134,28 +117,28 @@ func (fold *fold) Visit(node *Node) { a := toInteger(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&IntegerNode{Value: a.Value * b.Value}) + patch(&IntegerNode{Value: a.Value * b.Value}) } } { a := toInteger(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: float64(a.Value) * b.Value}) + patch(&FloatNode{Value: float64(a.Value) * b.Value}) } } { a := toFloat(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value * float64(b.Value)}) + patch(&FloatNode{Value: a.Value * float64(b.Value)}) } } { a := toFloat(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value * b.Value}) + patch(&FloatNode{Value: a.Value * b.Value}) } } case "/": @@ -163,28 +146,28 @@ func (fold *fold) Visit(node *Node) { a := toInteger(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: float64(a.Value) / float64(b.Value)}) + patch(&FloatNode{Value: float64(a.Value) / float64(b.Value)}) } } { a := toInteger(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: float64(a.Value) / b.Value}) + patch(&FloatNode{Value: float64(a.Value) / b.Value}) } } { a := toFloat(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value / float64(b.Value)}) + patch(&FloatNode{Value: a.Value / float64(b.Value)}) } } { a := toFloat(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: a.Value / b.Value}) + patch(&FloatNode{Value: a.Value / b.Value}) } } case "%": @@ -205,28 +188,28 @@ func (fold *fold) Visit(node *Node) { a := toInteger(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))}) + patch(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))}) } } { a := toInteger(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), b.Value)}) + patch(&FloatNode{Value: math.Pow(float64(a.Value), b.Value)}) } } { a := toFloat(n.Left) b := toInteger(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: math.Pow(a.Value, float64(b.Value))}) + patch(&FloatNode{Value: math.Pow(a.Value, float64(b.Value))}) } } { a := toFloat(n.Left) b := toFloat(n.Right) if a != nil && b != nil { - patchWithType(&FloatNode{Value: math.Pow(a.Value, b.Value)}) + patch(&FloatNode{Value: math.Pow(a.Value, b.Value)}) } } case "and", "&&": @@ -234,9 +217,9 @@ func (fold *fold) Visit(node *Node) { b := toBool(n.Right) if a != nil && a.Value { // true and x - patch(n.Right) + patchCopy(n.Right) } else if b != nil && b.Value { // x and true - patch(n.Left) + patchCopy(n.Left) } else if (a != nil && !a.Value) || (b != nil && !b.Value) { // "x and false" or "false and x" patch(&BoolNode{Value: false}) } @@ -245,9 +228,9 @@ func (fold *fold) Visit(node *Node) { b := toBool(n.Right) if a != nil && !a.Value { // false or x - patch(n.Right) + patchCopy(n.Right) } else if b != nil && !b.Value { // x or false - patch(n.Left) + patchCopy(n.Left) } else if (a != nil && a.Value) || (b != nil && b.Value) { // "x or true" or "true or x" patch(&BoolNode{Value: true}) } @@ -302,20 +285,21 @@ func (fold *fold) Visit(node *Node) { } case *BuiltinNode: + // TODO: Move this to a separate visitor filter_filter.go switch n.Name { case "filter": if len(n.Arguments) != 2 { return } if base, ok := n.Arguments[0].(*BuiltinNode); ok && base.Name == "filter" { - patch(&BuiltinNode{ + patchCopy(&BuiltinNode{ Name: "filter", Arguments: []Node{ base.Arguments[0], &BinaryNode{ Operator: "&&", - Left: base.Arguments[1], - Right: n.Arguments[1], + Left: base.Arguments[1].(*PredicateNode).Node, + Right: n.Arguments[1].(*PredicateNode).Node, }, }, }) diff --git a/vendor/github.com/expr-lang/expr/optimizer/in_array.go b/vendor/github.com/expr-lang/expr/optimizer/in_array.go index 8933d9b9..e91320c0 100644 --- a/vendor/github.com/expr-lang/expr/optimizer/in_array.go +++ b/vendor/github.com/expr-lang/expr/optimizer/in_array.go @@ -32,10 +32,12 @@ func (*inArray) Visit(node *Node) { for _, a := range array.Nodes { value[a.(*IntegerNode).Value] = struct{}{} } - Patch(node, &BinaryNode{ + m := &ConstantNode{Value: value} + m.SetType(reflect.TypeOf(value)) + patchCopyType(node, &BinaryNode{ Operator: n.Operator, Left: n.Left, - Right: &ConstantNode{Value: value}, + Right: m, }) } @@ -50,10 +52,12 @@ func (*inArray) Visit(node *Node) { for _, a := range array.Nodes { value[a.(*StringNode).Value] = struct{}{} } - Patch(node, &BinaryNode{ + m := &ConstantNode{Value: value} + m.SetType(reflect.TypeOf(value)) + patchCopyType(node, &BinaryNode{ Operator: n.Operator, Left: n.Left, - Right: &ConstantNode{Value: value}, + Right: m, }) } diff --git a/vendor/github.com/expr-lang/expr/optimizer/in_range.go b/vendor/github.com/expr-lang/expr/optimizer/in_range.go index 01faabbd..ed2f557e 100644 --- a/vendor/github.com/expr-lang/expr/optimizer/in_range.go +++ b/vendor/github.com/expr-lang/expr/optimizer/in_range.go @@ -22,7 +22,7 @@ func (*inRange) Visit(node *Node) { if rangeOp, ok := n.Right.(*BinaryNode); ok && rangeOp.Operator == ".." { if from, ok := rangeOp.Left.(*IntegerNode); ok { if to, ok := rangeOp.Right.(*IntegerNode); ok { - Patch(node, &BinaryNode{ + patchCopyType(node, &BinaryNode{ Operator: "and", Left: &BinaryNode{ Operator: ">=", diff --git a/vendor/github.com/expr-lang/expr/optimizer/optimizer.go b/vendor/github.com/expr-lang/expr/optimizer/optimizer.go index 4ceb3fa4..9a9677c1 100644 --- a/vendor/github.com/expr-lang/expr/optimizer/optimizer.go +++ b/vendor/github.com/expr-lang/expr/optimizer/optimizer.go @@ -1,6 +1,9 @@ package optimizer import ( + "fmt" + "reflect" + . "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/conf" ) @@ -41,3 +44,36 @@ func Optimize(node *Node, config *conf.Config) error { Walk(node, &sumMap{}) return nil } + +var ( + boolType = reflect.TypeOf(true) + integerType = reflect.TypeOf(0) + floatType = reflect.TypeOf(float64(0)) + stringType = reflect.TypeOf("") +) + +func patchWithType(node *Node, newNode Node) { + switch n := newNode.(type) { + case *BoolNode: + newNode.SetType(boolType) + case *IntegerNode: + newNode.SetType(integerType) + case *FloatNode: + newNode.SetType(floatType) + case *StringNode: + newNode.SetType(stringType) + case *ConstantNode: + newNode.SetType(reflect.TypeOf(n.Value)) + case *BinaryNode: + newNode.SetType(n.Type()) + default: + panic(fmt.Sprintf("unknown type %T", newNode)) + } + Patch(node, newNode) +} + +func patchCopyType(node *Node, newNode Node) { + t := (*node).Type() + newNode.SetType(t) + Patch(node, newNode) +} diff --git a/vendor/github.com/expr-lang/expr/optimizer/predicate_combination.go b/vendor/github.com/expr-lang/expr/optimizer/predicate_combination.go index 6e8a7f7c..65f88e34 100644 --- a/vendor/github.com/expr-lang/expr/optimizer/predicate_combination.go +++ b/vendor/github.com/expr-lang/expr/optimizer/predicate_combination.go @@ -21,19 +21,19 @@ func (v *predicateCombination) Visit(node *Node) { if combinedOp, ok := combinedOperator(left.Name, op.Operator); ok { if right, ok := op.Right.(*BuiltinNode); ok && right.Name == left.Name { if left.Arguments[0].Type() == right.Arguments[0].Type() && left.Arguments[0].String() == right.Arguments[0].String() { - closure := &ClosureNode{ + predicate := &PredicateNode{ Node: &BinaryNode{ Operator: combinedOp, - Left: left.Arguments[1].(*ClosureNode).Node, - Right: right.Arguments[1].(*ClosureNode).Node, + Left: left.Arguments[1].(*PredicateNode).Node, + Right: right.Arguments[1].(*PredicateNode).Node, }, } - v.Visit(&closure.Node) - Patch(node, &BuiltinNode{ + v.Visit(&predicate.Node) + patchCopyType(node, &BuiltinNode{ Name: left.Name, Arguments: []Node{ left.Arguments[0], - closure, + predicate, }, }) } diff --git a/vendor/github.com/expr-lang/expr/optimizer/sum_array.go b/vendor/github.com/expr-lang/expr/optimizer/sum_array.go index 0a05d1f2..3c96795e 100644 --- a/vendor/github.com/expr-lang/expr/optimizer/sum_array.go +++ b/vendor/github.com/expr-lang/expr/optimizer/sum_array.go @@ -14,7 +14,7 @@ func (*sumArray) Visit(node *Node) { len(sumBuiltin.Arguments) == 1 { if array, ok := sumBuiltin.Arguments[0].(*ArrayNode); ok && len(array.Nodes) >= 2 { - Patch(node, sumArrayFold(array)) + patchCopyType(node, sumArrayFold(array)) } } } diff --git a/vendor/github.com/expr-lang/expr/optimizer/sum_map.go b/vendor/github.com/expr-lang/expr/optimizer/sum_map.go index a41a5373..6de97d37 100644 --- a/vendor/github.com/expr-lang/expr/optimizer/sum_map.go +++ b/vendor/github.com/expr-lang/expr/optimizer/sum_map.go @@ -13,7 +13,7 @@ func (*sumMap) Visit(node *Node) { if mapBuiltin, ok := sumBuiltin.Arguments[0].(*BuiltinNode); ok && mapBuiltin.Name == "map" && len(mapBuiltin.Arguments) == 2 { - Patch(node, &BuiltinNode{ + patchCopyType(node, &BuiltinNode{ Name: "sum", Arguments: []Node{ mapBuiltin.Arguments[0], diff --git a/vendor/github.com/expr-lang/expr/parser/lexer/state.go b/vendor/github.com/expr-lang/expr/parser/lexer/state.go index d351e2f5..c694a2ca 100644 --- a/vendor/github.com/expr-lang/expr/parser/lexer/state.go +++ b/vendor/github.com/expr-lang/expr/parser/lexer/state.go @@ -129,9 +129,7 @@ loop: switch l.word() { case "not": return not - case "in", "or", "and", "matches", "contains", "startsWith", "endsWith": - l.emit(Operator) - case "let": + case "in", "or", "and", "matches", "contains", "startsWith", "endsWith", "let", "if", "else": l.emit(Operator) default: l.emit(Identifier) diff --git a/vendor/github.com/expr-lang/expr/parser/parser.go b/vendor/github.com/expr-lang/expr/parser/parser.go index 77b2a700..0a463fed 100644 --- a/vendor/github.com/expr-lang/expr/parser/parser.go +++ b/vendor/github.com/expr-lang/expr/parser/parser.go @@ -19,7 +19,7 @@ type arg byte const ( expr arg = 1 << iota - closure + predicate ) const optional arg = 1 << 7 @@ -27,30 +27,69 @@ const optional arg = 1 << 7 var predicates = map[string]struct { args []arg }{ - "all": {[]arg{expr, closure}}, - "none": {[]arg{expr, closure}}, - "any": {[]arg{expr, closure}}, - "one": {[]arg{expr, closure}}, - "filter": {[]arg{expr, closure}}, - "map": {[]arg{expr, closure}}, - "count": {[]arg{expr, closure | optional}}, - "sum": {[]arg{expr, closure | optional}}, - "find": {[]arg{expr, closure}}, - "findIndex": {[]arg{expr, closure}}, - "findLast": {[]arg{expr, closure}}, - "findLastIndex": {[]arg{expr, closure}}, - "groupBy": {[]arg{expr, closure}}, - "sortBy": {[]arg{expr, closure, expr | optional}}, - "reduce": {[]arg{expr, closure, expr | optional}}, + "all": {[]arg{expr, predicate}}, + "none": {[]arg{expr, predicate}}, + "any": {[]arg{expr, predicate}}, + "one": {[]arg{expr, predicate}}, + "filter": {[]arg{expr, predicate}}, + "map": {[]arg{expr, predicate}}, + "count": {[]arg{expr, predicate | optional}}, + "sum": {[]arg{expr, predicate | optional}}, + "find": {[]arg{expr, predicate}}, + "findIndex": {[]arg{expr, predicate}}, + "findLast": {[]arg{expr, predicate}}, + "findLastIndex": {[]arg{expr, predicate}}, + "groupBy": {[]arg{expr, predicate}}, + "sortBy": {[]arg{expr, predicate, expr | optional}}, + "reduce": {[]arg{expr, predicate, expr | optional}}, } type parser struct { - tokens []Token - current Token - pos int - err *file.Error - depth int // closure call depth - config *conf.Config + tokens []Token + current Token + pos int + err *file.Error + config *conf.Config + depth int // predicate call depth + nodeCount uint // tracks number of AST nodes created +} + +func (p *parser) checkNodeLimit() error { + p.nodeCount++ + if p.config == nil { + if p.nodeCount > conf.DefaultMaxNodes { + p.error("compilation failed: expression exceeds maximum allowed nodes") + return nil + } + return nil + } + if p.config.MaxNodes > 0 && p.nodeCount > p.config.MaxNodes { + p.error("compilation failed: expression exceeds maximum allowed nodes") + return nil + } + return nil +} + +func (p *parser) createNode(n Node, loc file.Location) Node { + if err := p.checkNodeLimit(); err != nil { + return nil + } + if n == nil || p.err != nil { + return nil + } + n.SetLocation(loc) + return n +} + +func (p *parser) createMemberNode(n *MemberNode, loc file.Location) *MemberNode { + if err := p.checkNodeLimit(); err != nil { + return nil + } + if n == nil || p.err != nil { + return nil + } + n.SetLocation(loc) + return n } type Tree struct { @@ -59,9 +98,7 @@ type Tree struct { } func Parse(input string) (*Tree, error) { - return ParseWithConfig(input, &conf.Config{ - Disabled: map[string]bool{}, - }) + return ParseWithConfig(input, nil) } func ParseWithConfig(input string, config *conf.Config) (*Tree, error) { @@ -78,7 +115,7 @@ func ParseWithConfig(input string, config *conf.Config) (*Tree, error) { config: config, } - node := p.parseExpression(0) + node := p.parseSequenceExpression() if !p.current.Is(EOF) { p.error("unexpected token %v", p.current) @@ -128,11 +165,40 @@ func (p *parser) expect(kind Kind, values ...string) { // parse functions +func (p *parser) parseSequenceExpression() Node { + nodes := []Node{p.parseExpression(0)} + + for p.current.Is(Operator, ";") && p.err == nil { + p.next() + // If a trailing semicolon is present, break out. + if p.current.Is(EOF) { + break + } + nodes = append(nodes, p.parseExpression(0)) + } + + if len(nodes) == 1 { + return nodes[0] + } + + return p.createNode(&SequenceNode{ + Nodes: nodes, + }, nodes[0].Location()) +} + func (p *parser) parseExpression(precedence int) Node { + if p.err != nil { + return nil + } + if precedence == 0 && p.current.Is(Operator, "let") { return p.parseVariableDeclaration() } + if precedence == 0 && p.current.Is(Operator, "if") { + return p.parseConditionalIf() + } + nodeLeft := p.parsePrimary() prevOperator := "" @@ -187,19 +253,23 @@ func (p *parser) parseExpression(precedence int) Node { nodeRight = p.parseExpression(op.Precedence) } - nodeLeft = &BinaryNode{ + nodeLeft = p.createNode(&BinaryNode{ Operator: opToken.Value, Left: nodeLeft, Right: nodeRight, + }, opToken.Location) + if nodeLeft == nil { + return nil } - nodeLeft.SetLocation(opToken.Location) if negate { - nodeLeft = &UnaryNode{ + nodeLeft = p.createNode(&UnaryNode{ Operator: "not", Node: nodeLeft, + }, notToken.Location) + if nodeLeft == nil { + return nil } - nodeLeft.SetLocation(notToken.Location) } goto next @@ -225,14 +295,31 @@ func (p *parser) parseVariableDeclaration() Node { p.expect(Operator, "=") value := p.parseExpression(0) p.expect(Operator, ";") - node := p.parseExpression(0) - let := &VariableDeclaratorNode{ + node := p.parseSequenceExpression() + return p.createNode(&VariableDeclaratorNode{ Name: variableName.Value, Value: value, Expr: node, + }, variableName.Location) +} + +func (p *parser) parseConditionalIf() Node { + p.next() + nodeCondition := p.parseExpression(0) + p.expect(Bracket, "{") + expr1 := p.parseSequenceExpression() + p.expect(Bracket, "}") + p.expect(Operator, "else") + p.expect(Bracket, "{") + expr2 := p.parseSequenceExpression() + p.expect(Bracket, "}") + + return &ConditionalNode{ + Cond: nodeCondition, + Exp1: expr1, + Exp2: expr2, } - let.SetLocation(variableName.Location) - return let + } func (p *parser) parseConditional(node Node) Node { @@ -250,10 +337,13 @@ func (p *parser) parseConditional(node Node) Node { expr2 = p.parseExpression(0) } - node = &ConditionalNode{ + node = p.createNode(&ConditionalNode{ Cond: node, Exp1: expr1, Exp2: expr2, + }, p.current.Location) + if node == nil { + return nil } } return node @@ -266,18 +356,20 @@ func (p *parser) parsePrimary() Node { if op, ok := operator.Unary[token.Value]; ok { p.next() expr := p.parseExpression(op.Precedence) - node := &UnaryNode{ + node := p.createNode(&UnaryNode{ Operator: token.Value, Node: expr, + }, token.Location) + if node == nil { + return nil } - node.SetLocation(token.Location) return p.parsePostfixExpression(node) } } if token.Is(Bracket, "(") { p.next() - expr := p.parseExpression(0) + expr := p.parseSequenceExpression() p.expect(Bracket, ")") // "an opened parenthesis is not properly closed" return p.parsePostfixExpression(expr) } @@ -292,14 +384,12 @@ func (p *parser) parsePrimary() Node { p.next() } } - node := &PointerNode{Name: name} - node.SetLocation(token.Location) + node := p.createNode(&PointerNode{Name: name}, token.Location) + if node == nil { + return nil + } return p.parsePostfixExpression(node) } - } else { - if token.Is(Operator, "#") || token.Is(Operator, ".") { - p.error("cannot use pointer accessor outside closure") - } } if token.Is(Operator, "::") { @@ -322,23 +412,31 @@ func (p *parser) parseSecondary() Node { p.next() switch token.Value { case "true": - node := &BoolNode{Value: true} - node.SetLocation(token.Location) + node = p.createNode(&BoolNode{Value: true}, token.Location) + if node == nil { + return nil + } return node case "false": - node := &BoolNode{Value: false} - node.SetLocation(token.Location) + node = p.createNode(&BoolNode{Value: false}, token.Location) + if node == nil { + return nil + } return node case "nil": - node := &NilNode{} - node.SetLocation(token.Location) + node = p.createNode(&NilNode{}, token.Location) + if node == nil { + return nil + } return node default: if p.current.Is(Bracket, "(") { node = p.parseCall(token, []Node{}, true) } else { - node = &IdentifierNode{Value: token.Value} - node.SetLocation(token.Location) + node = p.createNode(&IdentifierNode{Value: token.Value}, token.Location) + if node == nil { + return nil + } } } @@ -385,8 +483,10 @@ func (p *parser) parseSecondary() Node { return node case String: p.next() - node = &StringNode{Value: token.Value} - node.SetLocation(token.Location) + node = p.createNode(&StringNode{Value: token.Value}, token.Location) + if node == nil { + return nil + } default: if token.Is(Bracket, "[") { @@ -406,7 +506,7 @@ func (p *parser) toIntegerNode(number int64) Node { p.error("integer literal is too large") return nil } - return &IntegerNode{Value: int(number)} + return p.createNode(&IntegerNode{Value: int(number)}, p.current.Location) } func (p *parser) toFloatNode(number float64) Node { @@ -414,13 +514,16 @@ func (p *parser) toFloatNode(number float64) Node { p.error("float literal is too large") return nil } - return &FloatNode{Value: number} + return p.createNode(&FloatNode{Value: number}, p.current.Location) } func (p *parser) parseCall(token Token, arguments []Node, checkOverrides bool) Node { var node Node - isOverridden := p.config.IsOverridden(token.Value) + isOverridden := false + if p.config != nil { + isOverridden = p.config.IsOverridden(token.Value) + } isOverridden = isOverridden && checkOverrides if b, ok := predicates[token.Value]; ok && !isOverridden { @@ -448,33 +551,46 @@ func (p *parser) parseCall(token Token, arguments []Node, checkOverrides bool) N switch { case arg&expr == expr: node = p.parseExpression(0) - case arg&closure == closure: - node = p.parseClosure() + case arg&predicate == predicate: + node = p.parsePredicate() } arguments = append(arguments, node) } + // skip last comma + if p.current.Is(Operator, ",") { + p.next() + } p.expect(Bracket, ")") - node = &BuiltinNode{ + node = p.createNode(&BuiltinNode{ Name: token.Value, Arguments: arguments, + }, token.Location) + if node == nil { + return nil } - node.SetLocation(token.Location) - } else if _, ok := builtin.Index[token.Value]; ok && !p.config.Disabled[token.Value] && !isOverridden { - node = &BuiltinNode{ + } else if _, ok := builtin.Index[token.Value]; ok && (p.config == nil || !p.config.Disabled[token.Value]) && !isOverridden { + node = p.createNode(&BuiltinNode{ Name: token.Value, Arguments: p.parseArguments(arguments), + }, token.Location) + if node == nil { + return nil } - node.SetLocation(token.Location) + } else { - callee := &IdentifierNode{Value: token.Value} - callee.SetLocation(token.Location) - node = &CallNode{ + callee := p.createNode(&IdentifierNode{Value: token.Value}, token.Location) + if callee == nil { + return nil + } + node = p.createNode(&CallNode{ Callee: callee, Arguments: p.parseArguments(arguments), + }, token.Location) + if node == nil { + return nil } - node.SetLocation(token.Location) } return node } @@ -489,6 +605,9 @@ func (p *parser) parseArguments(arguments []Node) []Node { if len(arguments) > offset { p.expect(Operator, ",") } + if p.current.Is(Bracket, ")") { + break + } node := p.parseExpression(0) arguments = append(arguments, node) } @@ -497,26 +616,36 @@ func (p *parser) parseArguments(arguments []Node) []Node { return arguments } -func (p *parser) parseClosure() Node { +func (p *parser) parsePredicate() Node { startToken := p.current - expectClosingBracket := false + withBrackets := false if p.current.Is(Bracket, "{") { p.next() - expectClosingBracket = true + withBrackets = true } p.depth++ - node := p.parseExpression(0) + var node Node + if withBrackets { + node = p.parseSequenceExpression() + } else { + node = p.parseExpression(0) + if p.current.Is(Operator, ";") { + p.error("wrap predicate with brackets { and }") + } + } p.depth-- - if expectClosingBracket { + if withBrackets { p.expect(Bracket, "}") } - closure := &ClosureNode{ + predicateNode := p.createNode(&PredicateNode{ Node: node, + }, startToken.Location) + if predicateNode == nil { + return nil } - closure.SetLocation(startToken.Location) - return closure + return predicateNode } func (p *parser) parseArrayExpression(token Token) Node { @@ -536,8 +665,10 @@ func (p *parser) parseArrayExpression(token Token) Node { end: p.expect(Bracket, "]") - node := &ArrayNode{Nodes: nodes} - node.SetLocation(token.Location) + node := p.createNode(&ArrayNode{Nodes: nodes}, token.Location) + if node == nil { + return nil + } return node } @@ -563,8 +694,10 @@ func (p *parser) parseMapExpression(token Token) Node { // * identifier, which is equivalent to a string // * expression, which must be enclosed in parentheses -- (1 + 2) if p.current.Is(Number) || p.current.Is(String) || p.current.Is(Identifier) { - key = &StringNode{Value: p.current.Value} - key.SetLocation(token.Location) + key = p.createNode(&StringNode{Value: p.current.Value}, p.current.Location) + if key == nil { + return nil + } p.next() } else if p.current.Is(Bracket, "(") { key = p.parseExpression(0) @@ -575,16 +708,20 @@ func (p *parser) parseMapExpression(token Token) Node { p.expect(Operator, ":") node := p.parseExpression(0) - pair := &PairNode{Key: key, Value: node} - pair.SetLocation(token.Location) + pair := p.createNode(&PairNode{Key: key, Value: node}, token.Location) + if pair == nil { + return nil + } nodes = append(nodes, pair) } end: p.expect(Bracket, "}") - node := &MapNode{Pairs: nodes} - node.SetLocation(token.Location) + node := p.createNode(&MapNode{Pairs: nodes}, token.Location) + if node == nil { + return nil + } return node } @@ -609,8 +746,10 @@ func (p *parser) parsePostfixExpression(node Node) Node { p.error("expected name") } - property := &StringNode{Value: propertyToken.Value} - property.SetLocation(propertyToken.Location) + property := p.createNode(&StringNode{Value: propertyToken.Value}, propertyToken.Location) + if property == nil { + return nil + } chainNode, isChain := node.(*ChainNode) optional := postfixToken.Value == "?." @@ -619,26 +758,33 @@ func (p *parser) parsePostfixExpression(node Node) Node { node = chainNode.Node } - memberNode := &MemberNode{ + memberNode := p.createMemberNode(&MemberNode{ Node: node, Property: property, Optional: optional, + }, propertyToken.Location) + if memberNode == nil { + return nil } - memberNode.SetLocation(propertyToken.Location) if p.current.Is(Bracket, "(") { memberNode.Method = true - node = &CallNode{ + node = p.createNode(&CallNode{ Callee: memberNode, Arguments: p.parseArguments([]Node{}), + }, propertyToken.Location) + if node == nil { + return nil } - node.SetLocation(propertyToken.Location) } else { node = memberNode } if isChain || optional { - node = &ChainNode{Node: node} + node = p.createNode(&ChainNode{Node: node}, propertyToken.Location) + if node == nil { + return nil + } } } else if postfixToken.Value == "[" { @@ -652,11 +798,13 @@ func (p *parser) parsePostfixExpression(node Node) Node { to = p.parseExpression(0) } - node = &SliceNode{ + node = p.createNode(&SliceNode{ Node: node, To: to, + }, postfixToken.Location) + if node == nil { + return nil } - node.SetLocation(postfixToken.Location) p.expect(Bracket, "]") } else { @@ -670,25 +818,32 @@ func (p *parser) parsePostfixExpression(node Node) Node { to = p.parseExpression(0) } - node = &SliceNode{ + node = p.createNode(&SliceNode{ Node: node, From: from, To: to, + }, postfixToken.Location) + if node == nil { + return nil } - node.SetLocation(postfixToken.Location) p.expect(Bracket, "]") } else { // Slice operator [:] was not found, // it should be just an index node. - node = &MemberNode{ + node = p.createNode(&MemberNode{ Node: node, Property: from, Optional: optional, + }, postfixToken.Location) + if node == nil { + return nil } - node.SetLocation(postfixToken.Location) if optional { - node = &ChainNode{Node: node} + node = p.createNode(&ChainNode{Node: node}, postfixToken.Location) + if node == nil { + return nil + } } p.expect(Bracket, "]") } @@ -700,26 +855,29 @@ func (p *parser) parsePostfixExpression(node Node) Node { } return node } - func (p *parser) parseComparison(left Node, token Token, precedence int) Node { var rootNode Node for { comparator := p.parseExpression(precedence + 1) - cmpNode := &BinaryNode{ + cmpNode := p.createNode(&BinaryNode{ Operator: token.Value, Left: left, Right: comparator, + }, token.Location) + if cmpNode == nil { + return nil } - cmpNode.SetLocation(token.Location) if rootNode == nil { rootNode = cmpNode } else { - rootNode = &BinaryNode{ + rootNode = p.createNode(&BinaryNode{ Operator: "&&", Left: rootNode, Right: cmpNode, + }, token.Location) + if rootNode == nil { + return nil } - rootNode.SetLocation(token.Location) } left = comparator diff --git a/vendor/github.com/expr-lang/expr/patcher/operator_override.go b/vendor/github.com/expr-lang/expr/patcher/operator_override.go index 551fe09b..308cbdba 100644 --- a/vendor/github.com/expr-lang/expr/patcher/operator_override.go +++ b/vendor/github.com/expr-lang/expr/patcher/operator_override.go @@ -6,13 +6,14 @@ import ( "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" + "github.com/expr-lang/expr/checker/nature" "github.com/expr-lang/expr/conf" ) type OperatorOverloading struct { Operator string // Operator token to overload. Overloads []string // List of function names to replace operator with. - Types conf.TypesTable // Env types. + Env *nature.Nature // Env type. Functions conf.FunctionsTable // Env functions. applied bool // Flag to indicate if any changes were made to the tree. } @@ -42,6 +43,11 @@ func (p *OperatorOverloading) Visit(node *ast.Node) { } } +// Tracking must be reset before every walk over the AST tree +func (p *OperatorOverloading) Reset() { + p.applied = false +} + func (p *OperatorOverloading) ShouldRepeat() bool { return p.applied } @@ -56,7 +62,7 @@ func (p *OperatorOverloading) FindSuitableOperatorOverload(l, r reflect.Type) (r func (p *OperatorOverloading) findSuitableOperatorOverloadInTypes(l, r reflect.Type) (reflect.Type, string, bool) { for _, fn := range p.Overloads { - fnType, ok := p.Types[fn] + fnType, ok := p.Env.Get(fn) if !ok { continue } @@ -103,7 +109,7 @@ func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex func (p *OperatorOverloading) Check() { for _, fn := range p.Overloads { - fnType, foundType := p.Types[fn] + fnType, foundType := p.Env.Get(fn) fnFunc, foundFunc := p.Functions[fn] if !foundFunc && (!foundType || fnType.Type.Kind() != reflect.Func) { panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, p.Operator)) @@ -119,7 +125,7 @@ func (p *OperatorOverloading) Check() { } } -func checkType(fnType conf.Tag, fn string, operator string) { +func checkType(fnType nature.Nature, fn string, operator string) { requiredNumIn := 2 if fnType.Method { requiredNumIn = 3 // As first argument of method is receiver. diff --git a/vendor/github.com/expr-lang/expr/types/types.go b/vendor/github.com/expr-lang/expr/types/types.go new file mode 100644 index 00000000..bb1cbe5f --- /dev/null +++ b/vendor/github.com/expr-lang/expr/types/types.go @@ -0,0 +1,181 @@ +package types + +import ( + "fmt" + "reflect" + "strings" + + . "github.com/expr-lang/expr/checker/nature" +) + +// Type is a type that can be used to represent a value. +type Type interface { + Nature() Nature + Equal(Type) bool + String() string +} + +var ( + Int = TypeOf(0) + Int8 = TypeOf(int8(0)) + Int16 = TypeOf(int16(0)) + Int32 = TypeOf(int32(0)) + Int64 = TypeOf(int64(0)) + Uint = TypeOf(uint(0)) + Uint8 = TypeOf(uint8(0)) + Uint16 = TypeOf(uint16(0)) + Uint32 = TypeOf(uint32(0)) + Uint64 = TypeOf(uint64(0)) + Float = TypeOf(float32(0)) + Float64 = TypeOf(float64(0)) + String = TypeOf("") + Bool = TypeOf(true) + Nil = nilType{} + Any = anyType{} +) + +func TypeOf(v any) Type { + if v == nil { + return Nil + } + return rtype{t: reflect.TypeOf(v)} +} + +type anyType struct{} + +func (anyType) Nature() Nature { + return Nature{Type: nil} +} + +func (anyType) Equal(t Type) bool { + return true +} + +func (anyType) String() string { + return "any" +} + +type nilType struct{} + +func (nilType) Nature() Nature { + return Nature{Nil: true} +} + +func (nilType) Equal(t Type) bool { + if t == Any { + return true + } + return t == Nil +} + +func (nilType) String() string { + return "nil" +} + +type rtype struct { + t reflect.Type +} + +func (r rtype) Nature() Nature { + return Nature{Type: r.t} +} + +func (r rtype) Equal(t Type) bool { + if t == Any { + return true + } + if rt, ok := t.(rtype); ok { + return r.t.String() == rt.t.String() + } + return false +} + +func (r rtype) String() string { + return r.t.String() +} + +// Map represents a map[string]any type with defined keys. +type Map map[string]Type + +const Extra = "[[__extra_keys__]]" + +func (m Map) Nature() Nature { + nt := Nature{ + Type: reflect.TypeOf(map[string]any{}), + Fields: make(map[string]Nature, len(m)), + Strict: true, + } + for k, v := range m { + if k == Extra { + nt.Strict = false + natureOfDefaultValue := v.Nature() + nt.DefaultMapValue = &natureOfDefaultValue + continue + } + nt.Fields[k] = v.Nature() + } + return nt +} + +func (m Map) Equal(t Type) bool { + if t == Any { + return true + } + mt, ok := t.(Map) + if !ok { + return false + } + if len(m) != len(mt) { + return false + } + for k, v := range m { + if !v.Equal(mt[k]) { + return false + } + } + return true +} + +func (m Map) String() string { + pairs := make([]string, 0, len(m)) + for k, v := range m { + pairs = append(pairs, fmt.Sprintf("%s: %s", k, v.String())) + } + return fmt.Sprintf("Map{%s}", strings.Join(pairs, ", ")) +} + +// Array returns a type that represents an array of the given type. +func Array(of Type) Type { + return array{of} +} + +type array struct { + of Type +} + +func (a array) Nature() Nature { + of := a.of.Nature() + return Nature{ + Type: reflect.TypeOf([]any{}), + Fields: make(map[string]Nature, 1), + ArrayOf: &of, + } +} + +func (a array) Equal(t Type) bool { + if t == Any { + return true + } + at, ok := t.(array) + if !ok { + return false + } + if a.of.Equal(at.of) { + return true + } + return false +} + +func (a array) String() string { + return fmt.Sprintf("Array{%s}", a.of.String()) +} diff --git a/vendor/github.com/expr-lang/expr/vm/debug.go b/vendor/github.com/expr-lang/expr/vm/debug.go index ab95bf9a..470bf90e 100644 --- a/vendor/github.com/expr-lang/expr/vm/debug.go +++ b/vendor/github.com/expr-lang/expr/vm/debug.go @@ -1,4 +1,5 @@ //go:build expr_debug +// +build expr_debug package vm diff --git a/vendor/github.com/expr-lang/expr/vm/debug_off.go b/vendor/github.com/expr-lang/expr/vm/debug_off.go index e0f2955a..8a9e965e 100644 --- a/vendor/github.com/expr-lang/expr/vm/debug_off.go +++ b/vendor/github.com/expr-lang/expr/vm/debug_off.go @@ -1,4 +1,5 @@ //go:build !expr_debug +// +build !expr_debug package vm diff --git a/vendor/github.com/expr-lang/expr/vm/utils.go b/vendor/github.com/expr-lang/expr/vm/utils.go index fc2f5e7b..11005137 100644 --- a/vendor/github.com/expr-lang/expr/vm/utils.go +++ b/vendor/github.com/expr-lang/expr/vm/utils.go @@ -11,9 +11,6 @@ type ( ) var ( - // MemoryBudget represents an upper limit of memory usage. - MemoryBudget uint = 1e6 - errorType = reflect.TypeOf((*error)(nil)).Elem() ) diff --git a/vendor/github.com/expr-lang/expr/vm/vm.go b/vendor/github.com/expr-lang/expr/vm/vm.go index fa1223b4..de13cade 100644 --- a/vendor/github.com/expr-lang/expr/vm/vm.go +++ b/vendor/github.com/expr-lang/expr/vm/vm.go @@ -11,6 +11,7 @@ import ( "time" "github.com/expr-lang/expr/builtin" + "github.com/expr-lang/expr/conf" "github.com/expr-lang/expr/file" "github.com/expr-lang/expr/internal/deref" "github.com/expr-lang/expr/vm/runtime" @@ -20,7 +21,6 @@ func Run(program *Program, env any) (any, error) { if program == nil { return nil, fmt.Errorf("program is nil") } - vm := VM{} return vm.Run(program, env) } @@ -38,9 +38,9 @@ type VM struct { Stack []any Scopes []*Scope Variables []any + MemoryBudget uint ip int memory uint - memoryBudget uint debug bool step chan struct{} curr chan int @@ -76,7 +76,9 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { vm.Variables = make([]any, program.variables) } - vm.memoryBudget = MemoryBudget + if vm.MemoryBudget == 0 { + vm.MemoryBudget = conf.DefaultMemoryBudget + } vm.memory = 0 vm.ip = 0 @@ -332,10 +334,8 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { in := make([]reflect.Value, size) for i := int(size) - 1; i >= 0; i-- { param := vm.pop() - if param == nil && reflect.TypeOf(param) == nil { - // In case of nil value and nil type use this hack, - // otherwise reflect.Call will panic on zero value. - in[i] = reflect.ValueOf(¶m).Elem() + if param == nil { + in[i] = reflect.Zero(fn.Type().In(i)) } else { in[i] = reflect.ValueOf(param) } @@ -457,7 +457,7 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { case OpDeref: a := vm.pop() - vm.push(deref.Deref(a)) + vm.push(deref.Interface(a)) case OpIncrementIndex: vm.scope().Index++ @@ -599,7 +599,7 @@ func (vm *VM) pop() any { func (vm *VM) memGrow(size uint) { vm.memory += size - if vm.memory >= vm.memoryBudget { + if vm.memory >= vm.MemoryBudget { panic("memory budget exceeded") } } diff --git a/vendor/modules.txt b/vendor/modules.txt index 78b2f7d4..746cc3df 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -208,12 +208,13 @@ github.com/creachadair/jrpc2/code github.com/creachadair/jrpc2/handler github.com/creachadair/jrpc2/metrics github.com/creachadair/jrpc2/server -# github.com/expr-lang/expr v1.16.9 +# github.com/expr-lang/expr v1.17.2 ## explicit; go 1.18 github.com/expr-lang/expr github.com/expr-lang/expr/ast github.com/expr-lang/expr/builtin github.com/expr-lang/expr/checker +github.com/expr-lang/expr/checker/nature github.com/expr-lang/expr/compiler github.com/expr-lang/expr/conf github.com/expr-lang/expr/file @@ -224,6 +225,7 @@ github.com/expr-lang/expr/parser/lexer github.com/expr-lang/expr/parser/operator github.com/expr-lang/expr/parser/utils github.com/expr-lang/expr/patcher +github.com/expr-lang/expr/types github.com/expr-lang/expr/vm github.com/expr-lang/expr/vm/runtime # github.com/fatih/color v1.18.0