diff --git a/handle.go b/handle.go index c70859e..d296d93 100644 --- a/handle.go +++ b/handle.go @@ -237,22 +237,49 @@ func (h *Handler) AdminApper(f userApperHandlerFunc) http.HandlerFunc { } } +func apiAuth(app *App, r *http.Request) (*User, error) { + // Authorize user from Authorization header + t := r.Header.Get("Authorization") + if t == "" { + return nil, ErrNoAccessToken + } + u := &User{ID: app.db.GetUserID(t)} + if u.ID == -1 { + return nil, ErrBadAccessToken + } + + return u, nil +} + +// optionaAPIAuth is used for endpoints that accept authenticated requests via +// Authorization header or cookie, unlike apiAuth. It returns a different err +// in the case where no Authorization header is present. +func optionalAPIAuth(app *App, r *http.Request) (*User, error) { + // Authorize user from Authorization header + t := r.Header.Get("Authorization") + if t == "" { + return nil, ErrNotLoggedIn + } + u := &User{ID: app.db.GetUserID(t)} + if u.ID == -1 { + return nil, ErrBadAccessToken + } + + return u, nil +} + +func webAuth(app *App, r *http.Request) (*User, error) { + u := getUserSession(app, r) + if u == nil { + return nil, ErrNotLoggedIn + } + return u, nil +} + // UserAPI handles requests made in the API by the authenticated user. // This provides user-friendly HTML pages and actions that work in the browser. func (h *Handler) UserAPI(f userHandlerFunc) http.HandlerFunc { - return h.UserAll(false, f, func(app *App, r *http.Request) (*User, error) { - // Authorize user from Authorization header - t := r.Header.Get("Authorization") - if t == "" { - return nil, ErrNoAccessToken - } - u := &User{ID: app.db.GetUserID(t)} - if u.ID == -1 { - return nil, ErrBadAccessToken - } - - return u, nil - }) + return h.UserAll(false, f, apiAuth) } func (h *Handler) UserAll(web bool, f userHandlerFunc, a authFunc) http.HandlerFunc { @@ -515,6 +542,64 @@ func (h *Handler) All(f handlerFunc) http.HandlerFunc { } } +func (h *Handler) AllReader(f handlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + h.handleError(w, r, func() error { + status := 200 + start := time.Now() + + defer func() { + if e := recover(); e != nil { + log.Error("%s:\n%s", e, debug.Stack()) + impart.WriteError(w, impart.HTTPError{http.StatusInternalServerError, "Something didn't work quite right."}) + status = 500 + } + + log.Info("\"%s %s\" %d %s \"%s\"", r.Method, r.RequestURI, status, time.Since(start), r.UserAgent()) + }() + + if h.app.App().cfg.App.Private { + // This instance is private, so ensure it's being accessed by a valid user + // Check if authenticated with an access token + _, apiErr := optionalAPIAuth(h.app.App(), r) + if apiErr != nil { + if err, ok := apiErr.(impart.HTTPError); ok { + status = err.Status + } else { + status = 500 + } + + if apiErr == ErrNotLoggedIn { + // Fall back to web auth since there was no access token given + _, err := webAuth(h.app.App(), r) + if err != nil { + if err, ok := apiErr.(impart.HTTPError); ok { + status = err.Status + } else { + status = 500 + } + return err + } + } else { + return apiErr + } + } + } + + err := f(h.app.App(), w, r) + if err != nil { + if err, ok := err.(impart.HTTPError); ok { + status = err.Status + } else { + status = 500 + } + } + + return err + }()) + } +} + func (h *Handler) Download(f dataHandlerFunc, ul UserLevelFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { h.handleHTTPError(w, r, func() error { diff --git a/routes.go b/routes.go index 58bd455..a1a04b3 100644 --- a/routes.go +++ b/routes.go @@ -113,28 +113,28 @@ func InitRoutes(apper Apper, r *mux.Router) *mux.Router { // Handle collections write.HandleFunc("/api/collections", handler.All(newCollection)).Methods("POST") apiColls := write.PathPrefix("/api/collections/").Subrouter() - apiColls.HandleFunc("/{alias:[0-9a-zA-Z\\-]+}", handler.All(fetchCollection)).Methods("GET") + apiColls.HandleFunc("/{alias:[0-9a-zA-Z\\-]+}", handler.AllReader(fetchCollection)).Methods("GET") apiColls.HandleFunc("/{alias:[0-9a-zA-Z\\-]+}", handler.All(existingCollection)).Methods("POST", "DELETE") - apiColls.HandleFunc("/{alias}/posts", handler.All(fetchCollectionPosts)).Methods("GET") + apiColls.HandleFunc("/{alias}/posts", handler.AllReader(fetchCollectionPosts)).Methods("GET") apiColls.HandleFunc("/{alias}/posts", handler.All(newPost)).Methods("POST") - apiColls.HandleFunc("/{alias}/posts/{post}", handler.All(fetchPost)).Methods("GET") + apiColls.HandleFunc("/{alias}/posts/{post}", handler.AllReader(fetchPost)).Methods("GET") apiColls.HandleFunc("/{alias}/posts/{post:[a-zA-Z0-9]{10}}", handler.All(existingPost)).Methods("POST") - apiColls.HandleFunc("/{alias}/posts/{post}/{property}", handler.All(fetchPostProperty)).Methods("GET") + apiColls.HandleFunc("/{alias}/posts/{post}/{property}", handler.AllReader(fetchPostProperty)).Methods("GET") apiColls.HandleFunc("/{alias}/collect", handler.All(addPost)).Methods("POST") apiColls.HandleFunc("/{alias}/pin", handler.All(pinPost)).Methods("POST") apiColls.HandleFunc("/{alias}/unpin", handler.All(pinPost)).Methods("POST") apiColls.HandleFunc("/{alias}/inbox", handler.All(handleFetchCollectionInbox)).Methods("POST") - apiColls.HandleFunc("/{alias}/outbox", handler.All(handleFetchCollectionOutbox)).Methods("GET") - apiColls.HandleFunc("/{alias}/following", handler.All(handleFetchCollectionFollowing)).Methods("GET") - apiColls.HandleFunc("/{alias}/followers", handler.All(handleFetchCollectionFollowers)).Methods("GET") + apiColls.HandleFunc("/{alias}/outbox", handler.AllReader(handleFetchCollectionOutbox)).Methods("GET") + apiColls.HandleFunc("/{alias}/following", handler.AllReader(handleFetchCollectionFollowing)).Methods("GET") + apiColls.HandleFunc("/{alias}/followers", handler.AllReader(handleFetchCollectionFollowers)).Methods("GET") // Handle posts write.HandleFunc("/api/posts", handler.All(newPost)).Methods("POST") posts := write.PathPrefix("/api/posts/").Subrouter() - posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.All(fetchPost)).Methods("GET") + posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.AllReader(fetchPost)).Methods("GET") posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.All(existingPost)).Methods("POST", "PUT") posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}", handler.All(deletePost)).Methods("DELETE") - posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}/{property}", handler.All(fetchPostProperty)).Methods("GET") + posts.HandleFunc("/{post:[a-zA-Z0-9]{10}}/{property}", handler.AllReader(fetchPostProperty)).Methods("GET") posts.HandleFunc("/claim", handler.All(addPost)).Methods("POST") posts.HandleFunc("/disperse", handler.All(dispersePost)).Methods("POST")