mirror of
https://github.com/usememos/memos.git
synced 2025-06-05 22:09:59 +02:00
feat: memo filter for sqlite
This commit is contained in:
47
plugin/filter/filter.go
Normal file
47
plugin/filter/filter.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package filter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/google/cel-go/cel"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MemoFilterCELAttributes are the CEL attributes for memo.
|
||||||
|
var MemoFilterCELAttributes = []cel.EnvOption{
|
||||||
|
cel.Variable("content", cel.StringType),
|
||||||
|
cel.Variable("tag", cel.StringType),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse parses the filter string and returns the parsed expression.
|
||||||
|
// The filter string should be a CEL expression.
|
||||||
|
func Parse(filter string) (expr *exprv1.ParsedExpr, err error) {
|
||||||
|
e, err := cel.NewEnv(MemoFilterCELAttributes...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to create CEL environment")
|
||||||
|
}
|
||||||
|
ast, issues := e.Compile(filter)
|
||||||
|
if issues != nil {
|
||||||
|
return nil, errors.Errorf("failed to compile filter: %v", issues)
|
||||||
|
}
|
||||||
|
return cel.AstToParsedExpr(ast)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConstValue returns the constant value of the expression.
|
||||||
|
func GetConstValue(expr *exprv1.Expr) (any, error) {
|
||||||
|
switch v := expr.ExprKind.(type) {
|
||||||
|
case *exprv1.Expr_ConstExpr:
|
||||||
|
switch v.ConstExpr.ConstantKind.(type) {
|
||||||
|
case *exprv1.Constant_StringValue:
|
||||||
|
return v.ConstExpr.GetStringValue(), nil
|
||||||
|
case *exprv1.Constant_Int64Value:
|
||||||
|
return v.ConstExpr.GetInt64Value(), nil
|
||||||
|
case *exprv1.Constant_Uint64Value:
|
||||||
|
return v.ConstExpr.GetUint64Value(), nil
|
||||||
|
case *exprv1.Constant_DoubleValue:
|
||||||
|
return v.ConstExpr.GetDoubleValue(), nil
|
||||||
|
case *exprv1.Constant_BoolValue:
|
||||||
|
return v.ConstExpr.GetBoolValue(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, errors.New("invalid constant expression")
|
||||||
|
}
|
@@ -5,7 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/google/cel-go/cel"
|
"github.com/google/cel-go/cel"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
@@ -18,52 +18,52 @@ func (s *APIV1Service) buildMemoFindWithFilter(ctx context.Context, find *store.
|
|||||||
find.PayloadFind = &store.FindMemoPayload{}
|
find.PayloadFind = &store.FindMemoPayload{}
|
||||||
}
|
}
|
||||||
if filter != "" {
|
if filter != "" {
|
||||||
filter, err := parseMemoFilter(filter)
|
filterExpr, err := parseMemoFilter(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||||
}
|
}
|
||||||
if len(filter.ContentSearch) > 0 {
|
if len(filterExpr.ContentSearch) > 0 {
|
||||||
find.ContentSearch = filter.ContentSearch
|
find.ContentSearch = filterExpr.ContentSearch
|
||||||
}
|
}
|
||||||
if len(filter.Visibilities) > 0 {
|
if len(filterExpr.Visibilities) > 0 {
|
||||||
find.VisibilityList = filter.Visibilities
|
find.VisibilityList = filterExpr.Visibilities
|
||||||
}
|
}
|
||||||
if filter.TagSearch != nil {
|
if filterExpr.TagSearch != nil {
|
||||||
if find.PayloadFind == nil {
|
if find.PayloadFind == nil {
|
||||||
find.PayloadFind = &store.FindMemoPayload{}
|
find.PayloadFind = &store.FindMemoPayload{}
|
||||||
}
|
}
|
||||||
find.PayloadFind.TagSearch = filter.TagSearch
|
find.PayloadFind.TagSearch = filterExpr.TagSearch
|
||||||
}
|
}
|
||||||
if filter.OrderByPinned {
|
if filterExpr.OrderByPinned {
|
||||||
find.OrderByPinned = filter.OrderByPinned
|
find.OrderByPinned = filterExpr.OrderByPinned
|
||||||
}
|
}
|
||||||
if filter.OrderByTimeAsc {
|
if filterExpr.OrderByTimeAsc {
|
||||||
find.OrderByTimeAsc = filter.OrderByTimeAsc
|
find.OrderByTimeAsc = filterExpr.OrderByTimeAsc
|
||||||
}
|
}
|
||||||
if filter.DisplayTimeAfter != nil {
|
if filterExpr.DisplayTimeAfter != nil {
|
||||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||||
}
|
}
|
||||||
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||||
find.UpdatedTsAfter = filter.DisplayTimeAfter
|
find.UpdatedTsAfter = filterExpr.DisplayTimeAfter
|
||||||
} else {
|
} else {
|
||||||
find.CreatedTsAfter = filter.DisplayTimeAfter
|
find.CreatedTsAfter = filterExpr.DisplayTimeAfter
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if filter.DisplayTimeBefore != nil {
|
if filterExpr.DisplayTimeBefore != nil {
|
||||||
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
return status.Errorf(codes.Internal, "failed to get workspace memo related setting")
|
||||||
}
|
}
|
||||||
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
if workspaceMemoRelatedSetting.DisplayWithUpdateTime {
|
||||||
find.UpdatedTsBefore = filter.DisplayTimeBefore
|
find.UpdatedTsBefore = filterExpr.DisplayTimeBefore
|
||||||
} else {
|
} else {
|
||||||
find.CreatedTsBefore = filter.DisplayTimeBefore
|
find.CreatedTsBefore = filterExpr.DisplayTimeBefore
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if filter.Creator != nil {
|
if filterExpr.Creator != nil {
|
||||||
userID, err := ExtractUserIDFromName(*filter.Creator)
|
userID, err := ExtractUserIDFromName(*filterExpr.Creator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "invalid user name")
|
return errors.Wrap(err, "invalid user name")
|
||||||
}
|
}
|
||||||
@@ -78,28 +78,28 @@ func (s *APIV1Service) buildMemoFindWithFilter(ctx context.Context, find *store.
|
|||||||
}
|
}
|
||||||
find.CreatorID = &user.ID
|
find.CreatorID = &user.ID
|
||||||
}
|
}
|
||||||
if filter.RowStatus != nil {
|
if filterExpr.RowStatus != nil {
|
||||||
find.RowStatus = filter.RowStatus
|
find.RowStatus = filterExpr.RowStatus
|
||||||
}
|
}
|
||||||
if filter.Random {
|
if filterExpr.Random {
|
||||||
find.Random = filter.Random
|
find.Random = filterExpr.Random
|
||||||
}
|
}
|
||||||
if filter.Limit != nil {
|
if filterExpr.Limit != nil {
|
||||||
find.Limit = filter.Limit
|
find.Limit = filterExpr.Limit
|
||||||
}
|
}
|
||||||
if filter.IncludeComments {
|
if filterExpr.IncludeComments {
|
||||||
find.ExcludeComments = false
|
find.ExcludeComments = false
|
||||||
}
|
}
|
||||||
if filter.HasLink {
|
if filterExpr.HasLink {
|
||||||
find.PayloadFind.HasLink = true
|
find.PayloadFind.HasLink = true
|
||||||
}
|
}
|
||||||
if filter.HasTaskList {
|
if filterExpr.HasTaskList {
|
||||||
find.PayloadFind.HasTaskList = true
|
find.PayloadFind.HasTaskList = true
|
||||||
}
|
}
|
||||||
if filter.HasCode {
|
if filterExpr.HasCode {
|
||||||
find.PayloadFind.HasCode = true
|
find.PayloadFind.HasCode = true
|
||||||
}
|
}
|
||||||
if filter.HasIncompleteTasks {
|
if filterExpr.HasIncompleteTasks {
|
||||||
find.PayloadFind.HasIncompleteTasks = true
|
find.PayloadFind.HasIncompleteTasks = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -181,16 +181,16 @@ func parseMemoFilter(expression string) (*MemoFilter, error) {
|
|||||||
return nil, errors.Errorf("found issue %v", issues)
|
return nil, errors.Errorf("found issue %v", issues)
|
||||||
}
|
}
|
||||||
filter := &MemoFilter{}
|
filter := &MemoFilter{}
|
||||||
expr, err := cel.AstToParsedExpr(ast)
|
parsedExpr, err := cel.AstToParsedExpr(ast)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
callExpr := expr.GetExpr().GetCallExpr()
|
callExpr := parsedExpr.GetExpr().GetCallExpr()
|
||||||
findMemoField(callExpr, filter)
|
findMemoField(callExpr, filter)
|
||||||
return filter, nil
|
return filter, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func findMemoField(callExpr *expr.Expr_Call, filter *MemoFilter) {
|
func findMemoField(callExpr *exprv1.Expr_Call, filter *MemoFilter) {
|
||||||
if len(callExpr.Args) == 2 {
|
if len(callExpr.Args) == 2 {
|
||||||
idExpr := callExpr.Args[0].GetIdentExpr()
|
idExpr := callExpr.Args[0].GetIdentExpr()
|
||||||
if idExpr != nil {
|
if idExpr != nil {
|
||||||
|
96
store/db/sqlite/memo_filter.go
Normal file
96
store/db/sqlite/memo_filter.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
|
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RestoreExprToSQL(expr *exprv1.Expr) (string, error) {
|
||||||
|
var condition string
|
||||||
|
switch v := expr.ExprKind.(type) {
|
||||||
|
case *exprv1.Expr_CallExpr:
|
||||||
|
switch v.CallExpr.Function {
|
||||||
|
case "_||_", "_&&_":
|
||||||
|
for _, arg := range v.CallExpr.Args {
|
||||||
|
left, err := RestoreExprToSQL(arg)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
right, err := RestoreExprToSQL(arg)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
operator := "AND"
|
||||||
|
if v.CallExpr.Function == "_||_" {
|
||||||
|
operator = "OR"
|
||||||
|
}
|
||||||
|
condition = fmt.Sprintf("(%s %s %s)", left, operator, right)
|
||||||
|
}
|
||||||
|
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||||
|
if len(v.CallExpr.Args) != 2 {
|
||||||
|
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
// TODO(j): Implement this part.
|
||||||
|
case "@in":
|
||||||
|
if len(v.CallExpr.Args) != 2 {
|
||||||
|
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
factor := v.CallExpr.Args[0].GetIdentExpr().Name
|
||||||
|
if !slices.Contains([]string{"tag"}, factor) {
|
||||||
|
return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
values := []any{}
|
||||||
|
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
||||||
|
value, err := filter.GetConstValue(element)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
if factor == "tag" {
|
||||||
|
t := []string{}
|
||||||
|
for _, v := range values {
|
||||||
|
t = append(t, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`%%"%s"%%`, v)))
|
||||||
|
}
|
||||||
|
if len(t) == 1 {
|
||||||
|
condition = t[0]
|
||||||
|
} else {
|
||||||
|
condition = fmt.Sprintf("(%s)", strings.Join(t, " OR "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "contains":
|
||||||
|
if len(v.CallExpr.Args) != 1 {
|
||||||
|
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
factor, err := RestoreExprToSQL(v.CallExpr.Target)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if factor != "content" {
|
||||||
|
return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
condition = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %s", fmt.Sprintf(`%%"%s"%%`, arg))
|
||||||
|
case "!_":
|
||||||
|
if len(v.CallExpr.Args) != 1 {
|
||||||
|
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||||
|
}
|
||||||
|
arg, err := RestoreExprToSQL(v.CallExpr.Args[0])
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
condition = fmt.Sprintf("NOT (%s)", arg)
|
||||||
|
}
|
||||||
|
case *exprv1.Expr_IdentExpr:
|
||||||
|
return v.IdentExpr.GetName(), nil
|
||||||
|
}
|
||||||
|
return condition, nil
|
||||||
|
}
|
40
store/db/sqlite/memo_filter_test.go
Normal file
40
store/db/sqlite/memo_filter_test.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRestoreExprToSQL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
filter string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
filter: `tag in ["tag1", "tag2"]`,
|
||||||
|
want: "(JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag1\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag2\"%)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `!(tag in ["tag1", "tag2"])`,
|
||||||
|
want: "NOT ((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag1\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag2\"%))",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `tag in ["tag1", "tag2"] || tag in ["tag3", "tag4"]`,
|
||||||
|
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%))",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
filter: `content.contains("hello")`,
|
||||||
|
want: "JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %\"hello\"%",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
parsedExpr, err := filter.Parse(tt.filter)
|
||||||
|
require.NoError(t, err)
|
||||||
|
result, err := RestoreExprToSQL(parsedExpr.GetExpr())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.want, result)
|
||||||
|
}
|
||||||
|
}
|
@@ -1,13 +0,0 @@
|
|||||||
package store
|
|
||||||
|
|
||||||
type LogicOperator string
|
|
||||||
|
|
||||||
const (
|
|
||||||
AND LogicOperator = "AND"
|
|
||||||
OR LogicOperator = "OR"
|
|
||||||
)
|
|
||||||
|
|
||||||
type QueryExpression struct {
|
|
||||||
Operator LogicOperator
|
|
||||||
Children []*QueryExpression
|
|
||||||
}
|
|
Reference in New Issue
Block a user