diff --git a/internal/db/db.go b/internal/db/db.go index 4e315d26e..fbd13d729 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -138,8 +138,11 @@ type DB interface { // Returns an error if the username is already taken, or something went wrong in the db. IsUsernameAvailable(username string) error - // IsEmailAvailable checks whether a given email address for a user is available on our domain. - // Returns an error if the email is already associated with an account, or something went wrong in the db. + // IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain. + // Return an error if: + // A) the email is already associated with an account + // B) we block signups from this email domain + // C) something went wrong in the db IsEmailAvailable(email string) error /* diff --git a/internal/db/model/emaildomainblock.go b/internal/db/model/emaildomainblock.go index 5b910c0a2..6610a2075 100644 --- a/internal/db/model/emaildomainblock.go +++ b/internal/db/model/emaildomainblock.go @@ -20,13 +20,11 @@ package model import "time" -// SignUpDomainBlock represents a domain that the server should automatically reject sign-up requests from. -type SignUpDomainBlock struct { +// EmailDomainBlock represents a domain that the server should automatically reject sign-up requests from. +type EmailDomainBlock struct { // ID of this block in the database ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull,unique"` - // Domain to block. If ANY PART of the candidate domain contains this string, it will be blocked. - // For example: 'example.org' also blocks 'gts.example.org'. '.com' blocks *any* '.com' domains. - // TODO: implement wildcards here + // Email domain to block. Eg. 'gmail.com' or 'hotmail.com' Domain string `pg:",notnull"` // When was this block created CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"` diff --git a/internal/db/pg.go b/internal/db/pg.go index ab8af86ca..90c2d4687 100644 --- a/internal/db/pg.go +++ b/internal/db/pg.go @@ -22,6 +22,7 @@ import ( "context" "errors" "fmt" + "net/mail" "regexp" "strings" "time" @@ -378,19 +379,35 @@ func (ps *postgresService) IsUsernameAvailable(username string) error { if err := ps.conn.Model(&model.Account{}).Where("username = ?").Where("domain = ?", nil).Select(); err == nil { return fmt.Errorf("username %s already in use", username) } else if err != pg.ErrNoRows { - return err + return fmt.Errorf("db error: %s", err) } return nil } func (ps *postgresService) IsEmailAvailable(email string) error { - // if no error we fail because it means we found something - // if error but it's not db.ErrorNoEntries then we fail - // if err is ErrNoEntries we're good, we found nothing so continue + // parse the domain from the email + m, err := mail.ParseAddress(email) + if err != nil { + return fmt.Errorf("error parsing email address %s: %s", email, err) + } + domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @ + + // check if the email domain is blocked + if err := ps.conn.Model(&model.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil { + // fail because we found something + return fmt.Errorf("email domain %s is blocked", domain) + } else if err != pg.ErrNoRows { + // fail because we got an unexpected error + return fmt.Errorf("db error: %s", err) + } + + // check if this email is associated with an account already if err := ps.conn.Model(&model.Account{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil { + // fail because we found something return fmt.Errorf("email %s already in use", email) } else if err != pg.ErrNoRows { - return err + // fail because we got an unexpected error + return fmt.Errorf("db error: %s", err) } return nil }