From de3e55c2e6cf43208cff68a63e3a72df308d7a98 Mon Sep 17 00:00:00 2001 From: Steven Date: Wed, 28 May 2025 21:18:49 +0800 Subject: [PATCH] feat: support `now()` time functions --- plugin/filter/expr.go | 88 +++++++++++++++++++++++++++ plugin/filter/filter.go | 20 ++++-- store/db/mysql/memo_filter.go | 25 +++----- store/db/mysql/memo_filter_test.go | 11 ++-- store/db/postgres/memo_filter.go | 23 +++---- store/db/postgres/memo_filter_test.go | 11 ++-- store/db/sqlite/memo_filter.go | 21 +++---- store/db/sqlite/memo_filter_test.go | 11 ++-- 8 files changed, 149 insertions(+), 61 deletions(-) diff --git a/plugin/filter/expr.go b/plugin/filter/expr.go index 5bdcdc27..01ce5395 100644 --- a/plugin/filter/expr.go +++ b/plugin/filter/expr.go @@ -2,6 +2,7 @@ package filter import ( "errors" + "time" exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) @@ -37,3 +38,90 @@ func GetIdentExprName(expr *exprv1.Expr) (string, error) { } return expr.GetIdentExpr().GetName(), nil } + +// GetFunctionValue evaluates CEL function calls and returns their value. +// This is specifically for time functions like now(). +func GetFunctionValue(expr *exprv1.Expr) (any, error) { + callExpr, ok := expr.ExprKind.(*exprv1.Expr_CallExpr) + if !ok { + return nil, errors.New("invalid function call expression") + } + + switch callExpr.CallExpr.Function { + case "now": + if len(callExpr.CallExpr.Args) != 0 { + return nil, errors.New("now() function takes no arguments") + } + return time.Now().Unix(), nil + case "_-_": + // Handle subtraction for expressions like "now() - 60 * 60 * 24" + if len(callExpr.CallExpr.Args) != 2 { + return nil, errors.New("subtraction requires exactly two arguments") + } + left, err := GetExprValue(callExpr.CallExpr.Args[0]) + if err != nil { + return nil, err + } + right, err := GetExprValue(callExpr.CallExpr.Args[1]) + if err != nil { + return nil, err + } + leftInt, ok1 := left.(int64) + rightInt, ok2 := right.(int64) + if !ok1 || !ok2 { + return nil, errors.New("subtraction operands must be integers") + } + return leftInt - rightInt, nil + case "_*_": + // Handle multiplication for expressions like "60 * 60 * 24" + if len(callExpr.CallExpr.Args) != 2 { + return nil, errors.New("multiplication requires exactly two arguments") + } + left, err := GetExprValue(callExpr.CallExpr.Args[0]) + if err != nil { + return nil, err + } + right, err := GetExprValue(callExpr.CallExpr.Args[1]) + if err != nil { + return nil, err + } + leftInt, ok1 := left.(int64) + rightInt, ok2 := right.(int64) + if !ok1 || !ok2 { + return nil, errors.New("multiplication operands must be integers") + } + return leftInt * rightInt, nil + case "_+_": + // Handle addition + if len(callExpr.CallExpr.Args) != 2 { + return nil, errors.New("addition requires exactly two arguments") + } + left, err := GetExprValue(callExpr.CallExpr.Args[0]) + if err != nil { + return nil, err + } + right, err := GetExprValue(callExpr.CallExpr.Args[1]) + if err != nil { + return nil, err + } + leftInt, ok1 := left.(int64) + rightInt, ok2 := right.(int64) + if !ok1 || !ok2 { + return nil, errors.New("addition operands must be integers") + } + return leftInt + rightInt, nil + default: + return nil, errors.New("unsupported function: " + callExpr.CallExpr.Function) + } +} + +// GetExprValue attempts to get a value from an expression, trying constants first, then functions. +func GetExprValue(expr *exprv1.Expr) (any, error) { + // Try to get constant value first + if constValue, err := GetConstValue(expr); err == nil { + return constValue, nil + } + + // If not a constant, try to evaluate as a function + return GetFunctionValue(expr) +} diff --git a/plugin/filter/filter.go b/plugin/filter/filter.go index cdb13955..e3eed9a4 100644 --- a/plugin/filter/filter.go +++ b/plugin/filter/filter.go @@ -1,7 +1,11 @@ package filter import ( + "time" + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" "github.com/pkg/errors" exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) @@ -10,14 +14,22 @@ import ( var MemoFilterCELAttributes = []cel.EnvOption{ cel.Variable("content", cel.StringType), cel.Variable("creator_id", cel.IntType), - // As the built-in timestamp type is deprecated, we use string type for now. - // e.g., "2021-01-01T00:00:00Z" - cel.Variable("create_time", cel.StringType), + cel.Variable("created_ts", cel.IntType), + cel.Variable("updated_ts", cel.IntType), cel.Variable("pinned", cel.BoolType), cel.Variable("tag", cel.StringType), - cel.Variable("update_time", cel.StringType), cel.Variable("visibility", cel.StringType), cel.Variable("has_task_list", cel.BoolType), + // Current timestamp function. + cel.Function("now", + cel.Overload("now", + []*cel.Type{}, + cel.IntType, + cel.FunctionBinding(func(args ...ref.Val) ref.Val { + return types.Int(time.Now().Unix()) + }), + ), + ), } // Parse parses the filter string and returns the parsed expression. diff --git a/store/db/mysql/memo_filter.go b/store/db/mysql/memo_filter.go index 0f166b10..1b2e225f 100644 --- a/store/db/mysql/memo_filter.go +++ b/store/db/mysql/memo_filter.go @@ -4,7 +4,6 @@ import ( "fmt" "slices" "strings" - "time" "github.com/pkg/errors" exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -59,10 +58,10 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if err != nil { return err } - if !slices.Contains([]string{"creator_id", "create_time", "update_time", "visibility", "content", "has_task_list"}, identifier) { + if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) { return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } - value, err := filter.GetConstValue(v.CallExpr.Args[1]) + value, err := filter.GetExprValue(v.CallExpr.Args[1]) if err != nil { return err } @@ -82,26 +81,22 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err operator = ">=" } - if identifier == "create_time" || identifier == "update_time" { - timestampStr, ok := value.(string) + if identifier == "created_ts" || identifier == "updated_ts" { + timestampInt, ok := value.(int64) if !ok { return errors.New("invalid timestamp value") } - timestamp, err := time.Parse(time.RFC3339, timestampStr) - if err != nil { - return errors.Wrap(err, "failed to parse timestamp") - } var factor string - if identifier == "create_time" { - factor = "`memo`.`created_ts`" - } else if identifier == "update_time" { - factor = "`memo`.`updated_ts`" + if identifier == "created_ts" { + factor = "UNIX_TIMESTAMP(`memo`.`created_ts`)" + } else if identifier == "updated_ts" { + factor = "UNIX_TIMESTAMP(`memo`.`updated_ts`)" } - if _, err := ctx.Buffer.WriteString(fmt.Sprintf("UNIX_TIMESTAMP(%s) %s ?", factor, operator)); err != nil { + if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil { return err } - ctx.Args = append(ctx.Args, timestamp.Unix()) + ctx.Args = append(ctx.Args, timestampInt) } else if identifier == "visibility" || identifier == "content" { if operator != "=" && operator != "!=" { return errors.Errorf("invalid operator for %s", v.CallExpr.Function) diff --git a/store/db/mysql/memo_filter_test.go b/store/db/mysql/memo_filter_test.go index de7ad30e..7a89afaa 100644 --- a/store/db/mysql/memo_filter_test.go +++ b/store/db/mysql/memo_filter_test.go @@ -2,6 +2,7 @@ package mysql import ( "testing" + "time" "github.com/stretchr/testify/require" @@ -39,11 +40,6 @@ func TestConvertExprToSQL(t *testing.T) { want: "`memo`.`visibility` IN (?,?)", args: []any{"PUBLIC", "PRIVATE"}, }, - { - filter: `create_time == "2006-01-02T15:04:05+07:00"`, - want: "UNIX_TIMESTAMP(`memo`.`created_ts`) = ?", - args: []any{int64(1136189045)}, - }, { filter: `tag in ['tag1'] || content.contains('hello')`, want: "(JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR `memo`.`content` LIKE ?)", @@ -94,6 +90,11 @@ func TestConvertExprToSQL(t *testing.T) { want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON) AND `memo`.`content` LIKE ?)", args: []any{"%todo%"}, }, + { + filter: `created_ts > now() - 60 * 60 * 24`, + want: "UNIX_TIMESTAMP(`memo`.`created_ts`) > ?", + args: []any{time.Now().Unix() - 60*60*24}, + }, } for _, tt := range tests { diff --git a/store/db/postgres/memo_filter.go b/store/db/postgres/memo_filter.go index 233693a9..a6385a1d 100644 --- a/store/db/postgres/memo_filter.go +++ b/store/db/postgres/memo_filter.go @@ -4,7 +4,6 @@ import ( "fmt" "slices" "strings" - "time" "github.com/pkg/errors" exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -59,10 +58,10 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if err != nil { return err } - if !slices.Contains([]string{"creator_id", "create_time", "update_time", "visibility", "content", "has_task_list"}, identifier) { + if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) { return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } - value, err := filter.GetConstValue(v.CallExpr.Args[1]) + value, err := filter.GetExprValue(v.CallExpr.Args[1]) if err != nil { return err } @@ -82,26 +81,22 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err operator = ">=" } - if identifier == "create_time" || identifier == "update_time" { - timestampStr, ok := value.(string) + if identifier == "created_ts" || identifier == "updated_ts" { + timestampInt, ok := value.(int64) if !ok { return errors.New("invalid timestamp value") } - timestamp, err := time.Parse(time.RFC3339, timestampStr) - if err != nil { - return errors.Wrap(err, "failed to parse timestamp") - } var factor string - if identifier == "create_time" { - factor = "memo.created_ts" - } else if identifier == "update_time" { - factor = "memo.updated_ts" + if identifier == "created_ts" { + factor = "EXTRACT(EPOCH FROM memo.created_ts)" + } else if identifier == "updated_ts" { + factor = "EXTRACT(EPOCH FROM memo.updated_ts)" } if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", factor, operator, placeholder(len(ctx.Args)+ctx.ArgsOffset+1))); err != nil { return err } - ctx.Args = append(ctx.Args, timestamp.Unix()) + ctx.Args = append(ctx.Args, timestampInt) } else if identifier == "visibility" || identifier == "content" { if operator != "=" && operator != "!=" { return errors.Errorf("invalid operator for %s", v.CallExpr.Function) diff --git a/store/db/postgres/memo_filter_test.go b/store/db/postgres/memo_filter_test.go index ad659cfb..a610680f 100644 --- a/store/db/postgres/memo_filter_test.go +++ b/store/db/postgres/memo_filter_test.go @@ -2,6 +2,7 @@ package postgres import ( "testing" + "time" "github.com/stretchr/testify/require" @@ -39,11 +40,6 @@ func TestRestoreExprToSQL(t *testing.T) { want: "memo.visibility IN ($1,$2)", args: []any{"PUBLIC", "PRIVATE"}, }, - { - filter: `create_time == "2006-01-02T15:04:05+07:00"`, - want: "memo.created_ts = $1", - args: []any{int64(1136189045)}, - }, { filter: `tag in ['tag1'] || content.contains('hello')`, want: "(memo.payload->'tags' @> jsonb_build_array($1) OR memo.content ILIKE $2)", @@ -94,6 +90,11 @@ func TestRestoreExprToSQL(t *testing.T) { want: "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.content ILIKE $1)", args: []any{"%todo%"}, }, + { + filter: `created_ts > now() - 60 * 60 * 24`, + want: "EXTRACT(EPOCH FROM memo.created_ts) > $1", + args: []any{time.Now().Unix() - 60*60*24}, + }, } for _, tt := range tests { diff --git a/store/db/sqlite/memo_filter.go b/store/db/sqlite/memo_filter.go index eeb597f8..b348f726 100644 --- a/store/db/sqlite/memo_filter.go +++ b/store/db/sqlite/memo_filter.go @@ -4,7 +4,6 @@ import ( "fmt" "slices" "strings" - "time" "github.com/pkg/errors" exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -59,10 +58,10 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err if err != nil { return err } - if !slices.Contains([]string{"creator_id", "create_time", "update_time", "visibility", "content", "has_task_list"}, identifier) { + if !slices.Contains([]string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "has_task_list"}, identifier) { return errors.Errorf("invalid identifier for %s", v.CallExpr.Function) } - value, err := filter.GetConstValue(v.CallExpr.Args[1]) + value, err := filter.GetExprValue(v.CallExpr.Args[1]) if err != nil { return err } @@ -82,26 +81,22 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err operator = ">=" } - if identifier == "create_time" || identifier == "update_time" { - timestampStr, ok := value.(string) + if identifier == "created_ts" || identifier == "updated_ts" { + valueInt, ok := value.(int64) if !ok { - return errors.New("invalid timestamp value") - } - timestamp, err := time.Parse(time.RFC3339, timestampStr) - if err != nil { - return errors.Wrap(err, "failed to parse timestamp") + return errors.New("invalid integer timestamp value") } var factor string - if identifier == "create_time" { + if identifier == "created_ts" { factor = "`memo`.`created_ts`" - } else if identifier == "update_time" { + } else if identifier == "updated_ts" { factor = "`memo`.`updated_ts`" } if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil { return err } - ctx.Args = append(ctx.Args, timestamp.Unix()) + ctx.Args = append(ctx.Args, valueInt) } else if identifier == "visibility" || identifier == "content" { if operator != "=" && operator != "!=" { return errors.Errorf("invalid operator for %s", v.CallExpr.Function) diff --git a/store/db/sqlite/memo_filter_test.go b/store/db/sqlite/memo_filter_test.go index 522eb54a..c4156341 100644 --- a/store/db/sqlite/memo_filter_test.go +++ b/store/db/sqlite/memo_filter_test.go @@ -2,6 +2,7 @@ package sqlite import ( "testing" + "time" "github.com/stretchr/testify/require" @@ -44,11 +45,6 @@ func TestConvertExprToSQL(t *testing.T) { want: "`memo`.`visibility` IN (?,?)", args: []any{"PUBLIC", "PRIVATE"}, }, - { - filter: `create_time == "2006-01-02T15:04:05+07:00"`, - want: "`memo`.`created_ts` = ?", - args: []any{int64(1136189045)}, - }, { filter: `tag in ['tag1'] || content.contains('hello')`, want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR `memo`.`content` LIKE ?)", @@ -109,6 +105,11 @@ func TestConvertExprToSQL(t *testing.T) { want: "(JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE AND `memo`.`content` LIKE ?)", args: []any{"%todo%"}, }, + { + filter: `created_ts > now() - 60 * 60 * 24`, + want: "`memo`.`created_ts` > ?", + args: []any{time.Now().Unix() - 60*60*24}, + }, } for _, tt := range tests {