| // Copyright 2023 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package ext |
| |
| import ( |
| "errors" |
| "fmt" |
| "math" |
| "sort" |
| "strconv" |
| "strings" |
| "time" |
| "unicode" |
| |
| "github.com/google/cel-go/cel" |
| "github.com/google/cel-go/common/ast" |
| "github.com/google/cel-go/common/types" |
| "github.com/google/cel-go/common/types/ref" |
| "github.com/google/cel-go/common/types/traits" |
| ) |
| |
// clauseImplV2 is the signature of a formatting clause implementation: it renders a
// single argument value as a string, or returns an error if the value cannot be rendered.
type clauseImplV2 func(ref.Val) (string, error)
| |
// appendingFormatterV2 accumulates formatted output; its format method appends the
// rendering of each value it is given to buf.
type appendingFormatterV2 struct {
	// buf holds the bytes produced so far.
	buf []byte
}
| |
// formattedMapEntryV2 holds the already-formatted key and value of one map entry so that
// entries can be ordered by key before being written out.
type formattedMapEntryV2 struct {
	// key is the formatted map key; it is also the sort key.
	key string
	// val is the formatted map value.
	val string
}
| |
| func (af *appendingFormatterV2) format(arg ref.Val) error { |
| switch arg.Type() { |
| case types.BoolType: |
| argBool, ok := arg.Value().(bool) |
| if !ok { |
| return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BoolType) |
| } |
| af.buf = strconv.AppendBool(af.buf, argBool) |
| return nil |
| case types.IntType: |
| argInt, ok := arg.Value().(int64) |
| if !ok { |
| return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) |
| } |
| af.buf = strconv.AppendInt(af.buf, argInt, 10) |
| return nil |
| case types.UintType: |
| argUint, ok := arg.Value().(uint64) |
| if !ok { |
| return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) |
| } |
| af.buf = strconv.AppendUint(af.buf, argUint, 10) |
| return nil |
| case types.DoubleType: |
| argDbl, ok := arg.Value().(float64) |
| if !ok { |
| return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType) |
| } |
| if math.IsNaN(argDbl) { |
| af.buf = append(af.buf, "NaN"...) |
| return nil |
| } |
| if math.IsInf(argDbl, -1) { |
| af.buf = append(af.buf, "-Infinity"...) |
| return nil |
| } |
| if math.IsInf(argDbl, 1) { |
| af.buf = append(af.buf, "Infinity"...) |
| return nil |
| } |
| af.buf = strconv.AppendFloat(af.buf, argDbl, 'f', -1, 64) |
| return nil |
| case types.BytesType: |
| argBytes, ok := arg.Value().([]byte) |
| if !ok { |
| return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BytesType) |
| } |
| af.buf = append(af.buf, argBytes...) |
| return nil |
| case types.StringType: |
| argStr, ok := arg.Value().(string) |
| if !ok { |
| return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.StringType) |
| } |
| af.buf = append(af.buf, argStr...) |
| return nil |
| case types.DurationType: |
| argDur, ok := arg.Value().(time.Duration) |
| if !ok { |
| return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DurationType) |
| } |
| af.buf = strconv.AppendFloat(af.buf, argDur.Seconds(), 'f', -1, 64) |
| af.buf = append(af.buf, "s"...) |
| return nil |
| case types.TimestampType: |
| argTime, ok := arg.Value().(time.Time) |
| if !ok { |
| return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.TimestampType) |
| } |
| af.buf = argTime.UTC().AppendFormat(af.buf, time.RFC3339Nano) |
| return nil |
| case types.NullType: |
| af.buf = append(af.buf, "null"...) |
| return nil |
| case types.TypeType: |
| argType, ok := arg.Value().(string) |
| if !ok { |
| return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.TypeType) |
| } |
| af.buf = append(af.buf, argType...) |
| return nil |
| case types.ListType: |
| argList, ok := arg.(traits.Lister) |
| if !ok { |
| return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.ListType) |
| } |
| argIter := argList.Iterator() |
| af.buf = append(af.buf, "["...) |
| if argIter.HasNext() == types.True { |
| if err := af.format(argIter.Next()); err != nil { |
| return err |
| } |
| for argIter.HasNext() == types.True { |
| af.buf = append(af.buf, ", "...) |
| if err := af.format(argIter.Next()); err != nil { |
| return err |
| } |
| } |
| } |
| af.buf = append(af.buf, "]"...) |
| return nil |
| case types.MapType: |
| argMap, ok := arg.(traits.Mapper) |
| if !ok { |
| return fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.MapType) |
| } |
| argIter := argMap.Iterator() |
| ents := []formattedMapEntryV2{} |
| for argIter.HasNext() == types.True { |
| key := argIter.Next() |
| val, ok := argMap.Find(key) |
| if !ok { |
| return fmt.Errorf("key missing from map: '%s'", key) |
| } |
| keyStr, err := formatStringV2(key) |
| if err != nil { |
| return err |
| } |
| valStr, err := formatStringV2(val) |
| if err != nil { |
| return err |
| } |
| ents = append(ents, formattedMapEntryV2{keyStr, valStr}) |
| } |
| sort.SliceStable(ents, func(x, y int) bool { |
| return ents[x].key < ents[y].key |
| }) |
| af.buf = append(af.buf, "{"...) |
| for i, e := range ents { |
| if i > 0 { |
| af.buf = append(af.buf, ", "...) |
| } |
| af.buf = append(af.buf, e.key...) |
| af.buf = append(af.buf, ": "...) |
| af.buf = append(af.buf, e.val...) |
| } |
| af.buf = append(af.buf, "}"...) |
| return nil |
| default: |
| return stringFormatErrorV2(runtimeID, arg.Type().TypeName()) |
| } |
| } |
| |
| func formatStringV2(arg ref.Val) (string, error) { |
| var fmter appendingFormatterV2 |
| if err := fmter.format(arg); err != nil { |
| return "", err |
| } |
| return string(fmter.buf), nil |
| } |
| |
// stringFormatterV2 is a stateless type whose methods render values for each
// formatting clause of formatStringInterpolatorV2.
type stringFormatterV2 struct{}
| |
| // String implements formatStringInterpolatorV2.String. |
| func (c *stringFormatterV2) String(arg ref.Val) (string, error) { |
| return formatStringV2(arg) |
| } |
| |
| // Decimal implements formatStringInterpolatorV2.Decimal. |
| func (c *stringFormatterV2) Decimal(arg ref.Val) (string, error) { |
| switch arg.Type() { |
| case types.IntType: |
| argInt, ok := arg.Value().(int64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) |
| } |
| return strconv.FormatInt(argInt, 10), nil |
| case types.UintType: |
| argUint, ok := arg.Value().(uint64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) |
| } |
| return strconv.FormatUint(argUint, 10), nil |
| case types.DoubleType: |
| argDbl, ok := arg.Value().(float64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType) |
| } |
| if math.IsNaN(argDbl) { |
| return "NaN", nil |
| } |
| if math.IsInf(argDbl, -1) { |
| return "-Infinity", nil |
| } |
| if math.IsInf(argDbl, 1) { |
| return "Infinity", nil |
| } |
| return strconv.FormatFloat(argDbl, 'f', -1, 64), nil |
| default: |
| return "", decimalFormatErrorV2(runtimeID, arg.Type().TypeName()) |
| } |
| } |
| |
| // Fixed implements formatStringInterpolatorV2.Fixed. |
| func (c *stringFormatterV2) Fixed(precision int) func(ref.Val) (string, error) { |
| return func(arg ref.Val) (string, error) { |
| fmtStr := fmt.Sprintf("%%.%df", precision) |
| switch arg.Type() { |
| case types.IntType: |
| argInt, ok := arg.Value().(int64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) |
| } |
| return fmt.Sprintf(fmtStr, float64(argInt)), nil |
| case types.UintType: |
| argUint, ok := arg.Value().(uint64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) |
| } |
| return fmt.Sprintf(fmtStr, float64(argUint)), nil |
| case types.DoubleType: |
| argDbl, ok := arg.Value().(float64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType) |
| } |
| if math.IsNaN(argDbl) { |
| return "NaN", nil |
| } |
| if math.IsInf(argDbl, -1) { |
| return "-Infinity", nil |
| } |
| if math.IsInf(argDbl, 1) { |
| return "Infinity", nil |
| } |
| return fmt.Sprintf(fmtStr, argDbl), nil |
| default: |
| return "", fixedPointFormatErrorV2(runtimeID, arg.Type().TypeName()) |
| } |
| } |
| } |
| |
| // Scientific implements formatStringInterpolatorV2.Scientific. |
| func (c *stringFormatterV2) Scientific(precision int) func(ref.Val) (string, error) { |
| return func(arg ref.Val) (string, error) { |
| fmtStr := fmt.Sprintf("%%1.%de", precision) |
| switch arg.Type() { |
| case types.IntType: |
| argInt, ok := arg.Value().(int64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) |
| } |
| return fmt.Sprintf(fmtStr, float64(argInt)), nil |
| case types.UintType: |
| argUint, ok := arg.Value().(uint64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) |
| } |
| return fmt.Sprintf(fmtStr, float64(argUint)), nil |
| case types.DoubleType: |
| argDbl, ok := arg.Value().(float64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.DoubleType) |
| } |
| if math.IsNaN(argDbl) { |
| return "NaN", nil |
| } |
| if math.IsInf(argDbl, -1) { |
| return "-Infinity", nil |
| } |
| if math.IsInf(argDbl, 1) { |
| return "Infinity", nil |
| } |
| return fmt.Sprintf(fmtStr, argDbl), nil |
| default: |
| return "", scientificFormatErrorV2(runtimeID, arg.Type().TypeName()) |
| } |
| } |
| } |
| |
| // Binary implements formatStringInterpolatorV2.Binary. |
| func (c *stringFormatterV2) Binary(arg ref.Val) (string, error) { |
| switch arg.Type() { |
| case types.BoolType: |
| argBool, ok := arg.Value().(bool) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BoolType) |
| } |
| if argBool { |
| return "1", nil |
| } |
| return "0", nil |
| case types.IntType: |
| argInt, ok := arg.Value().(int64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) |
| } |
| return strconv.FormatInt(argInt, 2), nil |
| case types.UintType: |
| argUint, ok := arg.Value().(uint64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) |
| } |
| return strconv.FormatUint(argUint, 2), nil |
| default: |
| return "", binaryFormatErrorV2(runtimeID, arg.Type().TypeName()) |
| } |
| } |
| |
| // Hex implements formatStringInterpolatorV2.Hex. |
| func (c *stringFormatterV2) Hex(useUpper bool) func(ref.Val) (string, error) { |
| return func(arg ref.Val) (string, error) { |
| var fmtStr string |
| if useUpper { |
| fmtStr = "%X" |
| } else { |
| fmtStr = "%x" |
| } |
| switch arg.Type() { |
| case types.IntType: |
| argInt, ok := arg.Value().(int64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) |
| } |
| return fmt.Sprintf(fmtStr, argInt), nil |
| case types.UintType: |
| argUint, ok := arg.Value().(uint64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) |
| } |
| return fmt.Sprintf(fmtStr, argUint), nil |
| case types.StringType: |
| argStr, ok := arg.Value().(string) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.StringType) |
| } |
| return fmt.Sprintf(fmtStr, argStr), nil |
| case types.BytesType: |
| argBytes, ok := arg.Value().([]byte) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.BytesType) |
| } |
| return fmt.Sprintf(fmtStr, argBytes), nil |
| default: |
| return "", hexFormatErrorV2(runtimeID, arg.Type().TypeName()) |
| } |
| } |
| } |
| |
| // Octal implements formatStringInterpolatorV2.Octal. |
| func (c *stringFormatterV2) Octal(arg ref.Val) (string, error) { |
| switch arg.Type() { |
| case types.IntType: |
| argInt, ok := arg.Value().(int64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.IntType) |
| } |
| return strconv.FormatInt(argInt, 8), nil |
| case types.UintType: |
| argUint, ok := arg.Value().(uint64) |
| if !ok { |
| return "", fmt.Errorf("type conversion error from '%s' to '%s'", arg.Type(), types.UintType) |
| } |
| return strconv.FormatUint(argUint, 8), nil |
| default: |
| return "", octalFormatErrorV2(runtimeID, arg.Type().TypeName()) |
| } |
| } |
| |
// stringFormatValidatorV2 implements the cel.ASTValidator interface allowing for static validation
// of string.format calls.
type stringFormatValidatorV2 struct {
	// maxPrecision is forwarded to parseFormatStringV2 when Validate parses literal
	// format strings.
	maxPrecision int
}
| |
| // Name returns the name of the validator. |
| func (stringFormatValidatorV2) Name() string { |
| return "cel.validator.string_format" |
| } |
| |
| // Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip |
| // during homogeneous aggregate literal type-checks. |
| func (stringFormatValidatorV2) Configure(config cel.MutableValidatorConfig) error { |
| functions := config.GetOrDefault(cel.HomogeneousAggregateLiteralExemptFunctions, []string{}).([]string) |
| functions = append(functions, "format") |
| return config.Set(cel.HomogeneousAggregateLiteralExemptFunctions, functions) |
| } |
| |
| // Validate parses all literal format strings and type checks the format clause against the argument |
| // at the corresponding ordinal within the list literal argument to the function, if one is specified. |
| func (v stringFormatValidatorV2) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) { |
| root := ast.NavigateAST(a) |
| formatCallExprs := ast.MatchDescendants(root, matchConstantFormatStringWithListLiteralArgs(a)) |
| for _, e := range formatCallExprs { |
| call := e.AsCall() |
| formatStr := call.Target().AsLiteral().Value().(string) |
| args := call.Args()[0].AsList().Elements() |
| formatCheck := &stringFormatCheckerV2{ |
| args: args, |
| ast: a, |
| } |
| // use a placeholder locale, since locale doesn't affect syntax |
| _, err := parseFormatStringV2(formatStr, formatCheck, formatCheck, v.maxPrecision) |
| if err != nil { |
| iss.ReportErrorAtID(getErrorExprID(e.ID(), err), "%v", err) |
| continue |
| } |
| seenArgs := formatCheck.argsRequested |
| if len(args) > seenArgs { |
| iss.ReportErrorAtID(e.ID(), |
| "too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(args)) |
| } |
| } |
| } |
| |
// stringFormatCheckerV2 implements the formatStringInterpolatorV2 interface and also
// provides the Arg and Size methods used for argument retrieval, so a single value can
// act as both the clause callback and the argument list during static validation.
type stringFormatCheckerV2 struct {
	// args holds the expressions of the list literal passed to format.
	args []ast.Expr
	// argsRequested counts how many times Arg has been called.
	argsRequested int
	// currArgIndex is the index most recently passed to Arg; clause checks read
	// args[currArgIndex].
	currArgIndex int64
	// ast is the AST used to look up expression types by ID.
	ast *ast.AST
}
| |
| // String implements formatStringInterpolatorV2.String. |
| func (c *stringFormatCheckerV2) String(arg ref.Val) (string, error) { |
| formatArg := c.args[c.currArgIndex] |
| valid, badID := c.verifyString(formatArg) |
| if !valid { |
| return "", stringFormatErrorV2(badID, c.typeOf(badID).TypeName()) |
| } |
| return "", nil |
| } |
| |
| // Decimal implements formatStringInterpolatorV2.Decimal. |
| func (c *stringFormatCheckerV2) Decimal(arg ref.Val) (string, error) { |
| id := c.args[c.currArgIndex].ID() |
| valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.DoubleType) |
| if !valid { |
| return "", decimalFormatErrorV2(id, c.typeOf(id).TypeName()) |
| } |
| return "", nil |
| } |
| |
| // Fixed implements formatStringInterpolatorV2.Fixed. |
| func (c *stringFormatCheckerV2) Fixed(precision int) func(ref.Val) (string, error) { |
| return func(arg ref.Val) (string, error) { |
| id := c.args[c.currArgIndex].ID() |
| valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.DoubleType) |
| if !valid { |
| return "", fixedPointFormatErrorV2(id, c.typeOf(id).TypeName()) |
| } |
| return "", nil |
| } |
| } |
| |
| // Scientific implements formatStringInterpolatorV2.Scientific. |
| func (c *stringFormatCheckerV2) Scientific(precision int) func(ref.Val) (string, error) { |
| return func(arg ref.Val) (string, error) { |
| id := c.args[c.currArgIndex].ID() |
| valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.DoubleType) |
| if !valid { |
| return "", scientificFormatErrorV2(id, c.typeOf(id).TypeName()) |
| } |
| return "", nil |
| } |
| } |
| |
| // Binary implements formatStringInterpolatorV2.Binary. |
| func (c *stringFormatCheckerV2) Binary(arg ref.Val) (string, error) { |
| id := c.args[c.currArgIndex].ID() |
| valid := c.verifyTypeOneOf(id, types.BoolType, types.IntType, types.UintType) |
| if !valid { |
| return "", binaryFormatErrorV2(id, c.typeOf(id).TypeName()) |
| } |
| return "", nil |
| } |
| |
| // Hex implements formatStringInterpolatorV2.Hex. |
| func (c *stringFormatCheckerV2) Hex(useUpper bool) func(ref.Val) (string, error) { |
| return func(arg ref.Val) (string, error) { |
| id := c.args[c.currArgIndex].ID() |
| valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.StringType, types.BytesType) |
| if !valid { |
| return "", hexFormatErrorV2(id, c.typeOf(id).TypeName()) |
| } |
| return "", nil |
| } |
| } |
| |
| // Octal implements formatStringInterpolatorV2.Octal. |
| func (c *stringFormatCheckerV2) Octal(arg ref.Val) (string, error) { |
| id := c.args[c.currArgIndex].ID() |
| valid := c.verifyTypeOneOf(id, types.IntType, types.UintType) |
| if !valid { |
| return "", octalFormatErrorV2(id, c.typeOf(id).TypeName()) |
| } |
| return "", nil |
| } |
| |
| // Arg implements formatListArgs.Arg. |
| func (c *stringFormatCheckerV2) Arg(index int64) (ref.Val, error) { |
| c.argsRequested++ |
| c.currArgIndex = index |
| // return a dummy value - this is immediately passed to back to us |
| // through one of the FormatCallback functions, so anything will do |
| return types.Int(0), nil |
| } |
| |
| // Size implements formatListArgs.Size. |
| func (c *stringFormatCheckerV2) Size() int64 { |
| return int64(len(c.args)) |
| } |
| |
| func (c *stringFormatCheckerV2) typeOf(id int64) *cel.Type { |
| return c.ast.GetType(id) |
| } |
| |
| func (c *stringFormatCheckerV2) verifyTypeOneOf(id int64, validTypes ...*cel.Type) bool { |
| t := c.typeOf(id) |
| if t == cel.DynType { |
| return true |
| } |
| for _, vt := range validTypes { |
| // Only check runtime type compatibility without delving deeper into parameterized types |
| if t.Kind() == vt.Kind() { |
| return true |
| } |
| } |
| return false |
| } |
| |
| func (c *stringFormatCheckerV2) verifyString(sub ast.Expr) (bool, int64) { |
| paramA := cel.TypeParamType("A") |
| paramB := cel.TypeParamType("B") |
| subVerified := c.verifyTypeOneOf(sub.ID(), |
| cel.ListType(paramA), cel.MapType(paramA, paramB), |
| cel.IntType, cel.UintType, cel.DoubleType, cel.BoolType, cel.StringType, |
| cel.TimestampType, cel.BytesType, cel.DurationType, cel.TypeType, cel.NullType) |
| if !subVerified { |
| return false, sub.ID() |
| } |
| switch sub.Kind() { |
| case ast.ListKind: |
| for _, e := range sub.AsList().Elements() { |
| // recursively verify if we're dealing with a list/map |
| verified, id := c.verifyString(e) |
| if !verified { |
| return false, id |
| } |
| } |
| return true, sub.ID() |
| case ast.MapKind: |
| for _, e := range sub.AsMap().Entries() { |
| // recursively verify if we're dealing with a list/map |
| entry := e.AsMapEntry() |
| verified, id := c.verifyString(entry.Key()) |
| if !verified { |
| return false, id |
| } |
| verified, id = c.verifyString(entry.Value()) |
| if !verified { |
| return false, id |
| } |
| } |
| return true, sub.ID() |
| default: |
| return true, sub.ID() |
| } |
| } |
| |
| // helper routines for reporting common errors during string formatting static validation and |
| // runtime execution. |
| |
| func binaryFormatErrorV2(id int64, badType string) error { |
| return newFormatError(id, "only ints, uints, and bools can be formatted as binary, was given %s", badType) |
| } |
| |
| func decimalFormatErrorV2(id int64, badType string) error { |
| return newFormatError(id, "decimal clause can only be used on ints, uints, and doubles, was given %s", badType) |
| } |
| |
| func fixedPointFormatErrorV2(id int64, badType string) error { |
| return newFormatError(id, "fixed-point clause can only be used on ints, uints, and doubles, was given %s", badType) |
| } |
| |
| func hexFormatErrorV2(id int64, badType string) error { |
| return newFormatError(id, "only ints, uints, bytes, and strings can be formatted as hex, was given %s", badType) |
| } |
| |
| func octalFormatErrorV2(id int64, badType string) error { |
| return newFormatError(id, "octal clause can only be used on ints and uints, was given %s", badType) |
| } |
| |
| func scientificFormatErrorV2(id int64, badType string) error { |
| return newFormatError(id, "scientific clause can only be used on ints, uints, and doubles, was given %s", badType) |
| } |
| |
| func stringFormatErrorV2(id int64, badType string) error { |
| return newFormatError(id, "string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given %s", badType) |
| } |
| |
// formatStringInterpolatorV2 is an interface that allows user-defined behavior
// for formatting clause implementations.
// Each function is expected to support the appropriate types as laid out in
// the string.format documentation, and to return an error if given an inappropriate type.
type formatStringInterpolatorV2 interface {
	// String takes a ref.Val and returns the Val formatted as a string, or an
	// error if one occurred.
	String(ref.Val) (string, error)

	// Decimal takes a ref.Val and returns the Val formatted as a decimal number,
	// or an error if one occurred.
	Decimal(ref.Val) (string, error)

	// Fixed takes an int representing precision and returns a function operating
	// in a similar manner to String and Decimal, taking a ref.Val and returning
	// the appropriate string. A closure is returned so precision can be set
	// without needing an additional function call/configuration.
	Fixed(int) func(ref.Val) (string, error)

	// Scientific functions identically to Fixed, except the string returned from the closure
	// is expected to be in scientific notation.
	Scientific(int) func(ref.Val) (string, error)

	// Binary takes a ref.Val and returns the Val formatted as a binary integer,
	// or an error if one occurred.
	Binary(ref.Val) (string, error)

	// Hex takes a boolean that, if true, indicates the hex string output by the returned
	// closure should use uppercase letters for A-F.
	Hex(bool) func(ref.Val) (string, error)

	// Octal takes a ref.Val and returns the Val formatted in octal, or an error
	// if one occurred.
	Octal(ref.Val) (string, error)
}
| |
| // parseFormatString formats a string according to the string.format syntax, taking the clause implementations |
| // from the provided FormatCallback and the args from the given FormatList. |
| func parseFormatStringV2(formatStr string, callback formatStringInterpolatorV2, list formatListArgs, maxPrecision int) (string, error) { |
| i := 0 |
| argIndex := 0 |
| var builtStr strings.Builder |
| for i < len(formatStr) { |
| if formatStr[i] == '%' { |
| if i+1 < len(formatStr) && formatStr[i+1] == '%' { |
| err := builtStr.WriteByte('%') |
| if err != nil { |
| return "", fmt.Errorf("error writing format string: %w", err) |
| } |
| i += 2 |
| continue |
| } else { |
| argAny, err := list.Arg(int64(argIndex)) |
| if err != nil { |
| return "", err |
| } |
| if i+1 >= len(formatStr) { |
| return "", errors.New("unexpected end of string") |
| } |
| if int64(argIndex) >= list.Size() { |
| return "", fmt.Errorf("index %d out of range", argIndex) |
| } |
| numRead, val, refErr := parseAndFormatClauseV2(formatStr[i:], argAny, callback, list, maxPrecision) |
| if refErr != nil { |
| return "", refErr |
| } |
| _, err = builtStr.WriteString(val) |
| if err != nil { |
| return "", fmt.Errorf("error writing format string: %w", err) |
| } |
| i += numRead |
| argIndex++ |
| } |
| } else { |
| err := builtStr.WriteByte(formatStr[i]) |
| if err != nil { |
| return "", fmt.Errorf("error writing format string: %w", err) |
| } |
| i++ |
| } |
| } |
| return builtStr.String(), nil |
| } |
| |
| // parseAndFormatClause parses the format clause at the start of the given string with val, and returns |
| // how many characters were consumed and the substituted string form of val, or an error if one occurred. |
| func parseAndFormatClauseV2(formatStr string, val ref.Val, callback formatStringInterpolatorV2, list formatListArgs, maxPrecision int) (int, string, error) { |
| i := 1 |
| read, formatter, err := parseFormattingClauseV2(formatStr[i:], callback, maxPrecision) |
| i += read |
| if err != nil { |
| return -1, "", newParseFormatError("could not parse formatting clause", err) |
| } |
| |
| valStr, err := formatter(val) |
| if err != nil { |
| return -1, "", newParseFormatError("error during formatting", err) |
| } |
| return i, valStr, nil |
| } |
| |
| func parseFormattingClauseV2(formatStr string, callback formatStringInterpolatorV2, maxPrecision int) (int, clauseImplV2, error) { |
| i := 0 |
| read, precision, err := parsePrecisionV2(formatStr[i:], maxPrecision) |
| i += read |
| if err != nil { |
| return -1, nil, fmt.Errorf("error while parsing precision: %w", err) |
| } |
| r := rune(formatStr[i]) |
| i++ |
| switch r { |
| case 's': |
| return i, callback.String, nil |
| case 'd': |
| return i, callback.Decimal, nil |
| case 'f': |
| return i, callback.Fixed(precision), nil |
| case 'e': |
| return i, callback.Scientific(precision), nil |
| case 'b': |
| return i, callback.Binary, nil |
| case 'x', 'X': |
| return i, callback.Hex(unicode.IsUpper(r)), nil |
| case 'o': |
| return i, callback.Octal, nil |
| default: |
| return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r) |
| } |
| } |
| |
| func parsePrecisionV2(formatStr string, maxPrecision int) (int, int, error) { |
| i := 0 |
| if formatStr[i] != '.' { |
| return i, defaultPrecision, nil |
| } |
| i++ |
| var buffer strings.Builder |
| for { |
| if i >= len(formatStr) { |
| return -1, -1, errors.New("could not find end of precision specifier") |
| } |
| if !isASCIIDigit(rune(formatStr[i])) { |
| break |
| } |
| buffer.WriteByte(formatStr[i]) |
| i++ |
| } |
| precision, err := strconv.Atoi(buffer.String()) |
| if err != nil { |
| return -1, -1, fmt.Errorf("error while converting precision to integer: %w", err) |
| } |
| if precision < 0 { |
| return -1, -1, fmt.Errorf("negative precision: %d", precision) |
| } |
| if maxPrecision > 0 && precision > maxPrecision { |
| return -1, -1, fmt.Errorf("precision %d exceeds maximum allowed precision %d", precision, maxPrecision) |
| } |
| return i, precision, nil |
| } |