Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ linters:
enable:
- asciicheck
- bodyclose
- goconst
- gocritic
- goheader
- gomodguard
Expand All @@ -22,9 +21,6 @@ linters:
settings:
dupl:
threshold: 100
goconst:
min-len: 2
min-occurrences: 4
gocritic:
disabled-checks:
- singleCaseSwitch
Expand Down
5 changes: 5 additions & 0 deletions CEL.md
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,13 @@ Combines all elements of a collection using a binary function.

// For maps:
{"a": "apple", "b": "banana"}.fold(k, v, acc, acc + v) // "applebanana"

// Build a map from a list of key/value objects:
dyn(tags).fold(tag, acc, merge(acc, {tag.key: tag.value}))
```

When folding a variable declared as `any`, wrap it with `dyn(...)` so CEL can use it as a comprehension range. The `merge(left, right)` helper returns a map containing all keys from both maps; values from `right` replace values from `left` on duplicate keys.

### has

Tests whether a field is available in a message or map.
Expand Down
1 change: 1 addition & 0 deletions cel.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func GetCelEnv(environment map[string]any) []cel.EnvOption {
opts = append(opts, typeAdapters...)
opts = append(opts, getGoTemplateCelFunction())
opts = append(opts, getDebugCelFunction())
opts = append(opts, getFoldCelLibrary())

// Load input as variables
for k := range environment {
Expand Down
235 changes: 235 additions & 0 deletions cel_fold.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
package gomplate

import (
"fmt"
"sort"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)

const (
foldInitListFn = "cel.@foldInitList"
foldInitMapFn = "cel.@foldInitMap"
foldSortedMapEntries = "cel.@foldSortedMapEntries"
)

func getFoldCelLibrary() cel.EnvOption {
return cel.Lib(&foldCelLibrary{})
}

type foldCelLibrary struct{}

func (l *foldCelLibrary) LibraryName() string {
return "gomplate.fold"
}

func (l *foldCelLibrary) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
cel.ReceiverMacro("fold", 3, foldListMacro,
cel.MacroDocs("Folds a list using an element variable, an accumulator variable, and a step expression."),
cel.MacroExamples("[1, 2, 3].fold(e, acc, acc + e) // 6")),
cel.ReceiverMacro("fold", 4, foldMapMacro,
cel.MacroDocs("Folds a map using key/value variables, an accumulator variable, and a step expression."),
cel.MacroExamples(`{"a": "apple", "b": "banana"}.fold(k, v, acc, acc + v) // "applebanana"`)),
),
cel.Function(foldInitListFn,
cel.Overload("fold_init_list_dyn", []*cel.Type{cel.DynType}, cel.DynType,
cel.UnaryBinding(func(collection ref.Val) ref.Val {
return foldInitialValue(collection, false)
})),
),
cel.Function(foldInitMapFn,
cel.Overload("fold_init_map_dyn", []*cel.Type{cel.DynType}, cel.DynType,
cel.UnaryBinding(func(collection ref.Val) ref.Val {
return foldInitialValue(collection, true)
})),
),
cel.Function(foldSortedMapEntries,
cel.Overload("fold_sorted_map_entries_dyn", []*cel.Type{cel.DynType}, cel.ListType(cel.ListType(cel.DynType)),
cel.UnaryBinding(sortedMapEntries)),
),
cel.Function("merge",
cel.Overload("merge_map_map", []*cel.Type{cel.MapType(cel.DynType, cel.DynType), cel.MapType(cel.DynType, cel.DynType)}, cel.MapType(cel.DynType, cel.DynType),
cel.BinaryBinding(mergeMaps)),
),
}
}

func (*foldCelLibrary) ProgramOptions() []cel.ProgramOption {
return nil
}

func foldListMacro(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar, err := extractFoldIdent(mef, args[0])
if err != nil {
return nil, err
}
accuVar, err := extractFoldIdent(mef, args[1])
if err != nil {
return nil, err
}
if iterVar == accuVar {
return nil, mef.NewError(args[1].ID(), fmt.Sprintf("duplicate variable name: %s", accuVar))
}

return mef.NewComprehension(
target,
iterVar,
accuVar,
mef.NewCall(foldInitListFn, mef.Copy(target)),
mef.NewLiteral(types.True),
args[2],
mef.NewIdent(accuVar),
), nil
}

