update auth api with session

This commit is contained in:
steven 2021-12-10 13:41:17 +08:00
parent a08ad0ebab
commit 050c2ccbd5
15 changed files with 143 additions and 90 deletions

View File

@ -30,13 +30,10 @@ func handleUserSignUp(w http.ResponseWriter, r *http.Request) {
return return
} }
userIdCookie := &http.Cookie{ session, _ := SessionStore.Get(r, "session")
Name: "user_id",
Value: user.Id, session.Values["user_id"] = user.Id
Path: "/", session.Save(r, w)
MaxAge: 3600 * 24 * 30,
}
http.SetCookie(w, userIdCookie)
json.NewEncoder(w).Encode(Response{ json.NewEncoder(w).Encode(Response{
Succeed: true, Succeed: true,
@ -66,13 +63,10 @@ func handleUserSignIn(w http.ResponseWriter, r *http.Request) {
return return
} }
userIdCookie := &http.Cookie{ session, _ := SessionStore.Get(r, "session")
Name: "user_id",
Value: user.Id, session.Values["user_id"] = user.Id
Path: "/", session.Save(r, w)
MaxAge: 3600 * 24 * 30,
}
http.SetCookie(w, userIdCookie)
json.NewEncoder(w).Encode(Response{ json.NewEncoder(w).Encode(Response{
Succeed: true, Succeed: true,
@ -82,13 +76,10 @@ func handleUserSignIn(w http.ResponseWriter, r *http.Request) {
} }
func handleUserSignOut(w http.ResponseWriter, r *http.Request) { func handleUserSignOut(w http.ResponseWriter, r *http.Request) {
userIdCookie := &http.Cookie{ session, _ := SessionStore.Get(r, "session")
Name: "user_id",
Value: "", session.Values["user_id"] = ""
Path: "/", session.Save(r, w)
MaxAge: 0,
}
http.SetCookie(w, userIdCookie)
json.NewEncoder(w).Encode(Response{ json.NewEncoder(w).Encode(Response{
Succeed: true, Succeed: true,

View File

@ -10,7 +10,7 @@ import (
) )
func handleGetMyMemos(w http.ResponseWriter, r *http.Request) { func handleGetMyMemos(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r) userId, _ := GetUserIdInSession(r)
urlParams := r.URL.Query() urlParams := r.URL.Query()
deleted := urlParams.Get("deleted") deleted := urlParams.Get("deleted")
onlyDeletedFlag := deleted == "true" onlyDeletedFlag := deleted == "true"
@ -34,7 +34,7 @@ type CreateMemo struct {
} }
func handleCreateMemo(w http.ResponseWriter, r *http.Request) { func handleCreateMemo(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r) userId, _ := GetUserIdInSession(r)
createMemo := CreateMemo{} createMemo := CreateMemo{}
err := json.NewDecoder(r.Body).Decode(&createMemo) err := json.NewDecoder(r.Body).Decode(&createMemo)
@ -105,6 +105,8 @@ func handleDeleteMemo(w http.ResponseWriter, r *http.Request) {
func RegisterMemoRoutes(r *mux.Router) { func RegisterMemoRoutes(r *mux.Router) {
memoRouter := r.PathPrefix("/api/memo").Subrouter() memoRouter := r.PathPrefix("/api/memo").Subrouter()
memoRouter.Use(AuthCheckerMiddleWare)
memoRouter.HandleFunc("/all", handleGetMyMemos).Methods("GET") memoRouter.HandleFunc("/all", handleGetMyMemos).Methods("GET")
memoRouter.HandleFunc("/", handleCreateMemo).Methods("PUT") memoRouter.HandleFunc("/", handleCreateMemo).Methods("PUT")
memoRouter.HandleFunc("/{id}", handleUpdateMemo).Methods("PATCH") memoRouter.HandleFunc("/{id}", handleUpdateMemo).Methods("PATCH")

View File

@ -7,9 +7,9 @@ import (
func AuthCheckerMiddleWare(next http.Handler) http.Handler { func AuthCheckerMiddleWare(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userId, err := GetUserIdInCookie(r) session, _ := SessionStore.Get(r, "session")
if err != nil || userId == "" { if userId, ok := session.Values["user_id"].(string); !ok || userId == "" {
e.ErrorHandler(w, "NOT_AUTH", "Need authorize") e.ErrorHandler(w, "NOT_AUTH", "Need authorize")
return return
} }

View File

@ -10,7 +10,7 @@ import (
) )
func handleGetMyQueries(w http.ResponseWriter, r *http.Request) { func handleGetMyQueries(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r) userId, _ := GetUserIdInSession(r)
queries, err := store.GetQueriesByUserId(userId) queries, err := store.GetQueriesByUserId(userId)
@ -32,7 +32,7 @@ type QueryPut struct {
} }
func handleCreateQuery(w http.ResponseWriter, r *http.Request) { func handleCreateQuery(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r) userId, _ := GetUserIdInSession(r)
queryPut := QueryPut{} queryPut := QueryPut{}
err := json.NewDecoder(r.Body).Decode(&queryPut) err := json.NewDecoder(r.Body).Decode(&queryPut)
@ -103,6 +103,8 @@ func handleDeleteQuery(w http.ResponseWriter, r *http.Request) {
func RegisterQueryRoutes(r *mux.Router) { func RegisterQueryRoutes(r *mux.Router) {
queryRouter := r.PathPrefix("/api/query").Subrouter() queryRouter := r.PathPrefix("/api/query").Subrouter()
queryRouter.Use(AuthCheckerMiddleWare)
queryRouter.HandleFunc("/all", handleGetMyQueries).Methods("GET") queryRouter.HandleFunc("/all", handleGetMyQueries).Methods("GET")
queryRouter.HandleFunc("/", handleCreateQuery).Methods("PUT") queryRouter.HandleFunc("/", handleCreateQuery).Methods("PUT")
queryRouter.HandleFunc("/{id}", handleUpdateQuery).Methods("PATCH") queryRouter.HandleFunc("/{id}", handleUpdateQuery).Methods("PATCH")

9
api/session.go Normal file
View File

@ -0,0 +1,9 @@
package api
import (
"memos/common"
"github.com/gorilla/sessions"
)
var SessionStore = sessions.NewCookieStore([]byte(common.GenUUID()))

View File

@ -10,7 +10,7 @@ import (
) )
func handleGetMyUserInfo(w http.ResponseWriter, r *http.Request) { func handleGetMyUserInfo(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r) userId, _ := GetUserIdInSession(r)
user, err := store.GetUserById(userId) user, err := store.GetUserById(userId)
@ -27,7 +27,7 @@ func handleGetMyUserInfo(w http.ResponseWriter, r *http.Request) {
} }
func handleUpdateMyUserInfo(w http.ResponseWriter, r *http.Request) { func handleUpdateMyUserInfo(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r) userId, _ := GetUserIdInSession(r)
userPatch := store.UserPatch{} userPatch := store.UserPatch{}
err := json.NewDecoder(r.Body).Decode(&userPatch) err := json.NewDecoder(r.Body).Decode(&userPatch)
@ -83,7 +83,7 @@ type ValidPassword struct {
} }
func handleValidPassword(w http.ResponseWriter, r *http.Request) { func handleValidPassword(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r) userId, _ := GetUserIdInSession(r)
validPassword := ValidPassword{} validPassword := ValidPassword{}
err := json.NewDecoder(r.Body).Decode(&validPassword) err := json.NewDecoder(r.Body).Decode(&validPassword)

View File

@ -10,12 +10,14 @@ type Response struct {
Data interface{} `json:"data"` Data interface{} `json:"data"`
} }
func GetUserIdInCookie(r *http.Request) (string, error) { func GetUserIdInSession(r *http.Request) (string, error) {
userIdCookie, err := r.Cookie("user_id") session, _ := SessionStore.Get(r, "session")
if err != nil { userId, ok := session.Values["user_id"].(string)
return "", err
if !ok {
return "", http.ErrNoCookie
} }
return userIdCookie.Value, err return userId, nil
} }

5
go.mod
View File

@ -7,3 +7,8 @@ require github.com/gorilla/mux v1.8.0
require github.com/mattn/go-sqlite3 v1.14.9 require github.com/mattn/go-sqlite3 v1.14.9
require github.com/google/uuid v1.3.0 require github.com/google/uuid v1.3.0
require (
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1
)

4
go.sum
View File

@ -2,5 +2,9 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA= github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA=
github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=

View File

@ -1,26 +1,41 @@
/* /*
* Re-create tables and insert initial data(todo) * Re-create tables and insert initial data
*/ */
DROP TABLE IF EXISTS `users`;
CREATE TABLE `users` ( CREATE TABLE `users` (
`id` TEXT NOT NULL PRIMARY KEY, `id` TEXT NOT NULL PRIMARY KEY,
`username` TEXT NOT NULL, `username` TEXT NOT NULL,
`password` TEXT NOT NULL, `password` TEXT NOT NULL,
`github_name` TEXT NULL DEFAULT '', `github_name` TEXT DEFAULT '',
`wx_open_id` TEXT NULL DEFAULT '', `wx_open_id` TEXT DEFAULT '',
`created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, `created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
`updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP `updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
); );
INSERT INTO `users`
(`id`, `username`, `password`)
VALUES
('0', 'admin', '123456'),
('1', 'guest', '123456');
DROP TABLE IF EXISTS `memos`;
CREATE TABLE `memos` ( CREATE TABLE `memos` (
`id` TEXT NOT NULL PRIMARY KEY, `id` TEXT NOT NULL PRIMARY KEY,
`content` TEXT NOT NULL, `content` TEXT NOT NULL,
`user_id` TEXT NOT NULL, `user_id` TEXT NOT NULL,
`created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, `created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
`updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, `updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
`deleted_at` TEXT, `deleted_at` TEXT DEFAULT '',
FOREIGN KEY(`user_id`) REFERENCES `users`(`id`) FOREIGN KEY(`user_id`) REFERENCES `users`(`id`)
); );
INSERT INTO `memos`
(`id`, `content`, `user_id`, )
VALUES
('0', '👋 Welcome to memos', '0'),
('1', '👋 Welcome to memos', '1');
DROP TABLE IF EXISTS `queries`;
CREATE TABLE `queries` ( CREATE TABLE `queries` (
`id` TEXT NOT NULL PRIMARY KEY, `id` TEXT NOT NULL PRIMARY KEY,
`user_id` TEXT NOT NULL, `user_id` TEXT NOT NULL,
@ -28,6 +43,6 @@ CREATE TABLE `queries` (
`querystring` TEXT NOT NULL, `querystring` TEXT NOT NULL,
`created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, `created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
`updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, `updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
`pinned_at` TEXT NULL, `pinned_at` TEXT DEFAULT '',
FOREIGN KEY(`user_id`) REFERENCES `users`(`id`) FOREIGN KEY(`user_id`) REFERENCES `users`(`id`)
); );

Binary file not shown.

50
store/db.go Normal file
View File

@ -0,0 +1,50 @@
package store
import (
"database/sql"
"os"
_ "github.com/mattn/go-sqlite3"
)
/*
* Use a global variable to save the db connection: Quick and easy to setup.
* Reference: https://techinscribed.com/different-approaches-to-pass-database-connection-into-controllers-in-golang/
*/
var DB *sql.DB
func InitDBConn() {
dbFilePath := "/data/memos.db"
if _, err := os.Stat(dbFilePath); err != nil {
dbFilePath = "./resources/memos.db"
resetDataInDefaultDatabase()
println("use the default database")
} else {
println("use the custom database")
}
db, err := sql.Open("sqlite3", dbFilePath)
if err != nil {
println("connect failed")
} else {
DB = db
println("connect to sqlite succeed")
}
}
func FormatDBError(err error) error {
if err == nil {
return nil
}
switch err.Error() {
default:
return err
}
}
func resetDataInDefaultDatabase() {
// do nth
}

View File

@ -65,34 +65,36 @@ func DeleteMemo(memoId string) (error, error) {
} }
func GetMemoById(id string) (Memo, error) { func GetMemoById(id string) (Memo, error) {
query := `SELECT id, content, user_id, deleted_at, created_at, updated_at FROM memos WHERE id=?` query := `SELECT id, content, deleted_at, created_at, updated_at FROM memos WHERE id=?`
memo := Memo{} memo := Memo{}
err := DB.QueryRow(query, id).Scan(&memo.Id, &memo.Content, &memo.UserId, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt) err := DB.QueryRow(query, id).Scan(&memo.Id, &memo.Content, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt)
return memo, err return memo, err
} }
func GetMemosByUserId(userId string, deleted bool) ([]Memo, error) { func GetMemosByUserId(userId string, onlyDeleted bool) ([]Memo, error) {
query := `SELECT id, content, user_id, deleted_at, created_at, updated_at FROM memos WHERE user_id=?` sqlQuery := `SELECT id, content, deleted_at, created_at, updated_at FROM memos WHERE user_id=?`
if deleted { if onlyDeleted {
query = query + ` AND deleted_at!=""` sqlQuery = sqlQuery + ` AND deleted_at!=""`
} else { } else {
query = query + ` AND deleted_at=""` sqlQuery = sqlQuery + ` AND deleted_at=""`
} }
rows, _ := DB.Query(query, userId) rows, _ := DB.Query(sqlQuery, userId)
defer rows.Close() defer rows.Close()
memos := []Memo{} memos := []Memo{}
for rows.Next() { for rows.Next() {
memo := Memo{} memo := Memo{}
rows.Scan(&memo.Id, &memo.Content, &memo.UserId, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt) rows.Scan(&memo.Id, &memo.Content, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt)
memos = append(memos, memo) memos = append(memos, memo)
} }
err := rows.Err() if err := rows.Err(); err != nil {
return nil, err
}
return memos, err return memos, nil
} }

View File

@ -27,8 +27,8 @@ func CreateNewQuery(title string, querystring string, userId string) (Query, err
UpdatedAt: nowDateTimeStr, UpdatedAt: nowDateTimeStr,
} }
query := `INSERT INTO queries (id, title, querystring, user_id, pinned_at, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)` sqlQuery := `INSERT INTO queries (id, title, querystring, user_id, pinned_at, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)`
_, err := DB.Exec(query, newQuery.Id, newQuery.Title, newQuery.Querystring, newQuery.UserId, newQuery.PinnedAt, newQuery.CreatedAt, newQuery.UpdatedAt) _, err := DB.Exec(sqlQuery, newQuery.Id, newQuery.Title, newQuery.Querystring, newQuery.UserId, newQuery.PinnedAt, newQuery.CreatedAt, newQuery.UpdatedAt)
return newQuery, err return newQuery, err
} }
@ -72,14 +72,14 @@ func DeleteQuery(queryId string) (error, error) {
} }
func GetQueryById(queryId string) (Query, error) { func GetQueryById(queryId string) (Query, error) {
sqlQuery := `SELECT id, title, querystring, user_id, pinned_at, created_at, updated_at FROM queries WHERE id=?` sqlQuery := `SELECT id, title, querystring, pinned_at, created_at, updated_at FROM queries WHERE id=?`
query := Query{} query := Query{}
err := DB.QueryRow(sqlQuery, queryId).Scan(&query.Id, &query.Title, &query.Querystring, &query.UserId, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt) err := DB.QueryRow(sqlQuery, queryId).Scan(&query.Id, &query.Title, &query.Querystring, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt)
return query, err return query, err
} }
func GetQueriesByUserId(userId string) ([]Query, error) { func GetQueriesByUserId(userId string) ([]Query, error) {
query := `SELECT id, title, querystring, user_id, pinned_at, created_at, updated_at FROM queries WHERE user_id=?` query := `SELECT id, title, querystring, pinned_at, created_at, updated_at FROM queries WHERE user_id=?`
rows, _ := DB.Query(query, userId) rows, _ := DB.Query(query, userId)
defer rows.Close() defer rows.Close()
@ -88,12 +88,14 @@ func GetQueriesByUserId(userId string) ([]Query, error) {
for rows.Next() { for rows.Next() {
query := Query{} query := Query{}
rows.Scan(&query.Id, &query.Title, &query.Querystring, &query.UserId, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt) rows.Scan(&query.Id, &query.Title, &query.Querystring, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt)
queries = append(queries, query) queries = append(queries, query)
} }
err := rows.Err() if err := rows.Err(); err != nil {
return nil, err
}
return queries, err return queries, nil
} }

View File

@ -1,31 +0,0 @@
package store
import (
"database/sql"
"fmt"
_ "github.com/mattn/go-sqlite3"
)
var DB *sql.DB
func InitDBConn() {
db, err := sql.Open("sqlite3", "./resources/memos.db")
if err != nil {
fmt.Println("connect failed")
} else {
DB = db
fmt.Println("connect to sqlite succeed")
}
}
func FormatDBError(err error) error {
if err == nil {
return nil
}
switch err.Error() {
default:
return err
}
}