219 lines
6.5 KiB
Go
219 lines
6.5 KiB
Go
/*
|
|
* Copyright © 2020 A Bunch Tell LLC.
|
|
*
|
|
* This file is part of WriteFreely.
|
|
*
|
|
* WriteFreely is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU Affero General Public License, included
|
|
* in the LICENSE file in this source code package.
|
|
*/
|
|
|
|
package writefreely
|
|
|
|
import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"fmt"
	"html/template"
	"net/http"
	"strings"
	"time"

	"github.com/writeas/impart"
	"github.com/writeas/web-core/auth"
	"github.com/writeas/web-core/log"
	"github.com/writeas/writefreely/page"
)
|
|
|
|
type viewOauthSignupVars struct {
|
|
page.StaticPage
|
|
To string
|
|
Message template.HTML
|
|
Flashes []template.HTML
|
|
|
|
AccessToken string
|
|
TokenUsername string
|
|
TokenAlias string // TODO: rename this to match the data it represents: the collection title
|
|
TokenEmail string
|
|
TokenRemoteUser string
|
|
Provider string
|
|
ClientID string
|
|
TokenHash string
|
|
|
|
LoginUsername string
|
|
Alias string // TODO: rename this to match the data it represents: the collection title
|
|
Email string
|
|
}
|
|
|
|
// Form parameter names used by the OAuth signup flow.
const (
	// Token parameters; these are the values covered by the signature.
	oauthParamAccessToken       = "access_token"
	oauthParamTokenUsername     = "token_username"
	oauthParamTokenAlias        = "token_alias"
	oauthParamTokenEmail        = "token_email"
	oauthParamTokenRemoteUserID = "token_remote_user"
	oauthParamClientID          = "client_id"
	oauthParamProvider          = "provider"

	// Hex digest produced by oauthSignupPageParams.HashTokenParams.
	oauthParamHash = "signature"

	// Fields chosen by the user on the signup form.
	oauthParamUsername = "username"
	oauthParamAlias    = "alias"
	oauthParamEmail    = "email"
	oauthParamPassword = "password"
)
|
|
|
|
// oauthSignupPageParams holds the token parameters that travel with the OAuth
// signup form, together with the signature (TokenHash) computed over them.
type oauthSignupPageParams struct {
	AccessToken, TokenUsername  string
	TokenAlias                  string // TODO: rename this to match the data it represents: the collection title
	TokenEmail, TokenRemoteUser string
	ClientID, Provider          string
	TokenHash                   string
}

