From b7339e00baea4214824e8d43b5f3b56c77631618 Mon Sep 17 00:00:00 2001 From: boojack Date: Tue, 26 Jul 2022 21:12:20 +0800 Subject: [PATCH] feat: update finding memo with visibility --- api/auth.go | 4 ++ api/memo.go | 10 ++-- server/basic_auth.go | 92 +++++++++++++++++------------------ server/memo.go | 37 ++++++++------ server/tag.go | 24 ++++----- server/user.go | 30 ++++++++++++ store/db/seed/10002__memo.sql | 12 +++-- store/memo.go | 9 +++- 8 files changed, 136 insertions(+), 82 deletions(-) diff --git a/api/auth.go b/api/auth.go index ddb7998d..b04310d6 100644 --- a/api/auth.go +++ b/api/auth.go @@ -1,5 +1,9 @@ package api +var ( + UNKNOWN_ID = 0 +) + type Signin struct { Email string `json:"email"` Password string `json:"password"` diff --git a/api/memo.go b/api/memo.go index 0d49399e..d2933414 100644 --- a/api/memo.go +++ b/api/memo.go @@ -6,6 +6,8 @@ type Visibility string const ( // Public is the PUBLIC visibility. Public Visibility = "PUBLIC" + // Protected is the PROTECTED visibility. + Protected Visibility = "PROTECTED" // Privite is the PRIVATE visibility. Privite Visibility = "PRIVATE" ) @@ -14,6 +16,8 @@ func (e Visibility) String() string { switch e { case Public: return "PUBLIC" + case Protected: + return "PROTECTED" case Privite: return "PRIVATE" } @@ -65,9 +69,9 @@ type MemoFind struct { CreatorID *int `json:"creatorId"` // Domain specific fields - Pinned *bool - ContentSearch *string - Visibility *Visibility + Pinned *bool + ContentSearch *string + VisibilityList []Visibility // Pagination Limit int diff --git a/server/basic_auth.go b/server/basic_auth.go index 05f844f5..99e40457 100644 --- a/server/basic_auth.go +++ b/server/basic_auth.go @@ -21,30 +21,30 @@ func getUserIDContextKey() string { return userIDContextKey } -func setUserSession(c echo.Context, user *api.User) error { - sess, _ := session.Get("session", c) +func setUserSession(ctx echo.Context, user *api.User) error { + sess, _ := session.Get("session", ctx) sess.Options = &sessions.Options{ Path: "/", MaxAge: 1000 * 3600 * 24 * 30, HttpOnly: true, } sess.Values[userIDContextKey] = user.ID - err := sess.Save(c.Request(), c.Response()) + err := sess.Save(ctx.Request(), ctx.Response()) if err != nil { return fmt.Errorf("failed to set session, err: %w", err) } return nil } -func removeUserSession(c echo.Context) error { - sess, _ := session.Get("session", c) +func removeUserSession(ctx echo.Context) error { + sess, _ := session.Get("session", ctx) sess.Options = &sessions.Options{ Path: "/", MaxAge: 0, HttpOnly: true, } sess.Values[userIDContextKey] = nil - err := sess.Save(c.Request(), c.Response()) + err := sess.Save(ctx.Request(), ctx.Response()) if err != nil { return fmt.Errorf("failed to set session, err: %w", err) } @@ -53,14 +53,14 @@ func removeUserSession(c echo.Context) error { // Use session to store user.id. func BasicAuthMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(ctx echo.Context) error { // Skip auth for some paths. - if common.HasPrefixes(c.Path(), "/api/auth", "/api/ping", "/api/status", "/api/user/:userId") { - return next(c) + if common.HasPrefixes(ctx.Path(), "/api/auth", "/api/ping", "/api/status", "/api/user/:userId") { + return next(ctx) } // If there is openId in query string and related user is found, then skip auth. - openID := c.QueryParam("openId") + openID := ctx.QueryParam("openId") if openID != "" { userFind := &api.UserFind{ OpenID: &openID, @@ -71,49 +71,49 @@ func BasicAuthMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc { } if user != nil { // Stores userID into context. - c.Set(getUserIDContextKey(), user.ID) - return next(c) + ctx.Set(getUserIDContextKey(), user.ID) + return next(ctx) } } - if common.HasPrefixes(c.Path(), "/api/memo", "/api/tag", "/api/shortcut") && c.Request().Method == http.MethodGet { - if _, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { - return next(c) + needAuth := true + if common.HasPrefixes(ctx.Path(), "/api/memo", "/api/tag", "/api/shortcut") && ctx.Request().Method == http.MethodGet { + if _, err := strconv.Atoi(ctx.QueryParam("creatorId")); err == nil { + needAuth = false } } - sess, err := session.Get("session", c) - if err != nil { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing session").SetInternal(err) + { + sess, _ := session.Get("session", ctx) + userIDValue := sess.Values[userIDContextKey] + if userIDValue == nil && needAuth { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing userID in session") + } + + userID, err := strconv.Atoi(fmt.Sprintf("%v", userIDValue)) + if err != nil && needAuth { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to malformatted user id in the session.").SetInternal(err) + } + + userFind := &api.UserFind{ + ID: &userID, + } + user, err := s.Store.FindUser(userFind) + if err != nil && needAuth { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err) + } + if needAuth { + if user == nil { + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Not found user ID: %d", userID)) + } else if user.RowStatus == api.Archived { + return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email)) + } + } + + // Save userID into context. + ctx.Set(getUserIDContextKey(), userID) } - userIDValue := sess.Values[userIDContextKey] - if userIDValue == nil { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing userID in session") - } - - userID, err := strconv.Atoi(fmt.Sprintf("%v", userIDValue)) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to malformatted user id in the session.").SetInternal(err) - } - - // Even if there is no error, we still need to make sure the user still exists. - userFind := &api.UserFind{ - ID: &userID, - } - user, err := s.Store.FindUser(userFind) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err) - } - if user == nil { - return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Not found user ID: %d", userID)) - } else if user.RowStatus == api.Archived { - return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email)) - } - - // Stores userID into context. - c.Set(getUserIDContextKey(), userID) - - return next(c) + return next(ctx) } } diff --git a/server/memo.go b/server/memo.go index 25df508f..712b6042 100644 --- a/server/memo.go +++ b/server/memo.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "strconv" + "strings" "github.com/usememos/memos/api" "github.com/usememos/memos/common" @@ -68,21 +69,21 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { memoFind := &api.MemoFind{} if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { - memoFind.CreatorID = &userID - } else { - userID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo") - } - memoFind.CreatorID = &userID } - // Only can get PUBLIC memos in visitor mode - _, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - publicVisibility := api.Public - memoFind.Visibility = &publicVisibility + currentUserID := c.Get(getUserIDContextKey()).(int) + if currentUserID == api.UNKNOWN_ID { + if memoFind.CreatorID == nil { + return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo") + } + memoFind.VisibilityList = []api.Visibility{api.Public} + } else { + if memoFind.CreatorID == nil { + memoFind.CreatorID = ¤tUserID + } else { + memoFind.VisibilityList = []api.Visibility{api.Public, api.Protected} + } } rowStatus := api.RowStatus(c.QueryParam("rowStatus")) @@ -99,6 +100,14 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { contentSearch := "#" + tag + " " memoFind.ContentSearch = &contentSearch } + visibilitListStr := c.QueryParam("visibility") + if visibilitListStr != "" { + visibilityList := []api.Visibility{} + for _, visibility := range strings.Split(visibilitListStr, ",") { + visibilityList = append(visibilityList, api.Visibility(visibility)) + } + memoFind.VisibilityList = visibilityList + } if limit, err := strconv.Atoi(c.QueryParam("limit")); err == nil { memoFind.Limit = limit } @@ -190,9 +199,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { memoDelete := &api.MemoDelete{ ID: memoID, } - - err = s.Store.DeleteMemo(memoDelete) - if err != nil { + if err := s.Store.DeleteMemo(memoDelete); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete memo ID: %v", memoID)).SetInternal(err) } diff --git a/server/tag.go b/server/tag.go index 0737d674..c4686769 100644 --- a/server/tag.go +++ b/server/tag.go @@ -22,21 +22,21 @@ func (s *Server) registerTagRoutes(g *echo.Group) { } if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil { - memoFind.CreatorID = &userID - } else { - userID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag") - } - memoFind.CreatorID = &userID } - // Only can get PUBLIC memos in visitor mode - _, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - publicVisibility := api.Public - memoFind.Visibility = &publicVisibility + currentUserID := c.Get(getUserIDContextKey()).(int) + if currentUserID == api.UNKNOWN_ID { + if memoFind.CreatorID == nil { + return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo") + } + memoFind.VisibilityList = []api.Visibility{api.Public} + } else { + if memoFind.CreatorID == nil { + memoFind.CreatorID = ¤tUserID + } else { + memoFind.VisibilityList = []api.Visibility{api.Public, api.Protected} + } } memoList, err := s.Store.FindMemoList(&memoFind) diff --git a/server/user.go b/server/user.go index 9c9340a5..701041a8 100644 --- a/server/user.go +++ b/server/user.go @@ -182,4 +182,34 @@ func (s *Server) registerUserRoutes(g *echo.Group) { } return nil }) + + g.DELETE("/user/:userId", func(c echo.Context) error { + currentUserID := c.Get(getUserIDContextKey()).(int) + currentUser, err := s.Store.FindUser(&api.UserFind{ + ID: ¤tUserID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) + } + if currentUser == nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Current session user not found with ID: %d", currentUserID)).SetInternal(err) + } else if currentUser.Role != api.Host { + return echo.NewHTTPError(http.StatusForbidden, "Access forbidden for current session user").SetInternal(err) + } + + userID, err := strconv.Atoi(c.Param("userId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("userId"))).SetInternal(err) + } + + userDelete := &api.UserDelete{ + ID: userID, + } + if err := s.Store.DeleteUser(userDelete); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete user").SetInternal(err) + } + + return c.JSON(http.StatusOK, true) + }) + } diff --git a/store/db/seed/10002__memo.sql b/store/db/seed/10002__memo.sql index f1e0206b..c656da0d 100644 --- a/store/db/seed/10002__memo.sql +++ b/store/db/seed/10002__memo.sql @@ -16,7 +16,8 @@ INSERT INTO memo ( `id`, `content`, - `creator_id` + `creator_id`, + `visibility` ) VALUES ( @@ -26,7 +27,8 @@ VALUES - [x] Clean the room; - [x] Read *📖 The Little Prince*; (👆 click to toggle status)', - 101 + 101, + 'PROTECTED' ); INSERT INTO @@ -48,7 +50,8 @@ INSERT INTO memo ( `id`, `content`, - `creator_id` + `creator_id`, + `visibility` ) VALUES ( @@ -59,7 +62,8 @@ VALUES - [ ] Watch *👦 The Boys*; (👆 click to toggle status) ', - 102 + 102, + 'PROTECTED' ); INSERT INTO diff --git a/store/memo.go b/store/memo.go index 7512d93c..cf2cde52 100644 --- a/store/memo.go +++ b/store/memo.go @@ -222,8 +222,13 @@ func findMemoRawList(db *sql.DB, find *api.MemoFind) ([]*memoRaw, error) { if v := find.ContentSearch; v != nil { where, args = append(where, "content LIKE ?"), append(args, "%"+*v+"%") } - if v := find.Visibility; v != nil { - where, args = append(where, "visibility = ?"), append(args, *v) + if v := find.VisibilityList; len(v) != 0 { + list := []string{} + for _, visibility := range v { + list = append(list, fmt.Sprintf("$%d", len(args)+1)) + args = append(args, visibility) + } + where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ","))) } pagination := ""