chore: fix postgres stmts

This commit is contained in:
Steven 2024-01-05 21:27:16 +08:00
parent ee13927607
commit 501f8898f6
16 changed files with 589 additions and 672 deletions

View File

@ -2,8 +2,8 @@ package postgres
import (
"context"
"strings"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
@ -21,50 +21,29 @@ func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store
payloadString = string(bytes)
}
qb := squirrel.Insert("activity").
Columns("creator_id", "type", "level", "payload").
PlaceholderFormat(squirrel.Dollar)
values := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
qb = qb.Values(values...).Suffix("RETURNING id")
stmt, args, err := qb.ToSql()
if err != nil {
return nil, errors.Wrap(err, "failed to construct query")
}
var id int32
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, errors.Wrap(err, "failed to execute statement and retrieve ID")
}
list, err := d.ListActivities(ctx, &store.FindActivity{ID: &id})
if err != nil || len(list) == 0 {
return nil, errors.Wrap(err, "failed to find activity")
}
return list[0], nil
}
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
qb := squirrel.Select("id", "created_ts", "creator_id", "type", "level", "payload").
From("activity").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)
if find.ID != nil {
qb = qb.Where(squirrel.Eq{"id": *find.ID})
}
if find.Type != nil {
qb = qb.Where(squirrel.Eq{"type": find.Type.String()})
}
query, args, err := qb.ToSql()
if err != nil {
fields := []string{"creator_id", "type", "level", "payload"}
args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
stmt := "INSERT INTO activity (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.ID,
&create.CreatedTs,
); err != nil {
return nil, err
}
return create, nil
}
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
}
if find.Type != nil {
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type.String())
}
query := "SELECT id, creator_id, type, level, payload, created_ts FROM activity WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
@ -77,17 +56,17 @@ func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*s
var payloadBytes []byte
if err := rows.Scan(
&activity.ID,
&activity.CreatedTs,
&activity.CreatorID,
&activity.Type,
&activity.Level,
&payloadBytes,
&activity.CreatedTs,
); err != nil {
return nil, err
}
payload := &storepb.ActivityPayload{}
if err := protojson.Unmarshal(payloadBytes, payload); err != nil {
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
return nil, err
}
activity.Payload = payload

View File

@ -1,9 +1,26 @@
package postgres
import "google.golang.org/protobuf/encoding/protojson"
import (
"fmt"
"strings"
"google.golang.org/protobuf/encoding/protojson"
)
var (
protojsonUnmarshaler = protojson.UnmarshalOptions{
DiscardUnknown: true,
}
)
func placeholder(n int) string {
return "$" + fmt.Sprint(n)
}
func placeholders(n int) string {
list := []string{}
for i := 0; i < n; i++ {
list = append(list, placeholder(i+1))
}
return strings.Join(list, ", ")
}

View File

@ -3,8 +3,8 @@ package postgres
import (
"context"
"encoding/json"
"strings"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"github.com/usememos/memos/store"
@ -22,42 +22,34 @@ func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityP
return nil, errors.Errorf("unsupported idp type %s", string(create.Type))
}
qb := squirrel.Insert("idp").Columns("name", "type", "identifier_filter", "config")
values := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)}
qb = qb.Values(values...).PlaceholderFormat(squirrel.Dollar)
qb = qb.Suffix("RETURNING id")
stmt, args, err := qb.ToSql()
if err != nil {
fields := []string{"name", "type", "identifier_filter", "config"}
args := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)}
stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
return nil, err
}
var id int32
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, err
}
create.ID = id
return create, nil
identityProvider := create
return identityProvider, nil
}
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
qb := squirrel.Select("id", "name", "type", "identifier_filter", "config").
From("idp").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
qb = qb.Where(squirrel.Eq{"id": *v})
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
}
query, args, err := qb.ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
name,
type,
identifier_filter,
config
FROM idp
WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
args...,
)
if err != nil {
return nil, err
}
@ -111,15 +103,12 @@ func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityPr
}
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
qb := squirrel.Update("idp").
PlaceholderFormat(squirrel.Dollar)
var err error
set, args := []string{}, []any{}
if v := update.Name; v != nil {
qb = qb.Set("name", *v)
set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.IdentifierFilter; v != nil {
qb = qb.Set("identifier_filter", *v)
set, args = append(set, "identifier_filter = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Config; v != nil {
var configBytes []byte
@ -132,42 +121,53 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
} else {
return nil, errors.Errorf("unsupported idp type %s", string(update.Type))
}
qb = qb.Set("config", string(configBytes))
set, args = append(set, "config = "+placeholder(len(args)+1)), append(args, string(configBytes))
}
qb = qb.Where(squirrel.Eq{"id": update.ID})
stmt := `
UPDATE idp
SET ` + strings.Join(set, ", ") + `
WHERE id = ` + placeholder(len(args)+1) + `
RETURNING id, name, type, identifier_filter, config
`
args = append(args, update.ID)
stmt, args, err := qb.ToSql()
if err != nil {
var identityProvider store.IdentityProvider
var identityProviderConfig string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&identityProvider.ID,
&identityProvider.Name,
&identityProvider.Type,
&identityProvider.IdentifierFilter,
&identityProviderConfig,
); err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return nil, err
if identityProvider.Type == store.IdentityProviderOAuth2Type {
oauth2Config := &store.IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err
}
identityProvider.Config = &store.IdentityProviderConfig{
OAuth2Config: oauth2Config,
}
} else {
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
}
return d.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &update.ID})
return &identityProvider, nil
}
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
qb := squirrel.Delete("idp").
Where(squirrel.Eq{"id": delete.ID}).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := qb.ToSql()
if err != nil {
return err
}
where, args := []string{"id = $1"}, []any{delete.ID}
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
if _, err = result.RowsAffected(); err != nil {
return err
}
return nil
}

