mirror of
https://github.com/superseriousbusiness/gotosocial
synced 2025-06-05 21:59:39 +02:00
Pg to bun (#148)
* start moving to bun * changing more stuff * more * and yet more * tests passing * seems stable now * more big changes * small fix * little fixes
This commit is contained in:
@@ -19,6 +19,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -80,13 +81,13 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
|
||||
app := >smodel.Application{
|
||||
ClientID: clientID,
|
||||
}
|
||||
if err := m.db.GetWhere([]db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
|
||||
if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: app.ClientID}}, app); err != nil {
|
||||
m.clearSession(s)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := m.parseUserFromClaims(claims, net.IP(c.ClientIP()), app.ID)
|
||||
user, err := m.parseUserFromClaims(c.Request.Context(), claims, net.IP(c.ClientIP()), app.ID)
|
||||
if err != nil {
|
||||
m.clearSession(s)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
@@ -103,14 +104,14 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
|
||||
c.Redirect(http.StatusFound, OauthAuthorizePath)
|
||||
}
|
||||
|
||||
func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) {
|
||||
func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, ip net.IP, appID string) (*gtsmodel.User, error) {
|
||||
if claims.Email == "" {
|
||||
return nil, errors.New("no email returned in claims")
|
||||
}
|
||||
|
||||
// see if we already have a user for this email address
|
||||
user := >smodel.User{}
|
||||
err := m.db.GetWhere([]db.Where{{Key: "email", Value: claims.Email}}, user)
|
||||
err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user)
|
||||
if err == nil {
|
||||
// we do! so we can just return it
|
||||
return user, nil
|
||||
@@ -122,7 +123,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
|
||||
}
|
||||
|
||||
// maybe we have an unconfirmed user
|
||||
err = m.db.GetWhere([]db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user)
|
||||
err = m.db.GetWhere(ctx, []db.Where{{Key: "unconfirmed_email", Value: claims.Email}}, user)
|
||||
if err == nil {
|
||||
// user is unconfirmed so return an error
|
||||
return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email)
|
||||
@@ -137,9 +138,13 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
|
||||
// however, because we trust the OIDC provider, we should now create a user + account with the provided claims
|
||||
|
||||
// check if the email address is available for use; if it's not there's nothing we can so
|
||||
if err := m.db.IsEmailAvailable(claims.Email); err != nil {
|
||||
emailAvailable, err := m.db.IsEmailAvailable(ctx, claims.Email)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("email %s not available: %s", claims.Email, err)
|
||||
}
|
||||
if !emailAvailable {
|
||||
return nil, fmt.Errorf("email %s in use", claims.Email)
|
||||
}
|
||||
|
||||
// now we need a username
|
||||
var username string
|
||||
@@ -180,12 +185,11 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
|
||||
// note that for the first iteration, iString is still "" when the check is made, so our first choice
|
||||
// is still the raw username with no integer stuck on the end
|
||||
for i := 1; !found; i = i + 1 {
|
||||
if err := m.db.IsUsernameAvailable(username + iString); err != nil {
|
||||
if strings.Contains(err.Error(), "db error") {
|
||||
// if there's an actual db error we should return
|
||||
return nil, fmt.Errorf("error checking username availability: %s", err)
|
||||
}
|
||||
} else {
|
||||
usernameAvailable, err := m.db.IsUsernameAvailable(ctx, username+iString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if usernameAvailable {
|
||||
// no error so we've found a username that works
|
||||
found = true
|
||||
username = username + iString
|
||||
@@ -209,7 +213,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
|
||||
password := uuid.NewString() + uuid.NewString()
|
||||
|
||||
// create the user! this will also create an account and store it in the database so we don't need to do that here
|
||||
user, err = m.db.NewSignup(username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin)
|
||||
user, err = m.db.NewSignup(ctx, username, "", m.config.AccountsConfig.RequireApproval, claims.Email, password, ip, "", appID, claims.EmailVerified, admin)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating user: %s", err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user