chore: fix postgres stmts

This commit is contained in:
Steven 2024-01-05 21:27:16 +08:00
parent ee13927607
commit 501f8898f6
16 changed files with 589 additions and 672 deletions

View File

@ -2,8 +2,8 @@ package postgres
import ( import (
"context" "context"
"strings"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors" "github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
@ -21,50 +21,29 @@ func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store
payloadString = string(bytes) payloadString = string(bytes)
} }
qb := squirrel.Insert("activity"). fields := []string{"creator_id", "type", "level", "payload"}
Columns("creator_id", "type", "level", "payload"). args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
PlaceholderFormat(squirrel.Dollar) stmt := "INSERT INTO activity (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
values := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString} &create.ID,
qb = qb.Values(values...).Suffix("RETURNING id") &create.CreatedTs,
); err != nil {
stmt, args, err := qb.ToSql()
if err != nil {
return nil, errors.Wrap(err, "failed to construct query")
}
var id int32
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, errors.Wrap(err, "failed to execute statement and retrieve ID")
}
list, err := d.ListActivities(ctx, &store.FindActivity{ID: &id})
if err != nil || len(list) == 0 {
return nil, errors.Wrap(err, "failed to find activity")
}
return list[0], nil
}
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
qb := squirrel.Select("id", "created_ts", "creator_id", "type", "level", "payload").
From("activity").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)
if find.ID != nil {
qb = qb.Where(squirrel.Eq{"id": *find.ID})
}
if find.Type != nil {
qb = qb.Where(squirrel.Eq{"type": find.Type.String()})
}
query, args, err := qb.ToSql()
if err != nil {
return nil, err return nil, err
} }
return create, nil
}
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
}
if find.Type != nil {
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type.String())
}
query := "SELECT id, creator_id, type, level, payload, created_ts FROM activity WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"
rows, err := d.db.QueryContext(ctx, query, args...) rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -77,17 +56,17 @@ func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*s
var payloadBytes []byte var payloadBytes []byte
if err := rows.Scan( if err := rows.Scan(
&activity.ID, &activity.ID,
&activity.CreatedTs,
&activity.CreatorID, &activity.CreatorID,
&activity.Type, &activity.Type,
&activity.Level, &activity.Level,
&payloadBytes, &payloadBytes,
&activity.CreatedTs,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
payload := &storepb.ActivityPayload{} payload := &storepb.ActivityPayload{}
if err := protojson.Unmarshal(payloadBytes, payload); err != nil { if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
return nil, err return nil, err
} }
activity.Payload = payload activity.Payload = payload

View File

@ -1,9 +1,26 @@
package postgres package postgres
import "google.golang.org/protobuf/encoding/protojson" import (
"fmt"
"strings"
"google.golang.org/protobuf/encoding/protojson"
)
var ( var (
protojsonUnmarshaler = protojson.UnmarshalOptions{ protojsonUnmarshaler = protojson.UnmarshalOptions{
DiscardUnknown: true, DiscardUnknown: true,
} }
) )
func placeholder(n int) string {
return "$" + fmt.Sprint(n)
}
func placeholders(n int) string {
list := []string{}
for i := 0; i < n; i++ {
list = append(list, placeholder(i+1))
}
return strings.Join(list, ", ")
}

View File

@ -3,8 +3,8 @@ package postgres
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"strings"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
@ -22,42 +22,34 @@ func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityP
return nil, errors.Errorf("unsupported idp type %s", string(create.Type)) return nil, errors.Errorf("unsupported idp type %s", string(create.Type))
} }
qb := squirrel.Insert("idp").Columns("name", "type", "identifier_filter", "config") fields := []string{"name", "type", "identifier_filter", "config"}
values := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)} args := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)}
stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id"
qb = qb.Values(values...).PlaceholderFormat(squirrel.Dollar) if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
qb = qb.Suffix("RETURNING id")
stmt, args, err := qb.ToSql()
if err != nil {
return nil, err return nil, err
} }
var id int32 identityProvider := create
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id) return identityProvider, nil
if err != nil {
return nil, err
}
create.ID = id
return create, nil
} }
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) { func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
qb := squirrel.Select("id", "name", "type", "identifier_filter", "config"). where, args := []string{"1 = 1"}, []any{}
From("idp").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)
if v := find.ID; v != nil { if v := find.ID; v != nil {
qb = qb.Where(squirrel.Eq{"id": *v}) where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
} }
query, args, err := qb.ToSql() rows, err := d.db.QueryContext(ctx, `
if err != nil { SELECT
return nil, err id,
} name,
type,
rows, err := d.db.QueryContext(ctx, query, args...) identifier_filter,
config
FROM idp
WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
args...,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -111,15 +103,12 @@ func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityPr
} }
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) { func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
qb := squirrel.Update("idp"). set, args := []string{}, []any{}
PlaceholderFormat(squirrel.Dollar)
var err error
if v := update.Name; v != nil { if v := update.Name; v != nil {
qb = qb.Set("name", *v) set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.IdentifierFilter; v != nil { if v := update.IdentifierFilter; v != nil {
qb = qb.Set("identifier_filter", *v) set, args = append(set, "identifier_filter = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.Config; v != nil { if v := update.Config; v != nil {
var configBytes []byte var configBytes []byte
@ -132,42 +121,53 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
} else { } else {
return nil, errors.Errorf("unsupported idp type %s", string(update.Type)) return nil, errors.Errorf("unsupported idp type %s", string(update.Type))
} }
qb = qb.Set("config", string(configBytes)) set, args = append(set, "config = "+placeholder(len(args)+1)), append(args, string(configBytes))
} }
qb = qb.Where(squirrel.Eq{"id": update.ID}) stmt := `
UPDATE idp
SET ` + strings.Join(set, ", ") + `
WHERE id = ` + placeholder(len(args)+1) + `
RETURNING id, name, type, identifier_filter, config
`
args = append(args, update.ID)
stmt, args, err := qb.ToSql() var identityProvider store.IdentityProvider
if err != nil { var identityProviderConfig string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&identityProvider.ID,
&identityProvider.Name,
&identityProvider.Type,
&identityProvider.IdentifierFilter,
&identityProviderConfig,
); err != nil {
return nil, err return nil, err
} }
_, err = d.db.ExecContext(ctx, stmt, args...) if identityProvider.Type == store.IdentityProviderOAuth2Type {
if err != nil { oauth2Config := &store.IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err return nil, err
} }
identityProvider.Config = &store.IdentityProviderConfig{
OAuth2Config: oauth2Config,
}
} else {
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
}
return d.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &update.ID}) return &identityProvider, nil
} }
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error { func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
qb := squirrel.Delete("idp"). where, args := []string{"id = $1"}, []any{delete.ID}
Where(squirrel.Eq{"id": delete.ID}). stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := qb.ToSql()
if err != nil {
return err
}
result, err := d.db.ExecContext(ctx, stmt, args...) result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil { if err != nil {
return err return err
} }
if _, err = result.RowsAffected(); err != nil { if _, err = result.RowsAffected(); err != nil {
return err return err
} }
return nil return nil
} }

View File

@ -2,8 +2,8 @@ package postgres
import ( import (
"context" "context"
"strings"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors" "github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
@ -21,61 +21,54 @@ func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox
messageString = string(bytes) messageString = string(bytes)
} }
qb := squirrel.Insert("inbox"). fields := []string{"sender_id", "receiver_id", "status", "message"}
Columns("sender_id", "receiver_id", "status", "message"). args := []any{create.SenderID, create.ReceiverID, create.Status, messageString}
Values(create.SenderID, create.ReceiverID, create.Status, messageString). stmt := "INSERT INTO inbox (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
Suffix("RETURNING id"). if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
PlaceholderFormat(squirrel.Dollar) &create.ID,
&create.CreatedTs,
stmt, args, err := qb.ToSql() ); err != nil {
if err != nil {
return nil, err return nil, err
} }
var id int32 return create, nil
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
if err != nil {
return nil, err
}
return d.GetInbox(ctx, &store.FindInbox{ID: &id})
} }
func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) { func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) {
qb := squirrel.Select("id", "created_ts", "sender_id", "receiver_id", "status", "message"). where, args := []string{"1 = 1"}, []any{}
From("inbox").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)
if find.ID != nil { if find.ID != nil {
qb = qb.Where(squirrel.Eq{"id": *find.ID}) where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
} }
if find.SenderID != nil { if find.SenderID != nil {
qb = qb.Where(squirrel.Eq{"sender_id": *find.SenderID}) where, args = append(where, "sender_id = "+placeholder(len(args)+1)), append(args, *find.SenderID)
} }
if find.ReceiverID != nil { if find.ReceiverID != nil {
qb = qb.Where(squirrel.Eq{"receiver_id": *find.ReceiverID}) where, args = append(where, "receiver_id = "+placeholder(len(args)+1)), append(args, *find.ReceiverID)
} }
if find.Status != nil { if find.Status != nil {
qb = qb.Where(squirrel.Eq{"status": *find.Status}) where, args = append(where, "status = "+placeholder(len(args)+1)), append(args, *find.Status)
}
query, args, err := qb.ToSql()
if err != nil {
return nil, err
} }
query := "SELECT id, created_ts, sender_id, receiver_id, status, message FROM inbox WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"
rows, err := d.db.QueryContext(ctx, query, args...) rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
var list []*store.Inbox list := []*store.Inbox{}
for rows.Next() { for rows.Next() {
inbox := &store.Inbox{} inbox := &store.Inbox{}
var messageBytes []byte var messageBytes []byte
if err := rows.Scan(&inbox.ID, &inbox.CreatedTs, &inbox.SenderID, &inbox.ReceiverID, &inbox.Status, &messageBytes); err != nil { if err := rows.Scan(
&inbox.ID,
&inbox.CreatedTs,
&inbox.SenderID,
&inbox.ReceiverID,
&inbox.Status,
&messageBytes,
); err != nil {
return nil, err return nil, err
} }
@ -87,7 +80,11 @@ func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.I
list = append(list, inbox) list = append(list, inbox)
} }
return list, rows.Err() if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
} }
func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) { func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) {
@ -102,39 +99,36 @@ func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox,
} }
func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) { func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) {
qb := squirrel.Update("inbox"). set, args := []string{"status = $1"}, []any{update.Status.String()}
Set("status", update.Status.String()). args = append(args, update.ID)
Where(squirrel.Eq{"id": update.ID}). query := "UPDATE inbox SET " + strings.Join(set, ", ") + " WHERE id = $2 RETURNING id, created_ts, sender_id, receiver_id, status, message"
PlaceholderFormat(squirrel.Dollar) inbox := &store.Inbox{}
var messageBytes []byte
stmt, args, err := qb.ToSql() if err := d.db.QueryRowContext(ctx, query, args...).Scan(
if err != nil { &inbox.ID,
&inbox.CreatedTs,
&inbox.SenderID,
&inbox.ReceiverID,
&inbox.Status,
&messageBytes,
); err != nil {
return nil, err return nil, err
} }
message := &storepb.InboxMessage{}
_, err = d.db.ExecContext(ctx, stmt, args...) if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
if err != nil {
return nil, err return nil, err
} }
inbox.Message = message
return d.GetInbox(ctx, &store.FindInbox{ID: &update.ID}) return inbox, nil
} }
func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error { func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error {
qb := squirrel.Delete("inbox"). result, err := d.db.ExecContext(ctx, "DELETE FROM inbox WHERE id = $1", delete.ID)
Where(squirrel.Eq{"id": delete.ID}).
PlaceholderFormat(squirrel.Dollar)
stmt, args, err := qb.ToSql()
if err != nil { if err != nil {
return err return err
} }
if _, err := result.RowsAffected(); err != nil {
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err return err
} }
return nil
_, err = result.RowsAffected()
return err
} }

View File

@ -3,153 +3,127 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings" "strings"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) { func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
// Initialize a Squirrel statement builder for PostgreSQL fields := []string{"creator_id", "content", "visibility"}
builder := squirrel.Insert("memo"). args := []any{create.CreatorID, create.Content, create.Visibility}
PlaceholderFormat(squirrel.Dollar).
Columns("creator_id", "content", "visibility")
// Add initial values for the columns stmt := "INSERT INTO memo (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts, row_status"
values := []any{create.CreatorID, create.Content, create.Visibility} if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.ID,
// Add all the values at once &create.CreatedTs,
builder = builder.Values(values...) &create.UpdatedTs,
&create.RowStatus,
// Add the RETURNING clause to get the ID of the inserted row ); err != nil {
builder = builder.Suffix("RETURNING id")
// Prepare and execute the query
query, args, err := builder.ToSql()
if err != nil {
return nil, err return nil, err
} }
var id int32 return create, nil
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
if err != nil {
return nil, err
}
// Retrieve the newly created memo
memo, err := d.GetMemo(ctx, &store.FindMemo{ID: &id})
if err != nil {
return nil, err
}
if memo == nil {
return nil, errors.Errorf("failed to create memo")
}
return memo, nil
} }
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) { func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
// Start building the SELECT statement where, args := []string{"1 = 1"}, []any{}
builder := squirrel.Select(
"memo.id AS id",
"memo.creator_id AS creator_id",
"memo.created_ts AS created_ts",
"memo.updated_ts AS updated_ts",
"memo.row_status AS row_status",
"memo.content AS content",
"memo.visibility AS visibility",
"MAX(CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END) AS pinned").
From("memo").
LeftJoin("memo_organizer ON memo.id = memo_organizer.memo_id").
LeftJoin("resource ON memo.id = resource.memo_id").
GroupBy("memo.id").
PlaceholderFormat(squirrel.Dollar)
// Add conditional where clauses
if v := find.ID; v != nil { if v := find.ID; v != nil {
builder = builder.Where("memo.id = ?", *v) where, args = append(where, "memo.id = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.CreatorID; v != nil { if v := find.CreatorID; v != nil {
builder = builder.Where("memo.creator_id = ?", *v) where, args = append(where, "memo.creator_id = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.RowStatus; v != nil { if v := find.RowStatus; v != nil {
builder = builder.Where("memo.row_status = ?", *v) where, args = append(where, "memo.row_status = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.CreatedTsBefore; v != nil { if v := find.CreatedTsBefore; v != nil {
builder = builder.Where("memo.created_ts < ?", *v) where, args = append(where, "memo.created_ts < "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.CreatedTsAfter; v != nil { if v := find.CreatedTsAfter; v != nil {
builder = builder.Where("memo.created_ts > ?", *v) where, args = append(where, "memo.created_ts > "+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Pinned; v != nil {
builder = builder.Where("memo_organizer.pinned = 1")
} }
if v := find.ContentSearch; len(v) != 0 { if v := find.ContentSearch; len(v) != 0 {
for _, s := range v { for _, s := range v {
builder = builder.Where("memo.content LIKE ?", "%"+s+"%") where, args = append(where, "memo.content LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", s))
} }
} }
if v := find.VisibilityList; len(v) != 0 {
holders := []string{}
for _, visibility := range v {
holders = append(holders, placeholder(len(args)+1))
args = append(args, visibility.String())
}
where = append(where, fmt.Sprintf("memo.visibility in (%s)", strings.Join(holders, ", ")))
}
if v := find.Pinned; v != nil {
where = append(where, "memo_organizer.pinned = 1")
}
if v := find.VisibilityList; len(v) != 0 { orders := []string{}
placeholders := make([]string, len(v))
args := make([]any, len(v))
for i, visibility := range v {
placeholders[i] = "?"
args[i] = visibility // Assuming visibility can be directly used as an argument
}
inClause := strings.Join(placeholders, ",")
builder = builder.Where("memo.visibility IN ("+inClause+")", args...)
}
// Add order by clauses
if find.OrderByPinned { if find.OrderByPinned {
builder = builder.OrderBy("pinned DESC") orders = append(orders, "pinned DESC")
} }
if find.OrderByUpdatedTs { if find.OrderByUpdatedTs {
builder = builder.OrderBy("updated_ts DESC") orders = append(orders, "updated_ts DESC")
} else { } else {
builder = builder.OrderBy("created_ts DESC") orders = append(orders, "created_ts DESC")
} }
builder = builder.OrderBy("id DESC") orders = append(orders, "id DESC")
// Handle pagination fields := []string{
`memo.id AS id`,
`memo.creator_id AS creator_id`,
`memo.created_ts AS created_ts`,
`memo.updated_ts AS updated_ts`,
`memo.row_status AS row_status`,
`memo.visibility AS visibility`,
`MAX(CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END) AS pinned`,
}
if !find.ExcludeContent {
fields = append(fields, `memo.content AS content`)
}
query := `SELECT ` + strings.Join(fields, ", ") + `
FROM memo
LEFT JOIN memo_organizer ON memo.id = memo_organizer.memo_id
WHERE ` + strings.Join(where, " AND ") + `
GROUP BY memo.id
ORDER BY ` + strings.Join(orders, ", ")
if find.Limit != nil { if find.Limit != nil {
builder = builder.Limit(uint64(*find.Limit)) query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil { if find.Offset != nil {
builder = builder.Offset(uint64(*find.Offset)) query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
} }
} }
// Prepare and execute the query
query, args, err := builder.ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...) rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
// Process the result set
list := make([]*store.Memo, 0) list := make([]*store.Memo, 0)
for rows.Next() { for rows.Next() {
var memo store.Memo var memo store.Memo
if err := rows.Scan( dests := []any{
&memo.ID, &memo.ID,
&memo.CreatorID, &memo.CreatorID,
&memo.CreatedTs, &memo.CreatedTs,
&memo.UpdatedTs, &memo.UpdatedTs,
&memo.RowStatus, &memo.RowStatus,
&memo.Content,
&memo.Visibility, &memo.Visibility,
&memo.Pinned, &memo.Pinned,
); err != nil { }
if !find.ExcludeContent {
dests = append(dests, &memo.Content)
}
if err := rows.Scan(dests...); err != nil {
return nil, err return nil, err
} }
list = append(list, &memo) list = append(list, &memo)
} }
@ -174,51 +148,42 @@ func (d *DB) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, er
} }
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error { func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
// Start building the update statement set, args := []string{}, []any{}
builder := squirrel.Update("memo").
PlaceholderFormat(squirrel.Dollar).
Where("id = ?", update.ID)
// Conditionally add set clauses
if v := update.CreatedTs; v != nil { if v := update.CreatedTs; v != nil {
builder = builder.Set("created_ts", *v) set, args = append(set, "created_ts = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.UpdatedTs; v != nil { if v := update.UpdatedTs; v != nil {
builder = builder.Set("updated_ts", *v) set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.RowStatus; v != nil { if v := update.RowStatus; v != nil {
builder = builder.Set("row_status", *v) set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.Content; v != nil { if v := update.Content; v != nil {
builder = builder.Set("content", *v) set, args = append(set, "content = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.Visibility; v != nil { if v := update.Visibility; v != nil {
builder = builder.Set("visibility", *v) set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, *v)
} }
stmt := `UPDATE memo SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1)
// Prepare and execute the query args = append(args, update.ID)
query, args, err := builder.ToSql() if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
if err != nil {
return err return err
} }
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
return err
}
return nil return nil
} }
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error { func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
stmt := `DELETE FROM memo WHERE id = $1` where, args := []string{"id = " + placeholder(1)}, []any{delete.ID}
result, err := d.db.ExecContext(ctx, stmt, delete.ID) stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ")
println("stmt", stmt, delete.ID)
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil { if err != nil {
return err return errors.Wrap(err, "failed to delete memo")
} }
if _, err := result.RowsAffected(); err != nil { if _, err := result.RowsAffected(); err != nil {
return err return err
} }
return d.Vacuum(ctx) return nil
} }
func vacuumMemo(ctx context.Context, tx *sql.Tx) error { func vacuumMemo(ctx context.Context, tx *sql.Tx) error {

View File

@ -4,8 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
"github.com/Masterminds/squirrel"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
@ -15,99 +14,94 @@ func (d *DB) UpsertMemoOrganizer(ctx context.Context, upsert *store.MemoOrganize
if upsert.Pinned { if upsert.Pinned {
pinned = 1 pinned = 1
} }
stmt := "INSERT INTO memo_organizer (memo_id, user_id, pinned) VALUES ($1, $2, $3) ON CONFLICT (memo_id, user_id) DO UPDATE SET pinned = $4" stmt := `
if _, err := d.db.ExecContext(ctx, stmt, upsert.MemoID, upsert.UserID, pinned, pinned); err != nil { INSERT INTO memo_organizer (
memo_id,
user_id,
pinned
)
VALUES (` + placeholders(3) + `)
ON CONFLICT(memo_id, user_id) DO UPDATE
SET pinned = EXCLUDED.pinned`
if _, err := d.db.ExecContext(ctx, stmt, upsert.MemoID, upsert.UserID, pinned); err != nil {
return nil, err return nil, err
} }
return upsert, nil return upsert, nil
} }
func (d *DB) ListMemoOrganizer(ctx context.Context, find *store.FindMemoOrganizer) ([]*store.MemoOrganizer, error) { func (d *DB) ListMemoOrganizer(ctx context.Context, find *store.FindMemoOrganizer) ([]*store.MemoOrganizer, error) {
qb := squirrel.Select("memo_id", "user_id", "pinned"). where, args := []string{"1 = 1"}, []any{}
From("memo_organizer").
Where("1 = 1").
PlaceholderFormat(squirrel.Dollar)
if find.MemoID != 0 { if find.MemoID != 0 {
qb = qb.Where(squirrel.Eq{"memo_id": find.MemoID}) where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, find.MemoID)
} }
if find.UserID != 0 { if find.UserID != 0 {
qb = qb.Where(squirrel.Eq{"user_id": find.UserID}) where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, find.UserID)
}
query, args, err := qb.ToSql()
if err != nil {
return nil, err
} }
query := fmt.Sprintf(`
SELECT
memo_id,
user_id,
pinned
FROM memo_organizer
WHERE %s
`, strings.Join(where, " AND "))
rows, err := d.db.QueryContext(ctx, query, args...) rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
var list []*store.MemoOrganizer list := []*store.MemoOrganizer{}
for rows.Next() { for rows.Next() {
memoOrganizer := &store.MemoOrganizer{} memoOrganizer := &store.MemoOrganizer{}
if err := rows.Scan(&memoOrganizer.MemoID, &memoOrganizer.UserID, &memoOrganizer.Pinned); err != nil { pinned := 0
if err := rows.Scan(
&memoOrganizer.MemoID,
&memoOrganizer.UserID,
&pinned,
); err != nil {
return nil, err return nil, err
} }
memoOrganizer.Pinned = pinned == 1
list = append(list, memoOrganizer) list = append(list, memoOrganizer)
} }
return list, rows.Err() if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
} }
func (d *DB) DeleteMemoOrganizer(ctx context.Context, delete *store.DeleteMemoOrganizer) error { func (d *DB) DeleteMemoOrganizer(ctx context.Context, delete *store.DeleteMemoOrganizer) error {
qb := squirrel.Delete("memo_organizer"). where, args := []string{}, []any{}
PlaceholderFormat(squirrel.Dollar)
if v := delete.MemoID; v != nil { if v := delete.MemoID; v != nil {
qb = qb.Where(squirrel.Eq{"memo_id": *v}) where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := delete.UserID; v != nil { if v := delete.UserID; v != nil {
qb = qb.Where(squirrel.Eq{"user_id": *v}) where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *v)
} }
stmt := `DELETE FROM memo_organizer WHERE ` + strings.Join(where, " AND ")
stmt, args, err := qb.ToSql() if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
if err != nil {
return err return err
} }
return nil
}
if _, err = d.db.ExecContext(ctx, stmt, args...); err != nil { func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM
memo_organizer
WHERE
memo_id NOT IN (SELECT id FROM memo)
OR user_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err return err
} }
return nil return nil
} }
func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery for memo_id
subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
// Build the subquery for user_id
subQueryUser, subArgsUser, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
// Now, build the main delete query using the subqueries
query, args, err := squirrel.Delete("memo_organizer").
Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
Where(fmt.Sprintf("user_id NOT IN (%s)", subQueryUser), subArgsUser...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Combine the arguments from both subqueries
args = append(args, subArgsUser...)
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
}