View File

@ -2,8 +2,8 @@ package postgres
import (
"context"
"strings"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
@ -21,61 +21,54 @@ func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox
messageString = string(bytes)
}
qb := squirrel.Insert("inbox").
Columns("sender_id", "receiver_id", "status", "message").
Values(create.SenderID, create.ReceiverID, create.Status, messageString).
Suffix("RETURNING id").
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := qb.ToSql()
if err != nil {
fields := []string{"sender_id", "receiver_id", "status", "message"}
args := []any{create.SenderID, create.ReceiverID, create.Status, messageString}
stmt := "INSERT INTO inbox (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.ID,
&create.CreatedTs,
); err != nil {
return nil, err
}
var id int32
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, err
}
return d.GetInbox(ctx, &store.FindInbox{ID: &id})
return create, nil
}
func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) {
qb := squirrel.Select("id", "created_ts", "sender_id", "receiver_id", "status", "message").
From("inbox").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
qb = qb.Where(squirrel.Eq{"id": *find.ID})
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
}
if find.SenderID != nil {
qb = qb.Where(squirrel.Eq{"sender_id": *find.SenderID})
where, args = append(where, "sender_id = "+placeholder(len(args)+1)), append(args, *find.SenderID)
}
if find.ReceiverID != nil {
qb = qb.Where(squirrel.Eq{"receiver_id": *find.ReceiverID})
where, args = append(where, "receiver_id = "+placeholder(len(args)+1)), append(args, *find.ReceiverID)
}
if find.Status != nil {
qb = qb.Where(squirrel.Eq{"status": *find.Status})
}
query, args, err := qb.ToSql()
if err != nil {
return nil, err
where, args = append(where, "status = "+placeholder(len(args)+1)), append(args, *find.Status)
}
query := "SELECT id, created_ts, sender_id, receiver_id, status, message FROM inbox WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var list []*store.Inbox
list := []*store.Inbox{}
for rows.Next() {
inbox := &store.Inbox{}
var messageBytes []byte
if err := rows.Scan(&inbox.ID, &inbox.CreatedTs, &inbox.SenderID, &inbox.ReceiverID, &inbox.Status, &messageBytes); err != nil {
if err := rows.Scan(
&inbox.ID,
&inbox.CreatedTs,
&inbox.SenderID,
&inbox.ReceiverID,
&inbox.Status,
&messageBytes,
); err != nil {
return nil, err
}
@ -87,7 +80,11 @@ func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.I
list = append(list, inbox)
}
return list, rows.Err()
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) {
@ -102,39 +99,36 @@ func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox,
}
func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) {
qb := squirrel.Update("inbox").
Set("status", update.Status.String()).
Where(squirrel.Eq{"id": update.ID}).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := qb.ToSql()
if err != nil {
set, args := []string{"status = $1"}, []any{update.Status.String()}
args = append(args, update.ID)
query := "UPDATE inbox SET " + strings.Join(set, ", ") + " WHERE id = $2 RETURNING id, created_ts, sender_id, receiver_id, status, message"
inbox := &store.Inbox{}
var messageBytes []byte
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
&inbox.ID,
&inbox.CreatedTs,
&inbox.SenderID,
&inbox.ReceiverID,
&inbox.Status,
&messageBytes,
); err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, stmt, args...)
if err != nil {
message := &storepb.InboxMessage{}
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
return nil, err
}
return d.GetInbox(ctx, &store.FindInbox{ID: &update.ID})
inbox.Message = message
return inbox, nil
}
func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error {
qb := squirrel.Delete("inbox").
Where(squirrel.Eq{"id": delete.ID}).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := qb.ToSql()
result, err := d.db.ExecContext(ctx, "DELETE FROM inbox WHERE id = $1", delete.ID)
if err != nil {
return err
}
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
if _, err := result.RowsAffected(); err != nil {
return err
}
_, err = result.RowsAffected()
return err
return nil
}

View File

