mirror of
https://github.com/usememos/memos.git
synced 2025-06-05 22:09:59 +02:00
chore: update idp store (#1856)
This commit is contained in:
@@ -74,16 +74,19 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) {
|
|||||||
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
identityProviderMessage, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{
|
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||||
ID: &signin.IdentityProviderID,
|
ID: &signin.IdentityProviderID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
if identityProvider == nil {
|
||||||
|
return echo.NewHTTPError(http.StatusNotFound, "Identity provider not found")
|
||||||
|
}
|
||||||
|
|
||||||
var userInfo *idp.IdentityProviderUserInfo
|
var userInfo *idp.IdentityProviderUserInfo
|
||||||
if identityProviderMessage.Type == store.IdentityProviderOAuth2 {
|
if identityProvider.Type == store.IdentityProviderOAuth2 {
|
||||||
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProviderMessage.Config.OAuth2Config)
|
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.OAuth2Config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider instance").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider instance").SetInternal(err)
|
||||||
}
|
}
|
||||||
@@ -97,7 +100,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
identifierFilter := identityProviderMessage.IdentifierFilter
|
identifierFilter := identityProvider.IdentifierFilter
|
||||||
if identifierFilter != "" {
|
if identifierFilter != "" {
|
||||||
identifierFilterRegex, err := regexp.Compile(identifierFilter)
|
identifierFilterRegex, err := regexp.Compile(identifierFilter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -83,7 +83,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
identityProviderMessage, err := s.Store.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{
|
identityProvider, err := s.Store.CreateIdentityProvider(ctx, &store.IdentityProvider{
|
||||||
Name: identityProviderCreate.Name,
|
Name: identityProviderCreate.Name,
|
||||||
Type: store.IdentityProviderType(identityProviderCreate.Type),
|
Type: store.IdentityProviderType(identityProviderCreate.Type),
|
||||||
IdentifierFilter: identityProviderCreate.IdentifierFilter,
|
IdentifierFilter: identityProviderCreate.IdentifierFilter,
|
||||||
@@ -92,7 +92,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err)
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
|
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider))
|
||||||
})
|
})
|
||||||
|
|
||||||
g.PATCH("/idp/:idpId", func(c echo.Context) error {
|
g.PATCH("/idp/:idpId", func(c echo.Context) error {
|
||||||
@@ -124,7 +124,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch identity provider request").SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch identity provider request").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
identityProviderMessage, err := s.Store.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderMessage{
|
identityProvider, err := s.Store.UpdateIdentityProvider(ctx, &store.UpdateIdentityProvider{
|
||||||
ID: identityProviderPatch.ID,
|
ID: identityProviderPatch.ID,
|
||||||
Type: store.IdentityProviderType(identityProviderPatch.Type),
|
Type: store.IdentityProviderType(identityProviderPatch.Type),
|
||||||
Name: identityProviderPatch.Name,
|
Name: identityProviderPatch.Name,
|
||||||
@@ -134,12 +134,12 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err)
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
|
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider))
|
||||||
})
|
})
|
||||||
|
|
||||||
g.GET("/idp", func(c echo.Context) error {
|
g.GET("/idp", func(c echo.Context) error {
|
||||||
ctx := c.Request().Context()
|
ctx := c.Request().Context()
|
||||||
identityProviderMessageList, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProviderMessage{})
|
list, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider list").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider list").SetInternal(err)
|
||||||
}
|
}
|
||||||
@@ -159,8 +159,8 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
identityProviderList := []*IdentityProvider{}
|
identityProviderList := []*IdentityProvider{}
|
||||||
for _, identityProviderMessage := range identityProviderMessageList {
|
for _, item := range list {
|
||||||
identityProvider := convertIdentityProviderFromStore(identityProviderMessage)
|
identityProvider := convertIdentityProviderFromStore(item)
|
||||||
// data desensitize
|
// data desensitize
|
||||||
if !isHostUser {
|
if !isHostUser {
|
||||||
identityProvider.Config.OAuth2Config.ClientSecret = ""
|
identityProvider.Config.OAuth2Config.ClientSecret = ""
|
||||||
@@ -191,13 +191,17 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
|
||||||
}
|
}
|
||||||
identityProviderMessage, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{
|
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||||
ID: &identityProviderID,
|
ID: &identityProviderID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get identity provider").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get identity provider").SetInternal(err)
|
||||||
}
|
}
|
||||||
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage))
|
if identityProvider == nil {
|
||||||
|
return echo.NewHTTPError(http.StatusNotFound, "Identity provider not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider))
|
||||||
})
|
})
|
||||||
|
|
||||||
g.DELETE("/idp/:idpId", func(c echo.Context) error {
|
g.DELETE("/idp/:idpId", func(c echo.Context) error {
|
||||||
@@ -222,7 +226,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProviderMessage{ID: identityProviderID}); err != nil {
|
if err = s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: identityProviderID}); err != nil {
|
||||||
if common.ErrorCode(err) == common.NotFound {
|
if common.ErrorCode(err) == common.NotFound {
|
||||||
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Identity provider ID not found: %d", identityProviderID))
|
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Identity provider ID not found: %d", identityProviderID))
|
||||||
}
|
}
|
||||||
@@ -232,13 +236,13 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertIdentityProviderFromStore(identityProviderMessage *store.IdentityProviderMessage) *IdentityProvider {
|
func convertIdentityProviderFromStore(identityProvider *store.IdentityProvider) *IdentityProvider {
|
||||||
return &IdentityProvider{
|
return &IdentityProvider{
|
||||||
ID: identityProviderMessage.ID,
|
ID: identityProvider.ID,
|
||||||
Name: identityProviderMessage.Name,
|
Name: identityProvider.Name,
|
||||||
Type: IdentityProviderType(identityProviderMessage.Type),
|
Type: IdentityProviderType(identityProvider.Type),
|
||||||
IdentifierFilter: identityProviderMessage.IdentifierFilter,
|
IdentifierFilter: identityProvider.IdentifierFilter,
|
||||||
Config: convertIdentityProviderConfigFromStore(identityProviderMessage.Config),
|
Config: convertIdentityProviderConfigFromStore(identityProvider.Config),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
121
store/idp.go
121
store/idp.go
@@ -6,8 +6,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/usememos/memos/common"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type IdentityProviderType string
|
type IdentityProviderType string
|
||||||
@@ -36,7 +34,7 @@ type FieldMapping struct {
|
|||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type IdentityProviderMessage struct {
|
type IdentityProvider struct {
|
||||||
ID int
|
ID int
|
||||||
Name string
|
Name string
|
||||||
Type IdentityProviderType
|
Type IdentityProviderType
|
||||||
@@ -44,11 +42,11 @@ type IdentityProviderMessage struct {
|
|||||||
Config *IdentityProviderConfig
|
Config *IdentityProviderConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type FindIdentityProviderMessage struct {
|
type FindIdentityProvider struct {
|
||||||
ID *int
|
ID *int
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateIdentityProviderMessage struct {
|
type UpdateIdentityProvider struct {
|
||||||
ID int
|
ID int
|
||||||
Type IdentityProviderType
|
Type IdentityProviderType
|
||||||
Name *string
|
Name *string
|
||||||
@@ -56,14 +54,14 @@ type UpdateIdentityProviderMessage struct {
|
|||||||
Config *IdentityProviderConfig
|
Config *IdentityProviderConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeleteIdentityProviderMessage struct {
|
type DeleteIdentityProvider struct {
|
||||||
ID int
|
ID int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProviderMessage) (*IdentityProviderMessage, error) {
|
func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) {
|
||||||
tx, err := s.db.BeginTx(ctx, nil)
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
@@ -76,6 +74,7 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv
|
|||||||
} else {
|
} else {
|
||||||
return nil, fmt.Errorf("unsupported idp type %s", string(create.Type))
|
return nil, fmt.Errorf("unsupported idp type %s", string(create.Type))
|
||||||
}
|
}
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO idp (
|
INSERT INTO idp (
|
||||||
name,
|
name,
|
||||||
@@ -96,20 +95,22 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv
|
|||||||
).Scan(
|
).Scan(
|
||||||
&create.ID,
|
&create.ID,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
identityProviderMessage := create
|
|
||||||
s.idpCache.Store(identityProviderMessage.ID, identityProviderMessage)
|
identityProvider := create
|
||||||
return identityProviderMessage, nil
|
s.idpCache.Store(identityProvider.ID, identityProvider)
|
||||||
|
return identityProvider, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProviderMessage) ([]*IdentityProviderMessage, error) {
|
func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) {
|
||||||
tx, err := s.db.BeginTx(ctx, nil)
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
@@ -124,16 +125,16 @@ func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityPro
|
|||||||
return list, nil
|
return list, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProviderMessage) (*IdentityProviderMessage, error) {
|
func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) {
|
||||||
if find.ID != nil {
|
if find.ID != nil {
|
||||||
if cache, ok := s.idpCache.Load(*find.ID); ok {
|
if cache, ok := s.idpCache.Load(*find.ID); ok {
|
||||||
return cache.(*IdentityProviderMessage), nil
|
return cache.(*IdentityProvider), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTx(ctx, nil)
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
@@ -142,18 +143,18 @@ func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvi
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(list) == 0 {
|
if len(list) == 0 {
|
||||||
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
identityProviderMessage := list[0]
|
identityProvider := list[0]
|
||||||
s.idpCache.Store(identityProviderMessage.ID, identityProviderMessage)
|
s.idpCache.Store(identityProvider.ID, identityProvider)
|
||||||
return identityProviderMessage, nil
|
return identityProvider, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProviderMessage) (*IdentityProviderMessage, error) {
|
func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) {
|
||||||
tx, err := s.db.BeginTx(ctx, nil)
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
@@ -184,64 +185,65 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti
|
|||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
RETURNING id, name, type, identifier_filter, config
|
RETURNING id, name, type, identifier_filter, config
|
||||||
`
|
`
|
||||||
var identityProviderMessage IdentityProviderMessage
|
var identityProvider IdentityProvider
|
||||||
var identityProviderConfig string
|
var identityProviderConfig string
|
||||||
if err := tx.QueryRowContext(ctx, query, args...).Scan(
|
if err := tx.QueryRowContext(ctx, query, args...).Scan(
|
||||||
&identityProviderMessage.ID,
|
&identityProvider.ID,
|
||||||
&identityProviderMessage.Name,
|
&identityProvider.Name,
|
||||||
&identityProviderMessage.Type,
|
&identityProvider.Type,
|
||||||
&identityProviderMessage.IdentifierFilter,
|
&identityProvider.IdentifierFilter,
|
||||||
&identityProviderConfig,
|
&identityProviderConfig,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
if identityProviderMessage.Type == IdentityProviderOAuth2 {
|
|
||||||
|
if identityProvider.Type == IdentityProviderOAuth2 {
|
||||||
oauth2Config := &IdentityProviderOAuth2Config{}
|
oauth2Config := &IdentityProviderOAuth2Config{}
|
||||||
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
|
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
identityProviderMessage.Config = &IdentityProviderConfig{
|
identityProvider.Config = &IdentityProviderConfig{
|
||||||
OAuth2Config: oauth2Config,
|
OAuth2Config: oauth2Config,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type))
|
return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
s.idpCache.Store(identityProviderMessage.ID, identityProviderMessage)
|
|
||||||
return &identityProviderMessage, nil
|
s.idpCache.Store(identityProvider.ID, identityProvider)
|
||||||
|
return &identityProvider, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProviderMessage) error {
|
func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error {
|
||||||
tx, err := s.db.BeginTx(ctx, nil)
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return FormatError(err)
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
where, args := []string{"id = ?"}, []any{delete.ID}
|
where, args := []string{"id = ?"}, []any{delete.ID}
|
||||||
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
|
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
|
||||||
result, err := tx.ExecContext(ctx, stmt, args...)
|
result, err := tx.ExecContext(ctx, stmt, args...)
|
||||||
if err != nil {
|
|
||||||
return FormatError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if rows == 0 {
|
|
||||||
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("idp not found")}
|
if _, err = result.RowsAffected(); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.idpCache.Delete(delete.ID)
|
s.idpCache.Delete(delete.ID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProviderMessage) ([]*IdentityProviderMessage, error) {
|
func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProvider) ([]*IdentityProvider, error) {
|
||||||
where, args := []string{"TRUE"}, []any{}
|
where, args := []string{"TRUE"}, []any{}
|
||||||
if v := find.ID; v != nil {
|
if v := find.ID; v != nil {
|
||||||
where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
|
where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
|
||||||
@@ -259,40 +261,41 @@ func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityPr
|
|||||||
args...,
|
args...,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
var identityProviderMessages []*IdentityProviderMessage
|
var identityProviders []*IdentityProvider
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var identityProviderMessage IdentityProviderMessage
|
var identityProvider IdentityProvider
|
||||||
var identityProviderConfig string
|
var identityProviderConfig string
|
||||||
if err := rows.Scan(
|
if err := rows.Scan(
|
||||||
&identityProviderMessage.ID,
|
&identityProvider.ID,
|
||||||
&identityProviderMessage.Name,
|
&identityProvider.Name,
|
||||||
&identityProviderMessage.Type,
|
&identityProvider.Type,
|
||||||
&identityProviderMessage.IdentifierFilter,
|
&identityProvider.IdentifierFilter,
|
||||||
&identityProviderConfig,
|
&identityProviderConfig,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
if identityProviderMessage.Type == IdentityProviderOAuth2 {
|
|
||||||
|
if identityProvider.Type == IdentityProviderOAuth2 {
|
||||||
oauth2Config := &IdentityProviderOAuth2Config{}
|
oauth2Config := &IdentityProviderOAuth2Config{}
|
||||||
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
|
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
identityProviderMessage.Config = &IdentityProviderConfig{
|
identityProvider.Config = &IdentityProviderConfig{
|
||||||
OAuth2Config: oauth2Config,
|
OAuth2Config: oauth2Config,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type))
|
return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type))
|
||||||
}
|
}
|
||||||
identityProviderMessages = append(identityProviderMessages, &identityProviderMessage)
|
identityProviders = append(identityProviders, &identityProvider)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return identityProviderMessages, nil
|
return identityProviders, nil
|
||||||
}
|
}
|
||||||
|
@@ -16,7 +16,7 @@ type Store struct {
|
|||||||
userCache sync.Map // map[int]*userRaw
|
userCache sync.Map // map[int]*userRaw
|
||||||
userSettingCache sync.Map // map[string]*UserSettingMessage
|
userSettingCache sync.Map // map[string]*UserSettingMessage
|
||||||
shortcutCache sync.Map // map[int]*shortcutRaw
|
shortcutCache sync.Map // map[int]*shortcutRaw
|
||||||
idpCache sync.Map // map[int]*IdentityProviderMessage
|
idpCache sync.Map // map[int]*IdentityProvider
|
||||||
resourceCache sync.Map // map[int]*resourceRaw
|
resourceCache sync.Map // map[int]*resourceRaw
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -12,14 +12,14 @@ import (
|
|||||||
func TestIdentityProviderStore(t *testing.T) {
|
func TestIdentityProviderStore(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ts := NewTestingStore(ctx, t)
|
ts := NewTestingStore(ctx, t)
|
||||||
createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{
|
createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProvider{
|
||||||
Name: "GitHub OAuth",
|
Name: "GitHub OAuth",
|
||||||
Type: store.IdentityProviderOAuth2,
|
Type: store.IdentityProviderOAuth2,
|
||||||
IdentifierFilter: "",
|
IdentifierFilter: "",
|
||||||
Config: &store.IdentityProviderConfig{
|
Config: &store.IdentityProviderConfig{
|
||||||
OAuth2Config: &store.IdentityProviderOAuth2Config{
|
OAuth2Config: &store.IdentityProviderOAuth2Config{
|
||||||
ClientID: "asd",
|
ClientID: "client_id",
|
||||||
ClientSecret: "123",
|
ClientSecret: "client_secret",
|
||||||
AuthURL: "https://github.com/auth",
|
AuthURL: "https://github.com/auth",
|
||||||
TokenURL: "https://github.com/token",
|
TokenURL: "https://github.com/token",
|
||||||
UserInfoURL: "https://github.com/user",
|
UserInfoURL: "https://github.com/user",
|
||||||
@@ -33,16 +33,23 @@ func TestIdentityProviderStore(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{
|
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{
|
||||||
ID: &createdIDP.ID,
|
ID: &createdIDP.ID,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, createdIDP, idp)
|
require.Equal(t, createdIDP, idp)
|
||||||
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProviderMessage{
|
newName := "My GitHub OAuth"
|
||||||
|
updatedIdp, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProvider{
|
||||||
|
ID: idp.ID,
|
||||||
|
Name: &newName,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, newName, updatedIdp.Name)
|
||||||
|
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{
|
||||||
ID: idp.ID,
|
ID: idp.ID,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProviderMessage{})
|
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 0, len(idpList))
|
require.Equal(t, 0, len(idpList))
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user