refactor: merge sign in requests

This commit is contained in:
johnnyjoy
2025-05-14 22:13:52 +08:00
parent a0f68895ab
commit ca79990679
9 changed files with 557 additions and 446 deletions

View File

@@ -44,30 +44,115 @@ func (s *APIV1Service) GetAuthStatus(ctx context.Context, _ *v1pb.GetAuthStatusR
}
func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) (*v1pb.User, error) {
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &request.Username,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
}
// Compare the stored hashed password, with the hashed version of the password that was received.
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(request.Password)); err != nil {
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
var existingUser *store.User
if passwordCredentials := request.GetPasswordCredentials(); passwordCredentials != nil {
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &passwordCredentials.Username,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
}
// Compare the stored hashed password, with the hashed version of the password that was received.
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(passwordCredentials.Password)); err != nil {
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
}
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
}
// Check if the password auth in is allowed.
if workspaceGeneralSetting.DisallowPasswordAuth && user.Role == store.RoleUser {
return nil, status.Errorf(codes.PermissionDenied, "password signin is not allowed")
}
existingUser = user
} else if ssoCredentials := request.GetSsoCredentials(); ssoCredentials != nil {
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &ssoCredentials.IdpId,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.InvalidArgument, "identity provider not found")
}
var userInfo *idp.IdentityProviderUserInfo
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config())
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err)
}
token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err)
}
userInfo, err = oauth2IdentityProvider.UserInfo(token)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user info, error: %v", err)
}
}
identifierFilter := identityProvider.IdentifierFilter
if identifierFilter != "" {
identifierFilterRegex, err := regexp.Compile(identifierFilter)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to compile identifier filter regex, error: %v", err)
}
if !identifierFilterRegex.MatchString(userInfo.Identifier) {
return nil, status.Errorf(codes.PermissionDenied, "identifier %s is not allowed", userInfo.Identifier)
}
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &userInfo.Identifier,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
}
if user == nil {
// Check if the user is allowed to sign up.
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
}
if workspaceGeneralSetting.DisallowUserRegistration {
return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed")
}
// Create a new user with the user info from the identity provider.
userCreate := &store.User{
Username: userInfo.Identifier,
// The new signup user should be normal user by default.
Role: store.RoleUser,
Nickname: userInfo.DisplayName,
Email: userInfo.Email,
AvatarURL: userInfo.AvatarURL,
}
password, err := util.RandomString(20)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate random password, error: %v", err)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate password hash, error: %v", err)
}
userCreate.PasswordHash = string(passwordHash)
user, err = s.Store.CreateUser(ctx, userCreate)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create user, error: %v", err)
}
}
existingUser = user
}
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
if existingUser == nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid credentials")
}
// Check if the password auth in is allowed.
if workspaceGeneralSetting.DisallowPasswordAuth && user.Role == store.RoleUser {
return nil, status.Errorf(codes.PermissionDenied, "password signin is not allowed")
}
if user.RowStatus == store.Archived {
return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", request.Username)
if existingUser.RowStatus == store.Archived {
return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", existingUser.Username)
}
expireTime := time.Now().Add(AccessTokenDuration)
@@ -75,97 +160,10 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest)
// Set the expire time to 100 years.
expireTime = time.Now().Add(100 * 365 * 24 * time.Hour)
}
if err := s.doSignIn(ctx, user, expireTime); err != nil {
if err := s.doSignIn(ctx, existingUser, expireTime); err != nil {
return nil, status.Errorf(codes.Internal, "failed to sign in, error: %v", err)
}
return convertUserFromStore(user), nil
}
func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWithSSORequest) (*v1pb.User, error) {
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &request.IdpId,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.InvalidArgument, "identity provider not found")
}
var userInfo *idp.IdentityProviderUserInfo
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config())
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err)
}
token, err := oauth2IdentityProvider.ExchangeToken(ctx, request.RedirectUri, request.Code)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err)
}
userInfo, err = oauth2IdentityProvider.UserInfo(token)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user info, error: %v", err)
}
}
identifierFilter := identityProvider.IdentifierFilter
if identifierFilter != "" {
identifierFilterRegex, err := regexp.Compile(identifierFilter)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to compile identifier filter regex, error: %v", err)
}
if !identifierFilterRegex.MatchString(userInfo.Identifier) {
return nil, status.Errorf(codes.PermissionDenied, "identifier %s is not allowed", userInfo.Identifier)
}
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &userInfo.Identifier,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
}
if user == nil {
// Check if the user is allowed to sign up.
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
}
if workspaceGeneralSetting.DisallowUserRegistration {
return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed")
}
// Create a new user with the user info from the identity provider.
userCreate := &store.User{
Username: userInfo.Identifier,
// The new signup user should be normal user by default.
Role: store.RoleUser,
Nickname: userInfo.DisplayName,
Email: userInfo.Email,
AvatarURL: userInfo.AvatarURL,
}
password, err := util.RandomString(20)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate random password, error: %v", err)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate password hash, error: %v", err)
}
userCreate.PasswordHash = string(passwordHash)
user, err = s.Store.CreateUser(ctx, userCreate)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create user, error: %v", err)
}
}
if user.RowStatus == store.Archived {
return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", userInfo.Identifier)
}
if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil {
return nil, status.Errorf(codes.Internal, "failed to sign in, error: %v", err)
}
return convertUserFromStore(user), nil
return convertUserFromStore(existingUser), nil
}
func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error {