mirror of
https://github.com/usememos/memos.git
synced 2025-02-14 10:20:49 +01:00
update auth api with session
This commit is contained in:
parent
a08ad0ebab
commit
050c2ccbd5
33
api/auth.go
33
api/auth.go
@ -30,13 +30,10 @@ func handleUserSignUp(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userIdCookie := &http.Cookie{
|
||||
Name: "user_id",
|
||||
Value: user.Id,
|
||||
Path: "/",
|
||||
MaxAge: 3600 * 24 * 30,
|
||||
}
|
||||
http.SetCookie(w, userIdCookie)
|
||||
session, _ := SessionStore.Get(r, "session")
|
||||
|
||||
session.Values["user_id"] = user.Id
|
||||
session.Save(r, w)
|
||||
|
||||
json.NewEncoder(w).Encode(Response{
|
||||
Succeed: true,
|
||||
@ -66,13 +63,10 @@ func handleUserSignIn(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userIdCookie := &http.Cookie{
|
||||
Name: "user_id",
|
||||
Value: user.Id,
|
||||
Path: "/",
|
||||
MaxAge: 3600 * 24 * 30,
|
||||
}
|
||||
http.SetCookie(w, userIdCookie)
|
||||
session, _ := SessionStore.Get(r, "session")
|
||||
|
||||
session.Values["user_id"] = user.Id
|
||||
session.Save(r, w)
|
||||
|
||||
json.NewEncoder(w).Encode(Response{
|
||||
Succeed: true,
|
||||
@ -82,13 +76,10 @@ func handleUserSignIn(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func handleUserSignOut(w http.ResponseWriter, r *http.Request) {
|
||||
userIdCookie := &http.Cookie{
|
||||
Name: "user_id",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: 0,
|
||||
}
|
||||
http.SetCookie(w, userIdCookie)
|
||||
session, _ := SessionStore.Get(r, "session")
|
||||
|
||||
session.Values["user_id"] = ""
|
||||
session.Save(r, w)
|
||||
|
||||
json.NewEncoder(w).Encode(Response{
|
||||
Succeed: true,
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func handleGetMyMemos(w http.ResponseWriter, r *http.Request) {
|
||||
userId, _ := GetUserIdInCookie(r)
|
||||
userId, _ := GetUserIdInSession(r)
|
||||
urlParams := r.URL.Query()
|
||||
deleted := urlParams.Get("deleted")
|
||||
onlyDeletedFlag := deleted == "true"
|
||||
@ -34,7 +34,7 @@ type CreateMemo struct {
|
||||
}
|
||||
|
||||
func handleCreateMemo(w http.ResponseWriter, r *http.Request) {
|
||||
userId, _ := GetUserIdInCookie(r)
|
||||
userId, _ := GetUserIdInSession(r)
|
||||
|
||||
createMemo := CreateMemo{}
|
||||
err := json.NewDecoder(r.Body).Decode(&createMemo)
|
||||
@ -105,6 +105,8 @@ func handleDeleteMemo(w http.ResponseWriter, r *http.Request) {
|
||||
func RegisterMemoRoutes(r *mux.Router) {
|
||||
memoRouter := r.PathPrefix("/api/memo").Subrouter()
|
||||
|
||||
memoRouter.Use(AuthCheckerMiddleWare)
|
||||
|
||||
memoRouter.HandleFunc("/all", handleGetMyMemos).Methods("GET")
|
||||
memoRouter.HandleFunc("/", handleCreateMemo).Methods("PUT")
|
||||
memoRouter.HandleFunc("/{id}", handleUpdateMemo).Methods("PATCH")
|
||||
|
@ -7,9 +7,9 @@ import (
|
||||
|
||||
func AuthCheckerMiddleWare(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
userId, err := GetUserIdInCookie(r)
|
||||
session, _ := SessionStore.Get(r, "session")
|
||||
|
||||
if err != nil || userId == "" {
|
||||
if userId, ok := session.Values["user_id"].(string); !ok || userId == "" {
|
||||
e.ErrorHandler(w, "NOT_AUTH", "Need authorize")
|
||||
return
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func handleGetMyQueries(w http.ResponseWriter, r *http.Request) {
|
||||
userId, _ := GetUserIdInCookie(r)
|
||||
userId, _ := GetUserIdInSession(r)
|
||||
|
||||
queries, err := store.GetQueriesByUserId(userId)
|
||||
|
||||
@ -32,7 +32,7 @@ type QueryPut struct {
|
||||
}
|
||||
|
||||
func handleCreateQuery(w http.ResponseWriter, r *http.Request) {
|
||||
userId, _ := GetUserIdInCookie(r)
|
||||
userId, _ := GetUserIdInSession(r)
|
||||
|
||||
queryPut := QueryPut{}
|
||||
err := json.NewDecoder(r.Body).Decode(&queryPut)
|
||||
@ -103,6 +103,8 @@ func handleDeleteQuery(w http.ResponseWriter, r *http.Request) {
|
||||
func RegisterQueryRoutes(r *mux.Router) {
|
||||
queryRouter := r.PathPrefix("/api/query").Subrouter()
|
||||
|
||||
queryRouter.Use(AuthCheckerMiddleWare)
|
||||
|
||||
queryRouter.HandleFunc("/all", handleGetMyQueries).Methods("GET")
|
||||
queryRouter.HandleFunc("/", handleCreateQuery).Methods("PUT")
|
||||
queryRouter.HandleFunc("/{id}", handleUpdateQuery).Methods("PATCH")
|
||||
|
9
api/session.go
Normal file
9
api/session.go
Normal file
@ -0,0 +1,9 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"memos/common"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
var SessionStore = sessions.NewCookieStore([]byte(common.GenUUID()))
|
@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func handleGetMyUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
userId, _ := GetUserIdInCookie(r)
|
||||
userId, _ := GetUserIdInSession(r)
|
||||
|
||||
user, err := store.GetUserById(userId)
|
||||
|
||||
@ -27,7 +27,7 @@ func handleGetMyUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func handleUpdateMyUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
userId, _ := GetUserIdInCookie(r)
|
||||
userId, _ := GetUserIdInSession(r)
|
||||
|
||||
userPatch := store.UserPatch{}
|
||||
err := json.NewDecoder(r.Body).Decode(&userPatch)
|
||||
@ -83,7 +83,7 @@ type ValidPassword struct {
|
||||
}
|
||||
|
||||
func handleValidPassword(w http.ResponseWriter, r *http.Request) {
|
||||
userId, _ := GetUserIdInCookie(r)
|
||||
userId, _ := GetUserIdInSession(r)
|
||||
validPassword := ValidPassword{}
|
||||
err := json.NewDecoder(r.Body).Decode(&validPassword)
|
||||
|
||||
|
12
api/utils.go
12
api/utils.go
@ -10,12 +10,14 @@ type Response struct {
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
func GetUserIdInCookie(r *http.Request) (string, error) {
|
||||
userIdCookie, err := r.Cookie("user_id")
|
||||
func GetUserIdInSession(r *http.Request) (string, error) {
|
||||
session, _ := SessionStore.Get(r, "session")
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
userId, ok := session.Values["user_id"].(string)
|
||||
|
||||
if !ok {
|
||||
return "", http.ErrNoCookie
|
||||
}
|
||||
|
||||
return userIdCookie.Value, err
|
||||
return userId, nil
|
||||
}
|
||||
|
5
go.mod
5
go.mod
@ -7,3 +7,8 @@ require github.com/gorilla/mux v1.8.0
|
||||
require github.com/mattn/go-sqlite3 v1.14.9
|
||||
|
||||
require github.com/google/uuid v1.3.0
|
||||
|
||||
require (
|
||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
||||
github.com/gorilla/sessions v1.2.1
|
||||
)
|
||||
|
4
go.sum
4
go.sum
@ -2,5 +2,9 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
|
||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
|
||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA=
|
||||
github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||
|
@ -1,26 +1,41 @@
|
||||
/*
|
||||
* Re-create tables and insert initial data(todo)
|
||||
* Re-create tables and insert initial data
|
||||
*/
|
||||
DROP TABLE IF EXISTS `users`;
|
||||
CREATE TABLE `users` (
|
||||
`id` TEXT NOT NULL PRIMARY KEY,
|
||||
`username` TEXT NOT NULL,
|
||||
`password` TEXT NOT NULL,
|
||||
`github_name` TEXT NULL DEFAULT '',
|
||||
`wx_open_id` TEXT NULL DEFAULT '',
|
||||
`github_name` TEXT DEFAULT '',
|
||||
`wx_open_id` TEXT DEFAULT '',
|
||||
`created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
INSERT INTO `users`
|
||||
(`id`, `username`, `password`)
|
||||
VALUES
|
||||
('0', 'admin', '123456'),
|
||||
('1', 'guest', '123456');
|
||||
|
||||
DROP TABLE IF EXISTS `memos`;
|
||||
CREATE TABLE `memos` (
|
||||
`id` TEXT NOT NULL PRIMARY KEY,
|
||||
`content` TEXT NOT NULL,
|
||||
`user_id` TEXT NOT NULL,
|
||||
`created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`deleted_at` TEXT,
|
||||
`deleted_at` TEXT DEFAULT '',
|
||||
FOREIGN KEY(`user_id`) REFERENCES `users`(`id`)
|
||||
);
|
||||
|
||||
INSERT INTO `memos`
|
||||
(`id`, `content`, `user_id`, )
|
||||
VALUES
|
||||
('0', '👋 Welcome to memos', '0'),
|
||||
('1', '👋 Welcome to memos', '1');
|
||||
|
||||
DROP TABLE IF EXISTS `queries`;
|
||||
CREATE TABLE `queries` (
|
||||
`id` TEXT NOT NULL PRIMARY KEY,
|
||||
`user_id` TEXT NOT NULL,
|
||||
@ -28,6 +43,6 @@ CREATE TABLE `queries` (
|
||||
`querystring` TEXT NOT NULL,
|
||||
`created_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`updated_at` TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
`pinned_at` TEXT NULL,
|
||||
`pinned_at` TEXT DEFAULT '',
|
||||
FOREIGN KEY(`user_id`) REFERENCES `users`(`id`)
|
||||
);
|
||||
|
Binary file not shown.
50
store/db.go
Normal file
50
store/db.go
Normal file
@ -0,0 +1,50 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
/*
|
||||
* Use a global variable to save the db connection: Quick and easy to setup.
|
||||
* Reference: https://techinscribed.com/different-approaches-to-pass-database-connection-into-controllers-in-golang/
|
||||
*/
|
||||
var DB *sql.DB
|
||||
|
||||
func InitDBConn() {
|
||||
dbFilePath := "/data/memos.db"
|
||||
|
||||
if _, err := os.Stat(dbFilePath); err != nil {
|
||||
dbFilePath = "./resources/memos.db"
|
||||
resetDataInDefaultDatabase()
|
||||
println("use the default database")
|
||||
} else {
|
||||
println("use the custom database")
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", dbFilePath)
|
||||
|
||||
if err != nil {
|
||||
println("connect failed")
|
||||
} else {
|
||||
DB = db
|
||||
println("connect to sqlite succeed")
|
||||
}
|
||||
}
|
||||
|
||||
func FormatDBError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch err.Error() {
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func resetDataInDefaultDatabase() {
|
||||
// do nth
|
||||
}
|
@ -65,34 +65,36 @@ func DeleteMemo(memoId string) (error, error) {
|
||||
}
|
||||
|
||||
func GetMemoById(id string) (Memo, error) {
|
||||
query := `SELECT id, content, user_id, deleted_at, created_at, updated_at FROM memos WHERE id=?`
|
||||
query := `SELECT id, content, deleted_at, created_at, updated_at FROM memos WHERE id=?`
|
||||
memo := Memo{}
|
||||
err := DB.QueryRow(query, id).Scan(&memo.Id, &memo.Content, &memo.UserId, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt)
|
||||
err := DB.QueryRow(query, id).Scan(&memo.Id, &memo.Content, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt)
|
||||
return memo, err
|
||||
}
|
||||
|
||||
func GetMemosByUserId(userId string, deleted bool) ([]Memo, error) {
|
||||
query := `SELECT id, content, user_id, deleted_at, created_at, updated_at FROM memos WHERE user_id=?`
|
||||
func GetMemosByUserId(userId string, onlyDeleted bool) ([]Memo, error) {
|
||||
sqlQuery := `SELECT id, content, deleted_at, created_at, updated_at FROM memos WHERE user_id=?`
|
||||
|
||||
if deleted {
|
||||
query = query + ` AND deleted_at!=""`
|
||||
if onlyDeleted {
|
||||
sqlQuery = sqlQuery + ` AND deleted_at!=""`
|
||||
} else {
|
||||
query = query + ` AND deleted_at=""`
|
||||
sqlQuery = sqlQuery + ` AND deleted_at=""`
|
||||
}
|
||||
|
||||
rows, _ := DB.Query(query, userId)
|
||||
rows, _ := DB.Query(sqlQuery, userId)
|
||||
defer rows.Close()
|
||||
|
||||
memos := []Memo{}
|
||||
|
||||
for rows.Next() {
|
||||
memo := Memo{}
|
||||
rows.Scan(&memo.Id, &memo.Content, &memo.UserId, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt)
|
||||
rows.Scan(&memo.Id, &memo.Content, &memo.DeletedAt, &memo.CreatedAt, &memo.UpdatedAt)
|
||||
|
||||
memos = append(memos, memo)
|
||||
}
|
||||
|
||||
err := rows.Err()
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return memos, err
|
||||
return memos, nil
|
||||
}
|
||||
|
@ -27,8 +27,8 @@ func CreateNewQuery(title string, querystring string, userId string) (Query, err
|
||||
UpdatedAt: nowDateTimeStr,
|
||||
}
|
||||
|
||||
query := `INSERT INTO queries (id, title, querystring, user_id, pinned_at, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)`
|
||||
_, err := DB.Exec(query, newQuery.Id, newQuery.Title, newQuery.Querystring, newQuery.UserId, newQuery.PinnedAt, newQuery.CreatedAt, newQuery.UpdatedAt)
|
||||
sqlQuery := `INSERT INTO queries (id, title, querystring, user_id, pinned_at, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)`
|
||||
_, err := DB.Exec(sqlQuery, newQuery.Id, newQuery.Title, newQuery.Querystring, newQuery.UserId, newQuery.PinnedAt, newQuery.CreatedAt, newQuery.UpdatedAt)
|
||||
|
||||
return newQuery, err
|
||||
}
|
||||
@ -72,14 +72,14 @@ func DeleteQuery(queryId string) (error, error) {
|
||||
}
|
||||
|
||||
func GetQueryById(queryId string) (Query, error) {
|
||||
sqlQuery := `SELECT id, title, querystring, user_id, pinned_at, created_at, updated_at FROM queries WHERE id=?`
|
||||
sqlQuery := `SELECT id, title, querystring, pinned_at, created_at, updated_at FROM queries WHERE id=?`
|
||||
query := Query{}
|
||||
err := DB.QueryRow(sqlQuery, queryId).Scan(&query.Id, &query.Title, &query.Querystring, &query.UserId, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt)
|
||||
err := DB.QueryRow(sqlQuery, queryId).Scan(&query.Id, &query.Title, &query.Querystring, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt)
|
||||
return query, err
|
||||
}
|
||||
|
||||
func GetQueriesByUserId(userId string) ([]Query, error) {
|
||||
query := `SELECT id, title, querystring, user_id, pinned_at, created_at, updated_at FROM queries WHERE user_id=?`
|
||||
query := `SELECT id, title, querystring, pinned_at, created_at, updated_at FROM queries WHERE user_id=?`
|
||||
|
||||
rows, _ := DB.Query(query, userId)
|
||||
defer rows.Close()
|
||||
@ -88,12 +88,14 @@ func GetQueriesByUserId(userId string) ([]Query, error) {
|
||||
|
||||
for rows.Next() {
|
||||
query := Query{}
|
||||
rows.Scan(&query.Id, &query.Title, &query.Querystring, &query.UserId, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt)
|
||||
rows.Scan(&query.Id, &query.Title, &query.Querystring, &query.PinnedAt, &query.CreatedAt, &query.UpdatedAt)
|
||||
|
||||
queries = append(queries, query)
|
||||
}
|
||||
|
||||
err := rows.Err()
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return queries, err
|
||||
return queries, nil
|
||||
}
|
||||
|
@ -1,31 +0,0 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
var DB *sql.DB
|
||||
|
||||
func InitDBConn() {
|
||||
db, err := sql.Open("sqlite3", "./resources/memos.db")
|
||||
if err != nil {
|
||||
fmt.Println("connect failed")
|
||||
} else {
|
||||
DB = db
|
||||
fmt.Println("connect to sqlite succeed")
|
||||
}
|
||||
}
|
||||
|
||||
func FormatDBError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch err.Error() {
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user