mirror of
https://github.com/usememos/memos.git
synced 2025-06-05 22:09:59 +02:00
feat: validate shortcut's filter
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/types/known/emptypb"
|
"google.golang.org/protobuf/types/known/emptypb"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/usememos/memos/internal/util"
|
"github.com/usememos/memos/internal/util"
|
||||||
"github.com/usememos/memos/plugin/filter"
|
"github.com/usememos/memos/plugin/filter"
|
||||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||||
@@ -78,10 +79,7 @@ func (s *APIV1Service) CreateShortcut(ctx context.Context, request *v1pb.CreateS
|
|||||||
if newShortcut.Title == "" {
|
if newShortcut.Title == "" {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "title is required")
|
return nil, status.Errorf(codes.InvalidArgument, "title is required")
|
||||||
}
|
}
|
||||||
if newShortcut.Filter == "" {
|
if err := s.validateFilter(ctx, newShortcut.Filter); err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "filter is required")
|
|
||||||
}
|
|
||||||
if _, err := filter.Parse(newShortcut.Filter, filter.MemoFilterCELAttributes...); err != nil {
|
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||||
}
|
}
|
||||||
if request.ValidateOnly {
|
if request.ValidateOnly {
|
||||||
@@ -171,11 +169,7 @@ func (s *APIV1Service) UpdateShortcut(ctx context.Context, request *v1pb.UpdateS
|
|||||||
}
|
}
|
||||||
shortcut.Title = request.Shortcut.GetTitle()
|
shortcut.Title = request.Shortcut.GetTitle()
|
||||||
} else if field == "filter" {
|
} else if field == "filter" {
|
||||||
if request.Shortcut.GetFilter() == "" {
|
if err := s.validateFilter(ctx, request.Shortcut.GetFilter()); err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "filter is required")
|
|
||||||
}
|
|
||||||
// Validate the filter.
|
|
||||||
if _, err := filter.Parse(request.Shortcut.GetFilter(), filter.MemoFilterCELAttributes...); err != nil {
|
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
||||||
}
|
}
|
||||||
shortcut.Filter = request.Shortcut.GetFilter()
|
shortcut.Filter = request.Shortcut.GetFilter()
|
||||||
@@ -244,3 +238,20 @@ func (s *APIV1Service) DeleteShortcut(ctx context.Context, request *v1pb.DeleteS
|
|||||||
|
|
||||||
return &emptypb.Empty{}, nil
|
return &emptypb.Empty{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *APIV1Service) validateFilter(_ context.Context, filterStr string) error {
|
||||||
|
if filterStr == "" {
|
||||||
|
return errors.New("filter cannot be empty")
|
||||||
|
}
|
||||||
|
// Validate the filter.
|
||||||
|
parsedExpr, err := filter.Parse(filterStr, filter.MemoFilterCELAttributes...)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "failed to parse filter")
|
||||||
|
}
|
||||||
|
convertCtx := filter.NewConvertContext()
|
||||||
|
err = s.Store.GetDriver().ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "failed to convert filter to SQL")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@@ -118,7 +118,7 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
|
|||||||
}
|
}
|
||||||
convertCtx := filter.NewConvertContext()
|
convertCtx := filter.NewConvertContext()
|
||||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||||
if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
|
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
|
||||||
|
@@ -12,7 +12,7 @@ import (
|
|||||||
"github.com/usememos/memos/plugin/filter"
|
"github.com/usememos/memos/plugin/filter"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||||
switch v.CallExpr.Function {
|
switch v.CallExpr.Function {
|
||||||
case "_||_", "_&&_":
|
case "_||_", "_&&_":
|
||||||
@@ -22,7 +22,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
|||||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
operator := "AND"
|
operator := "AND"
|
||||||
@@ -32,7 +32,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
|||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
|
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
@@ -45,7 +45,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
|||||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
@@ -52,10 +52,11 @@ func TestConvertExprToSQL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
db := &DB{}
|
||||||
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
convertCtx := filter.NewConvertContext()
|
convertCtx := filter.NewConvertContext()
|
||||||
err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||||
require.Equal(t, tt.args, convertCtx.Args)
|
require.Equal(t, tt.args, convertCtx.Args)
|
||||||
|
@@ -110,7 +110,7 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
|
|||||||
convertCtx := filter.NewConvertContext()
|
convertCtx := filter.NewConvertContext()
|
||||||
convertCtx.ArgsOffset = len(args)
|
convertCtx.ArgsOffset = len(args)
|
||||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||||
if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
|
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
|
||||||
|
@@ -12,7 +12,7 @@ import (
|
|||||||
"github.com/usememos/memos/plugin/filter"
|
"github.com/usememos/memos/plugin/filter"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||||
switch v.CallExpr.Function {
|
switch v.CallExpr.Function {
|
||||||
case "_||_", "_&&_":
|
case "_||_", "_&&_":
|
||||||
@@ -22,7 +22,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
|||||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
operator := "AND"
|
operator := "AND"
|
||||||
@@ -32,7 +32,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
|||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
|
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
@@ -45,7 +45,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
|||||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
@@ -52,10 +52,11 @@ func TestRestoreExprToSQL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
db := &DB{}
|
||||||
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
convertCtx := filter.NewConvertContext()
|
convertCtx := filter.NewConvertContext()
|
||||||
err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||||
require.Equal(t, tt.args, convertCtx.Args)
|
require.Equal(t, tt.args, convertCtx.Args)
|
||||||
|
@@ -110,7 +110,7 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
|
|||||||
}
|
}
|
||||||
convertCtx := filter.NewConvertContext()
|
convertCtx := filter.NewConvertContext()
|
||||||
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
|
||||||
if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
|
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
|
||||||
|
@@ -12,7 +12,7 @@ import (
|
|||||||
"github.com/usememos/memos/plugin/filter"
|
"github.com/usememos/memos/plugin/filter"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
||||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||||
switch v.CallExpr.Function {
|
switch v.CallExpr.Function {
|
||||||
case "_||_", "_&&_":
|
case "_||_", "_&&_":
|
||||||
@@ -22,7 +22,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
|||||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
operator := "AND"
|
operator := "AND"
|
||||||
@@ -32,7 +32,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
|||||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
|
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
@@ -45,7 +45,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
|
|||||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||||
|
@@ -57,10 +57,11 @@ func TestConvertExprToSQL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
db := &DB{}
|
||||||
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
convertCtx := filter.NewConvertContext()
|
convertCtx := filter.NewConvertContext()
|
||||||
err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||||
require.Equal(t, tt.args, convertCtx.Args)
|
require.Equal(t, tt.args, convertCtx.Args)
|
||||||
|
@@ -3,6 +3,10 @@ package store
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
|
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/plugin/filter"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Driver is an interface for store driver.
|
// Driver is an interface for store driver.
|
||||||
@@ -73,4 +77,7 @@ type Driver interface {
|
|||||||
UpsertReaction(ctx context.Context, create *Reaction) (*Reaction, error)
|
UpsertReaction(ctx context.Context, create *Reaction) (*Reaction, error)
|
||||||
ListReactions(ctx context.Context, find *FindReaction) ([]*Reaction, error)
|
ListReactions(ctx context.Context, find *FindReaction) ([]*Reaction, error)
|
||||||
DeleteReaction(ctx context.Context, delete *DeleteReaction) error
|
DeleteReaction(ctx context.Context, delete *DeleteReaction) error
|
||||||
|
|
||||||
|
// Shortcut related methods.
|
||||||
|
ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error
|
||||||
}
|
}
|
||||||
|
@@ -24,6 +24,10 @@ func New(driver Driver, profile *profile.Profile) *Store {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetDriver() Driver {
|
||||||
|
return s.driver
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) Close() error {
|
func (s *Store) Close() error {
|
||||||
return s.driver.Close()
|
return s.driver.Close()
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user