View File

@ -3,127 +3,111 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt" "strings"
"github.com/Masterminds/squirrel"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) { func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
qb := squirrel.Insert("memo_relation"). stmt := `
Columns("memo_id", "related_memo_id", "type"). INSERT INTO memo_relation (
Values(create.MemoID, create.RelatedMemoID, create.Type). memo_id,
Suffix("ON CONFLICT (version) DO NOTHING"). related_memo_id,
PlaceholderFormat(squirrel.Dollar) type
)
stmt, args, err := qb.ToSql() VALUES (` + placeholders(3) + `)
if err != nil { RETURNING memo_id, related_memo_id, type
`
memoRelation := &store.MemoRelation{}
if err := d.db.QueryRowContext(
ctx,
stmt,
create.MemoID,
create.RelatedMemoID,
create.Type,
).Scan(
&memoRelation.MemoID,
&memoRelation.RelatedMemoID,
&memoRelation.Type,
); err != nil {
return nil, err return nil, err
} }
_, err = d.db.ExecContext(ctx, stmt, args...) return memoRelation, nil
if err != nil {
return nil, err
}
return &store.MemoRelation{
MemoID: create.MemoID,
RelatedMemoID: create.RelatedMemoID,
Type: create.Type,
}, nil
} }
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) { func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
qb := squirrel.Select("memo_id", "related_memo_id", "type"). where, args := []string{"1 = 1"}, []any{}
From("memo_relation").
Where("TRUE").
PlaceholderFormat(squirrel.Dollar)
if find.MemoID != nil { if find.MemoID != nil {
qb = qb.Where(squirrel.Eq{"memo_id": *find.MemoID}) where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, find.MemoID)
} }
if find.RelatedMemoID != nil { if find.RelatedMemoID != nil {
qb = qb.Where(squirrel.Eq{"related_memo_id": *find.RelatedMemoID}) where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, find.RelatedMemoID)
} }
if find.Type != nil { if find.Type != nil {
qb = qb.Where(squirrel.Eq{"type": *find.Type}) where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type)
} }
query, args, err := qb.ToSql() rows, err := d.db.QueryContext(ctx, `
if err != nil { SELECT
return nil, err memo_id,
} related_memo_id,
type
rows, err := d.db.QueryContext(ctx, query, args...) FROM memo_relation
WHERE `+strings.Join(where, " AND "), args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
var list []*store.MemoRelation list := []*store.MemoRelation{}
for rows.Next() { for rows.Next() {
memoRelation := &store.MemoRelation{} memoRelation := &store.MemoRelation{}
if err := rows.Scan(&memoRelation.MemoID, &memoRelation.RelatedMemoID, &memoRelation.Type); err != nil { if err := rows.Scan(
&memoRelation.MemoID,
&memoRelation.RelatedMemoID,
&memoRelation.Type,
); err != nil {
return nil, err return nil, err
} }
list = append(list, memoRelation) list = append(list, memoRelation)
} }
return list, rows.Err() if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
} }
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error { func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
qb := squirrel.Delete("memo_relation"). where, args := []string{"1 = 1"}, []any{}
PlaceholderFormat(squirrel.Dollar)
if delete.MemoID != nil { if delete.MemoID != nil {
qb = qb.Where(squirrel.Eq{"memo_id": *delete.MemoID}) where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, delete.MemoID)
} }
if delete.RelatedMemoID != nil { if delete.RelatedMemoID != nil {
qb = qb.Where(squirrel.Eq{"related_memo_id": *delete.RelatedMemoID}) where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, delete.RelatedMemoID)
} }
if delete.Type != nil { if delete.Type != nil {
qb = qb.Where(squirrel.Eq{"type": *delete.Type}) where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, delete.Type)
} }
stmt := `DELETE FROM memo_relation WHERE ` + strings.Join(where, " AND ")
stmt, args, err := qb.ToSql()
if err != nil {
return err
}
result, err := d.db.ExecContext(ctx, stmt, args...) result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil { if err != nil {
return err return err
} }
if _, err = result.RowsAffected(); err != nil {
_, err = result.RowsAffected()
return err return err
}
return nil
} }
func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error { func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery for memo_id if _, err := tx.ExecContext(ctx, `
subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql() DELETE FROM memo_relation
if err != nil { WHERE memo_id NOT IN (SELECT id FROM memo) OR related_memo_id NOT IN (SELECT id FROM memo)
`); err != nil {
return err return err
} }
return nil
// Note: The same subquery is used for related_memo_id as it's also checking against the "memo" table
// Now, build the main delete query using the subqueries
query, args, err := squirrel.Delete("memo_relation").
Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
Where(fmt.Sprintf("related_memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Combine the arguments for both instances of the same subquery
args = append(args, subArgsMemo...)
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
} }

