mirror of
https://github.com/usememos/memos.git
synced 2025-02-15 19:00:46 +01:00
refactor: migration idp api (#1842)
* refactor: migration idp api * chore: update
This commit is contained in:
parent
4ed9a3a0ea
commit
b34aded376
58
api/idp.go
58
api/idp.go
@ -1,58 +0,0 @@
|
||||
package api
|
||||
|
||||
type IdentityProviderType string
|
||||
|
||||
const (
|
||||
IdentityProviderOAuth2 IdentityProviderType = "OAUTH2"
|
||||
)
|
||||
|
||||
type IdentityProviderConfig struct {
|
||||
OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"`
|
||||
}
|
||||
|
||||
type IdentityProviderOAuth2Config struct {
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
AuthURL string `json:"authUrl"`
|
||||
TokenURL string `json:"tokenUrl"`
|
||||
UserInfoURL string `json:"userInfoUrl"`
|
||||
Scopes []string `json:"scopes"`
|
||||
FieldMapping *FieldMapping `json:"fieldMapping"`
|
||||
}
|
||||
|
||||
type FieldMapping struct {
|
||||
Identifier string `json:"identifier"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type IdentityProvider struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type IdentityProviderType `json:"type"`
|
||||
IdentifierFilter string `json:"identifierFilter"`
|
||||
Config *IdentityProviderConfig `json:"config"`
|
||||
}
|
||||
|
||||
type IdentityProviderCreate struct {
|
||||
Name string `json:"name"`
|
||||
Type IdentityProviderType `json:"type"`
|
||||
IdentifierFilter string `json:"identifierFilter"`
|
||||
Config *IdentityProviderConfig `json:"config"`
|
||||
}
|
||||
|
||||
type IdentityProviderFind struct {
|
||||
ID *int
|
||||
}
|
||||
|
||||
type IdentityProviderPatch struct {
|
||||
ID int
|
||||
Type IdentityProviderType `json:"type"`
|
||||
Name *string `json:"name"`
|
||||
IdentifierFilter *string `json:"identifierFilter"`
|
||||
Config *IdentityProviderConfig `json:"config"`
|
||||
}
|
||||
|
||||
type IdentityProviderDelete struct {
|
||||
ID int
|
||||
}
|
24
api/v1/common.go
Normal file
24
api/v1/common.go
Normal file
@ -0,0 +1,24 @@
|
||||
package v1
|
||||
|
||||
// UnknownID is the ID for unknowns.
|
||||
const UnknownID = -1
|
||||
|
||||
// RowStatus is the status for a row.
|
||||
type RowStatus string
|
||||
|
||||
const (
|
||||
// Normal is the status for a normal row.
|
||||
Normal RowStatus = "NORMAL"
|
||||
// Archived is the status for an archived row.
|
||||
Archived RowStatus = "ARCHIVED"
|
||||
)
|
||||
|
||||
func (e RowStatus) String() string {
|
||||
switch e {
|
||||
case Normal:
|
||||
return "NORMAL"
|
||||
case Archived:
|
||||
return "ARCHIVED"
|
||||
}
|
||||
return ""
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package server
|
||||
package v1
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@ -7,12 +7,60 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/usememos/memos/api"
|
||||
"github.com/usememos/memos/common"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
type IdentityProviderType string
|
||||
|
||||
const (
|
||||
IdentityProviderOAuth2 IdentityProviderType = "OAUTH2"
|
||||
)
|
||||
|
||||
type IdentityProviderConfig struct {
|
||||
OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"`
|
||||
}
|
||||
|
||||
type IdentityProviderOAuth2Config struct {
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
AuthURL string `json:"authUrl"`
|
||||
TokenURL string `json:"tokenUrl"`
|
||||
UserInfoURL string `json:"userInfoUrl"`
|
||||
Scopes []string `json:"scopes"`
|
||||
FieldMapping *FieldMapping `json:"fieldMapping"`
|
||||
}
|
||||
|
||||
type FieldMapping struct {
|
||||
Identifier string `json:"identifier"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type IdentityProvider struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type IdentityProviderType `json:"type"`
|
||||
IdentifierFilter string `json:"identifierFilter"`
|
||||
Config *IdentityProviderConfig `json:"config"`
|
||||
}
|
||||
|
||||
type IdentityProviderCreate struct {
|
||||
Name string `json:"name"`
|
||||
Type IdentityProviderType `json:"type"`
|
||||
IdentifierFilter string `json:"identifierFilter"`
|
||||
Config *IdentityProviderConfig `json:"config"`
|
||||
}
|
||||
|
||||
type IdentityProviderPatch struct {
|
||||
ID int
|
||||
Type IdentityProviderType `json:"type"`
|
||||
Name *string `json:"name"`
|
||||
IdentifierFilter *string `json:"identifierFilter"`
|
||||
Config *IdentityProviderConfig `json:"config"`
|
||||
}
|
||||
|
||||
func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
g.POST("/idp", func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
userID, ok := c.Get(getUserIDContextKey()).(int)
|
||||
@ -20,17 +68,17 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
||||
user, err := s.Store.FindUser(ctx, &api.UserFind{
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
||||
}
|
||||
if user == nil || user.Role != api.Host {
|
||||
if user == nil || user.Role != store.Host {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
|
||||
}
|
||||
|
||||
identityProviderCreate := &api.IdentityProviderCreate{}
|
||||
identityProviderCreate := &IdentityProviderCreate{}
|
||||
if err := json.NewDecoder(c.Request().Body).Decode(identityProviderCreate); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err)
|
||||
}
|
||||
@ -44,7 +92,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err)
|
||||
}
|
||||
return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage)))
|
||||
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
|
||||
})
|
||||
|
||||
g.PATCH("/idp/:idpId", func(c echo.Context) error {
|
||||
@ -54,13 +102,13 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
||||
user, err := s.Store.FindUser(ctx, &api.UserFind{
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
||||
}
|
||||
if user == nil || user.Role != api.Host {
|
||||
if user == nil || user.Role != store.Host {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
|
||||
}
|
||||
|
||||
@ -69,7 +117,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
|
||||
}
|
||||
|
||||
identityProviderPatch := &api.IdentityProviderPatch{
|
||||
identityProviderPatch := &IdentityProviderPatch{
|
||||
ID: identityProviderID,
|
||||
}
|
||||
if err := json.NewDecoder(c.Request().Body).Decode(identityProviderPatch); err != nil {
|
||||
@ -86,7 +134,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err)
|
||||
}
|
||||
return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage)))
|
||||
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
|
||||
})
|
||||
|
||||
g.GET("/idp", func(c echo.Context) error {
|
||||
@ -99,18 +147,18 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
userID, ok := c.Get(getUserIDContextKey()).(int)
|
||||
isHostUser := false
|
||||
if ok {
|
||||
user, err := s.Store.FindUser(ctx, &api.UserFind{
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil && common.ErrorCode(err) != common.NotFound {
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
||||
}
|
||||
if user != nil && user.Role == api.Host {
|
||||
if user == nil || user.Role == store.Host {
|
||||
isHostUser = true
|
||||
}
|
||||
}
|
||||
|
||||
identityProviderList := []*api.IdentityProvider{}
|
||||
identityProviderList := []*IdentityProvider{}
|
||||
for _, identityProviderMessage := range identityProviderMessageList {
|
||||
identityProvider := convertIdentityProviderFromStore(identityProviderMessage)
|
||||
// data desensitize
|
||||
@ -119,7 +167,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
}
|
||||
identityProviderList = append(identityProviderList, identityProvider)
|
||||
}
|
||||
return c.JSON(http.StatusOK, composeResponse(identityProviderList))
|
||||
return c.JSON(http.StatusOK, identityProviderList)
|
||||
})
|
||||
|
||||
g.GET("/idp/:idpId", func(c echo.Context) error {
|
||||
@ -129,14 +177,13 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
||||
user, err := s.Store.FindUser(ctx, &api.UserFind{
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
||||
}
|
||||
// We should only show identity provider list to host user.
|
||||
if user == nil || user.Role != api.Host {
|
||||
if user == nil || user.Role != store.Host {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
|
||||
}
|
||||
|
||||
@ -150,7 +197,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get identity provider").SetInternal(err)
|
||||
}
|
||||
return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage)))
|
||||
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
|
||||
})
|
||||
|
||||
g.DELETE("/idp/:idpId", func(c echo.Context) error {
|
||||
@ -160,13 +207,13 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||
}
|
||||
|
||||
user, err := s.Store.FindUser(ctx, &api.UserFind{
|
||||
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||
ID: &userID,
|
||||
})
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
||||
}
|
||||
if user == nil || user.Role != api.Host {
|
||||
if user == nil || user.Role != store.Host {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
|
||||
}
|
||||
|
||||
@ -185,26 +232,26 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
||||
})
|
||||
}
|
||||
|
||||
func convertIdentityProviderFromStore(identityProviderMessage *store.IdentityProviderMessage) *api.IdentityProvider {
|
||||
return &api.IdentityProvider{
|
||||
func convertIdentityProviderFromStore(identityProviderMessage *store.IdentityProviderMessage) *IdentityProvider {
|
||||
return &IdentityProvider{
|
||||
ID: identityProviderMessage.ID,
|
||||
Name: identityProviderMessage.Name,
|
||||
Type: api.IdentityProviderType(identityProviderMessage.Type),
|
||||
Type: IdentityProviderType(identityProviderMessage.Type),
|
||||
IdentifierFilter: identityProviderMessage.IdentifierFilter,
|
||||
Config: convertIdentityProviderConfigFromStore(identityProviderMessage.Config),
|
||||
}
|
||||
}
|
||||
|
||||
func convertIdentityProviderConfigFromStore(config *store.IdentityProviderConfig) *api.IdentityProviderConfig {
|
||||
return &api.IdentityProviderConfig{
|
||||
OAuth2Config: &api.IdentityProviderOAuth2Config{
|
||||
func convertIdentityProviderConfigFromStore(config *store.IdentityProviderConfig) *IdentityProviderConfig {
|
||||
return &IdentityProviderConfig{
|
||||
OAuth2Config: &IdentityProviderOAuth2Config{
|
||||
ClientID: config.OAuth2Config.ClientID,
|
||||
ClientSecret: config.OAuth2Config.ClientSecret,
|
||||
AuthURL: config.OAuth2Config.AuthURL,
|
||||
TokenURL: config.OAuth2Config.TokenURL,
|
||||
UserInfoURL: config.OAuth2Config.UserInfoURL,
|
||||
Scopes: config.OAuth2Config.Scopes,
|
||||
FieldMapping: &api.FieldMapping{
|
||||
FieldMapping: &FieldMapping{
|
||||
Identifier: config.OAuth2Config.FieldMapping.Identifier,
|
||||
DisplayName: config.OAuth2Config.FieldMapping.DisplayName,
|
||||
Email: config.OAuth2Config.FieldMapping.Email,
|
||||
@ -213,7 +260,7 @@ func convertIdentityProviderConfigFromStore(config *store.IdentityProviderConfig
|
||||
}
|
||||
}
|
||||
|
||||
func convertIdentityProviderConfigToStore(config *api.IdentityProviderConfig) *store.IdentityProviderConfig {
|
||||
func convertIdentityProviderConfigToStore(config *IdentityProviderConfig) *store.IdentityProviderConfig {
|
||||
return &store.IdentityProviderConfig{
|
||||
OAuth2Config: &store.IdentityProviderOAuth2Config{
|
||||
ClientID: config.OAuth2Config.ClientID,
|
25
api/v1/user.go
Normal file
25
api/v1/user.go
Normal file
@ -0,0 +1,25 @@
|
||||
package v1
|
||||
|
||||
// Role is the type of a role.
|
||||
type Role string
|
||||
|
||||
const (
|
||||
// Host is the HOST role.
|
||||
Host Role = "HOST"
|
||||
// Admin is the ADMIN role.
|
||||
Admin Role = "ADMIN"
|
||||
// NormalUser is the USER role.
|
||||
NormalUser Role = "USER"
|
||||
)
|
||||
|
||||
func (e Role) String() string {
|
||||
switch e {
|
||||
case Host:
|
||||
return "HOST"
|
||||
case Admin:
|
||||
return "ADMIN"
|
||||
case NormalUser:
|
||||
return "USER"
|
||||
}
|
||||
return "USER"
|
||||
}
|
15
api/v1/v1.go
15
api/v1/v1.go
@ -6,6 +6,17 @@ import (
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
const (
|
||||
// Context section
|
||||
// The key name used to store user id in the context
|
||||
// user id is extracted from the jwt token subject field.
|
||||
userIDContextKey = "user-id"
|
||||
)
|
||||
|
||||
func getUserIDContextKey() string {
|
||||
return userIDContextKey
|
||||
}
|
||||
|
||||
type APIV1Service struct {
|
||||
Secret string
|
||||
Profile *profile.Profile
|
||||
@ -20,8 +31,8 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIV1Service) Register(e *echo.Echo) {
|
||||
apiV1Group := e.Group("/api/v1")
|
||||
func (s *APIV1Service) Register(apiV1Group *echo.Group) {
|
||||
s.registerTestRoutes(apiV1Group)
|
||||
s.registerAuthRoutes(apiV1Group, s.Secret)
|
||||
s.registerIdentityProviderRoutes(apiV1Group)
|
||||
}
|
||||
|
@ -22,6 +22,10 @@ const (
|
||||
userIDContextKey = "user-id"
|
||||
)
|
||||
|
||||
func getUserIDContextKey() string {
|
||||
return userIDContextKey
|
||||
}
|
||||
|
||||
// Claims creates a struct that will be encoded to a JWT.
|
||||
// We add jwt.RegisteredClaims as an embedded type, to provide fields such as name.
|
||||
type Claims struct {
|
||||
@ -29,10 +33,6 @@ type Claims struct {
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func getUserIDContextKey() string {
|
||||
return userIDContextKey
|
||||
}
|
||||
|
||||
func extractTokenFromHeader(c echo.Context) (string, error) {
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
@ -82,7 +82,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
|
||||
}
|
||||
|
||||
// Skip validation for server status endpoints.
|
||||
if common.HasPrefixes(path, "/api/ping", "/api/idp", "/api/user/:id") && method == http.MethodGet {
|
||||
if common.HasPrefixes(path, "/api/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
@ -111,15 +111,9 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
|
||||
}
|
||||
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
|
||||
})
|
||||
|
||||
if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized,
|
||||
fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment",
|
||||
claims.Audience,
|
||||
auth.AccessTokenAudienceName,
|
||||
))
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName))
|
||||
}
|
||||
|
||||
generateToken := time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration
|
||||
if err != nil {
|
||||
var ve *jwt.ValidationError
|
||||
@ -130,11 +124,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
|
||||
generateToken = true
|
||||
}
|
||||
} else {
|
||||
return &echo.HTTPError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "Invalid or expired access token",
|
||||
Internal: err,
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -105,12 +105,15 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
|
||||
s.registerResourceRoutes(apiGroup)
|
||||
s.registerTagRoutes(apiGroup)
|
||||
s.registerStorageRoutes(apiGroup)
|
||||
s.registerIdentityProviderRoutes(apiGroup)
|
||||
s.registerOpenAIRoutes(apiGroup)
|
||||
s.registerMemoRelationRoutes(apiGroup)
|
||||
|
||||
apiV1Group := e.Group("/api/v1")
|
||||
apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return JWTMiddleware(s, next, s.Secret)
|
||||
})
|
||||
apiV1Service := apiV1.NewAPIV1Service(s.Secret, profile, store)
|
||||
apiV1Service.Register(e)
|
||||
apiV1Service.Register(apiV1Group)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
205
store/user.go
205
store/user.go
@ -27,6 +27,211 @@ func (s *Store) SeedDataForNewUser(ctx context.Context, user *api.User) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Role is the type of a role.
|
||||
type Role string
|
||||
|
||||
const (
|
||||
// Host is the HOST role.
|
||||
Host Role = "HOST"
|
||||
// Admin is the ADMIN role.
|
||||
Admin Role = "ADMIN"
|
||||
// NormalUser is the USER role.
|
||||
NormalUser Role = "USER"
|
||||
)
|
||||
|
||||
func (e Role) String() string {
|
||||
switch e {
|
||||
case Host:
|
||||
return "HOST"
|
||||
case Admin:
|
||||
return "ADMIN"
|
||||
case NormalUser:
|
||||
return "USER"
|
||||
}
|
||||
return "USER"
|
||||
}
|
||||
|
||||
type UserMessage struct {
|
||||
ID int
|
||||
|
||||
// Standard fields
|
||||
RowStatus RowStatus
|
||||
CreatedTs int64
|
||||
UpdatedTs int64
|
||||
|
||||
// Domain specific fields
|
||||
Username string
|
||||
Role Role
|
||||
Email string
|
||||
Nickname string
|
||||
PasswordHash string
|
||||
OpenID string
|
||||
AvatarURL string
|
||||
}
|
||||
|
||||
type FindUserMessage struct {
|
||||
ID *int
|
||||
|
||||
// Standard fields
|
||||
RowStatus *RowStatus
|
||||
|
||||
// Domain specific fields
|
||||
Username *string
|
||||
Role *Role
|
||||
Email *string
|
||||
Nickname *string
|
||||
OpenID *string
|
||||
}
|
||||
|
||||
func (s *Store) CreateUserV1(ctx context.Context, create *UserMessage) (*UserMessage, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
query := `
|
||||
INSERT INTO user (
|
||||
username,
|
||||
role,
|
||||
email,
|
||||
nickname,
|
||||
password_hash,
|
||||
open_id
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
RETURNING id, avatar_url, created_ts, updated_ts, row_status
|
||||
`
|
||||
if err := tx.QueryRowContext(ctx, query,
|
||||
create.Username,
|
||||
create.Role,
|
||||
create.Email,
|
||||
create.Nickname,
|
||||
create.PasswordHash,
|
||||
create.OpenID,
|
||||
).Scan(
|
||||
&create.ID,
|
||||
&create.AvatarURL,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&create.RowStatus,
|
||||
); err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
userMessage := create
|
||||
return userMessage, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListUsers(ctx context.Context, find *FindUserMessage) ([]*UserMessage, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
list, err := listUsers(ctx, tx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetUser(ctx context.Context, find *FindUserMessage) (*UserMessage, error) {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
list, err := listUsers(ctx, tx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user not found")}
|
||||
}
|
||||
|
||||
memoMessage := list[0]
|
||||
return memoMessage, nil
|
||||
}
|
||||
|
||||
func listUsers(ctx context.Context, tx *sql.Tx, find *FindUserMessage) ([]*UserMessage, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Username; v != nil {
|
||||
where, args = append(where, "username = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Role; v != nil {
|
||||
where, args = append(where, "role = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Email; v != nil {
|
||||
where, args = append(where, "email = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Nickname; v != nil {
|
||||
where, args = append(where, "nickname = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.OpenID; v != nil {
|
||||
where, args = append(where, "open_id = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
id,
|
||||
username,
|
||||
role,
|
||||
email,
|
||||
nickname,
|
||||
password_hash,
|
||||
open_id,
|
||||
avatar_url,
|
||||
created_ts,
|
||||
updated_ts,
|
||||
row_status
|
||||
FROM user
|
||||
WHERE ` + strings.Join(where, " AND ") + `
|
||||
ORDER BY created_ts DESC, row_status DESC
|
||||
`
|
||||
rows, err := tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
userMessageList := make([]*UserMessage, 0)
|
||||
for rows.Next() {
|
||||
var userMessage UserMessage
|
||||
if err := rows.Scan(
|
||||
&userMessage.ID,
|
||||
&userMessage.Username,
|
||||
&userMessage.Role,
|
||||
&userMessage.Email,
|
||||
&userMessage.Nickname,
|
||||
&userMessage.PasswordHash,
|
||||
&userMessage.OpenID,
|
||||
&userMessage.AvatarURL,
|
||||
&userMessage.CreatedTs,
|
||||
&userMessage.UpdatedTs,
|
||||
&userMessage.RowStatus,
|
||||
); err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
userMessageList = append(userMessageList, &userMessage)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, FormatError(err)
|
||||
}
|
||||
|
||||
return userMessageList, nil
|
||||
}
|
||||
|
||||
// userRaw is the store model for an User.
|
||||
// Fields have exactly the same meanings as User.
|
||||
type userRaw struct {
|
||||
|
48
test/store/idp_test.go
Normal file
48
test/store/idp_test.go
Normal file
@ -0,0 +1,48 @@
|
||||
package teststore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestIdentityProviderStore(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{
|
||||
Name: "GitHub OAuth",
|
||||
Type: store.IdentityProviderOAuth2,
|
||||
IdentifierFilter: "",
|
||||
Config: &store.IdentityProviderConfig{
|
||||
OAuth2Config: &store.IdentityProviderOAuth2Config{
|
||||
ClientID: "asd",
|
||||
ClientSecret: "123",
|
||||
AuthURL: "https://github.com/auth",
|
||||
TokenURL: "https://github.com/token",
|
||||
UserInfoURL: "https://github.com/user",
|
||||
Scopes: []string{"login"},
|
||||
FieldMapping: &store.FieldMapping{
|
||||
Identifier: "login",
|
||||
DisplayName: "name",
|
||||
Email: "emai",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{
|
||||
ID: &createdIDP.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, createdIDP, idp)
|
||||
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProviderMessage{
|
||||
ID: idp.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProviderMessage{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(idpList))
|
||||
}
|
@ -17,9 +17,7 @@ const SSOSection = () => {
|
||||
}, []);
|
||||
|
||||
const fetchIdentityProviderList = async () => {
|
||||
const {
|
||||
data: { data: identityProviderList },
|
||||
} = await api.getIdentityProviderList();
|
||||
const { data: identityProviderList } = await api.getIdentityProviderList();
|
||||
setIdentityProviderList(identityProviderList);
|
||||
};
|
||||
|
||||
|
@ -246,19 +246,19 @@ export function deleteStorage(storageId: StorageId) {
|
||||
}
|
||||
|
||||
export function getIdentityProviderList() {
|
||||
return axios.get<ResponseObject<IdentityProvider[]>>(`/api/idp`);
|
||||
return axios.get<IdentityProvider[]>(`/api/v1/idp`);
|
||||
}
|
||||
|
||||
export function createIdentityProvider(identityProviderCreate: IdentityProviderCreate) {
|
||||
return axios.post<ResponseObject<IdentityProvider>>(`/api/idp`, identityProviderCreate);
|
||||
return axios.post<IdentityProvider>(`/api/v1/idp`, identityProviderCreate);
|
||||
}
|
||||
|
||||
export function patchIdentityProvider(identityProviderPatch: IdentityProviderPatch) {
|
||||
return axios.patch<ResponseObject<IdentityProvider>>(`/api/idp/${identityProviderPatch.id}`, identityProviderPatch);
|
||||
return axios.patch<IdentityProvider>(`/api/v1/idp/${identityProviderPatch.id}`, identityProviderPatch);
|
||||
}
|
||||
|
||||
export function deleteIdentityProvider(id: IdentityProviderId) {
|
||||
return axios.delete(`/api/idp/${id}`);
|
||||
return axios.delete(`/api/v1/idp/${id}`);
|
||||
}
|
||||
|
||||
export async function getRepoStarCount() {
|
||||
|
@ -24,9 +24,7 @@ const Auth = () => {
|
||||
useEffect(() => {
|
||||
userStore.doSignOut().catch();
|
||||
const fetchIdentityProviderList = async () => {
|
||||
const {
|
||||
data: { data: identityProviderList },
|
||||
} = await api.getIdentityProviderList();
|
||||
const { data: identityProviderList } = await api.getIdentityProviderList();
|
||||
setIdentityProviderList(identityProviderList);
|
||||
};
|
||||
fetchIdentityProviderList();
|
||||
|
Loading…
x
Reference in New Issue
Block a user