refactor: migrate auth routes to v1 package (#1841)

* feat: add api v1 packages

* chore: migrate auth to v1

* chore: update test
This commit is contained in:
boojack
2023-06-17 21:25:46 +08:00
committed by GitHub
parent f1d85eeaec
commit 4ed9a3a0ea
16 changed files with 180 additions and 131 deletions

View File

@@ -1,257 +0,0 @@
package server
import (
"encoding/json"
"fmt"
"net/http"
"regexp"
"github.com/pkg/errors"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
"github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/plugin/idp/oauth2"
"github.com/usememos/memos/store"
"github.com/labstack/echo/v4"
"golang.org/x/crypto/bcrypt"
)
func (s *Server) registerAuthRoutes(g *echo.Group, secret string) {
g.POST("/auth/signin", func(c echo.Context) error {
ctx := c.Request().Context()
signin := &api.SignIn{}
if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
}
userFind := &api.UserFind{
Username: &signin.Username,
}
user, err := s.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again")
}
if user == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect login credentials, please try again")
} else if user.RowStatus == api.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", signin.Username))
}
// Compare the stored hashed password, with the hashed version of the password that was received.
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(signin.Password)); err != nil {
// If the two passwords don't match, return a 401 status.
return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect login credentials, please try again")
}
if err := GenerateTokensAndSetCookies(c, user, secret); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err)
}
if err := s.createUserAuthSignInActivity(c, user); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
}
return c.JSON(http.StatusOK, composeResponse(user))
})
g.POST("/auth/signin/sso", func(c echo.Context) error {
ctx := c.Request().Context()
signin := &api.SSOSignIn{}
if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
}
identityProviderMessage, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{
ID: &signin.IdentityProviderID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider").SetInternal(err)
}
var userInfo *idp.IdentityProviderUserInfo
if identityProviderMessage.Type == store.IdentityProviderOAuth2 {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProviderMessage.Config.OAuth2Config)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider instance").SetInternal(err)
}
token, err := oauth2IdentityProvider.ExchangeToken(ctx, signin.RedirectURI, signin.Code)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to exchange token").SetInternal(err)
}
userInfo, err = oauth2IdentityProvider.UserInfo(token)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user info").SetInternal(err)
}
}
identifierFilter := identityProviderMessage.IdentifierFilter
if identifierFilter != "" {
identifierFilterRegex, err := regexp.Compile(identifierFilter)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compile identifier filter").SetInternal(err)
}
if !identifierFilterRegex.MatchString(userInfo.Identifier) {
return echo.NewHTTPError(http.StatusUnauthorized, "Access denied, identifier does not match the filter.").SetInternal(err)
}
}
user, err := s.Store.FindUser(ctx, &api.UserFind{
Username: &userInfo.Identifier,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again")
}
if user == nil {
userCreate := &api.UserCreate{
Username: userInfo.Identifier,
// The new signup user should be normal user by default.
Role: api.NormalUser,
Nickname: userInfo.DisplayName,
Email: userInfo.Email,
Password: userInfo.Email,
OpenID: common.GenUUID(),
}
password, err := common.RandomString(20)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate random password").SetInternal(err)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
}
userCreate.PasswordHash = string(passwordHash)
user, err = s.Store.CreateUser(ctx, userCreate)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
}
}
if user.RowStatus == api.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", userInfo.Identifier))
}
if err := GenerateTokensAndSetCookies(c, user, secret); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err)
}
if err := s.createUserAuthSignInActivity(c, user); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
}
return c.JSON(http.StatusOK, composeResponse(user))
})
g.POST("/auth/signup", func(c echo.Context) error {
ctx := c.Request().Context()
signup := &api.SignUp{}
if err := json.NewDecoder(c.Request().Body).Decode(signup); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err)
}
userCreate := &api.UserCreate{
Username: signup.Username,
// The new signup user should be normal user by default.
Role: api.NormalUser,
Nickname: signup.Username,
Password: signup.Password,
OpenID: common.GenUUID(),
}
hostUserType := api.Host
existedHostUsers, err := s.Store.FindUserList(ctx, &api.UserFind{
Role: &hostUserType,
})
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Failed to find users").SetInternal(err)
}
if len(existedHostUsers) == 0 {
// Change the default role to host if there is no host user.
userCreate.Role = api.Host
} else {
allowSignUpSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
Name: api.SystemSettingAllowSignUpName,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err)
}
allowSignUpSettingValue := false
if allowSignUpSetting != nil {
err = json.Unmarshal([]byte(allowSignUpSetting.Value), &allowSignUpSettingValue)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting allow signup").SetInternal(err)
}
}
if !allowSignUpSettingValue {
return echo.NewHTTPError(http.StatusUnauthorized, "signup is disabled").SetInternal(err)
}
}
if err := userCreate.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid user create format").SetInternal(err)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(signup.Password), bcrypt.DefaultCost)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
}
userCreate.PasswordHash = string(passwordHash)
user, err := s.Store.CreateUser(ctx, userCreate)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
}
if err := GenerateTokensAndSetCookies(c, user, secret); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err)
}
if err := s.createUserAuthSignUpActivity(c, user); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
}
return c.JSON(http.StatusOK, composeResponse(user))
})
g.POST("/auth/signout", func(c echo.Context) error {
RemoveTokensAndCookies(c)
return c.JSON(http.StatusOK, true)
})
}
func (s *Server) createUserAuthSignInActivity(c echo.Context, user *api.User) error {
ctx := c.Request().Context()
payload := api.ActivityUserAuthSignInPayload{
UserID: user.ID,
IP: echo.ExtractIPFromRealIPHeader()(c.Request()),
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return errors.Wrap(err, "failed to marshal activity payload")
}
activity, err := s.Store.CreateActivity(ctx, &api.ActivityCreate{
CreatorID: user.ID,
Type: api.ActivityUserAuthSignIn,
Level: api.ActivityInfo,
Payload: string(payloadBytes),
})
if err != nil || activity == nil {
return errors.Wrap(err, "failed to create activity")
}
return err
}
func (s *Server) createUserAuthSignUpActivity(c echo.Context, user *api.User) error {
ctx := c.Request().Context()
payload := api.ActivityUserAuthSignUpPayload{
Username: user.Username,
IP: echo.ExtractIPFromRealIPHeader()(c.Request()),
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return errors.Wrap(err, "failed to marshal activity payload")
}
activity, err := s.Store.CreateActivity(ctx, &api.ActivityCreate{
CreatorID: user.ID,
Type: api.ActivityUserAuthSignUp,
Level: api.ActivityInfo,
Payload: string(payloadBytes),
})
if err != nil || activity == nil {
return errors.Wrap(err, "failed to create activity")
}
return err
}

