mirror of
https://github.com/usememos/memos.git
synced 2025-06-05 22:09:59 +02:00
feat: support more factors in filter
This commit is contained in:
@@ -9,7 +9,12 @@ import (
|
||||
// MemoFilterCELAttributes are the CEL attributes for memo.
|
||||
var MemoFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("content", cel.StringType),
|
||||
// As the built-in timestamp type is deprecated, we use string type for now.
|
||||
// e.g., "2021-01-01T00:00:00Z"
|
||||
cel.Variable("create_time", cel.StringType),
|
||||
cel.Variable("tag", cel.StringType),
|
||||
cel.Variable("update_time", cel.StringType),
|
||||
cel.Variable("visibility", cel.StringType),
|
||||
}
|
||||
|
||||
// Parse parses the filter string and returns the parsed expression.
|
||||
|
@@ -141,7 +141,6 @@ var MemoFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("display_time_before", cel.IntType),
|
||||
cel.Variable("display_time_after", cel.IntType),
|
||||
cel.Variable("creator", cel.StringType),
|
||||
cel.Variable("uid", cel.StringType),
|
||||
cel.Variable("state", cel.StringType),
|
||||
cel.Variable("random", cel.BoolType),
|
||||
cel.Variable("limit", cel.IntType),
|
||||
|
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
@@ -100,6 +101,20 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
|
||||
where = append(where, "JSON_EXTRACT(`memo`.`payload`, '$.property.hasIncompleteTasks') IS TRUE")
|
||||
}
|
||||
}
|
||||
if v := find.Filter; v != nil {
|
||||
// Parse filter string and return the parsed expression.
|
||||
// The filter string should be a CEL expression.
|
||||
parsedExpr, err := filter.Parse(*v, filter.MemoFilterCELAttributes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// RestoreExprToSQL parses the expression and returns the SQL condition.
|
||||
condition, err := RestoreExprToSQL(parsedExpr.GetExpr())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
where = append(where, condition)
|
||||
}
|
||||
if find.ExcludeComments {
|
||||
where = append(where, "`parent_id` IS NULL")
|
||||
}
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
@@ -36,15 +37,55 @@ func RestoreExprToSQL(expr *exprv1.Expr) (string, error) {
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
// TODO(j): Implement this part.
|
||||
identifier := v.CallExpr.Args[0].GetIdentExpr().GetName()
|
||||
if !slices.Contains([]string{"create_time", "update_time"}, identifier) {
|
||||
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
value, err := filter.GetConstValue(v.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
operator := "="
|
||||
switch v.CallExpr.Function {
|
||||
case "_==_":
|
||||
operator = "="
|
||||
case "_!=_":
|
||||
operator = "!="
|
||||
case "_<_":
|
||||
operator = "<"
|
||||
case "_>_":
|
||||
operator = ">"
|
||||
case "_<=_":
|
||||
operator = "<="
|
||||
case "_>=_":
|
||||
operator = ">="
|
||||
}
|
||||
|
||||
if identifier == "create_time" || identifier == "update_time" {
|
||||
timestampStr, ok := value.(string)
|
||||
if !ok {
|
||||
return "", errors.New("invalid timestamp value")
|
||||
}
|
||||
timestamp, err := time.Parse(time.RFC3339, timestampStr)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to parse timestamp")
|
||||
}
|
||||
|
||||
if identifier == "create_time" {
|
||||
condition = fmt.Sprintf("`memo`.`created_ts` %s %d", operator, timestamp.Unix())
|
||||
} else if identifier == "update_time" {
|
||||
condition = fmt.Sprintf("`memo`.`updated_ts` %s %d", operator, timestamp.Unix())
|
||||
}
|
||||
}
|
||||
case "@in":
|
||||
if len(v.CallExpr.Args) != 2 {
|
||||
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
factor := v.CallExpr.Args[0].GetIdentExpr().Name
|
||||
if !slices.Contains([]string{"tag"}, factor) {
|
||||
return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function)
|
||||
identifier := v.CallExpr.Args[0].GetIdentExpr().GetName()
|
||||
if !slices.Contains([]string{"tag", "visibility"}, identifier) {
|
||||
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
|
||||
values := []any{}
|
||||
for _, element := range v.CallExpr.Args[1].GetListExpr().Elements {
|
||||
value, err := filter.GetConstValue(element)
|
||||
@@ -53,33 +94,43 @@ func RestoreExprToSQL(expr *exprv1.Expr) (string, error) {
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
if factor == "tag" {
|
||||
t := []string{}
|
||||
if identifier == "tag" {
|
||||
subcodition := []string{}
|
||||
for _, v := range values {
|
||||
t = append(t, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`%%"%s"%%`, v)))
|
||||
subcodition = append(subcodition, fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %s", fmt.Sprintf(`%%"%s"%%`, v)))
|
||||
}
|
||||
if len(t) == 1 {
|
||||
condition = t[0]
|
||||
if len(subcodition) == 1 {
|
||||
condition = subcodition[0]
|
||||
} else {
|
||||
condition = fmt.Sprintf("(%s)", strings.Join(t, " OR "))
|
||||
condition = fmt.Sprintf("(%s)", strings.Join(subcodition, " OR "))
|
||||
}
|
||||
} else if identifier == "visibility" {
|
||||
vs := []string{}
|
||||
for _, v := range values {
|
||||
vs = append(vs, fmt.Sprintf(`"%s"`, v))
|
||||
}
|
||||
if len(vs) == 1 {
|
||||
condition = fmt.Sprintf("`memo`.`visibility` = %s", vs[0])
|
||||
} else {
|
||||
condition = fmt.Sprintf("`memo`.`visibility` IN (%s)", strings.Join(vs, ","))
|
||||
}
|
||||
}
|
||||
case "contains":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
}
|
||||
factor, err := RestoreExprToSQL(v.CallExpr.Target)
|
||||
identifier, err := RestoreExprToSQL(v.CallExpr.Target)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if factor != "content" {
|
||||
return "", errors.Errorf("invalid factor for %s", v.CallExpr.Function)
|
||||
if identifier != "content" {
|
||||
return "", errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
|
||||
}
|
||||
arg, err := filter.GetConstValue(v.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
condition = fmt.Sprintf("JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %s", fmt.Sprintf(`%%"%s"%%`, arg))
|
||||
condition = fmt.Sprintf("`memo`.`content` LIKE %s", fmt.Sprintf(`%%"%s"%%`, arg))
|
||||
case "!_":
|
||||
if len(v.CallExpr.Args) != 1 {
|
||||
return "", errors.Errorf("invalid number of arguments for %s", v.CallExpr.Function)
|
||||
|
@@ -26,8 +26,20 @@ func TestRestoreExprToSQL(t *testing.T) {
|
||||
want: "((JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%) OR (JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag3\"% OR JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE %\"tag4\"%))",
|
||||
},
|
||||
{
|
||||
filter: `content.contains("hello")`,
|
||||
want: "JSON_EXTRACT(`memo`.`payload`, '$.content') LIKE %\"hello\"%",
|
||||
filter: `content.contains("memos")`,
|
||||
want: "`memo`.`content` LIKE %\"memos\"%",
|
||||
},
|
||||
{
|
||||
filter: `visibility in ["PUBLIC"]`,
|
||||
want: "`memo`.`visibility` = \"PUBLIC\"",
|
||||
},
|
||||
{
|
||||
filter: `visibility in ["PUBLIC", "PRIVATE"]`,
|
||||
want: "`memo`.`visibility` IN (\"PUBLIC\",\"PRIVATE\")",
|
||||
},
|
||||
{
|
||||
filter: `create_time == "2006-01-02T15:04:05+07:00"`,
|
||||
want: "`memo`.`created_ts` = 1136189045",
|
||||
},
|
||||
}
|
||||
|
||||
|
@@ -74,6 +74,7 @@ type FindMemo struct {
|
||||
ExcludeContent bool
|
||||
ExcludeComments bool
|
||||
Random bool
|
||||
Filter *string
|
||||
|
||||
// Pagination
|
||||
Limit *int
|
||||
|
Reference in New Issue
Block a user