// HashTokenParams returns the hex-encoded SHA-256 digest of key followed by
// the token parameters, in this fixed order: AccessToken, TokenUsername,
// TokenAlias, TokenEmail, TokenRemoteUser, ClientID, Provider. TokenHash
// itself is not part of the input.
//
// NOTE(review): the values are written back-to-back with no separator, so
// inputs that concatenate to the same byte string produce the same digest.
// Changing that would change every signature; confirm before touching it.
func (p oauthSignupPageParams) HashTokenParams(key string) string {
	signed := [...]string{
		key,
		p.AccessToken,
		p.TokenUsername,
		p.TokenAlias,
		p.TokenEmail,
		p.TokenRemoteUser,
		p.ClientID,
		p.Provider,
	}

	digest := sha256.New()
	for _, part := range signed {
		digest.Write([]byte(part))
	}
	return hex.EncodeToString(digest.Sum(nil))
}
|
|
|
|
func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.Request) error {
|
|
tp := &oauthSignupPageParams{
|
|
AccessToken: r.FormValue(oauthParamAccessToken),
|
|
TokenUsername: r.FormValue(oauthParamTokenUsername),
|
|
TokenAlias: r.FormValue(oauthParamTokenAlias),
|
|
TokenEmail: r.FormValue(oauthParamTokenEmail),
|
|
TokenRemoteUser: r.FormValue(oauthParamTokenRemoteUserID),
|
|
ClientID: r.FormValue(oauthParamClientID),
|
|
Provider: r.FormValue(oauthParamProvider),
|
|
}
|
|
if tp.HashTokenParams(h.Config.Server.HashSeed) != r.FormValue(oauthParamHash) {
|
|
return impart.HTTPError{Status: http.StatusBadRequest, Message: "Request has been tampered with."}
|
|
}
|
|
tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
|
|
if err := h.validateOauthSignup(r); err != nil {
|
|
return h.showOauthSignupPage(app, w, r, tp, err)
|
|
}
|
|
|
|
var err error
|
|
hashedPass := []byte{}
|
|
clearPass := r.FormValue(oauthParamPassword)
|
|
hasPass := clearPass != ""
|
|
if hasPass {
|
|
hashedPass, err = auth.HashPass([]byte(clearPass))
|
|
if err != nil {
|
|
return h.showOauthSignupPage(app, w, r, tp, fmt.Errorf("unable to hash password"))
|
|
}
|
|
}
|
|
newUser := &User{
|
|
Username: r.FormValue(oauthParamUsername),
|
|
HashedPass: hashedPass,
|
|
HasPass: hasPass,
|
|
Email: prepareUserEmail(r.FormValue(oauthParamEmail), h.EmailKey),
|
|
Created: time.Now().Truncate(time.Second).UTC(),
|
|
}
|
|
displayName := r.FormValue(oauthParamAlias)
|
|
if len(displayName) == 0 {
|
|
displayName = r.FormValue(oauthParamUsername)
|
|
}
|
|
|
|
err = h.DB.CreateUser(h.Config, newUser, displayName)
|
|
if err != nil {
|
|
return h.showOauthSignupPage(app, w, r, tp, err)
|
|
}
|
|
|
|
err = h.DB.RecordRemoteUserID(r.Context(), newUser.ID, r.FormValue(oauthParamTokenRemoteUserID), r.FormValue(oauthParamProvider), r.FormValue(oauthParamClientID), r.FormValue(oauthParamAccessToken))
|
|
if err != nil {
|
|
return h.showOauthSignupPage(app, w, r, tp, err)
|
|
}
|
|
|
|
if err := loginOrFail(h.Store, w, r, newUser); err != nil {
|
|
return h.showOauthSignupPage(app, w, r, tp, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h oauthHandler) validateOauthSignup(r *http.Request) error {
|
|
username := r.FormValue(oauthParamUsername)
|
|
if len(username) < h.Config.App.MinUsernameLen {
|
|
return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too short."}
|
|
}
|
|
if len(username) > 100 {
|
|
return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too long."}
|
|
}
|
|
collTitle := r.FormValue(oauthParamAlias)
|
|
if len(collTitle) == 0 {
|
|
collTitle = username
|
|
}
|
|
email := r.FormValue(oauthParamEmail)
|
|
if len(email) > 0 {
|
|
parts := strings.Split(email, "@")
|
|
if len(parts) != 2 || (len(parts[0]) < 1 || len(parts[1]) < 1) {
|
|
return impart.HTTPError{Status: http.StatusBadRequest, Message: "Invalid email address"}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h oauthHandler) showOauthSignupPage(app *App, w http.ResponseWriter, r *http.Request, tp *oauthSignupPageParams, errMsg error) error {
|
|
username := tp.TokenUsername
|
|
collTitle := tp.TokenAlias
|
|
email := tp.TokenEmail
|
|
|
|
session, err := app.sessionStore.Get(r, cookieName)
|
|
if err != nil {
|
|
// Ignore this
|
|
log.Error("Unable to get session; ignoring: %v", err)
|
|
}
|
|
|
|
if tmpValue := r.FormValue(oauthParamUsername); len(tmpValue) > 0 {
|
|
username = tmpValue
|
|
}
|
|
if tmpValue := r.FormValue(oauthParamAlias); len(tmpValue) > 0 {
|
|
collTitle = tmpValue
|
|
}
|
|
if tmpValue := r.FormValue(oauthParamEmail); len(tmpValue) > 0 {
|
|
email = tmpValue
|
|
}
|
|
|
|
p := &viewOauthSignupVars{
|
|
StaticPage: pageForReq(app, r),
|
|
To: r.FormValue("to"),
|
|
Flashes: []template.HTML{},
|
|
|
|
AccessToken: tp.AccessToken,
|
|
TokenUsername: tp.TokenUsername,
|
|
TokenAlias: tp.TokenAlias,
|
|
TokenEmail: tp.TokenEmail,
|
|
TokenRemoteUser: tp.TokenRemoteUser,
|
|
Provider: tp.Provider,
|
|
ClientID: tp.ClientID,
|
|
TokenHash: tp.TokenHash,
|
|
|
|
LoginUsername: username,
|
|
Alias: collTitle,
|
|
Email: email,
|
|
}
|
|
|
|
// Display any error messages
|
|
flashes, _ := getSessionFlashes(app, w, r, session)
|
|
for _, flash := range flashes {
|
|
p.Flashes = append(p.Flashes, template.HTML(flash))
|
|
}
|
|
if errMsg != nil {
|
|
p.Flashes = append(p.Flashes, template.HTML(errMsg.Error()))
|
|
}
|
|
err = pages["signup-oauth.tmpl"].ExecuteTemplate(w, "base", p)
|
|
if err != nil {
|
|
log.Error("Unable to render signup-oauth: %v", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|