View File

@@ -1,10 +1,14 @@
package auth
import (
"net/http"
"strconv"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/api"
)
const (
@@ -59,6 +63,48 @@ func GenerateRefreshToken(userName string, userID int, secret string) (string, e
return generateToken(userName, userID, RefreshTokenAudienceName, expirationTime, []byte(secret))
}
// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie.
func GenerateTokensAndSetCookies(c echo.Context, user *api.User, secret string) error {
accessToken, err := GenerateAccessToken(user.Username, user.ID, secret)
if err != nil {
return errors.Wrap(err, "failed to generate access token")
}
cookieExp := time.Now().Add(CookieExpDuration)
setTokenCookie(c, AccessTokenCookieName, accessToken, cookieExp)
// We generate here a new refresh token and saving it to the cookie.
refreshToken, err := GenerateRefreshToken(user.Username, user.ID, secret)
if err != nil {
return errors.Wrap(err, "failed to generate refresh token")
}
setTokenCookie(c, RefreshTokenCookieName, refreshToken, cookieExp)
return nil
}
// RemoveTokensAndCookies removes the jwt token and refresh token from the cookies.
func RemoveTokensAndCookies(c echo.Context) {
// We set the expiration time to the past, so that the cookie will be removed.
cookieExp := time.Now().Add(-1 * time.Hour)
setTokenCookie(c, AccessTokenCookieName, "", cookieExp)
setTokenCookie(c, RefreshTokenCookieName, "", cookieExp)
}
// setTokenCookie sets the token to the cookie.
func setTokenCookie(c echo.Context, name, token string, expiration time.Time) {
cookie := new(http.Cookie)
cookie.Name = name
cookie.Value = token
cookie.Expires = expiration
cookie.Path = "/"
// Http-only helps mitigate the risk of client side script accessing the protected cookie.
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteStrictMode
c.SetCookie(cookie)
}
// generateToken generates a jwt token.
func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) {
// Create the JWT claims, which includes the username and expiry time.
claims := &claimsMessage{

View File

@@ -27,12 +27,12 @@ func defaultAPIRequestSkipper(c echo.Context) bool {
return common.HasPrefixes(path, "/api")
}
func (server *Server) defaultAuthSkipper(c echo.Context) bool {
func (s *Server) defaultAuthSkipper(c echo.Context) bool {
ctx := c.Request().Context()
path := c.Path()
// Skip auth.
if common.HasPrefixes(path, "/api/auth") {
if common.HasPrefixes(path, "/api/v1/auth") {
return true
}
@@ -42,7 +42,7 @@ func (server *Server) defaultAuthSkipper(c echo.Context) bool {
userFind := &api.UserFind{
OpenID: &openID,
}
user, err := server.Store.FindUser(ctx, userFind)
user, err := s.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return false
}

View File

@@ -33,47 +33,6 @@ func getUserIDContextKey() string {
return userIDContextKey
}
// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie.
func GenerateTokensAndSetCookies(c echo.Context, user *api.User, secret string) error {
accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, secret)
if err != nil {
return errors.Wrap(err, "failed to generate access token")
}
cookieExp := time.Now().Add(auth.CookieExpDuration)
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
// We generate here a new refresh token and saving it to the cookie.
refreshToken, err := auth.GenerateRefreshToken(user.Username, user.ID, secret)
if err != nil {
return errors.Wrap(err, "failed to generate refresh token")
}
setTokenCookie(c, auth.RefreshTokenCookieName, refreshToken, cookieExp)
return nil
}
// RemoveTokensAndCookies removes the jwt token and refresh token from the cookies.
func RemoveTokensAndCookies(c echo.Context) {
// We set the expiration time to the past, so that the cookie will be removed.
cookieExp := time.Now().Add(-1 * time.Hour)
setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp)
setTokenCookie(c, auth.RefreshTokenCookieName, "", cookieExp)
}
// Here we are creating a new cookie, which will store the valid JWT token.
func setTokenCookie(c echo.Context, name, token string, expiration time.Time) {
cookie := new(http.Cookie)
cookie.Name = name
cookie.Value = token
cookie.Expires = expiration
cookie.Path = "/"
// Http-only helps mitigate the risk of client side script accessing the protected cookie.
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteStrictMode
c.SetCookie(cookie)
}
func extractTokenFromHeader(c echo.Context) (string, error) {
authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
@@ -101,6 +60,15 @@ func findAccessToken(c echo.Context) string {
return accessToken
}
func audienceContains(audience jwt.ClaimStrings, token string) bool {
for _, v := range audience {
if v == token {
return true
}
}
return false
}
// JWTMiddleware validates the access token.
// If the access token is about to expire or has expired and the request has a valid refresh token, it
// will try to generate new access token and refresh token.
@@ -226,7 +194,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
// If we have a valid refresh token, we will generate new access token and refresh token
if refreshToken != nil && refreshToken.Valid {
if err := GenerateTokensAndSetCookies(c, user, secret); err != nil {
if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err)
}
}
@@ -246,12 +214,3 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
return next(c)
}
}
func audienceContains(audience jwt.ClaimStrings, token string) bool {
for _, v := range audience {
if v == token {
return true
}
}
return false
}

View File

@@ -2,53 +2,45 @@ package server
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/pkg/errors"
"github.com/usememos/memos/api"
apiV1 "github.com/usememos/memos/api/v1"
"github.com/usememos/memos/plugin/telegram"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"
"github.com/usememos/memos/store/db"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
)
type Server struct {
e *echo.Echo
db *sql.DB
e *echo.Echo
ID string
Secret string
Profile *profile.Profile
Store *store.Store
telegramBot *telegram.Bot
}
func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) {
func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store) (*Server, error) {
e := echo.New()
e.Debug = true
e.HideBanner = true
e.HidePort = true
db := db.NewDB(profile)
if err := db.Open(ctx); err != nil {
return nil, errors.Wrap(err, "cannot open db")
}
s := &Server{
e: e,
db: db.DBInstance,
Store: store,
Profile: profile,
}
storeInstance := store.New(db.DBInstance, profile)
s.Store = storeInstance
telegramBotHandler := newTelegramHandler(storeInstance)
telegramBotHandler := newTelegramHandler(store)
s.telegramBot = telegram.NewBotWithHandler(telegramBotHandler)
e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
@@ -89,23 +81,23 @@ func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) {
return nil, err
}
}
s.Secret = secret
rootGroup := e.Group("")
s.registerRSSRoutes(rootGroup)
publicGroup := e.Group("/o")
publicGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, secret)
return JWTMiddleware(s, next, s.Secret)
})
registerGetterPublicRoutes(publicGroup)
s.registerResourcePublicRoutes(publicGroup)
apiGroup := e.Group("/api")
apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, secret)
return JWTMiddleware(s, next, s.Secret)
})
s.registerSystemRoutes(apiGroup)
s.registerAuthRoutes(apiGroup, secret)
s.registerUserRoutes(apiGroup)
s.registerMemoRoutes(apiGroup)
s.registerMemoResourceRoutes(apiGroup)
@@ -117,6 +109,9 @@ func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) {
s.registerOpenAIRoutes(apiGroup)
s.registerMemoRelationRoutes(apiGroup)
apiV1Service := apiV1.NewAPIV1Service(s.Secret, profile, store)
apiV1Service.Register(e)
return s, nil
}
@@ -140,7 +135,7 @@ func (s *Server) Shutdown(ctx context.Context) {
}
// Close database connection
if err := s.db.Close(); err != nil {
if err := s.Store.GetDB().Close(); err != nil {
fmt.Printf("failed to close database, error: %v\n", err)
}