diff --git a/config/funcs.go b/config/funcs.go index a9c82ce..9678df0 100644 --- a/config/funcs.go +++ b/config/funcs.go @@ -11,7 +11,9 @@ package config import ( + "net/http" "strings" + "time" ) // FriendlyHost returns the app's Host sans any schema @@ -25,3 +27,16 @@ func (ac AppCfg) CanCreateBlogs(currentlyUsed uint64) bool { } return int(currentlyUsed) < ac.MaxBlogs } + +// OrDefaultString returns input or a default value if input is empty. +func OrDefaultString(input, defaultValue string) string { + if len(input) == 0 { + return defaultValue + } + return input +} + +// DefaultHTTPClient returns a sane default HTTP client. +func DefaultHTTPClient() *http.Client { + return &http.Client{Timeout: 10 * time.Second} +} diff --git a/oauth.go b/oauth.go index 18f79eb..2eccbdc 100644 --- a/oauth.go +++ b/oauth.go @@ -34,6 +34,7 @@ type InspectResponse struct { ExpiresAt time.Time `json:"expires_at"` Username string `json:"username"` Email string `json:"email"` + Error string `json:"error"` } // tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token @@ -104,7 +105,7 @@ func configureSlackOauth(r *mux.Router, app *App) { ClientSecret: app.Config().SlackOauth.ClientSecret, TeamID: app.Config().SlackOauth.TeamID, CallbackLocation: app.Config().App.Host + "/oauth/callback", - HttpClient: &http.Client{Timeout: 10 * time.Second}, + HttpClient: config.DefaultHTTPClient(), } configureOauthRoutes(r, app, oauthClient) } @@ -115,11 +116,14 @@ func configureWriteAsOauth(r *mux.Router, app *App) { oauthClient := writeAsOauthClient{ ClientID: app.Config().WriteAsOauth.ClientID, ClientSecret: app.Config().WriteAsOauth.ClientSecret, - ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, - InspectLocation: app.Config().WriteAsOauth.InspectLocation, - AuthLocation: app.Config().WriteAsOauth.AuthLocation, - HttpClient: &http.Client{Timeout: 10 * time.Second}, + ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation), + InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation), + AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation), + HttpClient: config.DefaultHTTPClient(), CallbackLocation: app.Config().App.Host + "/oauth/callback", + } + if oauthClient.ExchangeLocation == "" { + } configureOauthRoutes(r, app, oauthClient) } diff --git a/oauth_slack.go b/oauth_slack.go index 9c8508e..32ceea0 100644 --- a/oauth_slack.go +++ b/oauth_slack.go @@ -2,6 +2,7 @@ package writefreely import ( "context" + "errors" "github.com/writeas/slug" "net/http" "net/url" @@ -17,10 +18,12 @@ type slackOauthClient struct { } type slackExchangeResponse struct { + OK bool `json:"ok"` AccessToken string `json:"access_token"` Scope string `json:"scope"` TeamName string `json:"team_name"` TeamID string `json:"team_id"` + Error string `json:"error"` } type slackIdentity struct { @@ -103,11 +106,17 @@ func (c slackOauthClient) exchangeOauthCode(ctx context.Context, code string) (* if err != nil { return nil, err } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to exchange code for access token") + } var tokenResponse slackExchangeResponse if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { return nil, err } + if !tokenResponse.OK { + return nil, errors.New(tokenResponse.Error) + } return tokenResponse.TokenResponse(), nil } @@ -125,11 +134,17 @@ func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessTok if err != nil { return nil, err } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to inspect access token") + } var inspectResponse slackUserIdentityResponse if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { return nil, err } + if !inspectResponse.OK { + return nil, errors.New(inspectResponse.Error) + } return inspectResponse.InspectResponse(), nil } diff --git a/oauth_writeas.go b/oauth_writeas.go index 9550c35..eb12f64 100644 --- a/oauth_writeas.go +++ b/oauth_writeas.go @@ -2,6 +2,7 @@ package writefreely import ( "context" + "errors" "net/http" "net/url" "strings" @@ -19,6 +20,12 @@ type writeAsOauthClient struct { var _ oauthClient = writeAsOauthClient{} +const ( + writeAsAuthLocation = "https://write.as/oauth/login" + writeAsExchangeLocation = "https://write.as/oauth/token" + writeAsIdentityLocation = "https://write.as/oauth/inspect" +) + func (c writeAsOauthClient) GetProvider() string { return "write.as" } @@ -60,11 +67,17 @@ func (c writeAsOauthClient) exchangeOauthCode(ctx context.Context, code string) if err != nil { return nil, err } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to exchange code for access token") + } var tokenResponse TokenResponse if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { return nil, err } + if tokenResponse.Error != "" { + return nil, errors.New(tokenResponse.Error) + } return &tokenResponse, nil } @@ -82,10 +95,16 @@ func (c writeAsOauthClient) inspectOauthAccessToken(ctx context.Context, accessT if err != nil { return nil, err } + if resp.StatusCode != http.StatusOK { + return nil, errors.New("unable to inspect access token") + } var inspectResponse InspectResponse if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { return nil, err } + if inspectResponse.Error != "" { + return nil, errors.New(inspectResponse.Error) + } return &inspectResponse, nil }