refactor: migration idp api (#1842)

* refactor: migration idp api

* chore: update
This commit is contained in:
boojack 2023-06-17 22:35:17 +08:00 committed by GitHub
parent 4ed9a3a0ea
commit b34aded376
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 410 additions and 119 deletions

View File

@ -1,58 +0,0 @@
package api
type IdentityProviderType string
const (
IdentityProviderOAuth2 IdentityProviderType = "OAUTH2"
)
type IdentityProviderConfig struct {
OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"`
}
type IdentityProviderOAuth2Config struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
AuthURL string `json:"authUrl"`
TokenURL string `json:"tokenUrl"`
UserInfoURL string `json:"userInfoUrl"`
Scopes []string `json:"scopes"`
FieldMapping *FieldMapping `json:"fieldMapping"`
}
type FieldMapping struct {
Identifier string `json:"identifier"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
}
type IdentityProvider struct {
ID int `json:"id"`
Name string `json:"name"`
Type IdentityProviderType `json:"type"`
IdentifierFilter string `json:"identifierFilter"`
Config *IdentityProviderConfig `json:"config"`
}
type IdentityProviderCreate struct {
Name string `json:"name"`
Type IdentityProviderType `json:"type"`
IdentifierFilter string `json:"identifierFilter"`
Config *IdentityProviderConfig `json:"config"`
}
type IdentityProviderFind struct {
ID *int
}
type IdentityProviderPatch struct {
ID int
Type IdentityProviderType `json:"type"`
Name *string `json:"name"`
IdentifierFilter *string `json:"identifierFilter"`
Config *IdentityProviderConfig `json:"config"`
}
type IdentityProviderDelete struct {
ID int
}

24
api/v1/common.go Normal file
View File

@ -0,0 +1,24 @@
package v1
// UnknownID is the ID for unknowns.
const UnknownID = -1
// RowStatus is the status for a row.
type RowStatus string
const (
// Normal is the status for a normal row.
Normal RowStatus = "NORMAL"
// Archived is the status for an archived row.
Archived RowStatus = "ARCHIVED"
)
func (e RowStatus) String() string {
switch e {
case Normal:
return "NORMAL"
case Archived:
return "ARCHIVED"
}
return ""
}

View File

@ -1,4 +1,4 @@
package server package v1
import ( import (
"encoding/json" "encoding/json"
@ -7,12 +7,60 @@ import (
"strconv" "strconv"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common" "github.com/usememos/memos/common"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
func (s *Server) registerIdentityProviderRoutes(g *echo.Group) { type IdentityProviderType string
const (
IdentityProviderOAuth2 IdentityProviderType = "OAUTH2"
)
type IdentityProviderConfig struct {
OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"`
}
type IdentityProviderOAuth2Config struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
AuthURL string `json:"authUrl"`
TokenURL string `json:"tokenUrl"`
UserInfoURL string `json:"userInfoUrl"`
Scopes []string `json:"scopes"`
FieldMapping *FieldMapping `json:"fieldMapping"`
}
type FieldMapping struct {
Identifier string `json:"identifier"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
}
type IdentityProvider struct {
ID int `json:"id"`
Name string `json:"name"`
Type IdentityProviderType `json:"type"`
IdentifierFilter string `json:"identifierFilter"`
Config *IdentityProviderConfig `json:"config"`
}
type IdentityProviderCreate struct {
Name string `json:"name"`
Type IdentityProviderType `json:"type"`
IdentifierFilter string `json:"identifierFilter"`
Config *IdentityProviderConfig `json:"config"`
}
type IdentityProviderPatch struct {
ID int
Type IdentityProviderType `json:"type"`
Name *string `json:"name"`
IdentifierFilter *string `json:"identifierFilter"`
Config *IdentityProviderConfig `json:"config"`
}
func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
g.POST("/idp", func(c echo.Context) error { g.POST("/idp", func(c echo.Context) error {
ctx := c.Request().Context() ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int) userID, ok := c.Get(getUserIDContextKey()).(int)
@ -20,17 +68,17 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != api.Host { if user == nil || user.Role != store.Host {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
identityProviderCreate := &api.IdentityProviderCreate{} identityProviderCreate := &IdentityProviderCreate{}
if err := json.NewDecoder(c.Request().Body).Decode(identityProviderCreate); err != nil { if err := json.NewDecoder(c.Request().Body).Decode(identityProviderCreate); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err)
} }
@ -44,7 +92,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage))) return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
}) })
g.PATCH("/idp/:idpId", func(c echo.Context) error { g.PATCH("/idp/:idpId", func(c echo.Context) error {
@ -54,13 +102,13 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != api.Host { if user == nil || user.Role != store.Host {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
@ -69,7 +117,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
} }
identityProviderPatch := &api.IdentityProviderPatch{ identityProviderPatch := &IdentityProviderPatch{
ID: identityProviderID, ID: identityProviderID,
} }
if err := json.NewDecoder(c.Request().Body).Decode(identityProviderPatch); err != nil { if err := json.NewDecoder(c.Request().Body).Decode(identityProviderPatch); err != nil {
@ -86,7 +134,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage))) return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
}) })
g.GET("/idp", func(c echo.Context) error { g.GET("/idp", func(c echo.Context) error {
@ -99,18 +147,18 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
userID, ok := c.Get(getUserIDContextKey()).(int) userID, ok := c.Get(getUserIDContextKey()).(int)
isHostUser := false isHostUser := false
if ok { if ok {
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
ID: &userID, ID: &userID,
}) })
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user != nil && user.Role == api.Host { if user == nil || user.Role == store.Host {
isHostUser = true isHostUser = true
} }
} }
identityProviderList := []*api.IdentityProvider{} identityProviderList := []*IdentityProvider{}
for _, identityProviderMessage := range identityProviderMessageList { for _, identityProviderMessage := range identityProviderMessageList {
identityProvider := convertIdentityProviderFromStore(identityProviderMessage) identityProvider := convertIdentityProviderFromStore(identityProviderMessage)
// data desensitize // data desensitize
@ -119,7 +167,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
} }
identityProviderList = append(identityProviderList, identityProvider) identityProviderList = append(identityProviderList, identityProvider)
} }
return c.JSON(http.StatusOK, composeResponse(identityProviderList)) return c.JSON(http.StatusOK, identityProviderList)
}) })
g.GET("/idp/:idpId", func(c echo.Context) error { g.GET("/idp/:idpId", func(c echo.Context) error {
@ -129,14 +177,13 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
// We should only show identity provider list to host user. if user == nil || user.Role != store.Host {
if user == nil || user.Role != api.Host {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
@ -150,7 +197,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get identity provider").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get identity provider").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage))) return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
}) })
g.DELETE("/idp/:idpId", func(c echo.Context) error { g.DELETE("/idp/:idpId", func(c echo.Context) error {
@ -160,13 +207,13 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != api.Host { if user == nil || user.Role != store.Host {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
@ -185,26 +232,26 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
}) })
} }
func convertIdentityProviderFromStore(identityProviderMessage *store.IdentityProviderMessage) *api.IdentityProvider { func convertIdentityProviderFromStore(identityProviderMessage *store.IdentityProviderMessage) *IdentityProvider {
return &api.IdentityProvider{ return &IdentityProvider{
ID: identityProviderMessage.ID, ID: identityProviderMessage.ID,
Name: identityProviderMessage.Name, Name: identityProviderMessage.Name,
Type: api.IdentityProviderType(identityProviderMessage.Type), Type: IdentityProviderType(identityProviderMessage.Type),
IdentifierFilter: identityProviderMessage.IdentifierFilter, IdentifierFilter: identityProviderMessage.IdentifierFilter,
Config: convertIdentityProviderConfigFromStore(identityProviderMessage.Config), Config: convertIdentityProviderConfigFromStore(identityProviderMessage.Config),
} }
} }
func convertIdentityProviderConfigFromStore(config *store.IdentityProviderConfig) *api.IdentityProviderConfig { func convertIdentityProviderConfigFromStore(config *store.IdentityProviderConfig) *IdentityProviderConfig {
return &api.IdentityProviderConfig{ return &IdentityProviderConfig{
OAuth2Config: &api.IdentityProviderOAuth2Config{ OAuth2Config: &IdentityProviderOAuth2Config{
ClientID: config.OAuth2Config.ClientID, ClientID: config.OAuth2Config.ClientID,
ClientSecret: config.OAuth2Config.ClientSecret, ClientSecret: config.OAuth2Config.ClientSecret,
AuthURL: config.OAuth2Config.AuthURL, AuthURL: config.OAuth2Config.AuthURL,
TokenURL: config.OAuth2Config.TokenURL, TokenURL: config.OAuth2Config.TokenURL,
UserInfoURL: config.OAuth2Config.UserInfoURL, UserInfoURL: config.OAuth2Config.UserInfoURL,
Scopes: config.OAuth2Config.Scopes, Scopes: config.OAuth2Config.Scopes,
FieldMapping: &api.FieldMapping{ FieldMapping: &FieldMapping{
Identifier: config.OAuth2Config.FieldMapping.Identifier, Identifier: config.OAuth2Config.FieldMapping.Identifier,
DisplayName: config.OAuth2Config.FieldMapping.DisplayName, DisplayName: config.OAuth2Config.FieldMapping.DisplayName,
Email: config.OAuth2Config.FieldMapping.Email, Email: config.OAuth2Config.FieldMapping.Email,
@ -213,7 +260,7 @@ func convertIdentityProviderConfigFromStore(config *store.IdentityProviderConfig
} }
} }
func convertIdentityProviderConfigToStore(config *api.IdentityProviderConfig) *store.IdentityProviderConfig { func convertIdentityProviderConfigToStore(config *IdentityProviderConfig) *store.IdentityProviderConfig {
return &store.IdentityProviderConfig{ return &store.IdentityProviderConfig{
OAuth2Config: &store.IdentityProviderOAuth2Config{ OAuth2Config: &store.IdentityProviderOAuth2Config{
ClientID: config.OAuth2Config.ClientID, ClientID: config.OAuth2Config.ClientID,

25
api/v1/user.go Normal file
View File

@ -0,0 +1,25 @@
package v1
// Role is the type of a role.
type Role string
const (
// Host is the HOST role.
Host Role = "HOST"
// Admin is the ADMIN role.
Admin Role = "ADMIN"
// NormalUser is the USER role.
NormalUser Role = "USER"
)
func (e Role) String() string {
switch e {
case Host:
return "HOST"
case Admin:
return "ADMIN"
case NormalUser:
return "USER"
}
return "USER"
}

View File

@ -6,6 +6,17 @@ import (
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
const (
// Context section
// The key name used to store user id in the context
// user id is extracted from the jwt token subject field.
userIDContextKey = "user-id"
)
func getUserIDContextKey() string {
return userIDContextKey
}
type APIV1Service struct { type APIV1Service struct {
Secret string Secret string
Profile *profile.Profile Profile *profile.Profile
@ -20,8 +31,8 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store
} }
} }
func (s *APIV1Service) Register(e *echo.Echo) { func (s *APIV1Service) Register(apiV1Group *echo.Group) {
apiV1Group := e.Group("/api/v1")
s.registerTestRoutes(apiV1Group) s.registerTestRoutes(apiV1Group)
s.registerAuthRoutes(apiV1Group, s.Secret) s.registerAuthRoutes(apiV1Group, s.Secret)
s.registerIdentityProviderRoutes(apiV1Group)
} }

View File

@ -22,6 +22,10 @@ const (
userIDContextKey = "user-id" userIDContextKey = "user-id"
) )
func getUserIDContextKey() string {
return userIDContextKey
}
// Claims creates a struct that will be encoded to a JWT. // Claims creates a struct that will be encoded to a JWT.
// We add jwt.RegisteredClaims as an embedded type, to provide fields such as name. // We add jwt.RegisteredClaims as an embedded type, to provide fields such as name.
type Claims struct { type Claims struct {
@ -29,10 +33,6 @@ type Claims struct {
jwt.RegisteredClaims jwt.RegisteredClaims
} }
func getUserIDContextKey() string {
return userIDContextKey
}
func extractTokenFromHeader(c echo.Context) (string, error) { func extractTokenFromHeader(c echo.Context) (string, error) {
authHeader := c.Request().Header.Get("Authorization") authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" { if authHeader == "" {
@ -82,7 +82,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
} }
// Skip validation for server status endpoints. // Skip validation for server status endpoints.
if common.HasPrefixes(path, "/api/ping", "/api/idp", "/api/user/:id") && method == http.MethodGet { if common.HasPrefixes(path, "/api/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet {
return next(c) return next(c)
} }
@ -111,15 +111,9 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
} }
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
}) })
if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) {
return echo.NewHTTPError(http.StatusUnauthorized, return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName))
fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment",
claims.Audience,
auth.AccessTokenAudienceName,
))
} }
generateToken := time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration generateToken := time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration
if err != nil { if err != nil {
var ve *jwt.ValidationError var ve *jwt.ValidationError
@ -130,11 +124,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
generateToken = true generateToken = true
} }
} else { } else {
return &echo.HTTPError{ return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token"))
Code: http.StatusUnauthorized,
Message: "Invalid or expired access token",
Internal: err,
}
} }
} }

View File

@ -105,12 +105,15 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
s.registerResourceRoutes(apiGroup) s.registerResourceRoutes(apiGroup)
s.registerTagRoutes(apiGroup) s.registerTagRoutes(apiGroup)
s.registerStorageRoutes(apiGroup) s.registerStorageRoutes(apiGroup)
s.registerIdentityProviderRoutes(apiGroup)
s.registerOpenAIRoutes(apiGroup) s.registerOpenAIRoutes(apiGroup)
s.registerMemoRelationRoutes(apiGroup) s.registerMemoRelationRoutes(apiGroup)
apiV1Group := e.Group("/api/v1")
apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, s.Secret)
})
apiV1Service := apiV1.NewAPIV1Service(s.Secret, profile, store) apiV1Service := apiV1.NewAPIV1Service(s.Secret, profile, store)
apiV1Service.Register(e) apiV1Service.Register(apiV1Group)
return s, nil return s, nil
} }

View File

@ -27,6 +27,211 @@ func (s *Store) SeedDataForNewUser(ctx context.Context, user *api.User) error {
return err return err
} }
// Role is the type of a role.
type Role string
const (
// Host is the HOST role.
Host Role = "HOST"
// Admin is the ADMIN role.
Admin Role = "ADMIN"
// NormalUser is the USER role.
NormalUser Role = "USER"
)
func (e Role) String() string {
switch e {
case Host:
return "HOST"
case Admin:
return "ADMIN"
case NormalUser:
return "USER"
}
return "USER"
}
type UserMessage struct {
ID int
// Standard fields
RowStatus RowStatus
CreatedTs int64
UpdatedTs int64
// Domain specific fields
Username string
Role Role
Email string
Nickname string
PasswordHash string
OpenID string
AvatarURL string
}
type FindUserMessage struct {
ID *int
// Standard fields
RowStatus *RowStatus
// Domain specific fields
Username *string
Role *Role
Email *string
Nickname *string
OpenID *string
}
func (s *Store) CreateUserV1(ctx context.Context, create *UserMessage) (*UserMessage, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
query := `
INSERT INTO user (
username,
role,
email,
nickname,
password_hash,
open_id
)
VALUES (?, ?, ?, ?, ?, ?)
RETURNING id, avatar_url, created_ts, updated_ts, row_status
`
if err := tx.QueryRowContext(ctx, query,
create.Username,
create.Role,
create.Email,
create.Nickname,
create.PasswordHash,
create.OpenID,
).Scan(
&create.ID,
&create.AvatarURL,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, FormatError(err)
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
userMessage := create
return userMessage, nil
}
func (s *Store) ListUsers(ctx context.Context, find *FindUserMessage) ([]*UserMessage, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
list, err := listUsers(ctx, tx, find)
if err != nil {
return nil, err
}
return list, nil
}
func (s *Store) GetUser(ctx context.Context, find *FindUserMessage) (*UserMessage, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
list, err := listUsers(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user not found")}
}
memoMessage := list[0]
return memoMessage, nil
}
func listUsers(ctx context.Context, tx *sql.Tx, find *FindUserMessage) ([]*UserMessage, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.Username; v != nil {
where, args = append(where, "username = ?"), append(args, *v)
}
if v := find.Role; v != nil {
where, args = append(where, "role = ?"), append(args, *v)
}
if v := find.Email; v != nil {
where, args = append(where, "email = ?"), append(args, *v)
}
if v := find.Nickname; v != nil {
where, args = append(where, "nickname = ?"), append(args, *v)
}
if v := find.OpenID; v != nil {
where, args = append(where, "open_id = ?"), append(args, *v)
}
query := `
SELECT
id,
username,
role,
email,
nickname,
password_hash,
open_id,
avatar_url,
created_ts,
updated_ts,
row_status
FROM user
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC, row_status DESC
`
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, FormatError(err)
}
defer rows.Close()
userMessageList := make([]*UserMessage, 0)
for rows.Next() {
var userMessage UserMessage
if err := rows.Scan(
&userMessage.ID,
&userMessage.Username,
&userMessage.Role,
&userMessage.Email,
&userMessage.Nickname,
&userMessage.PasswordHash,
&userMessage.OpenID,
&userMessage.AvatarURL,
&userMessage.CreatedTs,
&userMessage.UpdatedTs,
&userMessage.RowStatus,
); err != nil {
return nil, FormatError(err)
}
userMessageList = append(userMessageList, &userMessage)
}
if err := rows.Err(); err != nil {
return nil, FormatError(err)
}
return userMessageList, nil
}
// userRaw is the store model for an User. // userRaw is the store model for an User.
// Fields have exactly the same meanings as User. // Fields have exactly the same meanings as User.
type userRaw struct { type userRaw struct {

48
test/store/idp_test.go Normal file
View File

@ -0,0 +1,48 @@
package teststore
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/store"
)
func TestIdentityProviderStore(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{
Name: "GitHub OAuth",
Type: store.IdentityProviderOAuth2,
IdentifierFilter: "",
Config: &store.IdentityProviderConfig{
OAuth2Config: &store.IdentityProviderOAuth2Config{
ClientID: "asd",
ClientSecret: "123",
AuthURL: "https://github.com/auth",
TokenURL: "https://github.com/token",
UserInfoURL: "https://github.com/user",
Scopes: []string{"login"},
FieldMapping: &store.FieldMapping{
Identifier: "login",
DisplayName: "name",
Email: "emai",
},
},
},
})
require.NoError(t, err)
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{
ID: &createdIDP.ID,
})
require.NoError(t, err)
require.Equal(t, createdIDP, idp)
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProviderMessage{
ID: idp.ID,
})
require.NoError(t, err)
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProviderMessage{})
require.NoError(t, err)
require.Equal(t, 0, len(idpList))
}

View File

@ -17,9 +17,7 @@ const SSOSection = () => {
}, []); }, []);
const fetchIdentityProviderList = async () => { const fetchIdentityProviderList = async () => {
const { const { data: identityProviderList } = await api.getIdentityProviderList();
data: { data: identityProviderList },
} = await api.getIdentityProviderList();
setIdentityProviderList(identityProviderList); setIdentityProviderList(identityProviderList);
}; };

View File

@ -246,19 +246,19 @@ export function deleteStorage(storageId: StorageId) {
} }
export function getIdentityProviderList() { export function getIdentityProviderList() {
return axios.get<ResponseObject<IdentityProvider[]>>(`/api/idp`); return axios.get<IdentityProvider[]>(`/api/v1/idp`);
} }
export function createIdentityProvider(identityProviderCreate: IdentityProviderCreate) { export function createIdentityProvider(identityProviderCreate: IdentityProviderCreate) {
return axios.post<ResponseObject<IdentityProvider>>(`/api/idp`, identityProviderCreate); return axios.post<IdentityProvider>(`/api/v1/idp`, identityProviderCreate);
} }
export function patchIdentityProvider(identityProviderPatch: IdentityProviderPatch) { export function patchIdentityProvider(identityProviderPatch: IdentityProviderPatch) {
return axios.patch<ResponseObject<IdentityProvider>>(`/api/idp/${identityProviderPatch.id}`, identityProviderPatch); return axios.patch<IdentityProvider>(`/api/v1/idp/${identityProviderPatch.id}`, identityProviderPatch);
} }
export function deleteIdentityProvider(id: IdentityProviderId) { export function deleteIdentityProvider(id: IdentityProviderId) {
return axios.delete(`/api/idp/${id}`); return axios.delete(`/api/v1/idp/${id}`);
} }
export async function getRepoStarCount() { export async function getRepoStarCount() {

View File

@ -24,9 +24,7 @@ const Auth = () => {
useEffect(() => { useEffect(() => {
userStore.doSignOut().catch(); userStore.doSignOut().catch();
const fetchIdentityProviderList = async () => { const fetchIdentityProviderList = async () => {
const { const { data: identityProviderList } = await api.getIdentityProviderList();
data: { data: identityProviderList },
} = await api.getIdentityProviderList();
setIdentityProviderList(identityProviderList); setIdentityProviderList(identityProviderList);
}; };
fetchIdentityProviderList(); fetchIdentityProviderList();