mirror of
https://github.com/usememos/memos.git
synced 2025-06-05 22:09:59 +02:00
chore: use tx
for user store
This commit is contained in:
@ -1,9 +1,5 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
var (
|
|
||||||
UNKNOWN_ID = 0
|
|
||||||
)
|
|
||||||
|
|
||||||
type Signin struct {
|
type Signin struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
|
@ -52,42 +52,44 @@ func removeUserSession(ctx echo.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
|
func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(ctx echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
|
ctx := c.Request().Context()
|
||||||
|
path := c.Path()
|
||||||
// Skip auth.
|
// Skip auth.
|
||||||
if common.HasPrefixes(ctx.Path(), "/api/auth") {
|
if common.HasPrefixes(path, "/api/auth") {
|
||||||
return next(ctx)
|
return next(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
if common.HasPrefixes(ctx.Path(), "/api/ping", "/api/status", "/api/user/:id") && ctx.Request().Method == http.MethodGet {
|
if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/user/:id") && c.Request().Method == http.MethodGet {
|
||||||
return next(ctx)
|
return next(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there is openId in query string and related user is found, then skip auth.
|
// If there is openId in query string and related user is found, then skip auth.
|
||||||
openID := ctx.QueryParam("openId")
|
openID := c.QueryParam("openId")
|
||||||
if openID != "" {
|
if openID != "" {
|
||||||
userFind := &api.UserFind{
|
userFind := &api.UserFind{
|
||||||
OpenID: &openID,
|
OpenID: &openID,
|
||||||
}
|
}
|
||||||
user, err := s.Store.FindUser(userFind)
|
user, err := s.Store.FindUser(ctx, userFind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by open_id").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by open_id").SetInternal(err)
|
||||||
}
|
}
|
||||||
if user != nil {
|
if user != nil {
|
||||||
// Stores userID into context.
|
// Stores userID into context.
|
||||||
ctx.Set(getUserIDContextKey(), user.ID)
|
c.Set(getUserIDContextKey(), user.ID)
|
||||||
return next(ctx)
|
return next(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
sess, _ := session.Get("session", ctx)
|
sess, _ := session.Get("session", c)
|
||||||
userIDValue := sess.Values[userIDContextKey]
|
userIDValue := sess.Values[userIDContextKey]
|
||||||
if userIDValue != nil {
|
if userIDValue != nil {
|
||||||
userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
|
userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
|
||||||
userFind := &api.UserFind{
|
userFind := &api.UserFind{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
}
|
}
|
||||||
user, err := s.Store.FindUser(userFind)
|
user, err := s.Store.FindUser(ctx, userFind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
|
||||||
}
|
}
|
||||||
@ -95,22 +97,22 @@ func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
|
|||||||
if user.RowStatus == api.Archived {
|
if user.RowStatus == api.Archived {
|
||||||
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email))
|
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email))
|
||||||
}
|
}
|
||||||
ctx.Set(getUserIDContextKey(), userID)
|
c.Set(getUserIDContextKey(), userID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if common.HasPrefixes(ctx.Path(), "/api/memo", "/api/tag", "/api/shortcut") && ctx.Request().Method == http.MethodGet {
|
if common.HasPrefixes(path, "/api/memo", "/api/tag", "/api/shortcut") && c.Request().Method == http.MethodGet {
|
||||||
if _, err := strconv.Atoi(ctx.QueryParam("creatorId")); err == nil {
|
if _, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil {
|
||||||
return next(ctx)
|
return next(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
userID := ctx.Get(getUserIDContextKey())
|
userID := c.Get(getUserIDContextKey())
|
||||||
if userID == nil {
|
if userID == nil {
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||||
}
|
}
|
||||||
|
|
||||||
return next(ctx)
|
return next(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
func (s *Server) registerAuthRoutes(g *echo.Group) {
|
func (s *Server) registerAuthRoutes(g *echo.Group) {
|
||||||
g.POST("/auth/signin", func(c echo.Context) error {
|
g.POST("/auth/signin", func(c echo.Context) error {
|
||||||
|
ctx := c.Request().Context()
|
||||||
signin := &api.Signin{}
|
signin := &api.Signin{}
|
||||||
if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil {
|
if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
|
||||||
@ -22,7 +23,7 @@ func (s *Server) registerAuthRoutes(g *echo.Group) {
|
|||||||
userFind := &api.UserFind{
|
userFind := &api.UserFind{
|
||||||
Email: &signin.Email,
|
Email: &signin.Email,
|
||||||
}
|
}
|
||||||
user, err := s.Store.FindUser(userFind)
|
user, err := s.Store.FindUser(ctx, userFind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by email %s", signin.Email)).SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by email %s", signin.Email)).SetInternal(err)
|
||||||
}
|
}
|
||||||
@ -60,12 +61,13 @@ func (s *Server) registerAuthRoutes(g *echo.Group) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
g.POST("/auth/signup", func(c echo.Context) error {
|
g.POST("/auth/signup", func(c echo.Context) error {
|
||||||
|
ctx := c.Request().Context()
|
||||||
// Don't allow to signup by this api if site host existed.
|
// Don't allow to signup by this api if site host existed.
|
||||||
hostUserType := api.Host
|
hostUserType := api.Host
|
||||||
hostUserFind := api.UserFind{
|
hostUserFind := api.UserFind{
|
||||||
Role: &hostUserType,
|
Role: &hostUserType,
|
||||||
}
|
}
|
||||||
hostUser, err := s.Store.FindUser(&hostUserFind)
|
hostUser, err := s.Store.FindUser(ctx, &hostUserFind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find host user").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find host user").SetInternal(err)
|
||||||
}
|
}
|
||||||
@ -99,7 +101,7 @@ func (s *Server) registerAuthRoutes(g *echo.Group) {
|
|||||||
PasswordHash: string(passwordHash),
|
PasswordHash: string(passwordHash),
|
||||||
OpenID: common.GenUUID(),
|
OpenID: common.GenUUID(),
|
||||||
}
|
}
|
||||||
user, err := s.Store.CreateUser(userCreate)
|
user, err := s.Store.CreateUser(ctx, userCreate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
@ -21,11 +21,12 @@ func (s *Server) registerSystemRoutes(g *echo.Group) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
g.GET("/status", func(c echo.Context) error {
|
g.GET("/status", func(c echo.Context) error {
|
||||||
|
ctx := c.Request().Context()
|
||||||
hostUserType := api.Host
|
hostUserType := api.Host
|
||||||
hostUserFind := api.UserFind{
|
hostUserFind := api.UserFind{
|
||||||
Role: &hostUserType,
|
Role: &hostUserType,
|
||||||
}
|
}
|
||||||
hostUser, err := s.Store.FindUser(&hostUserFind)
|
hostUser, err := s.Store.FindUser(ctx, &hostUserFind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find host user").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find host user").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
func (s *Server) registerUserRoutes(g *echo.Group) {
|
func (s *Server) registerUserRoutes(g *echo.Group) {
|
||||||
g.POST("/user", func(c echo.Context) error {
|
g.POST("/user", func(c echo.Context) error {
|
||||||
|
ctx := c.Request().Context()
|
||||||
userCreate := &api.UserCreate{}
|
userCreate := &api.UserCreate{}
|
||||||
if err := json.NewDecoder(c.Request().Body).Decode(userCreate); err != nil {
|
if err := json.NewDecoder(c.Request().Body).Decode(userCreate); err != nil {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post user request").SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post user request").SetInternal(err)
|
||||||
@ -26,7 +27,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
userCreate.PasswordHash = string(passwordHash)
|
userCreate.PasswordHash = string(passwordHash)
|
||||||
user, err := s.Store.CreateUser(userCreate)
|
user, err := s.Store.CreateUser(ctx, userCreate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
|
||||||
}
|
}
|
||||||
@ -39,7 +40,8 @@ func (s *Server) registerUserRoutes(g *echo.Group) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
g.GET("/user", func(c echo.Context) error {
|
g.GET("/user", func(c echo.Context) error {
|
||||||
userList, err := s.Store.FindUserList(&api.UserFind{})
|
ctx := c.Request().Context()
|
||||||
|
userList, err := s.Store.FindUserList(ctx, &api.UserFind{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch user list").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch user list").SetInternal(err)
|
||||||
}
|
}
|
||||||
@ -57,12 +59,13 @@ func (s *Server) registerUserRoutes(g *echo.Group) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
g.GET("/user/:id", func(c echo.Context) error {
|
g.GET("/user/:id", func(c echo.Context) error {
|
||||||
|
ctx := c.Request().Context()
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted user id").SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted user id").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.Store.FindUser(&api.UserFind{
|
user, err := s.Store.FindUser(ctx, &api.UserFind{
|
||||||
ID: &id,
|
ID: &id,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -83,6 +86,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) {
|
|||||||
|
|
||||||
// GET /api/user/me is used to check if the user is logged in.
|
// GET /api/user/me is used to check if the user is logged in.
|
||||||
g.GET("/user/me", func(c echo.Context) error {
|
g.GET("/user/me", func(c echo.Context) error {
|
||||||
|
ctx := c.Request().Context()
|
||||||
userID, ok := c.Get(getUserIDContextKey()).(int)
|
userID, ok := c.Get(getUserIDContextKey()).(int)
|
||||||
if !ok {
|
if !ok {
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
|
||||||
@ -91,7 +95,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) {
|
|||||||
userFind := &api.UserFind{
|
userFind := &api.UserFind{
|
||||||
ID: &userID,
|
ID: &userID,
|
||||||
}
|
}
|
||||||
user, err := s.Store.FindUser(userFind)
|
user, err := s.Store.FindUser(ctx, userFind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch user").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch user").SetInternal(err)
|
||||||
}
|
}
|
||||||
@ -104,6 +108,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
g.PATCH("/user/:id", func(c echo.Context) error {
|
g.PATCH("/user/:id", func(c echo.Context) error {
|
||||||
|
ctx := c.Request().Context()
|
||||||
userID, err := strconv.Atoi(c.Param("id"))
|
userID, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("id"))).SetInternal(err)
|
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("id"))).SetInternal(err)
|
||||||
@ -112,7 +117,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||||
}
|
}
|
||||||
currentUser, err := s.Store.FindUser(&api.UserFind{
|
currentUser, err := s.Store.FindUser(ctx, &api.UserFind{
|
||||||
ID: ¤tUserID,
|
ID: ¤tUserID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -146,7 +151,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) {
|
|||||||
userPatch.OpenID = &openID
|
userPatch.OpenID = &openID
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.Store.PatchUser(userPatch)
|
user, err := s.Store.PatchUser(ctx, userPatch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch user").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch user").SetInternal(err)
|
||||||
}
|
}
|
||||||
@ -159,11 +164,12 @@ func (s *Server) registerUserRoutes(g *echo.Group) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
g.DELETE("/user/:id", func(c echo.Context) error {
|
g.DELETE("/user/:id", func(c echo.Context) error {
|
||||||
|
ctx := c.Request().Context()
|
||||||
currentUserID, ok := c.Get(getUserIDContextKey()).(int)
|
currentUserID, ok := c.Get(getUserIDContextKey()).(int)
|
||||||
if !ok {
|
if !ok {
|
||||||
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
|
||||||
}
|
}
|
||||||
currentUser, err := s.Store.FindUser(&api.UserFind{
|
currentUser, err := s.Store.FindUser(ctx, &api.UserFind{
|
||||||
ID: ¤tUserID,
|
ID: ¤tUserID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -183,7 +189,7 @@ func (s *Server) registerUserRoutes(g *echo.Group) {
|
|||||||
userDelete := &api.UserDelete{
|
userDelete := &api.UserDelete{
|
||||||
ID: userID,
|
ID: userID,
|
||||||
}
|
}
|
||||||
if err := s.Store.DeleteUser(userDelete); err != nil {
|
if err := s.Store.DeleteUser(ctx, userDelete); err != nil {
|
||||||
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete user").SetInternal(err)
|
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete user").SetInternal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
110
store/user.go
110
store/user.go
@ -1,6 +1,7 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
@ -43,12 +44,22 @@ func (raw *userRaw) toUser() *api.User {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) CreateUser(create *api.UserCreate) (*api.User, error) {
|
func (s *Store) CreateUser(ctx context.Context, create *api.UserCreate) (*api.User, error) {
|
||||||
userRaw, err := createUser(s.db, create)
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
userRaw, err := createUser(ctx, tx, create)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
|
||||||
user := userRaw.toUser()
|
user := userRaw.toUser()
|
||||||
|
|
||||||
if err := s.cache.UpsertCache(api.UserCache, user.ID, user); err != nil {
|
if err := s.cache.UpsertCache(api.UserCache, user.ID, user); err != nil {
|
||||||
@ -58,12 +69,22 @@ func (s *Store) CreateUser(create *api.UserCreate) (*api.User, error) {
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) PatchUser(patch *api.UserPatch) (*api.User, error) {
|
func (s *Store) PatchUser(ctx context.Context, patch *api.UserPatch) (*api.User, error) {
|
||||||
userRaw, err := patchUser(s.db, patch)
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
userRaw, err := patchUser(ctx, tx, patch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
|
||||||
user := userRaw.toUser()
|
user := userRaw.toUser()
|
||||||
|
|
||||||
if err := s.cache.UpsertCache(api.UserCache, user.ID, user); err != nil {
|
if err := s.cache.UpsertCache(api.UserCache, user.ID, user); err != nil {
|
||||||
@ -73,8 +94,14 @@ func (s *Store) PatchUser(patch *api.UserPatch) (*api.User, error) {
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) FindUserList(find *api.UserFind) ([]*api.User, error) {
|
func (s *Store) FindUserList(ctx context.Context, find *api.UserFind) ([]*api.User, error) {
|
||||||
userRawList, err := findUserList(s.db, find)
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
userRawList, err := findUserList(ctx, tx, find)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -87,7 +114,7 @@ func (s *Store) FindUserList(find *api.UserFind) ([]*api.User, error) {
|
|||||||
return list, nil
|
return list, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) FindUser(find *api.UserFind) (*api.User, error) {
|
func (s *Store) FindUser(ctx context.Context, find *api.UserFind) (*api.User, error) {
|
||||||
if find.ID != nil {
|
if find.ID != nil {
|
||||||
user := &api.User{}
|
user := &api.User{}
|
||||||
has, err := s.cache.FindCache(api.UserCache, *find.ID, user)
|
has, err := s.cache.FindCache(api.UserCache, *find.ID, user)
|
||||||
@ -99,7 +126,13 @@ func (s *Store) FindUser(find *api.UserFind) (*api.User, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
list, err := findUserList(s.db, find)
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, FormatError(err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
list, err := findUserList(ctx, tx, find)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -119,19 +152,29 @@ func (s *Store) FindUser(find *api.UserFind) (*api.User, error) {
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteUser(delete *api.UserDelete) error {
|
func (s *Store) DeleteUser(ctx context.Context, delete *api.UserDelete) error {
|
||||||
err := deleteUser(s.db, delete)
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return FormatError(err)
|
return FormatError(err)
|
||||||
}
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
err = deleteUser(ctx, tx, delete)
|
||||||
|
if err != nil {
|
||||||
|
return FormatError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return FormatError(err)
|
||||||
|
}
|
||||||
|
|
||||||
s.cache.DeleteCache(api.UserCache, delete.ID)
|
s.cache.DeleteCache(api.UserCache, delete.ID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUser(db *sql.DB, create *api.UserCreate) (*userRaw, error) {
|
func createUser(ctx context.Context, tx *sql.Tx, create *api.UserCreate) (*userRaw, error) {
|
||||||
row, err := db.Query(`
|
query := `
|
||||||
INSERT INTO user (
|
INSERT INTO user (
|
||||||
email,
|
email,
|
||||||
role,
|
role,
|
||||||
@ -141,21 +184,15 @@ func createUser(db *sql.DB, create *api.UserCreate) (*userRaw, error) {
|
|||||||
)
|
)
|
||||||
VALUES (?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?)
|
||||||
RETURNING id, email, role, name, password_hash, open_id, created_ts, updated_ts, row_status
|
RETURNING id, email, role, name, password_hash, open_id, created_ts, updated_ts, row_status
|
||||||
`,
|
`
|
||||||
|
var userRaw userRaw
|
||||||
|
if err := tx.QueryRowContext(ctx, query,
|
||||||
create.Email,
|
create.Email,
|
||||||
create.Role,
|
create.Role,
|
||||||
create.Name,
|
create.Name,
|
||||||
create.PasswordHash,
|
create.PasswordHash,
|
||||||
create.OpenID,
|
create.OpenID,
|
||||||
)
|
).Scan(
|
||||||
if err != nil {
|
|
||||||
return nil, FormatError(err)
|
|
||||||
}
|
|
||||||
defer row.Close()
|
|
||||||
|
|
||||||
row.Next()
|
|
||||||
var userRaw userRaw
|
|
||||||
if err := row.Scan(
|
|
||||||
&userRaw.ID,
|
&userRaw.ID,
|
||||||
&userRaw.Email,
|
&userRaw.Email,
|
||||||
&userRaw.Role,
|
&userRaw.Role,
|
||||||
@ -172,7 +209,7 @@ func createUser(db *sql.DB, create *api.UserCreate) (*userRaw, error) {
|
|||||||
return &userRaw, nil
|
return &userRaw, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func patchUser(db *sql.DB, patch *api.UserPatch) (*userRaw, error) {
|
func patchUser(ctx context.Context, tx *sql.Tx, patch *api.UserPatch) (*userRaw, error) {
|
||||||
set, args := []string{}, []interface{}{}
|
set, args := []string{}, []interface{}{}
|
||||||
|
|
||||||
if v := patch.RowStatus; v != nil {
|
if v := patch.RowStatus; v != nil {
|
||||||
@ -193,12 +230,13 @@ func patchUser(db *sql.DB, patch *api.UserPatch) (*userRaw, error) {
|
|||||||
|
|
||||||
args = append(args, patch.ID)
|
args = append(args, patch.ID)
|
||||||
|
|
||||||
row, err := db.Query(`
|
query := `
|
||||||
UPDATE user
|
UPDATE user
|
||||||
SET ` + strings.Join(set, ", ") + `
|
SET ` + strings.Join(set, ", ") + `
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
RETURNING id, email, role, name, password_hash, open_id, created_ts, updated_ts, row_status
|
RETURNING id, email, role, name, password_hash, open_id, created_ts, updated_ts, row_status
|
||||||
`, args...)
|
`
|
||||||
|
row, err := tx.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, FormatError(err)
|
||||||
}
|
}
|
||||||
@ -226,7 +264,7 @@ func patchUser(db *sql.DB, patch *api.UserPatch) (*userRaw, error) {
|
|||||||
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user ID not found: %d", patch.ID)}
|
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user ID not found: %d", patch.ID)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func findUserList(db *sql.DB, find *api.UserFind) ([]*userRaw, error) {
|
func findUserList(ctx context.Context, tx *sql.Tx, find *api.UserFind) ([]*userRaw, error) {
|
||||||
where, args := []string{"1 = 1"}, []interface{}{}
|
where, args := []string{"1 = 1"}, []interface{}{}
|
||||||
|
|
||||||
if v := find.ID; v != nil {
|
if v := find.ID; v != nil {
|
||||||
@ -245,7 +283,7 @@ func findUserList(db *sql.DB, find *api.UserFind) ([]*userRaw, error) {
|
|||||||
where, args = append(where, "open_id = ?"), append(args, *v)
|
where, args = append(where, "open_id = ?"), append(args, *v)
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := db.Query(`
|
query := `
|
||||||
SELECT
|
SELECT
|
||||||
id,
|
id,
|
||||||
email,
|
email,
|
||||||
@ -258,9 +296,9 @@ func findUserList(db *sql.DB, find *api.UserFind) ([]*userRaw, error) {
|
|||||||
row_status
|
row_status
|
||||||
FROM user
|
FROM user
|
||||||
WHERE ` + strings.Join(where, " AND ") + `
|
WHERE ` + strings.Join(where, " AND ") + `
|
||||||
ORDER BY created_ts DESC, row_status DESC`,
|
ORDER BY created_ts DESC, row_status DESC
|
||||||
args...,
|
`
|
||||||
)
|
rows, err := tx.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, FormatError(err)
|
return nil, FormatError(err)
|
||||||
}
|
}
|
||||||
@ -293,19 +331,13 @@ func findUserList(db *sql.DB, find *api.UserFind) ([]*userRaw, error) {
|
|||||||
return userRawList, nil
|
return userRawList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func deleteUser(db *sql.DB, delete *api.UserDelete) error {
|
func deleteUser(ctx context.Context, tx *sql.Tx, delete *api.UserDelete) error {
|
||||||
result, err := db.Exec(`
|
if _, err := tx.ExecContext(ctx, `
|
||||||
PRAGMA foreign_keys = ON;
|
PRAGMA foreign_keys = ON;
|
||||||
DELETE FROM user WHERE id = ?
|
DELETE FROM user WHERE id = ?
|
||||||
`, delete.ID)
|
`, delete.ID); err != nil {
|
||||||
if err != nil {
|
|
||||||
return FormatError(err)
|
return FormatError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, _ := result.RowsAffected()
|
|
||||||
if rows == 0 {
|
|
||||||
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("user ID not found: %d", delete.ID)}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user