@ -3,153 +3,127 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"github.com/usememos/memos/store"
)
func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
// Initialize a Squirrel statement builder for PostgreSQL
builder := squirrel.Insert("memo").
PlaceholderFormat(squirrel.Dollar).
Columns("creator_id", "content", "visibility")
fields := []string{"creator_id", "content", "visibility"}
args := []any{create.CreatorID, create.Content, create.Visibility}
// Add initial values for the columns
values := []any{create.CreatorID, create.Content, create.Visibility}
// Add all the values at once
builder = builder.Values(values...)
// Add the RETURNING clause to get the ID of the inserted row
builder = builder.Suffix("RETURNING id")
// Prepare and execute the query
query, args, err := builder.ToSql()
if err != nil {
stmt := "INSERT INTO memo (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts, row_status"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.ID,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, err
}
var id int32
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
if err != nil {
return nil, err
}
// Retrieve the newly created memo
memo, err := d.GetMemo(ctx, &store.FindMemo{ID: &id})
if err != nil {
return nil, err
}
if memo == nil {
return nil, errors.Errorf("failed to create memo")
}
return memo, nil
return create, nil
}
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
// Start building the SELECT statement
builder := squirrel.Select(
"memo.id AS id",
"memo.creator_id AS creator_id",
"memo.created_ts AS created_ts",
"memo.updated_ts AS updated_ts",
"memo.row_status AS row_status",
"memo.content AS content",
"memo.visibility AS visibility",
"MAX(CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END) AS pinned").
From("memo").
LeftJoin("memo_organizer ON memo.id = memo_organizer.memo_id").
LeftJoin("resource ON memo.id = resource.memo_id").
GroupBy("memo.id").
PlaceholderFormat(squirrel.Dollar)
where, args := []string{"1 = 1"}, []any{}
// Add conditional where clauses
if v := find.ID; v != nil {
builder = builder.Where("memo.id = ?", *v)
where, args = append(where, "memo.id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.CreatorID; v != nil {
builder = builder.Where("memo.creator_id = ?", *v)
where, args = append(where, "memo.creator_id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.RowStatus; v != nil {
builder = builder.Where("memo.row_status = ?", *v)
where, args = append(where, "memo.row_status = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.CreatedTsBefore; v != nil {
builder = builder.Where("memo.created_ts < ?", *v)
where, args = append(where, "memo.created_ts < "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.CreatedTsAfter; v != nil {
builder = builder.Where("memo.created_ts > ?", *v)
}
if v := find.Pinned; v != nil {
builder = builder.Where("memo_organizer.pinned = 1")
where, args = append(where, "memo.created_ts > "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.ContentSearch; len(v) != 0 {
for _, s := range v {
builder = builder.Where("memo.content LIKE ?", "%"+s+"%")
where, args = append(where, "memo.content LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", s))
}
}
if v := find.VisibilityList; len(v) != 0 {
holders := []string{}
for _, visibility := range v {
holders = append(holders, placeholder(len(args)+1))
args = append(args, visibility.String())
}
where = append(where, fmt.Sprintf("memo.visibility in (%s)", strings.Join(holders, ", ")))
}
if v := find.Pinned; v != nil {
where = append(where, "memo_organizer.pinned = 1")
}
if v := find.VisibilityList; len(v) != 0 {
placeholders := make([]string, len(v))
args := make([]any, len(v))
for i, visibility := range v {
placeholders[i] = "?"
args[i] = visibility // Assuming visibility can be directly used as an argument
}
inClause := strings.Join(placeholders, ",")
builder = builder.Where("memo.visibility IN ("+inClause+")", args...)
}
// Add order by clauses
orders := []string{}
if find.OrderByPinned {
builder = builder.OrderBy("pinned DESC")
orders = append(orders, "pinned DESC")
}
if find.OrderByUpdatedTs {
builder = builder.OrderBy("updated_ts DESC")
orders = append(orders, "updated_ts DESC")
} else {
builder = builder.OrderBy("created_ts DESC")
orders = append(orders, "created_ts DESC")
}
builder = builder.OrderBy("id DESC")
orders = append(orders, "id DESC")
// Handle pagination
fields := []string{
`memo.id AS id`,
`memo.creator_id AS creator_id`,
`memo.created_ts AS created_ts`,
`memo.updated_ts AS updated_ts`,
`memo.row_status AS row_status`,
`memo.visibility AS visibility`,
`MAX(CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END) AS pinned`,
}
if !find.ExcludeContent {
fields = append(fields, `memo.content AS content`)
}
query := `SELECT ` + strings.Join(fields, ", ") + `
FROM memo
LEFT JOIN memo_organizer ON memo.id = memo_organizer.memo_id
WHERE ` + strings.Join(where, " AND ") + `
GROUP BY memo.id
ORDER BY ` + strings.Join(orders, ", ")
if find.Limit != nil {
builder = builder.Limit(uint64(*find.Limit))
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
builder = builder.Offset(uint64(*find.Offset))
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
}
}
// Prepare and execute the query
query, args, err := builder.ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
// Process the result set
list := make([]*store.Memo, 0)
for rows.Next() {
var memo store.Memo
if err := rows.Scan(
dests := []any{
&memo.ID,
&memo.CreatorID,
&memo.CreatedTs,
&memo.UpdatedTs,
&memo.RowStatus,
&memo.Content,
&memo.Visibility,
&memo.Pinned,
); err != nil {
}
if !find.ExcludeContent {
dests = append(dests, &memo.Content)
}
if err := rows.Scan(dests...); err != nil {
return nil, err
}
list = append(list, &memo)
}
@ -174,51 +148,42 @@ func (d *DB) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, er
}
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
// Start building the update statement
builder := squirrel.Update("memo").
PlaceholderFormat(squirrel.Dollar).
Where("id = ?", update.ID)
// Conditionally add set clauses
set, args := []string{}, []any{}
if v := update.CreatedTs; v != nil {
builder = builder.Set("created_ts", *v)
set, args = append(set, "created_ts = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.UpdatedTs; v != nil {
builder = builder.Set("updated_ts", *v)
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.RowStatus; v != nil {
builder = builder.Set("row_status", *v)
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Content; v != nil {
builder = builder.Set("content", *v)
set, args = append(set, "content = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Visibility; v != nil {
builder = builder.Set("visibility", *v)
set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, *v)
}
// Prepare and execute the query
query, args, err := builder.ToSql()
if err != nil {
stmt := `UPDATE memo SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1)
args = append(args, update.ID)
if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
return err
}
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
return err
}
return nil
}
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
stmt := `DELETE FROM memo WHERE id = $1`
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
where, args := []string{"id = " + placeholder(1)}, []any{delete.ID}
stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ")
println("stmt", stmt, delete.ID)
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
return errors.Wrap(err, "failed to delete memo")
}
if _, err := result.RowsAffected(); err != nil {
return err
}
return d.Vacuum(ctx)
return nil
}
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {

View File

@ -4,8 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"github.com/Masterminds/squirrel"
"strings"
"github.com/usememos/memos/store"
)
@ -15,99 +14,94 @@ func (d *DB) UpsertMemoOrganizer(ctx context.Context, upsert *store.MemoOrganize
if upsert.Pinned {
pinned = 1
}
stmt := "INSERT INTO memo_organizer (memo_id, user_id, pinned) VALUES ($1, $2, $3) ON CONFLICT (memo_id, user_id) DO UPDATE SET pinned = $4"
if _, err := d.db.ExecContext(ctx, stmt, upsert.MemoID, upsert.UserID, pinned, pinned); err != nil {
stmt := `
INSERT INTO memo_organizer (
memo_id,
user_id,
pinned
)
VALUES (` + placeholders(3) + `)
ON CONFLICT(memo_id, user_id) DO UPDATE
SET pinned = EXCLUDED.pinned`
if _, err := d.db.ExecContext(ctx, stmt, upsert.MemoID, upsert.UserID, pinned); err != nil {
return nil, err
}
return upsert, nil
}
func (d *DB) ListMemoOrganizer(ctx context.Context, find *store.FindMemoOrganizer) ([]*store.MemoOrganizer, error) {
qb := squirrel.Select("memo_id", "user_id", "pinned").
From("memo_organizer").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)
where, args := []string{"1 = 1"}, []any{}
if find.MemoID != 0 {
qb = qb.Where(squirrel.Eq{"memo_id": find.MemoID})
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, find.MemoID)
}
if find.UserID != 0 {
qb = qb.Where(squirrel.Eq{"user_id": find.UserID})
}
query, args, err := qb.ToSql()
if err != nil {
return nil, err
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, find.UserID)
}
query := fmt.Sprintf(`
SELECT
memo_id,
user_id,
pinned
FROM memo_organizer
WHERE %s
`, strings.Join(where, " AND "))
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var list []*store.MemoOrganizer
list := []*store.MemoOrganizer{}
for rows.Next() {
memoOrganizer := &store.MemoOrganizer{}
if err := rows.Scan(&memoOrganizer.MemoID, &memoOrganizer.UserID, &memoOrganizer.Pinned); err != nil {
pinned := 0
if err := rows.Scan(
&memoOrganizer.MemoID,
&memoOrganizer.UserID,
&pinned,
); err != nil {
return nil, err
}
memoOrganizer.Pinned = pinned == 1
list = append(list, memoOrganizer)
}
return list, rows.Err()
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteMemoOrganizer(ctx context.Context, delete *store.DeleteMemoOrganizer) error {
qb := squirrel.Delete("memo_organizer").
PlaceholderFormat(squirrel.Dollar)
where, args := []string{}, []any{}
if v := delete.MemoID; v != nil {
qb = qb.Where(squirrel.Eq{"memo_id": *v})
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := delete.UserID; v != nil {
qb = qb.Where(squirrel.Eq{"user_id": *v})
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *v)
}
stmt, args, err := qb.ToSql()
if err != nil {
stmt := `DELETE FROM memo_organizer WHERE ` + strings.Join(where, " AND ")
if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
return err
}
return nil
}
if _, err = d.db.ExecContext(ctx, stmt, args...); err != nil {
func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM
memo_organizer
WHERE
memo_id NOT IN (SELECT id FROM memo)
OR user_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}
func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery for memo_id
subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
// Build the subquery for user_id
subQueryUser, subArgsUser, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
// Now, build the main delete query using the subqueries
query, args, err := squirrel.Delete("memo_organizer").
Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
Where(fmt.Sprintf("user_id NOT IN (%s)", subQueryUser), subArgsUser...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Combine the arguments from both subqueries
args = append(args, subArgsUser...)
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
}

View File

@ -3,127 +3,111 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/Masterminds/squirrel"
"strings"
"github.com/usememos/memos/store"
)
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
qb := squirrel.Insert("memo_relation").
Columns("memo_id", "related_memo_id", "type").
Values(create.MemoID, create.RelatedMemoID, create.Type).
Suffix("ON CONFLICT (version) DO NOTHING").
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := qb.ToSql()
if err != nil {
stmt := `
INSERT INTO memo_relation (
memo_id,
related_memo_id,
type
)
VALUES (` + placeholders(3) + `)
RETURNING memo_id, related_memo_id, type
`
memoRelation := &store.MemoRelation{}
if err := d.db.QueryRowContext(
ctx,
stmt,
create.MemoID,
create.RelatedMemoID,
create.Type,
).Scan(
&memoRelation.MemoID,
&memoRelation.RelatedMemoID,
&memoRelation.Type,
); err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return nil, err
}
return &store.MemoRelation{
MemoID: create.MemoID,
RelatedMemoID: create.RelatedMemoID,
Type: create.Type,
}, nil
return memoRelation, nil
}
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
qb := squirrel.Select("memo_id", "related_memo_id", "type").
From("memo_relation").
Where("TRUE").
PlaceholderFormat(squirrel.Dollar)
where, args := []string{"1 = 1"}, []any{}
if find.MemoID != nil {
qb = qb.Where(squirrel.Eq{"memo_id": *find.MemoID})
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, find.MemoID)
}
if find.RelatedMemoID != nil {
qb = qb.Where(squirrel.Eq{"related_memo_id": *find.RelatedMemoID})
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, find.RelatedMemoID)
}
if find.Type != nil {
qb = qb.Where(squirrel.Eq{"type": *find.Type})
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type)
}
query, args, err := qb.ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
rows, err := d.db.QueryContext(ctx, `
SELECT
memo_id,
related_memo_id,
type
FROM memo_relation
WHERE `+strings.Join(where, " AND "), args...)
if err != nil {
return nil, err
}
defer rows.Close()
var list []*store.MemoRelation
list := []*store.MemoRelation{}
for rows.Next() {
memoRelation := &store.MemoRelation{}
if err := rows.Scan(&memoRelation.MemoID, &memoRelation.RelatedMemoID, &memoRelation.Type); err != nil {
if err := rows.Scan(
&memoRelation.MemoID,
&memoRelation.RelatedMemoID,
&memoRelation.Type,
); err != nil {
return nil, err
}
list = append(list, memoRelation)
}
return list, rows.Err()
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
qb := squirrel.Delete("memo_relation").
PlaceholderFormat(squirrel.Dollar)
where, args := []string{"1 = 1"}, []any{}
if delete.MemoID != nil {
qb = qb.Where(squirrel.Eq{"memo_id": *delete.MemoID})
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, delete.MemoID)
}
if delete.RelatedMemoID != nil {
qb = qb.Where(squirrel.Eq{"related_memo_id": *delete.RelatedMemoID})
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, delete.RelatedMemoID)
}
if delete.Type != nil {
qb = qb.Where(squirrel.Eq{"type": *delete.Type})
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, delete.Type)
}
stmt, args, err := qb.ToSql()
if err != nil {
return err
}
stmt := `DELETE FROM memo_relation WHERE ` + strings.Join(where, " AND ")
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
_, err = result.RowsAffected()
return err
if _, err = result.RowsAffected(); err != nil {
return err
}
return nil
}
func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery for memo_id
subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
if _, err := tx.ExecContext(ctx, `
DELETE FROM memo_relation
WHERE memo_id NOT IN (SELECT id FROM memo) OR related_memo_id NOT IN (SELECT id FROM memo)
`); err != nil {
return err
}
// Note: The same subquery is used for related_memo_id as it's also checking against the "memo" table
// Now, build the main delete query using the subqueries
query, args, err := squirrel.Delete("memo_relation").
Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
Where(fmt.Sprintf("related_memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Combine the arguments for both instances of the same subquery
args = append(args, subArgsMemo...)
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
return nil
}

View File

@ -3,22 +3,12 @@ package postgres
import (
"context"
"github.com/Masterminds/squirrel"
"github.com/usememos/memos/store"
)
func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) {
qb := squirrel.Select("version", "created_ts").
From("migration_history").
OrderBy("created_ts DESC")
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
query := "SELECT version, created_ts FROM migration_history ORDER BY created_ts DESC"
rows, err := d.db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
@ -27,9 +17,13 @@ func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigratio
list := make([]*store.MigrationHistory, 0)
for rows.Next() {
var migrationHistory store.MigrationHistory
if err := rows.Scan(&migrationHistory.Version, &migrationHistory.CreatedTs); err != nil {
if err := rows.Scan(
&migrationHistory.Version,
&migrationHistory.CreatedTs,
); err != nil {
return nil, err
}
list = append(list, &migrationHistory)
}
@ -41,33 +35,21 @@ func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigratio
}
func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) {
qb := squirrel.Insert("migration_history").
Columns("version").
Values(upsert.Version).
Suffix("ON CONFLICT (version) DO NOTHING").
PlaceholderFormat(squirrel.Dollar)
query, args, err := qb.ToSql()
if err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, query, args...)
if err != nil {
return nil, err
}
stmt := `
INSERT INTO migration_history (
version
)
VALUES ($1)
ON CONFLICT(version) DO UPDATE
SET
version=EXCLUDED.version
RETURNING version, created_ts
`
var migrationHistory store.MigrationHistory
query, args, err = squirrel.Select("version", "created_ts").
From("migration_history").
Where(squirrel.Eq{"version": upsert.Version}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return nil, err
}
if err := d.db.QueryRowContext(ctx, query, args...).Scan(&migrationHistory.Version, &migrationHistory.CreatedTs); err != nil {
if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan(
&migrationHistory.Version,
&migrationHistory.CreatedTs,
); err != nil {
return nil, err
}

View File

@ -4,77 +4,61 @@ import (
"context"
"database/sql"
"fmt"
"time"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"strings"
"github.com/usememos/memos/store"
)
func (d *DB) CreateResource(ctx context.Context, create *store.Resource) (*store.Resource, error) {
qb := squirrel.Insert("resource").Columns("filename", "blob", "external_link", "type", "size", "creator_id", "internal_path", "memo_id")
values := []any{create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath, create.MemoID}
fields := []string{"filename", "blob", "external_link", "type", "size", "creator_id", "internal_path", "memo_id"}
args := []any{create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath, create.MemoID}
qb = qb.Values(values...).Suffix("RETURNING id")
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
stmt := "INSERT INTO resource (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil {
return nil, err
}
var id int32
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
if err != nil {
return nil, err
}
list, err := d.ListResources(ctx, &store.FindResource{ID: &id})
if err != nil {
return nil, err
}
if len(list) != 1 {
return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list))
}
return list[0], nil
return create, nil
}
func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*store.Resource, error) {
qb := squirrel.Select("id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path", "memo_id").From("resource")
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
qb = qb.Where(squirrel.Eq{"id": *v})
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.CreatorID; v != nil {
qb = qb.Where(squirrel.Eq{"creator_id": *v})
where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Filename; v != nil {
qb = qb.Where(squirrel.Eq{"filename": *v})
where, args = append(where, "filename = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.MemoID; v != nil {
qb = qb.Where(squirrel.Eq{"memo_id": *v})
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
}
if find.HasRelatedMemo {
qb = qb.Where("memo_id IS NOT NULL")
where = append(where, "memo_id IS NOT NULL")
}
fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path", "memo_id"}
if find.GetBlob {
qb = qb.Columns("blob")
fields = append(fields, "blob")
}
qb = qb.GroupBy("id").OrderBy("created_ts DESC")
query := fmt.Sprintf(`
SELECT
%s
FROM resource
WHERE %s
GROUP BY id
ORDER BY created_ts DESC
`, strings.Join(fields, ", "), strings.Join(where, " AND "))
if find.Limit != nil {
qb = qb.Limit(uint64(*find.Limit))
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
qb = qb.Offset(uint64(*find.Offset))
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
}
}
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
@ -103,7 +87,6 @@ func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*st
if err := rows.Scan(dests...); err != nil {
return nil, err
}
if memoID.Valid {
resource.MemoID = &memoID.Int32
}
@ -118,88 +101,72 @@ func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*st
}
func (d *DB) UpdateResource(ctx context.Context, update *store.UpdateResource) (*store.Resource, error) {
qb := squirrel.Update("resource")
set, args := []string{}, []any{}
if v := update.UpdatedTs; v != nil {
qb = qb.Set("updated_ts", time.Unix(0, *v))
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Filename; v != nil {
qb = qb.Set("filename", *v)
set, args = append(set, "filename = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.InternalPath; v != nil {
qb = qb.Set("internal_path", *v)
set, args = append(set, "internal_path = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.MemoID; v != nil {
qb = qb.Set("memo_id", *v)
set, args = append(set, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Blob; v != nil {
qb = qb.Set("blob", v)
set, args = append(set, "blob = "+placeholder(len(args)+1)), append(args, v)
}
qb = qb.Where(squirrel.Eq{"id": update.ID})
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path"}
stmt := `
UPDATE resource
SET ` + strings.Join(set, ", ") + `
WHERE id = ` + placeholder(len(args)+1) + `
RETURNING ` + strings.Join(fields, ", ")
args = append(args, update.ID)
resource := store.Resource{}
dests := []any{
&resource.ID,
&resource.Filename,
&resource.ExternalLink,
&resource.Type,
&resource.Size,
&resource.CreatorID,
&resource.CreatedTs,
&resource.UpdatedTs,
&resource.InternalPath,
}
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(dests...); err != nil {
return nil, err
}
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
return nil, err
}
list, err := d.ListResources(ctx, &store.FindResource{ID: &update.ID})
if err != nil {
return nil, err
}
if len(list) != 1 {
return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list))
}
return list[0], nil
return &resource, nil
}
func (d *DB) DeleteResource(ctx context.Context, delete *store.DeleteResource) error {
qb := squirrel.Delete("resource").Where(squirrel.Eq{"id": delete.ID})
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
stmt := `DELETE FROM resource WHERE id = $1`
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
if err != nil {
return err
}
result, err := d.db.ExecContext(ctx, query, args...)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
return nil
}
if err := d.Vacuum(ctx); err != nil {
// Prevent linter warning.
func vacuumResource(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM
resource
WHERE
creator_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}
func vacuumResource(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery
subQuery, subArgs, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
// Now, build the main delete query using the subquery
query, args, err := squirrel.Delete("resource").
Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
}

View File

@ -3,7 +3,6 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/Masterminds/squirrel"
@ -82,22 +81,15 @@ func (d *DB) DeleteTag(ctx context.Context, delete *store.DeleteTag) error {
}
func vacuumTag(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery for creator_id
subQuery, subArgs, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql()
stmt := `
DELETE FROM
tag
WHERE
creator_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
// Now, build the main delete query using the subquery
query, args, err := squirrel.Delete("tag").
Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
return nil
}

View File

@ -2,132 +2,113 @@ package postgres
import (
"context"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"strings"
"github.com/usememos/memos/store"
)
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
// Start building the insert statement
builder := squirrel.Insert(`"user"`).PlaceholderFormat(squirrel.Dollar)
columns := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"}
builder = builder.Columns(columns...)
values := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
builder = builder.Values(values...)
builder = builder.Suffix("RETURNING id")
// Prepare the final query
query, args, err := builder.ToSql()
if err != nil {
fields := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"}
args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
stmt := "INSERT INTO \"user\" (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, avatar_url, created_ts, updated_ts, row_status"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.ID,
&create.AvatarURL,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, err
}
// Execute the query and get the returned ID
var id int32
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
if err != nil {
return nil, err
}
// Use the returned ID to retrieve the full user object
user, err := d.GetUser(ctx, &store.FindUser{ID: &id})
if err != nil {
return nil, err
}
return user, nil
return create, nil
}
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
// Start building the update statement
builder := squirrel.Update(`"user"`).PlaceholderFormat(squirrel.Dollar)
// Conditionally add set clauses
set, args := []string{}, []any{}
if v := update.UpdatedTs; v != nil {
builder = builder.Set("updated_ts", *v)
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.RowStatus; v != nil {
builder = builder.Set("row_status", *v)
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Username; v != nil {
builder = builder.Set("username", *v)
set, args = append(set, "username = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Email; v != nil {
builder = builder.Set("email", *v)
set, args = append(set, "email = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Nickname; v != nil {
builder = builder.Set("nickname", *v)
set, args = append(set, "nickname = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.AvatarURL; v != nil {
builder = builder.Set("avatar_url", *v)
set, args = append(set, "avatar_url = "+placeholder(len(args)+1)), append(args, *v)
}
if v := update.PasswordHash; v != nil {
builder = builder.Set("password_hash", *v)
set, args = append(set, "password_hash = "+placeholder(len(args)+1)), append(args, *v)
}
// Add the WHERE clause
builder = builder.Where(squirrel.Eq{"id": update.ID})
// Prepare the final query
query, args, err := builder.ToSql()
if err != nil {
query := `
UPDATE "user"
SET ` + strings.Join(set, ", ") + `
WHERE id = ` + placeholder(len(args)+1) + `
RETURNING id, username, role, email, nickname, password_hash, avatar_url, created_ts, updated_ts, row_status
`
args = append(args, update.ID)
user := &store.User{}
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
&user.ID,
&user.Username,
&user.Role,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.AvatarURL,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
); err != nil {
return nil, err
}
// Execute the query with the context
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
return nil, err
}
// Retrieve the updated user
user, err := d.GetUser(ctx, &store.FindUser{ID: &update.ID})
if err != nil {
return nil, err
}
return user, nil
}
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
// Start building the SELECT statement
builder := squirrel.Select("id", "username", "role", "email", "nickname", "password_hash", "avatar_url", "created_ts", "updated_ts", "row_status").
From(`"user"`).
PlaceholderFormat(squirrel.Dollar)
where, args := []string{"1 = 1"}, []any{}
// 1 = 1 is often used as a no-op in SQL, ensuring there's always a WHERE clause
builder = builder.Where("1 = 1")
// Conditionally add where clauses
if v := find.ID; v != nil {
builder = builder.Where(squirrel.Eq{"id": *v})
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Username; v != nil {
builder = builder.Where(squirrel.Eq{"username": *v})
where, args = append(where, "username = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Role; v != nil {
builder = builder.Where(squirrel.Eq{"role": *v})
where, args = append(where, "role = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Email; v != nil {
builder = builder.Where(squirrel.Eq{"email": *v})
where, args = append(where, "email = "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Nickname; v != nil {
builder = builder.Where(squirrel.Eq{"nickname": *v})
where, args = append(where, "nickname = "+placeholder(len(args)+1)), append(args, *v)
}
// Add ordering
builder = builder.OrderBy("created_ts DESC", "row_status DESC")
// Prepare the final query
query, args, err := builder.ToSql()
if err != nil {
return nil, err
}
// Execute the query with the context
query := `
SELECT
id,
username,
role,
email,
nickname,
password_hash,
avatar_url,
created_ts,
updated_ts,
row_status
FROM "user"
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC, row_status DESC
`
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
@ -161,35 +142,13 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User
return list, nil
}
func (d *DB) GetUser(ctx context.Context, find *store.FindUser) (*store.User, error) {
list, err := d.ListUsers(ctx, find)
if err != nil {
return nil, err
}
if len(list) != 1 {
return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list))
}
return list[0], nil
}
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
// Start building the DELETE statement
builder := squirrel.Delete(`"user"`).
PlaceholderFormat(squirrel.Dollar).
Where(squirrel.Eq{"id": delete.ID})
// Prepare the final query
query, args, err := builder.ToSql()
result, err := d.db.ExecContext(ctx, `
DELETE FROM "user" WHERE id = $1
`, delete.ID)
if err != nil {
return err
}
// Execute the query with the context
result, err := d.db.ExecContext(ctx, query, args...)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}

View File

@ -3,7 +3,6 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
@ -130,22 +129,15 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
}
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery
subQuery, subArgs, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql()
stmt := `
DELETE FROM
user_setting
WHERE
user_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
// Now, build the main delete query using the subquery
query, args, err := squirrel.Delete("user_setting").
Where(fmt.Sprintf("user_id NOT IN (%s)", subQuery), subArgs...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
return nil
}

View File

@ -38,7 +38,6 @@ func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "`id` = ?"), append(args, *find.ID)
}

View File

@ -0,0 +1,62 @@
package teststore
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/store"
)
func TestMemoOrganizerStore(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memoCreate := &store.Memo{
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
}
memo, err := ts.CreateMemo(ctx, memoCreate)
require.NoError(t, err)
require.Equal(t, memoCreate.Content, memo.Content)
memoOrganizer, err := ts.UpsertMemoOrganizer(ctx, &store.MemoOrganizer{
MemoID: memo.ID,
UserID: user.ID,
Pinned: true,
})
require.NoError(t, err)
require.NotNil(t, memoOrganizer)
require.Equal(t, memo.ID, memoOrganizer.MemoID)
require.Equal(t, user.ID, memoOrganizer.UserID)
require.Equal(t, true, memoOrganizer.Pinned)
memoOrganizerTemp, err := ts.GetMemoOrganizer(ctx, &store.FindMemoOrganizer{
MemoID: memo.ID,
})
require.NoError(t, err)
require.Equal(t, memoOrganizer, memoOrganizerTemp)
memoOrganizerTemp, err = ts.UpsertMemoOrganizer(ctx, &store.MemoOrganizer{
MemoID: memo.ID,
UserID: user.ID,
Pinned: false,
})
require.NoError(t, err)
require.NotNil(t, memoOrganizerTemp)
require.Equal(t, memo.ID, memoOrganizerTemp.MemoID)
require.Equal(t, user.ID, memoOrganizerTemp.UserID)
require.Equal(t, false, memoOrganizerTemp.Pinned)
err = ts.DeleteMemoOrganizer(ctx, &store.DeleteMemoOrganizer{
MemoID: &memo.ID,
UserID: &user.ID,
})
require.NoError(t, err)
memoOrganizers, err := ts.ListMemoOrganizer(ctx, &store.FindMemoOrganizer{
UserID: user.ID,
})
require.NoError(t, err)
require.Equal(t, 0, len(memoOrganizers))
}

View File

@ -59,3 +59,22 @@ func TestMemoStore(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 0, len(memoList))
}
func TestDeleteMemoStore(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memoCreate := &store.Memo{
CreatorID: user.ID,
Content: "test_content",
Visibility: store.Public,
}
memo, err := ts.CreateMemo(ctx, memoCreate)
require.NoError(t, err)
require.Equal(t, memoCreate.Content, memo.Content)
err = ts.DeleteMemo(ctx, &store.DeleteMemo{
ID: memo.ID,
})
require.NoError(t, err)
}

View File

@ -8,6 +8,7 @@ import (
// sqlite driver.
_ "modernc.org/sqlite"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"
"github.com/usememos/memos/store/db"
"github.com/usememos/memos/test"
@ -19,6 +20,7 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
if err != nil {
fmt.Printf("failed to create db driver, error: %+v\n", err)
}
resetTestingDB(ctx, profile, dbDriver)
if err := dbDriver.Migrate(ctx); err != nil {
fmt.Printf("failed to migrate db, error: %+v\n", err)
}
@ -26,3 +28,13 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
store := store.New(dbDriver, profile)
return store
}
func resetTestingDB(ctx context.Context, profile *profile.Profile, dbDriver store.Driver) {
if profile.Driver == "postgres" {
_, err := dbDriver.GetDB().ExecContext(ctx, `DROP SCHEMA public CASCADE; CREATE SCHEMA public;`)
if err != nil {
fmt.Printf("failed to reset testing db, error: %+v\n", err)
panic(err)
}
}
}