mirror of
https://github.com/usememos/memos.git
synced 2025-06-05 22:09:59 +02:00
feat: support memo filter for mysql and postgres
This commit is contained in:
20
plugin/filter/converter.go
Normal file
20
plugin/filter/converter.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package filter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConvertContext struct {
|
||||||
|
Buffer strings.Builder
|
||||||
|
Args []any
|
||||||
|
// The offset of the next argument in the condition string.
|
||||||
|
// Mainly using for PostgreSQL.
|
||||||
|
ArgsOffset int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConvertContext() *ConvertContext {
|
||||||
|
return &ConvertContext{
|
||||||
|
Buffer: strings.Builder{},
|
||||||
|
Args: []any{},
|
||||||
|
}
|
||||||
|
}
|
39
plugin/filter/expr.go
Normal file
39
plugin/filter/expr.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package filter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetConstValue returns the constant value of the expression.
|
||||||
|
func GetConstValue(expr *exprv1.Expr) (any, error) {
|
||||||
|
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid constant expression")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v.ConstExpr.ConstantKind.(type) {
|
||||||
|
case *exprv1.Constant_StringValue:
|
||||||
|
return v.ConstExpr.GetStringValue(), nil
|
||||||
|
case *exprv1.Constant_Int64Value:
|
||||||
|
return v.ConstExpr.GetInt64Value(), nil
|
||||||
|
case *exprv1.Constant_Uint64Value:
|
||||||
|
return v.ConstExpr.GetUint64Value(), nil
|
||||||
|
case *exprv1.Constant_DoubleValue:
|
||||||
|
return v.ConstExpr.GetDoubleValue(), nil
|
||||||
|
case *exprv1.Constant_BoolValue:
|
||||||
|
return v.ConstExpr.GetBoolValue(), nil
|
||||||
|
default:
|
||||||
|
return nil, errors.New("unexpected constant type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIdentExprName returns the name of the identifier expression.
|
||||||
|
func GetIdentExprName(expr *exprv1.Expr) (string, error) {
|
||||||
|
_, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("invalid identifier expression")
|
||||||
|
}
|
||||||
|
return expr.GetIdentExpr().GetName(), nil
|
||||||
|
}
|
@@ -30,26 +30,3 @@ func Parse(filter string, opts ...cel.EnvOption) (expr *exprv1.ParsedExpr, err e
|
|||||||
}
|
}
|
||||||
return cel.AstToParsedExpr(ast)
|
return cel.AstToParsedExpr(ast)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConstValue returns the constant value of the expression.
|
|
||||||
func GetConstValue(expr *exprv1.Expr) (any, error) {
|
|
||||||
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("invalid constant expression")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch v.ConstExpr.ConstantKind.(type) {
|
|
||||||
case *exprv1.Constant_StringValue:
|
|
||||||
return v.ConstExpr.GetStringValue(), nil
|
|
||||||
case *exprv1.Constant_Int64Value:
|
|
||||||
return v.ConstExpr.GetInt64Value(), nil
|
|
||||||
case *exprv1.Constant_Uint64Value:
|
|
||||||
return v.ConstExpr.GetUint64Value(), nil
|
|
||||||
case *exprv1.Constant_DoubleValue:
|
|
||||||
return v.ConstExpr.GetDoubleValue(), nil
|
|
||||||
case *exprv1.Constant_BoolValue:
|
|
||||||
return v.ConstExpr.GetBoolValue(), nil
|
|
||||||
default:
|
|
||||||
return nil, errors.New("unexpected constant type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
storepb "github.com/usememos/memos/proto/gen/store"
|
storepb "github.com/usememos/memos/proto/gen/store"
|
||||||
"github.com/usememos/memos/store"
|
"github.com/usememos/memos/store"
|
||||||
)
|
)
|
||||||
@@ -108,6 +109,21 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
|
|||||||
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE")
|
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v := find.Filter; v != nil {
|
||||||
|
// Parse filter string and return the parsed expression.
|
||||||
|
// The filter string should be a CEL expression.
|
||||||
|
parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
convertCtx := filter.NewConvertContext()
|
||||||
|
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||||
|
if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
|
||||||
|
args = append(args, convertCtx.Args...)
|
||||||
|
}
|
||||||
if find.ExcludeComments {
|
if find.ExcludeComments {
|
||||||
having = append(having, "`parent_id` IS NULL")
|
having = append(having, "`parent_id` IS NULL")
|
||||||
}
|
}
|
||||||
|
175
store/db/mysql/memo_filter.go
Normal file
175
store/db/mysql/memo_filter.go
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
package mysql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||||
|
switch v := expr.ExprKind.(type) {
|
||||||
|
case *exprv1.Expr_CallExpr:
|
||||||
|
switch v.CallExpr.Function {
|
||||||
|
case "_||_", "_&&_":
|
||||||
|
if len(v.CallExpr.Args) != 2 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
operator := "AND"
|
||||||
|
if v.CallExpr.Function == "_||_" {
|
||||||
|
operator = "OR"
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case "!_":
|
||||||
|
if len(v.CallExpr.Args) != 1 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||||
|
if len(v.CallExpr.Args) != 2 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !slices.Contains([]string{"create_time", "update_time"}, identifier) {
|
||||||
|
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
value, err := filter.GetConstValue(v.CallExpr.Args[1])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
operator := "="
|
||||||
|
switch v.CallExpr.Function {
|
||||||
|
case "_==_":
|
||||||
|
operator = "="
|
||||||
|
case "_!=_":
|
||||||
|
operator = "!="
|
||||||
|
case "_<_":
|
||||||
|
operator = "<"
|
||||||
|
case "_>_":
|
||||||
|
operator = ">"
|
||||||
|
case "_<=_":
|
||||||
|
operator = "<="
|
||||||
|
case "_>=_":
|
||||||
|
operator = ">="
|
||||||
|
}
|
||||||
|
|
||||||
|
if identifier == "create_time" || identifier == "update_time" {
|
||||||
|
timestampStr, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("invalid timestamp value")
|
||||||
|
}
|
||||||
|
timestamp, err := time.Parse(time.RFC3339, timestampStr)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "failed to parse timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
var factor string
|
||||||
|
if identifier == "create_time" {
|
||||||
|
factor = "`memo`.`created_ts`"
|
||||||
|
} else if identifier == "update_time" {
|
||||||
|
factor = "`memo`.`updated_ts`"
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("UNIX_TIMESTAMP(%s) %s ?", factor, operator)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, timestamp.Unix())
|
||||||
|
}
|
||||||
|
case "@in":
|
||||||
|
if len(v.CallExpr.Args) != 2 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||||
|
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
|
||||||
|
values := []any{}
|
||||||
|
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
||||||
|
value, err := filter.GetConstValue(element)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
if identifier == "tag" {
|
||||||
|
subcodition := []string{}
|
||||||
|
args := []any{}
|
||||||
|
for _, v := range values {
|
||||||
|
subcodition, args = append(subcodition, "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)"), append(args, v)
|
||||||
|
}
|
||||||
|
if len(subcodition) == 1 {
|
||||||
|
if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, args...)
|
||||||
|
} else if identifier == "visibility" {
|
||||||
|
placeholder := []string{}
|
||||||
|
for range values {
|
||||||
|
placeholder = append(placeholder, "?")
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(placeholder, ","))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, values...)
|
||||||
|
}
|
||||||
|
case "contains":
|
||||||
|
if len(v.CallExpr.Args) != 1 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if identifier != "content" {
|
||||||
|
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString("`memo`.`content` LIKE ?"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
63
store/db/mysql/memo_filter_test.go
Normal file
63
store/db/mysql/memo_filter_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package mysql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvertExprToSQL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
filter string
|
||||||
|
want string
|
||||||
|
args []any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
filter: `tag in ["tag1", "tag2"]`,
|
||||||
|
want: "(JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?))",
|
||||||
|
args: []any{"tag1", "tag2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `!(tag in ["tag1", "tag2"])`,
|
||||||
|
want: "NOT ((JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)))",
|
||||||
|
args: []any{"tag1", "tag2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `content.contains("memos")`,
|
||||||
|
want: "`memo`.`content` LIKE ?",
|
||||||
|
args: []any{"%memos%"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `visibility in ["PUBLIC"]`,
|
||||||
|
want: "`memo`.`visibility` IN (?)",
|
||||||
|
args: []any{"PUBLIC"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
|
||||||
|
want: "`memo`.`visibility` IN (?,?)",
|
||||||
|
args: []any{"PUBLIC", "PRIVATE"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `create_time == "2006-01-02T15:04:05+07:00"`,
|
||||||
|
want: "UNIX_TIMESTAMP(`memo`.`created_ts`) = ?",
|
||||||
|
args: []any{int64(1136189045)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `tag in ['tag1'] || content.contains('hello')`,
|
||||||
|
want: "(JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?) OR `memo`.`content` LIKE ?)",
|
||||||
|
args: []any{"tag1", "%hello%"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
convertCtx := filter.NewConvertContext()
|
||||||
|
err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||||
|
require.Equal(t, tt.args, convertCtx.Args)
|
||||||
|
}
|
||||||
|
}
|
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
storepb "github.com/usememos/memos/proto/gen/store"
|
storepb "github.com/usememos/memos/proto/gen/store"
|
||||||
"github.com/usememos/memos/store"
|
"github.com/usememos/memos/store"
|
||||||
)
|
)
|
||||||
@@ -99,6 +100,22 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
|
|||||||
where = append(where, "(memo.payload->'property'->>'hasIncompleteTasks')::BOOLEAN IS TRUE")
|
where = append(where, "(memo.payload->'property'->>'hasIncompleteTasks')::BOOLEAN IS TRUE")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if v := find.Filter; v != nil {
|
||||||
|
// Parse filter string and return the parsed expression.
|
||||||
|
// The filter string should be a CEL expression.
|
||||||
|
parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
convertCtx := filter.NewConvertContext()
|
||||||
|
convertCtx.ArgsOffset = len(args)
|
||||||
|
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||||
|
if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
|
||||||
|
args = append(args, convertCtx.Args...)
|
||||||
|
}
|
||||||
if find.ExcludeComments {
|
if find.ExcludeComments {
|
||||||
where = append(where, "memo_relation.related_memo_id IS NULL")
|
where = append(where, "memo_relation.related_memo_id IS NULL")
|
||||||
}
|
}
|
||||||
|
175
store/db/postgres/memo_filter.go
Normal file
175
store/db/postgres/memo_filter.go
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||||
|
switch v := expr.ExprKind.(type) {
|
||||||
|
case *exprv1.Expr_CallExpr:
|
||||||
|
switch v.CallExpr.Function {
|
||||||
|
case "_||_", "_&&_":
|
||||||
|
if len(v.CallExpr.Args) != 2 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
operator := "AND"
|
||||||
|
if v.CallExpr.Function == "_||_" {
|
||||||
|
operator = "OR"
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case "!_":
|
||||||
|
if len(v.CallExpr.Args) != 1 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||||
|
if len(v.CallExpr.Args) != 2 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !slices.Contains([]string{"create_time", "update_time"}, identifier) {
|
||||||
|
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
value, err := filter.GetConstValue(v.CallExpr.Args[1])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
operator := "="
|
||||||
|
switch v.CallExpr.Function {
|
||||||
|
case "_==_":
|
||||||
|
operator = "="
|
||||||
|
case "_!=_":
|
||||||
|
operator = "!="
|
||||||
|
case "_<_":
|
||||||
|
operator = "<"
|
||||||
|
case "_>_":
|
||||||
|
operator = ">"
|
||||||
|
case "_<=_":
|
||||||
|
operator = "<="
|
||||||
|
case "_>=_":
|
||||||
|
operator = ">="
|
||||||
|
}
|
||||||
|
|
||||||
|
if identifier == "create_time" || identifier == "update_time" {
|
||||||
|
timestampStr, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("invalid timestamp value")
|
||||||
|
}
|
||||||
|
timestamp, err := time.Parse(time.RFC3339, timestampStr)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "failed to parse timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
var factor string
|
||||||
|
if identifier == "create_time" {
|
||||||
|
factor = "memo.created_ts"
|
||||||
|
} else if identifier == "update_time" {
|
||||||
|
factor = "memo.updated_ts"
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", factor, operator, placeholder(len(ctx.Args)+ctx.ArgsOffset+1))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, timestamp.Unix())
|
||||||
|
}
|
||||||
|
case "@in":
|
||||||
|
if len(v.CallExpr.Args) != 2 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||||
|
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
|
||||||
|
values := []any{}
|
||||||
|
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
||||||
|
value, err := filter.GetConstValue(element)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
if identifier == "tag" {
|
||||||
|
subcodition := []string{}
|
||||||
|
args := []any{}
|
||||||
|
for _, v := range values {
|
||||||
|
subcodition, args = append(subcodition, fmt.Sprintf(`memo.payload->'tags' @> %s::jsonb`, placeholder(len(ctx.Args)+ctx.ArgsOffset+len(args)+1))), append(args, []any{v})
|
||||||
|
}
|
||||||
|
if len(subcodition) == 1 {
|
||||||
|
if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, args...)
|
||||||
|
} else if identifier == "visibility" {
|
||||||
|
placeholders := []string{}
|
||||||
|
for i := range values {
|
||||||
|
placeholders = append(placeholders, placeholder(len(ctx.Args)+ctx.ArgsOffset+i+1))
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("memo.visibility IN (%s)", strings.Join(placeholders, ","))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, values...)
|
||||||
|
}
|
||||||
|
case "contains":
|
||||||
|
if len(v.CallExpr.Args) != 1 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if identifier != "content" {
|
||||||
|
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString("memo.content ILIKE LIKE " + placeholder(len(ctx.Args)+ctx.ArgsOffset+1)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
63
store/db/postgres/memo_filter_test.go
Normal file
63
store/db/postgres/memo_filter_test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRestoreExprToSQL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
filter string
|
||||||
|
want string
|
||||||
|
args []any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
filter: `tag in ["tag1", "tag2"]`,
|
||||||
|
want: "(memo.payload->'tags' @> $1::jsonb OR memo.payload->'tags' @> $2::jsonb)",
|
||||||
|
args: []any{[]any{"tag1"}, []any{"tag2"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `!(tag in ["tag1", "tag2"])`,
|
||||||
|
want: `NOT ((memo.payload->'tags' @> $1::jsonb OR memo.payload->'tags' @> $2::jsonb))`,
|
||||||
|
args: []any{[]any{"tag1"}, []any{"tag2"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `content.contains("memos")`,
|
||||||
|
want: "memo.content ILIKE LIKE $1",
|
||||||
|
args: []any{"%memos%"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `visibility in ["PUBLIC"]`,
|
||||||
|
want: "memo.visibility IN ($1)",
|
||||||
|
args: []any{"PUBLIC"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
|
||||||
|
want: "memo.visibility IN ($1,$2)",
|
||||||
|
args: []any{"PUBLIC", "PRIVATE"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `create_time == "2006-01-02T15:04:05+07:00"`,
|
||||||
|
want: "memo.created_ts = $1",
|
||||||
|
args: []any{int64(1136189045)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `tag in ['tag1'] || content.contains('hello')`,
|
||||||
|
want: "(memo.payload->'tags' @> $1::jsonb OR memo.content ILIKE LIKE $2)",
|
||||||
|
args: []any{[]any{"tag1"}, "%hello%"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
||||||
|
require.NoError(t, err)
|
||||||
|
convertCtx := filter.NewConvertContext()
|
||||||
|
err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||||
|
require.Equal(t, tt.args, convertCtx.Args)
|
||||||
|
}
|
||||||
|
}
|
@@ -108,12 +108,13 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// RestoreExprToSQL parses the expression and returns the SQL condition.
|
convertCtx := filter.NewConvertContext()
|
||||||
condition, err := RestoreExprToSQL(parsedExpr.GetExpr())
|
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||||
if err != nil {
|
if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
where = append(where, fmt.Sprintf("(%s)", condition))
|
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
|
||||||
|
args = append(args, convertCtx.Args...)
|
||||||
}
|
}
|
||||||
if find.ExcludeComments {
|
if find.ExcludeComments {
|
||||||
where = append(where, "`parent_id` IS NULL")
|
where = append(where, "`parent_id` IS NULL")
|
||||||
|
@@ -12,39 +12,60 @@ import (
|
|||||||
"github.com/usememos/memos/plugin/filter"
|
"github.com/usememos/memos/plugin/filter"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RestoreExprToSQL(expr *exprv1.Expr) (string, error) {
|
func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||||
var condition string
|
|
||||||
switch v := expr.ExprKind.(type) {
|
switch v := expr.ExprKind.(type) {
|
||||||
case *exprv1.Expr_CallExpr:
|
case *exprv1.Expr_CallExpr:
|
||||||
switch v.CallExpr.Function {
|
switch v.CallExpr.Function {
|
||||||
case "_||_", "_&&_":
|
case "_||_", "_&&_":
|
||||||
if len(v.CallExpr.Args) != 2 {
|
if len(v.CallExpr.Args) != 2 {
|
||||||
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
left, err := RestoreExprToSQL(v.CallExpr.Args[0])
|
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
right, err := RestoreExprToSQL(v.CallExpr.Args[1])
|
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
operator := "AND"
|
operator := "AND"
|
||||||
if v.CallExpr.Function == "_||_" {
|
if v.CallExpr.Function == "_||_" {
|
||||||
operator = "OR"
|
operator = "OR"
|
||||||
}
|
}
|
||||||
condition = fmt.Sprintf("(%s %s %s)", left, operator, right)
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case "!_":
|
||||||
|
if len(v.CallExpr.Args) != 1 {
|
||||||
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||||
if len(v.CallExpr.Args) != 2 {
|
if len(v.CallExpr.Args) != 2 {
|
||||||
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
identifier := v.CallExpr.Args[0].GetIdentExpr().GetName()
|
|
||||||
if !slices.Contains([]string{"create_time", "update_time"}, identifier) {
|
if !slices.Contains([]string{"create_time", "update_time"}, identifier) {
|
||||||
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
value, err := filter.GetConstValue(v.CallExpr.Args[1])
|
value, err := filter.GetConstValue(v.CallExpr.Args[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return err
|
||||||
}
|
}
|
||||||
operator := "="
|
operator := "="
|
||||||
switch v.CallExpr.Function {
|
switch v.CallExpr.Function {
|
||||||
@@ -65,85 +86,90 @@ func RestoreExprToSQL(expr *exprv1.Expr) (string, error) {
|
|||||||
if identifier == "create_time" || identifier == "update_time" {
|
if identifier == "create_time" || identifier == "update_time" {
|
||||||
timestampStr, ok := value.(string)
|
timestampStr, ok := value.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", errors.New("invalid timestamp value")
|
return errors.New("invalid timestamp value")
|
||||||
}
|
}
|
||||||
timestamp, err := time.Parse(time.RFC3339, timestampStr)
|
timestamp, err := time.Parse(time.RFC3339, timestampStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "failed to parse timestamp")
|
return errors.Wrap(err, "failed to parse timestamp")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var factor string
|
||||||
if identifier == "create_time" {
|
if identifier == "create_time" {
|
||||||
condition = fmt.Sprintf("`memo`.`created_ts` %s %d", operator, timestamp.Unix())
|
factor = "`memo`.`created_ts`"
|
||||||
} else if identifier == "update_time" {
|
} else if identifier == "update_time" {
|
||||||
condition = fmt.Sprintf("`memo`.`updated_ts` %s %d", operator, timestamp.Unix())
|
factor = "`memo`.`updated_ts`"
|
||||||
}
|
}
|
||||||
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ctx.Args = append(ctx.Args, timestamp.Unix())
|
||||||
}
|
}
|
||||||
case "@in":
|
case "@in":
|
||||||
if len(v.CallExpr.Args) != 2 {
|
if len(v.CallExpr.Args) != 2 {
|
||||||
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
identifier := v.CallExpr.Args[0].GetIdentExpr().GetName()
|
|
||||||
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||||
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
|
|
||||||
values := []any{}
|
values := []any{}
|
||||||
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
||||||
value, err := filter.GetConstValue(element)
|
value, err := filter.GetConstValue(element)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return err
|
||||||
}
|
}
|
||||||
values = append(values, value)
|
values = append(values, value)
|
||||||
}
|
}
|
||||||
if identifier == "tag" {
|
if identifier == "tag" {
|
||||||
subcodition := []string{}
|
subcodition := []string{}
|
||||||
|
args := []any{}
|
||||||
for _, v := range values {
|
for _, v := range values {
|
||||||
subcodition = append(subcodition, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`'%%"%s"%%'`, v)))
|
subcodition, args = append(subcodition, "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?"), append(args, fmt.Sprintf(`%%"%s"%%`, v))
|
||||||
}
|
}
|
||||||
if len(subcodition) == 1 {
|
if len(subcodition) == 1 {
|
||||||
condition = subcodition[0]
|
if _, err := ctx.Buffer.WriteString(subcodition[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
condition = fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
ctx.Args = append(ctx.Args, args...)
|
||||||
} else if identifier == "visibility" {
|
} else if identifier == "visibility" {
|
||||||
vs := []string{}
|
placeholder := []string{}
|
||||||
for _, v := range values {
|
for range values {
|
||||||
vs = append(vs, fmt.Sprintf(`"%s"`, v))
|
placeholder = append(placeholder, "?")
|
||||||
}
|
}
|
||||||
if len(vs) == 1 {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(placeholder, ","))); err != nil {
|
||||||
condition = fmt.Sprintf("`memo`.`visibility` = %s", vs[0])
|
return err
|
||||||
} else {
|
|
||||||
condition = fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(vs, ","))
|
|
||||||
}
|
}
|
||||||
|
ctx.Args = append(ctx.Args, values...)
|
||||||
}
|
}
|
||||||
case "contains":
|
case "contains":
|
||||||
if len(v.CallExpr.Args) != 1 {
|
if len(v.CallExpr.Args) != 1 {
|
||||||
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
return errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
identifier, err := RestoreExprToSQL(v.CallExpr.Target)
|
identifier, err := filter.GetIdentExprName(v.CallExpr.Target)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return err
|
||||||
}
|
}
|
||||||
if identifier != "content" {
|
if identifier != "content" {
|
||||||
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||||
}
|
}
|
||||||
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return err
|
||||||
}
|
}
|
||||||
condition = fmt.Sprintf("`memo`.`content` LIKE %s", fmt.Sprintf(`'%%%s%%'`, arg))
|
if _, err := ctx.Buffer.WriteString("`memo`.`content` LIKE ?"); err != nil {
|
||||||
case "!_":
|
return err
|
||||||
if len(v.CallExpr.Args) != 1 {
|
|
||||||
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
|
||||||
}
|
}
|
||||||
arg, err := RestoreExprToSQL(v.CallExpr.Args[0])
|
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
condition = fmt.Sprintf("NOT (%s)", arg)
|
|
||||||
}
|
}
|
||||||
case *exprv1.Expr_IdentExpr:
|
|
||||||
return v.IdentExpr.GetName(), nil
|
|
||||||
}
|
}
|
||||||
return condition, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -8,50 +8,61 @@ import (
|
|||||||
"github.com/usememos/memos/plugin/filter"
|
"github.com/usememos/memos/plugin/filter"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRestoreExprToSQL(t *testing.T) {
|
func TestConvertExprToSQL(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
filter string
|
filter string
|
||||||
want string
|
want string
|
||||||
|
args []any
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
filter: `tag in ["tag1", "tag2"]`,
|
filter: `tag in ["tag1", "tag2"]`,
|
||||||
want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag1\"%' OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag2\"%')",
|
want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?)",
|
||||||
|
args: []any{`%"tag1"%`, `%"tag2"%`},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
filter: `!(tag in ["tag1", "tag2"])`,
|
filter: `!(tag in ["tag1", "tag2"])`,
|
||||||
want: "NOT ((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag1\"%' OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag2\"%'))",
|
want: "NOT ((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))",
|
||||||
|
args: []any{`%"tag1"%`, `%"tag2"%`},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
filter: `tag in ["tag1", "tag2"] || tag in ["tag3", "tag4"]`,
|
filter: `tag in ["tag1", "tag2"] || tag in ["tag3", "tag4"]`,
|
||||||
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag1\"%' OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag2\"%') OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag3\"%' OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag4\"%'))",
|
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?))",
|
||||||
|
args: []any{`%"tag1"%`, `%"tag2"%`, `%"tag3"%`, `%"tag4"%`},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
filter: `content.contains("memos")`,
|
filter: `content.contains("memos")`,
|
||||||
want: "`memo`.`content` LIKE '%memos%'",
|
want: "`memo`.`content` LIKE ?",
|
||||||
|
args: []any{"%memos%"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
filter: `visibility in ["PUBLIC"]`,
|
filter: `visibility in ["PUBLIC"]`,
|
||||||
want: "`memo`.`visibility` = \"PUBLIC\"",
|
want: "`memo`.`visibility` IN (?)",
|
||||||
|
args: []any{"PUBLIC"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
|
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
|
||||||
want: "`memo`.`visibility` IN (\"PUBLIC\",\"PRIVATE\")",
|
want: "`memo`.`visibility` IN (?,?)",
|
||||||
|
args: []any{"PUBLIC", "PRIVATE"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
filter: `create_time == "2006-01-02T15:04:05+07:00"`,
|
filter: `create_time == "2006-01-02T15:04:05+07:00"`,
|
||||||
want: "`memo`.`created_ts` = 1136189045",
|
want: "`memo`.`created_ts` = ?",
|
||||||
|
args: []any{int64(1136189045)},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
filter: `tag in ['tag1'] || content.contains('hello')`,
|
filter: `tag in ['tag1'] || content.contains('hello')`,
|
||||||
want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE '%\"tag1\"%' OR `memo`.`content` LIKE '%hello%')",
|
want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ? OR `memo`.`content` LIKE ?)",
|
||||||
|
args: []any{`%"tag1"%`, "%hello%"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
result, err := RestoreExprToSQL(parsedExpr.GetExpr())
|
convertCtx := filter.NewConvertContext()
|
||||||
|
err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, tt.want, result)
|
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||||
|
require.Equal(t, tt.args, convertCtx.Args)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user