| // Copyright 2018 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package interpreter |
| |
| import ( |
| "github.com/google/cel-go/common/ast" |
| "github.com/google/cel-go/common/operators" |
| "github.com/google/cel-go/common/overloads" |
| "github.com/google/cel-go/common/types" |
| "github.com/google/cel-go/common/types/ref" |
| "github.com/google/cel-go/common/types/traits" |
| ) |
| |
// astPruner carries the state needed to rewrite an expression using the values
// recorded in an EvalState by a prior evaluation.
type astPruner struct {
	// ExprFactory is embedded so the pruner can build replacement nodes
	// (literals, calls, lists, maps, ...) directly, e.g. p.NewLiteral.
	ast.ExprFactory
	// expr is the root expression handed to PruneAst.
	expr ast.Expr
	// macroCalls maps expression IDs to their macro call form; entries are
	// rewritten or deleted as the corresponding nodes are pruned.
	macroCalls map[int64]ast.Expr
	// state holds evaluated values keyed by expression ID. PruneAst fills it
	// with a copy of the caller's state, and maybeCreateLiteral records values
	// for the literals it creates.
	state EvalState
	// nextExprID is the next unused expression ID, handed out by nextID when
	// synthesizing new nodes.
	nextExprID int64
}
| |
| // TODO Consider having a separate walk of the AST that finds common |
| // subexpressions. This can be called before or after constant folding to find |
| // common subexpressions. |
| |
| // PruneAst prunes the given AST based on the given EvalState and generates a new AST. |
| // Given AST is copied on write and a new AST is returned. |
| // Couple of typical use cases this interface would be: |
| // |
| // A) |
| // 1) Evaluate expr with some unknowns, |
| // 2) If result is unknown: |
| // |
| // a) PruneAst |
| // b) Goto 1 |
| // |
| // Functional call results which are known would be effectively cached across |
| // iterations. |
| // |
| // B) |
| // 1) Compile the expression (maybe via a service and maybe after checking a |
| // |
| // compiled expression does not exists in local cache) |
| // |
| // 2) Prepare the environment and the interpreter. Activation might be empty. |
| // 3) Eval the expression. This might return unknown or error or a concrete |
| // |
| // value. |
| // |
| // 4) PruneAst |
| // 4) Maybe cache the expression |
| // This is effectively constant folding the expression. How the environment is |
| // prepared in step 2 is flexible. For example, If the caller caches the |
| // compiled and constant folded expressions, but is not willing to constant |
| // fold(and thus cache results of) some external calls, then they can prepare |
| // the overloads accordingly. |
| func PruneAst(expr ast.Expr, macroCalls map[int64]ast.Expr, state EvalState) *ast.AST { |
| pruneState := NewEvalState() |
| for _, id := range state.IDs() { |
| v, _ := state.Value(id) |
| pruneState.SetValue(id, v) |
| } |
| pruner := &astPruner{ |
| ExprFactory: ast.NewExprFactory(), |
| expr: expr, |
| macroCalls: macroCalls, |
| state: pruneState, |
| nextExprID: getMaxID(expr)} |
| newExpr, _ := pruner.maybePrune(expr) |
| newInfo := ast.NewSourceInfo(nil) |
| for id, call := range pruner.macroCalls { |
| newInfo.SetMacroCall(id, call) |
| } |
| return ast.NewAST(newExpr, newInfo) |
| } |
| |
| func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (ast.Expr, bool) { |
| switch v := val.(type) { |
| case types.Bool, types.Bytes, types.Double, types.Int, types.Null, types.String, types.Uint, *types.Optional: |
| p.state.SetValue(id, val) |
| return p.NewLiteral(id, val), true |
| case types.Duration: |
| p.state.SetValue(id, val) |
| durationString := v.ConvertToType(types.StringType).(types.String) |
| return p.NewCall(id, overloads.TypeConvertDuration, p.NewLiteral(p.nextID(), durationString)), true |
| case types.Timestamp: |
| timestampString := v.ConvertToType(types.StringType).(types.String) |
| return p.NewCall(id, overloads.TypeConvertTimestamp, p.NewLiteral(p.nextID(), timestampString)), true |
| } |
| |
| // Attempt to build a list literal. |
| if list, isList := val.(traits.Lister); isList { |
| sz := list.Size().(types.Int) |
| elemExprs := make([]ast.Expr, sz) |
| for i := types.Int(0); i < sz; i++ { |
| elem := list.Get(i) |
| if types.IsUnknownOrError(elem) { |
| return nil, false |
| } |
| elemExpr, ok := p.maybeCreateLiteral(p.nextID(), elem) |
| if !ok { |
| return nil, false |
| } |
| elemExprs[i] = elemExpr |
| } |
| p.state.SetValue(id, val) |
| return p.NewList(id, elemExprs, []int32{}), true |
| } |
| |
| // Create a map literal if possible. |
| if mp, isMap := val.(traits.Mapper); isMap { |
| it := mp.Iterator() |
| entries := make([]ast.EntryExpr, mp.Size().(types.Int)) |
| i := 0 |
| for it.HasNext() != types.False { |
| key := it.Next() |
| val := mp.Get(key) |
| if types.IsUnknownOrError(key) || types.IsUnknownOrError(val) { |
| return nil, false |
| } |
| keyExpr, ok := p.maybeCreateLiteral(p.nextID(), key) |
| if !ok { |
| return nil, false |
| } |
| valExpr, ok := p.maybeCreateLiteral(p.nextID(), val) |
| if !ok { |
| return nil, false |
| } |
| entry := p.NewMapEntry(p.nextID(), keyExpr, valExpr, false) |
| entries[i] = entry |
| i++ |
| } |
| p.state.SetValue(id, val) |
| return p.NewMap(id, entries), true |
| } |
| |
| // TODO(issues/377) To construct message literals, the type provider will need to support |
| // the enumeration the fields for a given message. |
| return nil, false |
| } |
| |
| func (p *astPruner) maybePruneOptional(elem ast.Expr) (ast.Expr, bool) { |
| elemVal, found := p.value(elem.ID()) |
| if found && elemVal.Type() == types.OptionalType { |
| opt := elemVal.(*types.Optional) |
| if !opt.HasValue() { |
| return nil, true |
| } |
| if newElem, pruned := p.maybeCreateLiteral(elem.ID(), opt.GetValue()); pruned { |
| return newElem, true |
| } |
| } |
| return elem, false |
| } |
| |
| func (p *astPruner) maybePruneIn(node ast.Expr) (ast.Expr, bool) { |
| // elem in list |
| call := node.AsCall() |
| val, exists := p.maybeValue(call.Args()[1].ID()) |
| if !exists { |
| return nil, false |
| } |
| if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero { |
| return p.maybeCreateLiteral(node.ID(), types.False) |
| } |
| return nil, false |
| } |
| |
| func (p *astPruner) maybePruneLogicalNot(node ast.Expr) (ast.Expr, bool) { |
| call := node.AsCall() |
| arg := call.Args()[0] |
| val, exists := p.maybeValue(arg.ID()) |
| if !exists { |
| return nil, false |
| } |
| if b, ok := val.(types.Bool); ok { |
| return p.maybeCreateLiteral(node.ID(), !b) |
| } |
| return nil, false |
| } |
| |
| func (p *astPruner) maybePruneOr(node ast.Expr) (ast.Expr, bool) { |
| call := node.AsCall() |
| // We know result is unknown, so we have at least one unknown arg |
| // and if one side is a known value, we know we can ignore it. |
| if v, exists := p.maybeValue(call.Args()[0].ID()); exists { |
| if v == types.True { |
| return p.maybeCreateLiteral(node.ID(), types.True) |
| } |
| return call.Args()[1], true |
| } |
| if v, exists := p.maybeValue(call.Args()[1].ID()); exists { |
| if v == types.True { |
| return p.maybeCreateLiteral(node.ID(), types.True) |
| } |
| return call.Args()[0], true |
| } |
| return nil, false |
| } |
| |
| func (p *astPruner) maybePruneAnd(node ast.Expr) (ast.Expr, bool) { |
| call := node.AsCall() |
| // We know result is unknown, so we have at least one unknown arg |
| // and if one side is a known value, we know we can ignore it. |
| if v, exists := p.maybeValue(call.Args()[0].ID()); exists { |
| if v == types.False { |
| return p.maybeCreateLiteral(node.ID(), types.False) |
| } |
| return call.Args()[1], true |
| } |
| if v, exists := p.maybeValue(call.Args()[1].ID()); exists { |
| if v == types.False { |
| return p.maybeCreateLiteral(node.ID(), types.False) |
| } |
| return call.Args()[0], true |
| } |
| return nil, false |
| } |
| |
| func (p *astPruner) maybePruneConditional(node ast.Expr) (ast.Expr, bool) { |
| call := node.AsCall() |
| cond, exists := p.maybeValue(call.Args()[0].ID()) |
| if !exists { |
| return nil, false |
| } |
| if cond.Value().(bool) { |
| return call.Args()[1], true |
| } |
| return call.Args()[2], true |
| } |
| |
| func (p *astPruner) maybePruneFunction(node ast.Expr) (ast.Expr, bool) { |
| if _, exists := p.value(node.ID()); !exists { |
| return nil, false |
| } |
| call := node.AsCall() |
| if call.FunctionName() == operators.LogicalOr { |
| return p.maybePruneOr(node) |
| } |
| if call.FunctionName() == operators.LogicalAnd { |
| return p.maybePruneAnd(node) |
| } |
| if call.FunctionName() == operators.Conditional { |
| return p.maybePruneConditional(node) |
| } |
| if call.FunctionName() == operators.In { |
| return p.maybePruneIn(node) |
| } |
| if call.FunctionName() == operators.LogicalNot { |
| return p.maybePruneLogicalNot(node) |
| } |
| return nil, false |
| } |
| |
| func (p *astPruner) maybePrune(node ast.Expr) (ast.Expr, bool) { |
| return p.prune(node) |
| } |
| |
| func (p *astPruner) prune(node ast.Expr) (ast.Expr, bool) { |
| if node == nil { |
| return node, false |
| } |
| val, valueExists := p.maybeValue(node.ID()) |
| if valueExists { |
| if newNode, ok := p.maybeCreateLiteral(node.ID(), val); ok { |
| delete(p.macroCalls, node.ID()) |
| return newNode, true |
| } |
| } |
| if macro, found := p.macroCalls[node.ID()]; found { |
| // Ensure that intermediate values for the comprehension are cleared during pruning |
| pruneMacroCall := node.Kind() != ast.UnspecifiedExprKind |
| if node.Kind() == ast.ComprehensionKind { |
| // Only prune cel.bind() calls since the variables of the comprehension are all |
| // visible to the user, so there's no chance of an incorrect value being observed |
| // as a result of looking at intermediate computations within a comprehension. |
| pruneMacroCall = isCelBindMacro(macro) |
| } |
| if pruneMacroCall { |
| // prune the expression in terms of the macro call instead of the expanded form when |
| // dealing with macro call tracking references. |
| if newMacro, pruned := p.prune(macro); pruned { |
| p.macroCalls[node.ID()] = newMacro |
| } |
| } else { |
| // Otherwise just prune the macro target in keeping with the pruning behavior of the |
| // comprehensions later in the call graph. |
| macroCall := macro.AsCall() |
| if macroCall.Target() != nil { |
| if newTarget, pruned := p.prune(macroCall.Target()); pruned { |
| macro = p.NewMemberCall(macro.ID(), macroCall.FunctionName(), newTarget, macroCall.Args()...) |
| p.macroCalls[node.ID()] = macro |
| } |
| } |
| } |
| } |
| |
| // We have either an unknown/error value, or something we don't want to |
| // transform, or expression was not evaluated. If possible, drill down |
| // more. |
| switch node.Kind() { |
| case ast.SelectKind: |
| sel := node.AsSelect() |
| if operand, isPruned := p.maybePrune(sel.Operand()); isPruned { |
| if sel.IsTestOnly() { |
| return p.NewPresenceTest(node.ID(), operand, sel.FieldName()), true |
| } |
| return p.NewSelect(node.ID(), operand, sel.FieldName()), true |
| } |
| case ast.CallKind: |
| argsPruned := false |
| call := node.AsCall() |
| args := call.Args() |
| newArgs := make([]ast.Expr, len(args)) |
| for i, a := range args { |
| newArgs[i] = a |
| if arg, isPruned := p.maybePrune(a); isPruned { |
| argsPruned = true |
| newArgs[i] = arg |
| } |
| } |
| if !call.IsMemberFunction() { |
| newCall := p.NewCall(node.ID(), call.FunctionName(), newArgs...) |
| if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned { |
| return prunedCall, true |
| } |
| return newCall, argsPruned |
| } |
| newTarget := call.Target() |
| targetPruned := false |
| if prunedTarget, isPruned := p.maybePrune(call.Target()); isPruned { |
| targetPruned = true |
| newTarget = prunedTarget |
| } |
| newCall := p.NewMemberCall(node.ID(), call.FunctionName(), newTarget, newArgs...) |
| if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned { |
| return prunedCall, true |
| } |
| return newCall, targetPruned || argsPruned |
| case ast.ListKind: |
| l := node.AsList() |
| elems := l.Elements() |
| optIndices := l.OptionalIndices() |
| optIndexMap := map[int32]bool{} |
| for _, i := range optIndices { |
| optIndexMap[i] = true |
| } |
| newOptIndexMap := make(map[int32]bool, len(optIndexMap)) |
| newElems := make([]ast.Expr, 0, len(elems)) |
| var listPruned bool |
| prunedIdx := 0 |
| for i, elem := range elems { |
| _, isOpt := optIndexMap[int32(i)] |
| if isOpt { |
| newElem, pruned := p.maybePruneOptional(elem) |
| if pruned { |
| listPruned = true |
| if newElem != nil { |
| newElems = append(newElems, newElem) |
| prunedIdx++ |
| } |
| continue |
| } |
| newOptIndexMap[int32(prunedIdx)] = true |
| } |
| if newElem, prunedElem := p.maybePrune(elem); prunedElem { |
| newElems = append(newElems, newElem) |
| listPruned = true |
| } else { |
| newElems = append(newElems, elem) |
| } |
| prunedIdx++ |
| } |
| optIndices = make([]int32, len(newOptIndexMap)) |
| idx := 0 |
| for i := range newOptIndexMap { |
| optIndices[idx] = i |
| idx++ |
| } |
| if listPruned { |
| return p.NewList(node.ID(), newElems, optIndices), true |
| } |
| case ast.MapKind: |
| var mapPruned bool |
| m := node.AsMap() |
| entries := m.Entries() |
| newEntries := make([]ast.EntryExpr, len(entries)) |
| for i, entry := range entries { |
| newEntries[i] = entry |
| e := entry.AsMapEntry() |
| newKey, keyPruned := p.maybePrune(e.Key()) |
| newValue, valuePruned := p.maybePrune(e.Value()) |
| if !keyPruned && !valuePruned { |
| continue |
| } |
| mapPruned = true |
| newEntry := p.NewMapEntry(entry.ID(), newKey, newValue, e.IsOptional()) |
| newEntries[i] = newEntry |
| } |
| if mapPruned { |
| return p.NewMap(node.ID(), newEntries), true |
| } |
| case ast.StructKind: |
| var structPruned bool |
| obj := node.AsStruct() |
| fields := obj.Fields() |
| newFields := make([]ast.EntryExpr, len(fields)) |
| for i, field := range fields { |
| newFields[i] = field |
| f := field.AsStructField() |
| newValue, prunedValue := p.maybePrune(f.Value()) |
| if !prunedValue { |
| continue |
| } |
| structPruned = true |
| newEntry := p.NewStructField(field.ID(), f.Name(), newValue, f.IsOptional()) |
| newFields[i] = newEntry |
| } |
| if structPruned { |
| return p.NewStruct(node.ID(), obj.TypeName(), newFields), true |
| } |
| case ast.ComprehensionKind: |
| compre := node.AsComprehension() |
| // Only the range of the comprehension is pruned since the state tracking only records |
| // the last iteration of the comprehension and not each step in the evaluation which |
| // means that the any residuals computed in between might be inaccurate. |
| if newRange, pruned := p.maybePrune(compre.IterRange()); pruned { |
| if compre.HasIterVar2() { |
| return p.NewComprehensionTwoVar( |
| node.ID(), |
| newRange, |
| compre.IterVar(), |
| compre.IterVar2(), |
| compre.AccuVar(), |
| compre.AccuInit(), |
| compre.LoopCondition(), |
| compre.LoopStep(), |
| compre.Result(), |
| ), true |
| } |
| return p.NewComprehension( |
| node.ID(), |
| newRange, |
| compre.IterVar(), |
| compre.AccuVar(), |
| compre.AccuInit(), |
| compre.LoopCondition(), |
| compre.LoopStep(), |
| compre.Result(), |
| ), true |
| } |
| } |
| return node, false |
| } |
| |
| func (p *astPruner) value(id int64) (ref.Val, bool) { |
| val, found := p.state.Value(id) |
| return val, (found && val != nil) |
| } |
| |
| func (p *astPruner) maybeValue(id int64) (ref.Val, bool) { |
| val, found := p.value(id) |
| if !found || types.IsUnknownOrError(val) { |
| return nil, false |
| } |
| return val, true |
| } |
| |
| func (p *astPruner) nextID() int64 { |
| next := p.nextExprID |
| p.nextExprID++ |
| return next |
| } |
| |
| type astVisitor struct { |
| // visitEntry is called on every expr node, including those within a map/struct entry. |
| visitExpr func(expr ast.Expr) |
| // visitEntry is called before entering the key, value of a map/struct entry. |
| visitEntry func(entry ast.EntryExpr) |
| } |
| |
| func getMaxID(expr ast.Expr) int64 { |
| maxID := int64(1) |
| visit(expr, maxIDVisitor(&maxID)) |
| return maxID |
| } |
| |
| func maxIDVisitor(maxID *int64) astVisitor { |
| return astVisitor{ |
| visitExpr: func(e ast.Expr) { |
| if e.ID() >= *maxID { |
| *maxID = e.ID() + 1 |
| } |
| }, |
| visitEntry: func(e ast.EntryExpr) { |
| if e.ID() >= *maxID { |
| *maxID = e.ID() + 1 |
| } |
| }, |
| } |
| } |
| |
| func visit(expr ast.Expr, visitor astVisitor) { |
| exprs := []ast.Expr{expr} |
| for len(exprs) != 0 { |
| e := exprs[0] |
| if visitor.visitExpr != nil { |
| visitor.visitExpr(e) |
| } |
| exprs = exprs[1:] |
| switch e.Kind() { |
| case ast.SelectKind: |
| exprs = append(exprs, e.AsSelect().Operand()) |
| case ast.CallKind: |
| call := e.AsCall() |
| if call.Target() != nil { |
| exprs = append(exprs, call.Target()) |
| } |
| exprs = append(exprs, call.Args()...) |
| case ast.ComprehensionKind: |
| compre := e.AsComprehension() |
| exprs = append(exprs, |
| compre.IterRange(), |
| compre.AccuInit(), |
| compre.LoopCondition(), |
| compre.LoopStep(), |
| compre.Result()) |
| case ast.ListKind: |
| list := e.AsList() |
| exprs = append(exprs, list.Elements()...) |
| case ast.MapKind: |
| for _, entry := range e.AsMap().Entries() { |
| e := entry.AsMapEntry() |
| if visitor.visitEntry != nil { |
| visitor.visitEntry(entry) |
| } |
| exprs = append(exprs, e.Key()) |
| exprs = append(exprs, e.Value()) |
| } |
| case ast.StructKind: |
| for _, entry := range e.AsStruct().Fields() { |
| f := entry.AsStructField() |
| if visitor.visitEntry != nil { |
| visitor.visitEntry(entry) |
| } |
| exprs = append(exprs, f.Value()) |
| } |
| } |
| } |
| } |
| |
| func isCelBindMacro(macro ast.Expr) bool { |
| if macro.Kind() != ast.CallKind { |
| return false |
| } |
| macroCall := macro.AsCall() |
| target := macroCall.Target() |
| return macroCall.FunctionName() == "bind" && |
| macroCall.IsMemberFunction() && |
| target.Kind() == ast.IdentKind && |
| target.AsIdent() == "cel" |
| } |