mirror of
https://github.com/usememos/memos.git
synced 2025-02-15 19:00:46 +01:00
chore: update acl middleware
This commit is contained in:
parent
873973a088
commit
d83f204d8c
@ -51,11 +51,10 @@ func removeUserSession(ctx echo.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use session to store user.id.
|
||||
func BasicAuthMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
|
||||
func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(ctx echo.Context) error {
|
||||
// Skip auth for some paths.
|
||||
if common.HasPrefixes(ctx.Path(), "/api/auth", "/api/ping", "/api/status", "/api/user/:userId") {
|
||||
if common.HasPrefixes(ctx.Path(), "/api/auth", "/api/ping", "/api/status", "/api/user/:id") {
|
||||
return next(ctx)
|
||||
}
|
||||
|
||||
@ -76,42 +75,36 @@ func BasicAuthMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
needAuth := true
|
||||
if common.HasPrefixes(ctx.Path(), "/api/memo", "/api/tag", "/api/shortcut") && ctx.Request().Method == http.MethodGet {
|
||||
if _, err := strconv.Atoi(ctx.QueryParam("creatorId")); err == nil {
|
||||
needAuth = false
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
sess, _ := session.Get("session", ctx)
|
||||
userIDValue := sess.Values[userIDContextKey]
|
||||
if userIDValue == nil && needAuth {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing userID in session")
|
||||
}
|
||||
|
||||
userID, err := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
|
||||
if err != nil && needAuth {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to malformatted user id in the session.").SetInternal(err)
|
||||
}
|
||||
|
||||
userFind := &api.UserFind{
|
||||
ID: &userID,
|
||||
}
|
||||
user, err := s.Store.FindUser(userFind)
|
||||
if err != nil && needAuth {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
|
||||
}
|
||||
if needAuth {
|
||||
if user == nil {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Not found user ID: %d", userID))
|
||||
} else if user.RowStatus == api.Archived {
|
||||
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email))
|
||||
if userIDValue != nil {
|
||||
userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
|
||||
userFind := &api.UserFind{
|
||||
ID: &userID,
|
||||
}
|
||||
user, err := s.Store.FindUser(userFind)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
|
||||
}
|
||||
if user != nil {
|
||||
if user.RowStatus == api.Archived {
|
||||
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email))
|
||||
}
|
||||
ctx.Set(getUserIDContextKey(), userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save userID into context.
|
||||
ctx.Set(getUserIDContextKey(), userID)
|
||||
if common.HasPrefixes(ctx.Path(), "/api/memo", "/api/tag", "/api/shortcut") && ctx.Request().Method == http.MethodGet {
|
||||
if _, err := strconv.Atoi(ctx.QueryParam("creatorId")); err == nil {
|
||||
return next(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
userID := ctx.Get(getUserIDContextKey())
|
||||
if userID == nil {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing userID in session")
|
||||
}
|
||||
|
||||
return next(ctx)
|
@ -72,8 +72,8 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
|
||||
memoFind.CreatorID = &userID
|
||||
}
|
||||
|
||||
currentUserID := c.Get(getUserIDContextKey()).(int)
|
||||
if currentUserID == api.UNKNOWN_ID {
|
||||
currentUserID, ok := c.Get(getUserIDContextKey()).(int)
|
||||
if !ok {
|
||||
if memoFind.CreatorID == nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo")
|
||||
}
|
||||
|
@ -58,7 +58,7 @@ func NewServer(profile *profile.Profile) *Server {
|
||||
|
||||
apiGroup := e.Group("/api")
|
||||
apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return BasicAuthMiddleware(s, next)
|
||||
return aclMiddleware(s, next)
|
||||
})
|
||||
s.registerSystemRoutes(apiGroup)
|
||||
s.registerAuthRoutes(apiGroup)
|
||||
|
@ -25,8 +25,8 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
|
||||
memoFind.CreatorID = &userID
|
||||
}
|
||||
|
||||
currentUserID := c.Get(getUserIDContextKey()).(int)
|
||||
if currentUserID == api.UNKNOWN_ID {
|
||||
currentUserID, ok := c.Get(getUserIDContextKey()).(int)
|
||||
if !ok {
|
||||
if memoFind.CreatorID == nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo")
|
||||
}
|
||||
|
@ -83,12 +83,11 @@ func (s *Server) registerUserRoutes(g *echo.Group) {
|
||||
|
||||
// GET /api/user/me is used to check if the user is logged in.
|
||||
g.GET("/user/me", func(c echo.Context) error {
|
||||
userSessionID := c.Get(getUserIDContextKey())
|
||||
if userSessionID == nil {
|
||||
userID, ok := c.Get(getUserIDContextKey()).(int)
|
||||
if !ok {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
|
||||
}
|
||||
|
||||
userID := userSessionID.(int)
|
||||
userFind := &api.UserFind{
|
||||
ID: &userID,
|
||||
}
|
||||
|
@ -255,7 +255,6 @@ func findUserList(db *sql.DB, find *api.UserFind) ([]*userRaw, error) {
|
||||
&userRaw.UpdatedTs,
|
||||
&userRaw.RowStatus,
|
||||
); err != nil {
|
||||
fmt.Println(err)
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user