update auth api with session

This commit is contained in:
steven 2021-12-10 13:41:17 +08:00
parent a08ad0ebab
commit 050c2ccbd5
15 changed files with 143 additions and 90 deletions

View File

@ -30,13 +30,10 @@ func handleUserSignUp(w http.ResponseWriter, r *http.Request) {
return
}
userIdCookie := &http.Cookie{
Name: "user_id",
Value: user.Id,
Path: "/",
MaxAge: 3600 * 24 * 30,
}
http.SetCookie(w, userIdCookie)
session, _ := SessionStore.Get(r, "session")
session.Values["user_id"] = user.Id
session.Save(r, w)
json.NewEncoder(w).Encode(Response{
Succeed: true,
@ -66,13 +63,10 @@ func handleUserSignIn(w http.ResponseWriter, r *http.Request) {
return
}
userIdCookie := &http.Cookie{
Name: "user_id",
Value: user.Id,
Path: "/",
MaxAge: 3600 * 24 * 30,
}
http.SetCookie(w, userIdCookie)
session, _ := SessionStore.Get(r, "session")
session.Values["user_id"] = user.Id
session.Save(r, w)
json.NewEncoder(w).Encode(Response{
Succeed: true,
@ -82,13 +76,10 @@ func handleUserSignIn(w http.ResponseWriter, r *http.Request) {
}
func handleUserSignOut(w http.ResponseWriter, r *http.Request) {
userIdCookie := &http.Cookie{
Name: "user_id",
Value: "",
Path: "/",
MaxAge: 0,
}
http.SetCookie(w, userIdCookie)
session, _ := SessionStore.Get(r, "session")
session.Values["user_id"] = ""
session.Save(r, w)
json.NewEncoder(w).Encode(Response{
Succeed: true,

View File

@ -10,7 +10,7 @@ import (
)
func handleGetMyMemos(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r)
userId, _ := GetUserIdInSession(r)
urlParams := r.URL.Query()
deleted := urlParams.Get("deleted")
onlyDeletedFlag := deleted == "true"
@ -34,7 +34,7 @@ type CreateMemo struct {
}
func handleCreateMemo(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r)
userId, _ := GetUserIdInSession(r)
createMemo := CreateMemo{}
err := json.NewDecoder(r.Body).Decode(&createMemo)
@ -105,6 +105,8 @@ func handleDeleteMemo(w http.ResponseWriter, r *http.Request) {
func RegisterMemoRoutes(r *mux.Router) {
memoRouter := r.PathPrefix("/api/memo").Subrouter()
memoRouter.Use(AuthCheckerMiddleWare)
memoRouter.HandleFunc("/all", handleGetMyMemos).Methods("GET")
memoRouter.HandleFunc("/", handleCreateMemo).Methods("PUT")
memoRouter.HandleFunc("/{id}", handleUpdateMemo).Methods("PATCH")

View File

@ -7,9 +7,9 @@ import (
func AuthCheckerMiddleWare(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userId, err := GetUserIdInCookie(r)
session, _ := SessionStore.Get(r, "session")
if err != nil || userId == "" {
if userId, ok := session.Values["user_id"].(string); !ok || userId == "" {
e.ErrorHandler(w, "NOT_AUTH", "Need authorize")
return
}

View File

@ -10,7 +10,7 @@ import (
)
func handleGetMyQueries(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r)
userId, _ := GetUserIdInSession(r)
queries, err := store.GetQueriesByUserId(userId)
@ -32,7 +32,7 @@ type QueryPut struct {
}
func handleCreateQuery(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r)
userId, _ := GetUserIdInSession(r)
queryPut := QueryPut{}
err := json.NewDecoder(r.Body).Decode(&queryPut)
@ -103,6 +103,8 @@ func handleDeleteQuery(w http.ResponseWriter, r *http.Request) {
func RegisterQueryRoutes(r *mux.Router) {
queryRouter := r.PathPrefix("/api/query").Subrouter()
queryRouter.Use(AuthCheckerMiddleWare)
queryRouter.HandleFunc("/all", handleGetMyQueries).Methods("GET")
queryRouter.HandleFunc("/", handleCreateQuery).Methods("PUT")
queryRouter.HandleFunc("/{id}", handleUpdateQuery).Methods("PATCH")

9
api/session.go Normal file
View File

@ -0,0 +1,9 @@
package api
import (
"memos/common"
"github.com/gorilla/sessions"
)
var SessionStore = sessions.NewCookieStore([]byte(common.GenUUID()))

View File

@ -10,7 +10,7 @@ import (
)
func handleGetMyUserInfo(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r)
userId, _ := GetUserIdInSession(r)
user, err := store.GetUserById(userId)
@ -27,7 +27,7 @@ func handleGetMyUserInfo(w http.ResponseWriter, r *http.Request) {
}
func handleUpdateMyUserInfo(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r)
userId, _ := GetUserIdInSession(r)
userPatch := store.UserPatch{}
err := json.NewDecoder(r.Body).Decode(&userPatch)
@ -83,7 +83,7 @@ type ValidPassword struct {
}
func handleValidPassword(w http.ResponseWriter, r *http.Request) {
userId, _ := GetUserIdInCookie(r)
userId, _ := GetUserIdInSession(r)
validPassword := ValidPassword{}
err := json.NewDecoder(r.Body).Decode(&validPassword)

View File

@ -10,12 +10,14 @@ type Response struct {
Data interface{} `json:"data"`
}
func GetUserIdInCookie(r *http.Request) (string, error) {
userIdCookie, err := r.Cookie("user_id")
func GetUserIdInSession(r *http.Request) (string, error) {
session, _ := SessionStore.Get(r, "session")
if err != nil {
return "", err
userId, ok := session.Values["user_id"].(string)
if !ok {
return "", http.ErrNoCookie
}
return userIdCookie.Value, err
return userId, nil
}

5
go.mod
View File

@ -7,3 +7,8 @@ require github.com/gorilla/mux v1.8.0
require github.com/mattn/go-sqlite3 v1.14.9
require github.com/google/uuid v1.3.0
require (
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1
)

4
go.sum
View File

@ -2,5 +2,9 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA=
github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=

View File

@ -1,26 +1,41 @@
/*
* Re-create tables and insert initial data(todo)
* Re-create tables and insert initial data
*/
DROP TABLE IF EXISTS `users`;
CREATE TABLE `users` (
`id` TEXT NOT NULL PRIMARY KEY,
`username` TEXT NOT NULL,
`password` TEXT NOT NULL,
`github_name` TEXT NULL DEFAULT '',
`wx_open_id` TEXT NULL DEFAULT '',
`github_name` TEXT DEFAULT '',
`wx_open_id` TEXT DEFAULT '',
`created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
`updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
);
INSERT INTO `users`
(`id`, `username`, `password`)
VALUES
('0', 'admin', '123456'),
('1', 'guest', '123456');
DROP TABLE IF EXISTS `memos`;
CREATE TABLE `memos` (
`id` TEXT NOT NULL PRIMARY KEY,
`content` TEXT NOT NULL,
`user_id` TEXT NOT NULL,
`created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
`updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
`deleted_at` TEXT,
`deleted_at` TEXT DEFAULT '',
FOREIGN KEY(`user_id`) REFERENCES `users`(`id`)
);
INSERT INTO `memos`
(`id`, `content`, `user_id`, )
VALUES
('0', '👋 Welcome to memos', '0'),
('1', '👋 Welcome to memos', '1');
DROP TABLE IF EXISTS `queries`;
CREATE TABLE `queries` (
`id` TEXT NOT NULL PRIMARY KEY,
`user_id` TEXT NOT NULL,
@ -28,6 +43,6 @@ CREATE TABLE `queries` (
`querystring` TEXT NOT NULL,
`created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
`updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
`pinned_at` TEXT NULL,
`pinned_at` TEXT DEFAULT '',
FOREIGN KEY(`user_id`) REFERENCES `users`(`id`)
);

Binary file not shown.

50
store/db.go Normal file
View File

@ -0,0 +1,50 @@
package store
import (
"database/sql"
"os"
_ "github.com/mattn/go-sqlite3"
)
/*
* Use a global variable to save the db connection: Quick and easy to setup.
* Reference: https://techinscribed.com/different-approaches-to-pass-database-connection-into-controllers-in-golang/
*/
var DB *sql.DB
func InitDBConn() {
dbFilePath := "/data/memos.db"
if _, err := os.Stat(dbFilePath); err != nil {
dbFilePath = "./resources/memos.db"
resetDataInDefaultDatabase()
println("use the default database")
} else {
println("use the custom database")
}
db, err := sql.Open("sqlite3", dbFilePath)
if err != nil {
println("connect failed")
} else {
DB = db
println("connect to sqlite succeed")
}
}
func FormatDBError(err error) error {
if err == nil {
return nil
}
switch err.Error() {
default:
return err
}
}
func resetDataInDefaultDatabase() {
// do nth
}

View File

@ -65,34 +65,36 @@ func DeleteMemo(memoId string) (error, error) {
}
func GetMemoById(id string) (Memo, error) {
query := `SELECT id, content, user_id, deleted_at, created_at, updated_at FROM memos WHERE id=?`
query := `SELECT id, content, deleted_at, created_at, updated_at FROM memos WHERE id=?`
memo := Memo{}
err := DB.QueryRow(query, id).Scan(&memo.Id, &memo.Content, &memo.UserId, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt)
err := DB.QueryRow(query, id).Scan(&memo.Id, &memo.Content, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt)
return memo, err
}
func GetMemosByUserId(userId string, deleted bool) ([]Memo, error) {
query := `SELECT id, content, user_id, deleted_at, created_at, updated_at FROM memos WHERE user_id=?`
func GetMemosByUserId(userId string, onlyDeleted bool) ([]Memo, error) {
sqlQuery := `SELECT id, content, deleted_at, created_at, updated_at FROM memos WHERE user_id=?`
if deleted {
query = query + ` AND deleted_at!=""`
if onlyDeleted {
sqlQuery = sqlQuery + ` AND deleted_at!=""`
} else {
query = query + ` AND deleted_at=""`
sqlQuery = sqlQuery + ` AND deleted_at=""`
}
rows, _ := DB.Query(query, userId)
rows, _ := DB.Query(sqlQuery, userId)
defer rows.Close()
memos := []Memo{}
for rows.Next() {
memo := Memo{}
rows.Scan(&memo.Id, &memo.Content, &memo.UserId, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt)
rows.Scan(&memo.Id, &memo.Content, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt)
memos = append(memos, memo)
}
err := rows.Err()
if err := rows.Err(); err != nil {
return nil, err
}
return memos, err
return memos, nil
}

View File

@ -27,8 +27,8 @@ func CreateNewQuery(title string, querystring string, userId string) (Query, err
UpdatedAt: nowDateTimeStr,
}
query := `INSERT INTO queries (id, title, querystring, user_id, pinned_at, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)`
_, err := DB.Exec(query, newQuery.Id, newQuery.Title, newQuery.Querystring, newQuery.UserId, newQuery.PinnedAt, newQuery.CreatedAt, newQuery.UpdatedAt)
sqlQuery := `INSERT INTO queries (id, title, querystring, user_id, pinned_at, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)`
_, err := DB.Exec(sqlQuery, newQuery.Id, newQuery.Title, newQuery.Querystring, newQuery.UserId, newQuery.PinnedAt, newQuery.CreatedAt, newQuery.UpdatedAt)
return newQuery, err
}
@ -72,14 +72,14 @@ func DeleteQuery(queryId string) (error, error) {
}
func GetQueryById(queryId string) (Query, error) {
sqlQuery := `SELECT id, title, querystring, user_id, pinned_at, created_at, updated_at FROM queries WHERE id=?`
sqlQuery := `SELECT id, title, querystring, pinned_at, created_at, updated_at FROM queries WHERE id=?`
query := Query{}
err := DB.QueryRow(sqlQuery, queryId).Scan(&query.Id, &query.Title, &query.Querystring, &query.UserId, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt)
err := DB.QueryRow(sqlQuery, queryId).Scan(&query.Id, &query.Title, &query.Querystring, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt)
return query, err
}
func GetQueriesByUserId(userId string) ([]Query, error) {
query := `SELECT id, title, querystring, user_id, pinned_at, created_at, updated_at FROM queries WHERE user_id=?`
query := `SELECT id, title, querystring, pinned_at, created_at, updated_at FROM queries WHERE user_id=?`
rows, _ := DB.Query(query, userId)
defer rows.Close()
@ -88,12 +88,14 @@ func GetQueriesByUserId(userId string) ([]Query, error) {
for rows.Next() {
query := Query{}
rows.Scan(&query.Id, &query.Title, &query.Querystring, &query.UserId, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt)
rows.Scan(&query.Id, &query.Title, &query.Querystring, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt)
queries = append(queries, query)
}
err := rows.Err()
if err := rows.Err(); err != nil {
return nil, err
}
return queries, err
return queries, nil
}

View File

@ -1,31 +0,0 @@
package store
import (
"database/sql"
"fmt"
_ "github.com/mattn/go-sqlite3"
)
var DB *sql.DB
func InitDBConn() {
db, err := sql.Open("sqlite3", "./resources/memos.db")
if err != nil {
fmt.Println("connect failed")
} else {
DB = db
fmt.Println("connect to sqlite succeed")
}
}
func FormatDBError(err error) error {
if err == nil {
return nil
}
switch err.Error() {
default:
return err
}
}