func foldMapMacro(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
keyVar, err := extractFoldIdent(mef, args[0])
if err != nil {
return nil, err
}
valVar, err := extractFoldIdent(mef, args[1])
if err != nil {
return nil, err
}
accuVar, err := extractFoldIdent(mef, args[2])
if err != nil {
return nil, err
}
if keyVar == valVar || keyVar == accuVar || valVar == accuVar {
return nil, mef.NewError(args[2].ID(), "fold variable names must be unique")
}

entryVar := "__fold_entry__"
if entryVar == keyVar || entryVar == valVar || entryVar == accuVar {
entryVar = "__fold_entry2__"
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
entry := mef.NewIdent(entryVar)
key := mef.NewCall(operators.Index, entry, mef.NewLiteral(types.IntZero))
value := mef.NewCall(operators.Index, mef.NewIdent(entryVar), mef.NewLiteral(types.Int(1)))
step := mef.NewComprehensionTwoVar(
mef.NewMap(mef.NewMapEntry(key, value, false)),
keyVar,
valVar,
accuVar,
mef.NewIdent(accuVar),
mef.NewLiteral(types.True),
args[3],
mef.NewIdent(accuVar),
)

return mef.NewComprehension(
mef.NewCall(foldSortedMapEntries, mef.Copy(target)),
entryVar,
accuVar,
mef.NewCall(foldInitMapFn, mef.Copy(target)),
mef.NewLiteral(types.True),
step,
mef.NewIdent(accuVar),
), nil
}

func extractFoldIdent(mef cel.MacroExprFactory, expr ast.Expr) (string, *cel.Error) {
if expr.Kind() != ast.IdentKind {
return "", mef.NewError(expr.ID(), "argument must be a simple name")
}
return expr.AsIdent(), nil
}

func foldInitialValue(collection ref.Val, mapValues bool) ref.Val {
var first ref.Val
if mapValues {
m, ok := collection.(traits.Mapper)
if !ok {
return types.NewErr("fold target is not a map")
}
it := m.Iterator()
if it.HasNext() != types.True {
return types.DefaultTypeAdapter.NativeToValue(nil)
}
first = m.Get(it.Next())
} else {
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
l, ok := collection.(traits.Lister)
if !ok {
return types.NewErr("fold target is not a list")
}
if l.Size().Equal(types.IntZero) == types.True {
return types.DefaultTypeAdapter.NativeToValue(nil)
}
first = l.Get(types.IntZero)
}
return zeroValueForFold(first)
}

func sortedMapEntries(collection ref.Val) ref.Val {
m, ok := collection.(traits.Mapper)
if !ok {
return types.NewErr("fold target is not a map")
}

keys := []ref.Val{}
for it := m.Iterator(); it.HasNext() == types.True; {
keys = append(keys, it.Next())
}
sort.SliceStable(keys, func(i, j int) bool {
return fmt.Sprint(keys[i].Value()) < fmt.Sprint(keys[j].Value())
})
Comment thread
coderabbitai[bot] marked this conversation as resolved.

entries := make([]ref.Val, 0, len(keys))
for _, key := range keys {
entries = append(entries, types.NewRefValList(types.DefaultTypeAdapter, []ref.Val{key, m.Get(key)}))
}
return types.NewRefValList(types.DefaultTypeAdapter, entries)
}

func mergeMaps(lhs, rhs ref.Val) ref.Val {
left, ok := lhs.(traits.Mapper)
if !ok {
return types.NewErr("left operand is not a map")
}
right, ok := rhs.(traits.Mapper)
if !ok {
return types.NewErr("right operand is not a map")
}

out := map[ref.Val]ref.Val{}
for it := left.Iterator(); it.HasNext() == types.True; {
key := it.Next()
out[key] = left.Get(key)
}
for it := right.Iterator(); it.HasNext() == types.True; {
key := it.Next()
out[key] = right.Get(key)
}
return types.NewRefValMap(types.DefaultTypeAdapter, out)
}

func zeroValueForFold(v ref.Val) ref.Val {
if types.IsError(v) {
return v
}
switch v.(type) {
case types.Int:
return types.IntZero
case types.Uint:
return types.Uint(0)
case types.Double:
return types.Double(0)
case types.String:
return types.String("")
case types.Bytes:
return types.Bytes([]byte{})
case traits.Mapper:
return types.NewRefValMap(types.DefaultTypeAdapter, map[ref.Val]ref.Val{})
case traits.Lister:
return types.NewRefValList(types.DefaultTypeAdapter, []ref.Val{})
default:
return types.DefaultTypeAdapter.NativeToValue(nil)
}
}
18 changes: 18 additions & 0 deletions tests/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,24 @@ func TestCelExtensions(t *testing.T) {
})
}

func TestCelFold(t *testing.T) {
runTests(t, []Test{
{nil, `[1, 2, 3].fold(e, acc, acc + e)`, "6"},
{nil, `{"a": "apple", "b": "banana"}.fold(k, v, acc, acc + v)`, "applebanana"},
})

out, err := gomplate.RunExpression(map[string]any{
"tags": []map[string]any{
{"key": "application", "value": "orders"},
{"key": "environment", "value": "prod"},
},
}, gomplate.Template{
Expression: `dyn(tags).fold(tag, acc, merge(acc, {tag.key: tag.value})).toJSON()`,
})
assert.NoError(t, err)
assert.JSONEq(t, `{"application":"orders","environment":"prod"}`, out.(string))
}

func TestCelEncode(t *testing.T) {
tests := []Test{
{map[string]interface{}{"hello": "hello world ?"}, "urlencode(hello)", `hello+world+%3F`},
Expand Down
Loading