View File

@ -3,22 +3,12 @@ package postgres
import ( import (
"context" "context"
"github.com/Masterminds/squirrel"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) { func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) {
qb := squirrel.Select("version", "created_ts"). query := "SELECT version, created_ts FROM migration_history ORDER BY created_ts DESC"
From("migration_history"). rows, err := d.db.QueryContext(ctx, query)
OrderBy("created_ts DESC")
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -27,9 +17,13 @@ func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigratio
list := make([]*store.MigrationHistory, 0) list := make([]*store.MigrationHistory, 0)
for rows.Next() { for rows.Next() {
var migrationHistory store.MigrationHistory var migrationHistory store.MigrationHistory
if err := rows.Scan(&migrationHistory.Version, &migrationHistory.CreatedTs); err != nil { if err := rows.Scan(
&migrationHistory.Version,
&migrationHistory.CreatedTs,
); err != nil {
return nil, err return nil, err
} }
list = append(list, &migrationHistory) list = append(list, &migrationHistory)
} }
@ -41,33 +35,21 @@ func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigratio
} }
func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) { func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) {
qb := squirrel.Insert("migration_history"). stmt := `
Columns("version"). INSERT INTO migration_history (
Values(upsert.Version). version
Suffix("ON CONFLICT (version) DO NOTHING"). )
PlaceholderFormat(squirrel.Dollar) VALUES ($1)
ON CONFLICT(version) DO UPDATE
query, args, err := qb.ToSql() SET
if err != nil { version=EXCLUDED.version
return nil, err RETURNING version, created_ts
} `
_, err = d.db.ExecContext(ctx, query, args...)
if err != nil {
return nil, err
}
var migrationHistory store.MigrationHistory var migrationHistory store.MigrationHistory
query, args, err = squirrel.Select("version", "created_ts"). if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan(
From("migration_history"). &migrationHistory.Version,
Where(squirrel.Eq{"version": upsert.Version}). &migrationHistory.CreatedTs,
PlaceholderFormat(squirrel.Dollar). ); err != nil {
ToSql()
if err != nil {
return nil, err
}
if err := d.db.QueryRowContext(ctx, query, args...).Scan(&migrationHistory.Version, &migrationHistory.CreatedTs); err != nil {
return nil, err return nil, err
} }

View File

@ -4,77 +4,61 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"time" "strings"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
func (d *DB) CreateResource(ctx context.Context, create *store.Resource) (*store.Resource, error) { func (d *DB) CreateResource(ctx context.Context, create *store.Resource) (*store.Resource, error) {
qb := squirrel.Insert("resource").Columns("filename", "blob", "external_link", "type", "size", "creator_id", "internal_path", "memo_id") fields := []string{"filename", "blob", "external_link", "type", "size", "creator_id", "internal_path", "memo_id"}
values := []any{create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath, create.MemoID} args := []any{create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath, create.MemoID}
qb = qb.Values(values...).Suffix("RETURNING id") stmt := "INSERT INTO resource (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts"
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil {
if err != nil {
return nil, err return nil, err
} }
return create, nil
var id int32
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
if err != nil {
return nil, err
}
list, err := d.ListResources(ctx, &store.FindResource{ID: &id})
if err != nil {
return nil, err
}
if len(list) != 1 {
return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list))
}
return list[0], nil
} }
func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*store.Resource, error) { func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*store.Resource, error) {
qb := squirrel.Select("id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path", "memo_id").From("resource") where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil { if v := find.ID; v != nil {
qb = qb.Where(squirrel.Eq{"id": *v}) where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.CreatorID; v != nil { if v := find.CreatorID; v != nil {
qb = qb.Where(squirrel.Eq{"creator_id": *v}) where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.Filename; v != nil { if v := find.Filename; v != nil {
qb = qb.Where(squirrel.Eq{"filename": *v}) where, args = append(where, "filename = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.MemoID; v != nil { if v := find.MemoID; v != nil {
qb = qb.Where(squirrel.Eq{"memo_id": *v}) where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
} }
if find.HasRelatedMemo { if find.HasRelatedMemo {
qb = qb.Where("memo_id IS NOT NULL") where = append(where, "memo_id IS NOT NULL")
} }
fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path", "memo_id"}
if find.GetBlob { if find.GetBlob {
qb = qb.Columns("blob") fields = append(fields, "blob")
} }
qb = qb.GroupBy("id").OrderBy("created_ts DESC") query := fmt.Sprintf(`
SELECT
%s
FROM resource
WHERE %s
GROUP BY id
ORDER BY created_ts DESC
`, strings.Join(fields, ", "), strings.Join(where, " AND "))
if find.Limit != nil { if find.Limit != nil {
qb = qb.Limit(uint64(*find.Limit)) query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil { if find.Offset != nil {
qb = qb.Offset(uint64(*find.Offset)) query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
} }
} }
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...) rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -103,7 +87,6 @@ func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*st
if err := rows.Scan(dests...); err != nil { if err := rows.Scan(dests...); err != nil {
return nil, err return nil, err
} }
if memoID.Valid { if memoID.Valid {
resource.MemoID = &memoID.Int32 resource.MemoID = &memoID.Int32
} }
@ -118,88 +101,72 @@ func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*st
} }
func (d *DB) UpdateResource(ctx context.Context, update *store.UpdateResource) (*store.Resource, error) { func (d *DB) UpdateResource(ctx context.Context, update *store.UpdateResource) (*store.Resource, error) {
qb := squirrel.Update("resource") set, args := []string{}, []any{}
if v := update.UpdatedTs; v != nil { if v := update.UpdatedTs; v != nil {
qb = qb.Set("updated_ts", time.Unix(0, *v)) set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.Filename; v != nil { if v := update.Filename; v != nil {
qb = qb.Set("filename", *v) set, args = append(set, "filename = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.InternalPath; v != nil { if v := update.InternalPath; v != nil {
qb = qb.Set("internal_path", *v) set, args = append(set, "internal_path = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.MemoID; v != nil { if v := update.MemoID; v != nil {
qb = qb.Set("memo_id", *v) set, args = append(set, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.Blob; v != nil { if v := update.Blob; v != nil {
qb = qb.Set("blob", v) set, args = append(set, "blob = "+placeholder(len(args)+1)), append(args, v)
} }
qb = qb.Where(squirrel.Eq{"id": update.ID}) fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path"}
stmt := `
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() UPDATE resource
if err != nil { SET ` + strings.Join(set, ", ") + `
WHERE id = ` + placeholder(len(args)+1) + `
RETURNING ` + strings.Join(fields, ", ")
args = append(args, update.ID)
resource := store.Resource{}
dests := []any{
&resource.ID,
&resource.Filename,
&resource.ExternalLink,
&resource.Type,
&resource.Size,
&resource.CreatorID,
&resource.CreatedTs,
&resource.UpdatedTs,
&resource.InternalPath,
}
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(dests...); err != nil {
return nil, err return nil, err
} }
if _, err := d.db.ExecContext(ctx, query, args...); err != nil { return &resource, nil
return nil, err
}
list, err := d.ListResources(ctx, &store.FindResource{ID: &update.ID})
if err != nil {
return nil, err
}
if len(list) != 1 {
return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list))
}
return list[0], nil
} }
func (d *DB) DeleteResource(ctx context.Context, delete *store.DeleteResource) error { func (d *DB) DeleteResource(ctx context.Context, delete *store.DeleteResource) error {
qb := squirrel.Delete("resource").Where(squirrel.Eq{"id": delete.ID}) stmt := `DELETE FROM resource WHERE id = $1`
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil { if err != nil {
return err return err
} }
result, err := d.db.ExecContext(ctx, query, args...)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil { if _, err := result.RowsAffected(); err != nil {
return err return err
} }
return nil
}
if err := d.Vacuum(ctx); err != nil { func vacuumResource(ctx context.Context, tx *sql.Tx) error {
// Prevent linter warning. stmt := `
DELETE FROM
resource
WHERE
creator_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err return err
} }
return nil return nil
} }
func vacuumResource(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery
subQuery, subArgs, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
// Now, build the main delete query using the subquery
query, args, err := squirrel.Delete("resource").
Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
}

View File

@ -3,7 +3,6 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/Masterminds/squirrel" "github.com/Masterminds/squirrel"
@ -82,22 +81,15 @@ func (d *DB) DeleteTag(ctx context.Context, delete *store.DeleteTag) error {
} }
func vacuumTag(ctx context.Context, tx *sql.Tx) error { func vacuumTag(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery for creator_id stmt := `
subQuery, subArgs, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql() DELETE FROM
tag
WHERE
creator_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
if err != nil { if err != nil {
return err return err
} }
// Now, build the main delete query using the subquery return nil
query, args, err := squirrel.Delete("tag").
Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
} }

View File

@ -2,132 +2,113 @@ package postgres
import ( import (
"context" "context"
"strings"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) { func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
// Start building the insert statement fields := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"}
builder := squirrel.Insert(`"user"`).PlaceholderFormat(squirrel.Dollar) args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
stmt := "INSERT INTO \"user\" (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, avatar_url, created_ts, updated_ts, row_status"
columns := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"} if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
builder = builder.Columns(columns...) &create.ID,
&create.AvatarURL,
values := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL} &create.CreatedTs,
&create.UpdatedTs,
builder = builder.Values(values...) &create.RowStatus,
builder = builder.Suffix("RETURNING id") ); err != nil {
// Prepare the final query
query, args, err := builder.ToSql()
if err != nil {
return nil, err return nil, err
} }
// Execute the query and get the returned ID return create, nil
var id int32
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
if err != nil {
return nil, err
}
// Use the returned ID to retrieve the full user object
user, err := d.GetUser(ctx, &store.FindUser{ID: &id})
if err != nil {
return nil, err
}
return user, nil
} }
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) { func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
// Start building the update statement set, args := []string{}, []any{}
builder := squirrel.Update(`"user"`).PlaceholderFormat(squirrel.Dollar)
// Conditionally add set clauses
if v := update.UpdatedTs; v != nil { if v := update.UpdatedTs; v != nil {
builder = builder.Set("updated_ts", *v) set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.RowStatus; v != nil { if v := update.RowStatus; v != nil {
builder = builder.Set("row_status", *v) set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.Username; v != nil { if v := update.Username; v != nil {
builder = builder.Set("username", *v) set, args = append(set, "username = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.Email; v != nil { if v := update.Email; v != nil {
builder = builder.Set("email", *v) set, args = append(set, "email = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.Nickname; v != nil { if v := update.Nickname; v != nil {
builder = builder.Set("nickname", *v) set, args = append(set, "nickname = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.AvatarURL; v != nil { if v := update.AvatarURL; v != nil {
builder = builder.Set("avatar_url", *v) set, args = append(set, "avatar_url = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.PasswordHash; v != nil { if v := update.PasswordHash; v != nil {
builder = builder.Set("password_hash", *v) set, args = append(set, "password_hash = "+placeholder(len(args)+1)), append(args, *v)
} }
// Add the WHERE clause query := `
builder = builder.Where(squirrel.Eq{"id": update.ID}) UPDATE "user"
SET ` + strings.Join(set, ", ") + `
// Prepare the final query WHERE id = ` + placeholder(len(args)+1) + `
query, args, err := builder.ToSql() RETURNING id, username, role, email, nickname, password_hash, avatar_url, created_ts, updated_ts, row_status
if err != nil { `
args = append(args, update.ID)
user := &store.User{}
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
&user.ID,
&user.Username,
&user.Role,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.AvatarURL,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
); err != nil {
return nil, err return nil, err
} }
// Execute the query with the context
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
return nil, err
}
// Retrieve the updated user
user, err := d.GetUser(ctx, &store.FindUser{ID: &update.ID})
if err != nil {
return nil, err
}
return user, nil return user, nil
} }
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) { func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
// Start building the SELECT statement where, args := []string{"1 = 1"}, []any{}
builder := squirrel.Select("id", "username", "role", "email", "nickname", "password_hash", "avatar_url", "created_ts", "updated_ts", "row_status").
From(`"user"`).
PlaceholderFormat(squirrel.Dollar)
// 1 = 1 is often used as a no-op in SQL, ensuring there's always a WHERE clause
builder = builder.Where("1 = 1")
// Conditionally add where clauses
if v := find.ID; v != nil { if v := find.ID; v != nil {
builder = builder.Where(squirrel.Eq{"id": *v}) where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.Username; v != nil { if v := find.Username; v != nil {
builder = builder.Where(squirrel.Eq{"username": *v}) where, args = append(where, "username = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.Role; v != nil { if v := find.Role; v != nil {
builder = builder.Where(squirrel.Eq{"role": *v}) where, args = append(where, "role = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.Email; v != nil { if v := find.Email; v != nil {
builder = builder.Where(squirrel.Eq{"email": *v}) where, args = append(where, "email = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.Nickname; v != nil { if v := find.Nickname; v != nil {
builder = builder.Where(squirrel.Eq{"nickname": *v}) where, args = append(where, "nickname = "+placeholder(len(args)+1)), append(args, *v)
} }
// Add ordering query := `
builder = builder.OrderBy("created_ts DESC", "row_status DESC") SELECT
id,
// Prepare the final query username,
query, args, err := builder.ToSql() role,
if err != nil { email,
return nil, err nickname,
} password_hash,
avatar_url,
// Execute the query with the context created_ts,
updated_ts,
row_status
FROM "user"
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC, row_status DESC
`
rows, err := d.db.QueryContext(ctx, query, args...) rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -161,35 +142,13 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User
return list, nil return list, nil
} }
func (d *DB) GetUser(ctx context.Context, find *store.FindUser) (*store.User, error) {
list, err := d.ListUsers(ctx, find)
if err != nil {
return nil, err
}
if len(list) != 1 {
return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list))
}
return list[0], nil
}
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error { func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
// Start building the DELETE statement result, err := d.db.ExecContext(ctx, `
builder := squirrel.Delete(`"user"`). DELETE FROM "user" WHERE id = $1
PlaceholderFormat(squirrel.Dollar). `, delete.ID)
Where(squirrel.Eq{"id": delete.ID})
// Prepare the final query
query, args, err := builder.ToSql()
if err != nil { if err != nil {
return err return err
} }
// Execute the query with the context
result, err := d.db.ExecContext(ctx, query, args...)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil { if _, err := result.RowsAffected(); err != nil {
return err return err
} }

View File

@ -3,7 +3,6 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/Masterminds/squirrel" "github.com/Masterminds/squirrel"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -130,22 +129,15 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
} }
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
// First, build the subquery stmt := `
subQuery, subArgs, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql() DELETE FROM
user_setting
WHERE
user_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
if err != nil { if err != nil {
return err return err
} }
// Now, build the main delete query using the subquery return nil
query, args, err := squirrel.Delete("user_setting").
Where(fmt.Sprintf("user_id NOT IN (%s)", subQuery), subArgs...).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return err
}
// Execute the query
_, err = tx.ExecContext(ctx, query, args...)
return err
} }

View File

@ -38,7 +38,6 @@ func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) { func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
if find.ID != nil { if find.ID != nil {
where, args = append(where, "`id` = ?"), append(args, *find.ID) where, args = append(where, "`id` = ?"), append(args, *find.ID)
} }

View File

@ -0,0 +1,62 @@
package teststore
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/store"
)
func TestMemoOrganizerStore(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memoCreate := &store.Memo{
CreatorID: user.ID,
Content: "main memo content",
Visibility: store.Public,
}
memo, err := ts.CreateMemo(ctx, memoCreate)
require.NoError(t, err)
require.Equal(t, memoCreate.Content, memo.Content)
memoOrganizer, err := ts.UpsertMemoOrganizer(ctx, &store.MemoOrganizer{
MemoID: memo.ID,
UserID: user.ID,
Pinned: true,
})
require.NoError(t, err)
require.NotNil(t, memoOrganizer)
require.Equal(t, memo.ID, memoOrganizer.MemoID)
require.Equal(t, user.ID, memoOrganizer.UserID)
require.Equal(t, true, memoOrganizer.Pinned)
memoOrganizerTemp, err := ts.GetMemoOrganizer(ctx, &store.FindMemoOrganizer{
MemoID: memo.ID,
})
require.NoError(t, err)
require.Equal(t, memoOrganizer, memoOrganizerTemp)
memoOrganizerTemp, err = ts.UpsertMemoOrganizer(ctx, &store.MemoOrganizer{
MemoID: memo.ID,
UserID: user.ID,
Pinned: false,
})
require.NoError(t, err)
require.NotNil(t, memoOrganizerTemp)
require.Equal(t, memo.ID, memoOrganizerTemp.MemoID)
require.Equal(t, user.ID, memoOrganizerTemp.UserID)
require.Equal(t, false, memoOrganizerTemp.Pinned)
err = ts.DeleteMemoOrganizer(ctx, &store.DeleteMemoOrganizer{
MemoID: &memo.ID,
UserID: &user.ID,
})
require.NoError(t, err)
memoOrganizers, err := ts.ListMemoOrganizer(ctx, &store.FindMemoOrganizer{
UserID: user.ID,
})
require.NoError(t, err)
require.Equal(t, 0, len(memoOrganizers))
}

View File

@ -59,3 +59,22 @@ func TestMemoStore(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(memoList)) require.Equal(t, 0, len(memoList))
} }
func TestDeleteMemoStore(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
memoCreate := &store.Memo{
CreatorID: user.ID,
Content: "test_content",
Visibility: store.Public,
}
memo, err := ts.CreateMemo(ctx, memoCreate)
require.NoError(t, err)
require.Equal(t, memoCreate.Content, memo.Content)
err = ts.DeleteMemo(ctx, &store.DeleteMemo{
ID: memo.ID,
})
require.NoError(t, err)
}

View File

@ -8,6 +8,7 @@ import (
// sqlite driver. // sqlite driver.
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
"github.com/usememos/memos/store/db" "github.com/usememos/memos/store/db"
"github.com/usememos/memos/test" "github.com/usememos/memos/test"
@ -19,6 +20,7 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
if err != nil { if err != nil {
fmt.Printf("failed to create db driver, error: %+v\n", err) fmt.Printf("failed to create db driver, error: %+v\n", err)
} }
resetTestingDB(ctx, profile, dbDriver)
if err := dbDriver.Migrate(ctx); err != nil { if err := dbDriver.Migrate(ctx); err != nil {
fmt.Printf("failed to migrate db, error: %+v\n", err) fmt.Printf("failed to migrate db, error: %+v\n", err)
} }
@ -26,3 +28,13 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
store := store.New(dbDriver, profile) store := store.New(dbDriver, profile)
return store return store
} }
func resetTestingDB(ctx context.Context, profile *profile.Profile, dbDriver store.Driver) {
if profile.Driver == "postgres" {
_, err := dbDriver.GetDB().ExecContext(ctx, `DROP SCHEMA public CASCADE; CREATE SCHEMA public;`)
if err != nil {
fmt.Printf("failed to reset testing db, error: %+v\n", err)
panic(err)
}
}
}