mirror of
https://github.com/usememos/memos.git
synced 2025-02-15 19:00:46 +01:00
refactor: migration idp api (#1842)
* refactor: migration idp api * chore: update
This commit is contained in:
parent
4ed9a3a0ea
commit
b34aded376
58
api/idp.go
58
api/idp.go
@ -1,58 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
type IdentityProviderType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
IdentityProviderOAuth2 IdentityProviderType = "OAUTH2"
|
|
||||||
)
|
|
||||||
|
|
||||||
type IdentityProviderConfig struct {
|
|
||||||
OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type IdentityProviderOAuth2Config struct {
|
|
||||||
ClientID string `json:"clientId"`
|
|
||||||
ClientSecret string `json:"clientSecret"`
|
|
||||||
AuthURL string `json:"authUrl"`
|
|
||||||
TokenURL string `json:"tokenUrl"`
|
|
||||||
UserInfoURL string `json:"userInfoUrl"`
|
|
||||||
Scopes []string `json:"scopes"`
|
|
||||||
FieldMapping *FieldMapping `json:"fieldMapping"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FieldMapping struct {
|
|
||||||
Identifier string `json:"identifier"`
|
|
||||||
DisplayName string `json:"displayName"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type IdentityProvider struct {
|
|
||||||
ID int `json:"id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Type IdentityProviderType `json:"type"`
|
|
||||||
IdentifierFilter string `json:"identifierFilter"`
|
|
||||||
Config *IdentityProviderConfig `json:"config"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type IdentityProviderCreate struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Type IdentityProviderType `json:"type"`
|
|
||||||
IdentifierFilter string `json:"identifierFilter"`
|
|
||||||
Config *IdentityProviderConfig `json:"config"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type IdentityProviderFind struct {
|
|
||||||
ID *int
|
|
||||||
}
|
|
||||||
|
|
||||||
type IdentityProviderPatch struct {
|
|
||||||
ID int
|
|
||||||
Type IdentityProviderType `json:"type"`
|
|
||||||
Name *string `json:"name"`
|
|
||||||
IdentifierFilter *string `json:"identifierFilter"`
|
|
||||||
Config *IdentityProviderConfig `json:"config"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type IdentityProviderDelete struct {
|
|
||||||
ID int
|
|
||||||
}
|
|
24
api/v1/common.go
Normal file
24
api/v1/common.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package v1
|
||||||
|
|
||||||
|
// UnknownID is the ID for unknowns.
|
||||||
|
const UnknownID = -1
|
||||||
|
|
||||||
|
// RowStatus is the status for a row.
|
||||||
|
type RowStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Normal is the status for a normal row.
|
||||||
|
Normal RowStatus = "NORMAL"
|
||||||
|
// Archived is the status for an archived row.
|
||||||
|
Archived RowStatus = "ARCHIVED"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e RowStatus) String() string {
|
||||||
|
switch e {
|
||||||
|
case Normal:
|
||||||
|
return "NORMAL"
|
||||||
|
case Archived:
|
||||||
|
return "ARCHIVED"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package server
|
package v1
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -7,12 +7,60 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/usememos/memos/api"
|
|
||||||
"github.com/usememos/memos/common"
|
"github.com/usememos/memos/common"
|
||||||
"github.com/usememos/memos/store"
|
"github.com/usememos/memos/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
type IdentityProviderType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
IdentityProviderOAuth2 IdentityProviderType = "OAUTH2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IdentityProviderConfig struct {
|
||||||
|
OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type IdentityProviderOAuth2Config struct {
|
||||||
|
ClientID string `json:"clientId"`
|
||||||
|
ClientSecret string `json:"clientSecret"`
|
||||||
|
AuthURL string `json:"authUrl"`
|
||||||
|
TokenURL string `json:"tokenUrl"`
|
||||||
|
UserInfoURL string `json:"userInfoUrl"`
|
||||||
|
Scopes []string `json:"scopes"`
|
||||||
|
FieldMapping *FieldMapping `json:"fieldMapping"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FieldMapping struct {
|
||||||
|
Identifier string `json:"identifier"`
|
||||||
|
DisplayName string `json:"displayName"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type IdentityProvider struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type IdentityProviderType `json:"type"`
|
||||||
|
IdentifierFilter string `json:"identifierFilter"`
|
||||||
|
Config *IdentityProviderConfig `json:"config"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type IdentityProviderCreate struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type IdentityProviderType `json:"type"`
|
||||||
|
IdentifierFilter string `json:"identifierFilter"`
|
||||||
|
Config *IdentityProviderConfig `json:"config"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type IdentityProviderPatch struct {
|
||||||
|
ID int
|
||||||
|
Type IdentityProviderType `json:"type"`
|
||||||
|
Name *string `json:"name"`
|
||||||
|
IdentifierFilter *string `json:"identifierFilter"`
|
||||||
|
Config *IdentityProviderConfig `json:"config"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
||||||
g.POST("/idp", func(c echo.Context) error {
|
g.POST("/idp", func(c echo.Context) error {
|
||||||
ctx := c.Request().Context()
|
ctx := c.Request().Context()
|
||||||
userID, ok := c.Get(getUserIDContextKey()).(int)
|
userID, ok := c.Get(getUserIDContextKey()).(int)
|
||||||
@ -20,17 +68,17 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.Store.FindUser(ctx, &api.UserFind{
|
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
||||||
}
|
}
|
||||||
if user == nil || user.Role != api.Host {
|
if user == nil || user.Role != store.Host {
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
|
||||||
}
|
}
|
||||||
|
|
||||||
identityProviderCreate := &api.IdentityProviderCreate{}
|
identityProviderCreate := &IdentityProviderCreate{}
|
||||||
if err := json.NewDecoder(c.Request().Body).Decode(identityProviderCreate); err != nil {
|
if err := json.NewDecoder(c.Request().Body).Decode(identityProviderCreate); err != nil {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err)
|
||||||
}
|
}
|
||||||
@ -44,7 +92,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err)
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage)))
|
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
|
||||||
})
|
})
|
||||||
|
|
||||||
g.PATCH("/idp/:idpId", func(c echo.Context) error {
|
g.PATCH("/idp/:idpId", func(c echo.Context) error {
|
||||||
@ -54,13 +102,13 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.Store.FindUser(ctx, &api.UserFind{
|
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
||||||
}
|
}
|
||||||
if user == nil || user.Role != api.Host {
|
if user == nil || user.Role != store.Host {
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,7 +117,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
identityProviderPatch := &api.IdentityProviderPatch{
|
identityProviderPatch := &IdentityProviderPatch{
|
||||||
ID: identityProviderID,
|
ID: identityProviderID,
|
||||||
}
|
}
|
||||||
if err := json.NewDecoder(c.Request().Body).Decode(identityProviderPatch); err != nil {
|
if err := json.NewDecoder(c.Request().Body).Decode(identityProviderPatch); err != nil {
|
||||||
@ -86,7 +134,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err)
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage)))
|
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
|
||||||
})
|
})
|
||||||
|
|
||||||
g.GET("/idp", func(c echo.Context) error {
|
g.GET("/idp", func(c echo.Context) error {
|
||||||
@ -99,18 +147,18 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
userID, ok := c.Get(getUserIDContextKey()).(int)
|
userID, ok := c.Get(getUserIDContextKey()).(int)
|
||||||
isHostUser := false
|
isHostUser := false
|
||||||
if ok {
|
if ok {
|
||||||
user, err := s.Store.FindUser(ctx, &api.UserFind{
|
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
})
|
})
|
||||||
if err != nil && common.ErrorCode(err) != common.NotFound {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
||||||
}
|
}
|
||||||
if user != nil && user.Role == api.Host {
|
if user == nil || user.Role == store.Host {
|
||||||
isHostUser = true
|
isHostUser = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
identityProviderList := []*api.IdentityProvider{}
|
identityProviderList := []*IdentityProvider{}
|
||||||
for _, identityProviderMessage := range identityProviderMessageList {
|
for _, identityProviderMessage := range identityProviderMessageList {
|
||||||
identityProvider := convertIdentityProviderFromStore(identityProviderMessage)
|
identityProvider := convertIdentityProviderFromStore(identityProviderMessage)
|
||||||
// data desensitize
|
// data desensitize
|
||||||
@ -119,7 +167,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
}
|
}
|
||||||
identityProviderList = append(identityProviderList, identityProvider)
|
identityProviderList = append(identityProviderList, identityProvider)
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, composeResponse(identityProviderList))
|
return c.JSON(http.StatusOK, identityProviderList)
|
||||||
})
|
})
|
||||||
|
|
||||||
g.GET("/idp/:idpId", func(c echo.Context) error {
|
g.GET("/idp/:idpId", func(c echo.Context) error {
|
||||||
@ -129,14 +177,13 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.Store.FindUser(ctx, &api.UserFind{
|
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
||||||
}
|
}
|
||||||
// We should only show identity provider list to host user.
|
if user == nil || user.Role != store.Host {
|
||||||
if user == nil || user.Role != api.Host {
|
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,7 +197,7 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get identity provider").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get identity provider").SetInternal(err)
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, composeResponse(convertIdentityProviderFromStore(identityProviderMessage)))
|
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
|
||||||
})
|
})
|
||||||
|
|
||||||
g.DELETE("/idp/:idpId", func(c echo.Context) error {
|
g.DELETE("/idp/:idpId", func(c echo.Context) error {
|
||||||
@ -160,13 +207,13 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.Store.FindUser(ctx, &api.UserFind{
|
user, err := s.Store.GetUser(ctx, &store.FindUserMessage{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
|
||||||
}
|
}
|
||||||
if user == nil || user.Role != api.Host {
|
if user == nil || user.Role != store.Host {
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -185,26 +232,26 @@ func (s *Server) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertIdentityProviderFromStore(identityProviderMessage *store.IdentityProviderMessage) *api.IdentityProvider {
|
func convertIdentityProviderFromStore(identityProviderMessage *store.IdentityProviderMessage) *IdentityProvider {
|
||||||
return &api.IdentityProvider{
|
return &IdentityProvider{
|
||||||
ID: identityProviderMessage.ID,
|
ID: identityProviderMessage.ID,
|
||||||
Name: identityProviderMessage.Name,
|
Name: identityProviderMessage.Name,
|
||||||
Type: api.IdentityProviderType(identityProviderMessage.Type),
|
Type: IdentityProviderType(identityProviderMessage.Type),
|
||||||
IdentifierFilter: identityProviderMessage.IdentifierFilter,
|
IdentifierFilter: identityProviderMessage.IdentifierFilter,
|
||||||
Config: convertIdentityProviderConfigFromStore(identityProviderMessage.Config),
|
Config: convertIdentityProviderConfigFromStore(identityProviderMessage.Config),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertIdentityProviderConfigFromStore(config *store.IdentityProviderConfig) *api.IdentityProviderConfig {
|
func convertIdentityProviderConfigFromStore(config *store.IdentityProviderConfig) *IdentityProviderConfig {
|
||||||
return &api.IdentityProviderConfig{
|
return &IdentityProviderConfig{
|
||||||
OAuth2Config: &api.IdentityProviderOAuth2Config{
|
OAuth2Config: &IdentityProviderOAuth2Config{
|
||||||
ClientID: config.OAuth2Config.ClientID,
|
ClientID: config.OAuth2Config.ClientID,
|
||||||
ClientSecret: config.OAuth2Config.ClientSecret,
|
ClientSecret: config.OAuth2Config.ClientSecret,
|
||||||
AuthURL: config.OAuth2Config.AuthURL,
|
AuthURL: config.OAuth2Config.AuthURL,
|
||||||
TokenURL: config.OAuth2Config.TokenURL,
|
TokenURL: config.OAuth2Config.TokenURL,
|
||||||
UserInfoURL: config.OAuth2Config.UserInfoURL,
|
UserInfoURL: config.OAuth2Config.UserInfoURL,
|
||||||
Scopes: config.OAuth2Config.Scopes,
|
Scopes: config.OAuth2Config.Scopes,
|
||||||
FieldMapping: &api.FieldMapping{
|
FieldMapping: &FieldMapping{
|
||||||
Identifier: config.OAuth2Config.FieldMapping.Identifier,
|
Identifier: config.OAuth2Config.FieldMapping.Identifier,
|
||||||
DisplayName: config.OAuth2Config.FieldMapping.DisplayName,
|
DisplayName: config.OAuth2Config.FieldMapping.DisplayName,
|
||||||
Email: config.OAuth2Config.FieldMapping.Email,
|
Email: config.OAuth2Config.FieldMapping.Email,
|
||||||
@ -213,7 +260,7 @@ func convertIdentityProviderConfigFromStore(config *store.IdentityProviderConfig
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertIdentityProviderConfigToStore(config *api.IdentityProviderConfig) *store.IdentityProviderConfig {
|
func convertIdentityProviderConfigToStore(config *IdentityProviderConfig) *store.IdentityProviderConfig {
|
||||||
return &store.IdentityProviderConfig{
|
return &store.IdentityProviderConfig{
|
||||||
OAuth2Config: &store.IdentityProviderOAuth2Config{
|
OAuth2Config: &store.IdentityProviderOAuth2Config{
|
||||||
ClientID: config.OAuth2Config.ClientID,
|
ClientID: config.OAuth2Config.ClientID,
|
25
api/v1/user.go
Normal file
25
api/v1/user.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package v1
|
||||||
|
|
||||||
|
// Role is the type of a role.
|
||||||
|
type Role string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Host is the HOST role.
|
||||||
|
Host Role = "HOST"
|
||||||
|
// Admin is the ADMIN role.
|
||||||
|
Admin Role = "ADMIN"
|
||||||
|
// NormalUser is the USER role.
|
||||||
|
NormalUser Role = "USER"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e Role) String() string {
|
||||||
|
switch e {
|
||||||
|
case Host:
|
||||||
|
return "HOST"
|
||||||
|
case Admin:
|
||||||
|
return "ADMIN"
|
||||||
|
case NormalUser:
|
||||||
|
return "USER"
|
||||||
|
}
|
||||||
|
return "USER"
|
||||||
|
}
|
15
api/v1/v1.go
15
api/v1/v1.go
@ -6,6 +6,17 @@ import (
|
|||||||
"github.com/usememos/memos/store"
|
"github.com/usememos/memos/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Context section
|
||||||
|
// The key name used to store user id in the context
|
||||||
|
// user id is extracted from the jwt token subject field.
|
||||||
|
userIDContextKey = "user-id"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getUserIDContextKey() string {
|
||||||
|
return userIDContextKey
|
||||||
|
}
|
||||||
|
|
||||||
type APIV1Service struct {
|
type APIV1Service struct {
|
||||||
Secret string
|
Secret string
|
||||||
Profile *profile.Profile
|
Profile *profile.Profile
|
||||||
@ -20,8 +31,8 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *APIV1Service) Register(e *echo.Echo) {
|
func (s *APIV1Service) Register(apiV1Group *echo.Group) {
|
||||||
apiV1Group := e.Group("/api/v1")
|
|
||||||
s.registerTestRoutes(apiV1Group)
|
s.registerTestRoutes(apiV1Group)
|
||||||
s.registerAuthRoutes(apiV1Group, s.Secret)
|
s.registerAuthRoutes(apiV1Group, s.Secret)
|
||||||
|
s.registerIdentityProviderRoutes(apiV1Group)
|
||||||
}
|
}
|
||||||
|
@ -22,6 +22,10 @@ const (
|
|||||||
userIDContextKey = "user-id"
|
userIDContextKey = "user-id"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func getUserIDContextKey() string {
|
||||||
|
return userIDContextKey
|
||||||
|
}
|
||||||
|
|
||||||
// Claims creates a struct that will be encoded to a JWT.
|
// Claims creates a struct that will be encoded to a JWT.
|
||||||
// We add jwt.RegisteredClaims as an embedded type, to provide fields such as name.
|
// We add jwt.RegisteredClaims as an embedded type, to provide fields such as name.
|
||||||
type Claims struct {
|
type Claims struct {
|
||||||
@ -29,10 +33,6 @@ type Claims struct {
|
|||||||
jwt.RegisteredClaims
|
jwt.RegisteredClaims
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUserIDContextKey() string {
|
|
||||||
return userIDContextKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractTokenFromHeader(c echo.Context) (string, error) {
|
func extractTokenFromHeader(c echo.Context) (string, error) {
|
||||||
authHeader := c.Request().Header.Get("Authorization")
|
authHeader := c.Request().Header.Get("Authorization")
|
||||||
if authHeader == "" {
|
if authHeader == "" {
|
||||||
@ -82,7 +82,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Skip validation for server status endpoints.
|
// Skip validation for server status endpoints.
|
||||||
if common.HasPrefixes(path, "/api/ping", "/api/idp", "/api/user/:id") && method == http.MethodGet {
|
if common.HasPrefixes(path, "/api/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet {
|
||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,15 +111,9 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
|
|||||||
}
|
}
|
||||||
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
|
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
|
||||||
})
|
})
|
||||||
|
|
||||||
if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) {
|
if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) {
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized,
|
return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName))
|
||||||
fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment",
|
|
||||||
claims.Audience,
|
|
||||||
auth.AccessTokenAudienceName,
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
generateToken := time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration
|
generateToken := time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var ve *jwt.ValidationError
|
var ve *jwt.ValidationError
|
||||||
@ -130,11 +124,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
|
|||||||
generateToken = true
|
generateToken = true
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return &echo.HTTPError{
|
return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token"))
|
||||||
Code: http.StatusUnauthorized,
|
|
||||||
Message: "Invalid or expired access token",
|
|
||||||
Internal: err,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,12 +105,15 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
|
|||||||
s.registerResourceRoutes(apiGroup)
|
s.registerResourceRoutes(apiGroup)
|
||||||
s.registerTagRoutes(apiGroup)
|
s.registerTagRoutes(apiGroup)
|
||||||
s.registerStorageRoutes(apiGroup)
|
s.registerStorageRoutes(apiGroup)
|
||||||
s.registerIdentityProviderRoutes(apiGroup)
|
|
||||||
s.registerOpenAIRoutes(apiGroup)
|
s.registerOpenAIRoutes(apiGroup)
|
||||||
s.registerMemoRelationRoutes(apiGroup)
|
s.registerMemoRelationRoutes(apiGroup)
|
||||||
|
|
||||||
|
apiV1Group := e.Group("/api/v1")
|
||||||
|
apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return JWTMiddleware(s, next, s.Secret)
|
||||||
|
})
|
||||||
apiV1Service := apiV1.NewAPIV1Service(s.Secret, profile, store)
|
apiV1Service := apiV1.NewAPIV1Service(s.Secret, profile, store)
|
||||||
apiV1Service.Register(e)
|
apiV1Service.Register(apiV1Group)
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
205
store/user.go
205
store/user.go
@ -27,6 +27,211 @@ func (s *Store) SeedDataForNewUser(ctx context.Context, user *api.User) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Role is the type of a role.
|
||||||
|
type Role string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Host is the HOST role.
|
||||||
|
Host Role = "HOST"
|
||||||
|
// Admin is the ADMIN role.
|
||||||
|
Admin Role = "ADMIN"
|
||||||
|
// NormalUser is the USER role.
|
||||||
|
NormalUser Role = "USER"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e Role) String() string {
|
||||||
|
switch e {
|
||||||
|
case Host:
|
||||||
|
return "HOST"
|
||||||
|
case Admin:
|
||||||
|
return "ADMIN"
|
||||||
|
case NormalUser:
|
||||||
|
return "USER"
|
||||||
|
}
|
||||||
|
return "USER"
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserMessage struct {
|
||||||
|
ID int
|
||||||
|
|
||||||
|
// Standard fields
|
||||||
|
RowStatus RowStatus
|
||||||
|
CreatedTs int64
|
||||||
|
UpdatedTs int64
|
||||||
|
|
||||||
|
// Domain specific fields
|
||||||
|
Username string
|
||||||
|
Role Role
|
||||||
|
Email string
|
||||||
|
Nickname string
|
||||||
|
PasswordHash string
|
||||||
|
OpenID string
|
||||||
|
AvatarURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
type FindUserMessage struct {
|
||||||
|
ID *int
|
||||||
|
|
||||||
|
// Standard fields
|
||||||
|
RowStatus *RowStatus
|
||||||
|
|
||||||
|
// Domain specific fields
|
||||||
|
Username *string
|
||||||
|
Role *Role
|
||||||
|
Email *string
|
||||||
|
Nickname *string
|
||||||
|
OpenID *string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateUserV1(ctx context.Context, create *UserMessage) (*UserMessage, error) {
|
||||||
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
query := `
|
||||||
|
INSERT INTO user (
|
||||||
|
username,
|
||||||
|
role,
|
||||||
|
email,
|
||||||
|
nickname,
|
||||||
|
password_hash,
|
||||||
|
open_id
|
||||||
|
)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
|
RETURNING id, avatar_url, created_ts, updated_ts, row_status
|
||||||
|
`
|
||||||
|
if err := tx.QueryRowContext(ctx, query,
|
||||||
|
create.Username,
|
||||||
|
create.Role,
|
||||||
|
create.Email,
|
||||||
|
create.Nickname,
|
||||||
|
create.PasswordHash,
|
||||||
|
create.OpenID,
|
||||||
|
).Scan(
|
||||||
|
&create.ID,
|
||||||
|
&create.AvatarURL,
|
||||||
|
&create.CreatedTs,
|
||||||
|
&create.UpdatedTs,
|
||||||
|
&create.RowStatus,
|
||||||
|
); err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
userMessage := create
|
||||||
|
return userMessage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) ListUsers(ctx context.Context, find *FindUserMessage) ([]*UserMessage, error) {
|
||||||
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
list, err := listUsers(ctx, tx, find)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return list, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetUser(ctx context.Context, find *FindUserMessage) (*UserMessage, error) {
|
||||||
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
list, err := listUsers(ctx, tx, find)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(list) == 0 {
|
||||||
|
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user not found")}
|
||||||
|
}
|
||||||
|
|
||||||
|
memoMessage := list[0]
|
||||||
|
return memoMessage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func listUsers(ctx context.Context, tx *sql.Tx, find *FindUserMessage) ([]*UserMessage, error) {
|
||||||
|
where, args := []string{"1 = 1"}, []any{}
|
||||||
|
|
||||||
|
if v := find.ID; v != nil {
|
||||||
|
where, args = append(where, "id = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
if v := find.Username; v != nil {
|
||||||
|
where, args = append(where, "username = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
if v := find.Role; v != nil {
|
||||||
|
where, args = append(where, "role = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
if v := find.Email; v != nil {
|
||||||
|
where, args = append(where, "email = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
if v := find.Nickname; v != nil {
|
||||||
|
where, args = append(where, "nickname = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
if v := find.OpenID; v != nil {
|
||||||
|
where, args = append(where, "open_id = ?"), append(args, *v)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
username,
|
||||||
|
role,
|
||||||
|
email,
|
||||||
|
nickname,
|
||||||
|
password_hash,
|
||||||
|
open_id,
|
||||||
|
avatar_url,
|
||||||
|
created_ts,
|
||||||
|
updated_ts,
|
||||||
|
row_status
|
||||||
|
FROM user
|
||||||
|
WHERE ` + strings.Join(where, " AND ") + `
|
||||||
|
ORDER BY created_ts DESC, row_status DESC
|
||||||
|
`
|
||||||
|
rows, err := tx.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
userMessageList := make([]*UserMessage, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var userMessage UserMessage
|
||||||
|
if err := rows.Scan(
|
||||||
|
&userMessage.ID,
|
||||||
|
&userMessage.Username,
|
||||||
|
&userMessage.Role,
|
||||||
|
&userMessage.Email,
|
||||||
|
&userMessage.Nickname,
|
||||||
|
&userMessage.PasswordHash,
|
||||||
|
&userMessage.OpenID,
|
||||||
|
&userMessage.AvatarURL,
|
||||||
|
&userMessage.CreatedTs,
|
||||||
|
&userMessage.UpdatedTs,
|
||||||
|
&userMessage.RowStatus,
|
||||||
|
); err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
userMessageList = append(userMessageList, &userMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return userMessageList, nil
|
||||||
|
}
|
||||||
|
|
||||||
// userRaw is the store model for an User.
|
// userRaw is the store model for an User.
|
||||||
// Fields have exactly the same meanings as User.
|
// Fields have exactly the same meanings as User.
|
||||||
type userRaw struct {
|
type userRaw struct {
|
||||||
|
48
test/store/idp_test.go
Normal file
48
test/store/idp_test.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package teststore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIdentityProviderStore(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ts := NewTestingStore(ctx, t)
|
||||||
|
createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{
|
||||||
|
Name: "GitHub OAuth",
|
||||||
|
Type: store.IdentityProviderOAuth2,
|
||||||
|
IdentifierFilter: "",
|
||||||
|
Config: &store.IdentityProviderConfig{
|
||||||
|
OAuth2Config: &store.IdentityProviderOAuth2Config{
|
||||||
|
ClientID: "asd",
|
||||||
|
ClientSecret: "123",
|
||||||
|
AuthURL: "https://github.com/auth",
|
||||||
|
TokenURL: "https://github.com/token",
|
||||||
|
UserInfoURL: "https://github.com/user",
|
||||||
|
Scopes: []string{"login"},
|
||||||
|
FieldMapping: &store.FieldMapping{
|
||||||
|
Identifier: "login",
|
||||||
|
DisplayName: "name",
|
||||||
|
Email: "emai",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{
|
||||||
|
ID: &createdIDP.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, createdIDP, idp)
|
||||||
|
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProviderMessage{
|
||||||
|
ID: idp.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProviderMessage{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 0, len(idpList))
|
||||||
|
}
|
@ -17,9 +17,7 @@ const SSOSection = () => {
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const fetchIdentityProviderList = async () => {
|
const fetchIdentityProviderList = async () => {
|
||||||
const {
|
const { data: identityProviderList } = await api.getIdentityProviderList();
|
||||||
data: { data: identityProviderList },
|
|
||||||
} = await api.getIdentityProviderList();
|
|
||||||
setIdentityProviderList(identityProviderList);
|
setIdentityProviderList(identityProviderList);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -246,19 +246,19 @@ export function deleteStorage(storageId: StorageId) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function getIdentityProviderList() {
|
export function getIdentityProviderList() {
|
||||||
return axios.get<ResponseObject<IdentityProvider[]>>(`/api/idp`);
|
return axios.get<IdentityProvider[]>(`/api/v1/idp`);
|
||||||
}
|
}
|
||||||
|
|
||||||
export function createIdentityProvider(identityProviderCreate: IdentityProviderCreate) {
|
export function createIdentityProvider(identityProviderCreate: IdentityProviderCreate) {
|
||||||
return axios.post<ResponseObject<IdentityProvider>>(`/api/idp`, identityProviderCreate);
|
return axios.post<IdentityProvider>(`/api/v1/idp`, identityProviderCreate);
|
||||||
}
|
}
|
||||||
|
|
||||||
export function patchIdentityProvider(identityProviderPatch: IdentityProviderPatch) {
|
export function patchIdentityProvider(identityProviderPatch: IdentityProviderPatch) {
|
||||||
return axios.patch<ResponseObject<IdentityProvider>>(`/api/idp/${identityProviderPatch.id}`, identityProviderPatch);
|
return axios.patch<IdentityProvider>(`/api/v1/idp/${identityProviderPatch.id}`, identityProviderPatch);
|
||||||
}
|
}
|
||||||
|
|
||||||
export function deleteIdentityProvider(id: IdentityProviderId) {
|
export function deleteIdentityProvider(id: IdentityProviderId) {
|
||||||
return axios.delete(`/api/idp/${id}`);
|
return axios.delete(`/api/v1/idp/${id}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function getRepoStarCount() {
|
export async function getRepoStarCount() {
|
||||||
|
@ -24,9 +24,7 @@ const Auth = () => {
|
|||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
userStore.doSignOut().catch();
|
userStore.doSignOut().catch();
|
||||||
const fetchIdentityProviderList = async () => {
|
const fetchIdentityProviderList = async () => {
|
||||||
const {
|
const { data: identityProviderList } = await api.getIdentityProviderList();
|
||||||
data: { data: identityProviderList },
|
|
||||||
} = await api.getIdentityProviderList();
|
|
||||||
setIdentityProviderList(identityProviderList);
|
setIdentityProviderList(identityProviderList);
|
||||||
};
|
};
|
||||||
fetchIdentityProviderList();
|
fetchIdentityProviderList();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user