chore: update acl middleware

This commit is contained in:
boojack 2022-07-27 19:45:37 +08:00
parent 873973a088
commit d83f204d8c
6 changed files with 33 additions and 42 deletions

View File

@ -51,11 +51,10 @@ func removeUserSession(ctx echo.Context) error {
return nil
}
// Use session to store user.id.
func BasicAuthMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
return func(ctx echo.Context) error {
// Skip auth for some paths.
if common.HasPrefixes(ctx.Path(), "/api/auth", "/api/ping", "/api/status", "/api/user/:userId") {
if common.HasPrefixes(ctx.Path(), "/api/auth", "/api/ping", "/api/status", "/api/user/:id") {
return next(ctx)
}
@ -76,42 +75,36 @@ func BasicAuthMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
}
}
needAuth := true
if common.HasPrefixes(ctx.Path(), "/api/memo", "/api/tag", "/api/shortcut") && ctx.Request().Method == http.MethodGet {
if _, err := strconv.Atoi(ctx.QueryParam("creatorId")); err == nil {
needAuth = false
}
}
{
sess, _ := session.Get("session", ctx)
userIDValue := sess.Values[userIDContextKey]
if userIDValue == nil && needAuth {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing userID in session")
}
userID, err := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
if err != nil && needAuth {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to malformatted user id in the session.").SetInternal(err)
}
userFind := &api.UserFind{
ID: &userID,
}
user, err := s.Store.FindUser(userFind)
if err != nil && needAuth {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
}
if needAuth {
if user == nil {
return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Not found user ID: %d", userID))
} else if user.RowStatus == api.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email))
if userIDValue != nil {
userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
userFind := &api.UserFind{
ID: &userID,
}
user, err := s.Store.FindUser(userFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
}
if user != nil {
if user.RowStatus == api.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email))
}
ctx.Set(getUserIDContextKey(), userID)
}
}
}
// Save userID into context.
ctx.Set(getUserIDContextKey(), userID)
if common.HasPrefixes(ctx.Path(), "/api/memo", "/api/tag", "/api/shortcut") && ctx.Request().Method == http.MethodGet {
if _, err := strconv.Atoi(ctx.QueryParam("creatorId")); err == nil {
return next(ctx)
}
}
userID := ctx.Get(getUserIDContextKey())
if userID == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing userID in session")
}
return next(ctx)

View File

@ -72,8 +72,8 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
memoFind.CreatorID = &userID
}
currentUserID := c.Get(getUserIDContextKey()).(int)
if currentUserID == api.UNKNOWN_ID {
currentUserID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
if memoFind.CreatorID == nil {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo")
}

View File

@ -58,7 +58,7 @@ func NewServer(profile *profile.Profile) *Server {
apiGroup := e.Group("/api")
apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return BasicAuthMiddleware(s, next)
return aclMiddleware(s, next)
})
s.registerSystemRoutes(apiGroup)
s.registerAuthRoutes(apiGroup)

View File

@ -25,8 +25,8 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
memoFind.CreatorID = &userID
}
currentUserID := c.Get(getUserIDContextKey()).(int)
if currentUserID == api.UNKNOWN_ID {
currentUserID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
if memoFind.CreatorID == nil {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo")
}

View File

@ -83,12 +83,11 @@ func (s *Server) registerUserRoutes(g *echo.Group) {
// GET /api/user/me is used to check if the user is logged in.
g.GET("/user/me", func(c echo.Context) error {
userSessionID := c.Get(getUserIDContextKey())
if userSessionID == nil {
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
userID := userSessionID.(int)
userFind := &api.UserFind{
ID: &userID,
}

View File

@ -255,7 +255,6 @@ func findUserList(db *sql.DB, find *api.UserFind) ([]*userRaw, error) {
&userRaw.UpdatedTs,
&userRaw.RowStatus,
); err != nil {
fmt.Println(err)
return nil, FormatError(err)
}