diff --git a/server/basic_auth.go b/server/acl.go similarity index 63% rename from server/basic_auth.go rename to server/acl.go index 99e40457..97ab517b 100644 --- a/server/basic_auth.go +++ b/server/acl.go @@ -51,11 +51,10 @@ func removeUserSession(ctx echo.Context) error { return nil } -// Use session to store user.id. -func BasicAuthMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc { +func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc { return func(ctx echo.Context) error { // Skip auth for some paths. - if common.HasPrefixes(ctx.Path(), "/api/auth", "/api/ping", "/api/status", "/api/user/:userId") { + if common.HasPrefixes(ctx.Path(), "/api/auth", "/api/ping", "/api/status", "/api/user/:id") { return next(ctx) } @@ -76,42 +75,36 @@ func BasicAuthMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc { } } - needAuth := true - if common.HasPrefixes(ctx.Path(), "/api/memo", "/api/tag", "/api/shortcut") && ctx.Request().Method == http.MethodGet { - if _, err := strconv.Atoi(ctx.QueryParam("creatorId")); err == nil { - needAuth = false - } - } - { sess, _ := session.Get("session", ctx) userIDValue := sess.Values[userIDContextKey] - if userIDValue == nil && needAuth { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing userID in session") - } - - userID, err := strconv.Atoi(fmt.Sprintf("%v", userIDValue)) - if err != nil && needAuth { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to malformatted user id in the session.").SetInternal(err) - } - - userFind := &api.UserFind{ - ID: &userID, - } - user, err := s.Store.FindUser(userFind) - if err != nil && needAuth { - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err) - } - if needAuth { - if user == nil { - return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Not found user ID: %d", userID)) - } else if user.RowStatus == api.Archived { - return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email)) + if userIDValue != nil { + userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue)) + userFind := &api.UserFind{ + ID: &userID, + } + user, err := s.Store.FindUser(userFind) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err) + } + if user != nil { + if user.RowStatus == api.Archived { + return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email)) + } + ctx.Set(getUserIDContextKey(), userID) } } + } - // Save userID into context. - ctx.Set(getUserIDContextKey(), userID) + if common.HasPrefixes(ctx.Path(), "/api/memo", "/api/tag", "/api/shortcut") && ctx.Request().Method == http.MethodGet { + if _, err := strconv.Atoi(ctx.QueryParam("creatorId")); err == nil { + return next(ctx) + } + } + + userID := ctx.Get(getUserIDContextKey()) + if userID == nil { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing userID in session") } return next(ctx) diff --git a/server/memo.go b/server/memo.go index 712b6042..8c49c872 100644 --- a/server/memo.go +++ b/server/memo.go @@ -72,8 +72,8 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { memoFind.CreatorID = &userID } - currentUserID := c.Get(getUserIDContextKey()).(int) - if currentUserID == api.UNKNOWN_ID { + currentUserID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { if memoFind.CreatorID == nil { return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo") } diff --git a/server/server.go b/server/server.go index 6b72d965..f2573751 100644 --- a/server/server.go +++ b/server/server.go @@ -58,7 +58,7 @@ func NewServer(profile *profile.Profile) *Server { apiGroup := e.Group("/api") apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return BasicAuthMiddleware(s, next) + return aclMiddleware(s, next) }) s.registerSystemRoutes(apiGroup) s.registerAuthRoutes(apiGroup) diff --git a/server/tag.go b/server/tag.go index c4686769..8c755eef 100644 --- a/server/tag.go +++ b/server/tag.go @@ -25,8 +25,8 @@ func (s *Server) registerTagRoutes(g *echo.Group) { memoFind.CreatorID = &userID } - currentUserID := c.Get(getUserIDContextKey()).(int) - if currentUserID == api.UNKNOWN_ID { + currentUserID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { if memoFind.CreatorID == nil { return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo") } diff --git a/server/user.go b/server/user.go index 6e6fb41d..5df2793e 100644 --- a/server/user.go +++ b/server/user.go @@ -83,12 +83,11 @@ func (s *Server) registerUserRoutes(g *echo.Group) { // GET /api/user/me is used to check if the user is logged in. g.GET("/user/me", func(c echo.Context) error { - userSessionID := c.Get(getUserIDContextKey()) - if userSessionID == nil { + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") } - userID := userSessionID.(int) userFind := &api.UserFind{ ID: &userID, } diff --git a/store/user.go b/store/user.go index 552bf74c..ba25c2d4 100644 --- a/store/user.go +++ b/store/user.go @@ -255,7 +255,6 @@ func findUserList(db *sql.DB, find *api.UserFind) ([]*userRaw, error) { &userRaw.UpdatedTs, &userRaw.RowStatus, ); err != nil { - fmt.Println(err) return nil, FormatError(err) }