mirror of
https://github.com/usememos/memos.git
synced 2025-02-16 03:12:13 +01:00
chore: fix postgres stmts
This commit is contained in:
parent
ee13927607
commit
501f8898f6
@ -2,8 +2,8 @@ package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
@ -21,50 +21,29 @@ func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store
|
||||
payloadString = string(bytes)
|
||||
}
|
||||
|
||||
qb := squirrel.Insert("activity").
|
||||
Columns("creator_id", "type", "level", "payload").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
values := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
|
||||
qb = qb.Values(values...).Suffix("RETURNING id")
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to construct query")
|
||||
}
|
||||
|
||||
var id int32
|
||||
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to execute statement and retrieve ID")
|
||||
}
|
||||
|
||||
list, err := d.ListActivities(ctx, &store.FindActivity{ID: &id})
|
||||
if err != nil || len(list) == 0 {
|
||||
return nil, errors.Wrap(err, "failed to find activity")
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
|
||||
qb := squirrel.Select("id", "created_ts", "creator_id", "type", "level", "payload").
|
||||
From("activity").
|
||||
Where("1 = 1").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
if find.ID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"id": *find.ID})
|
||||
}
|
||||
if find.Type != nil {
|
||||
qb = qb.Where(squirrel.Eq{"type": find.Type.String()})
|
||||
}
|
||||
|
||||
query, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
fields := []string{"creator_id", "type", "level", "payload"}
|
||||
args := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString}
|
||||
stmt := "INSERT INTO activity (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
|
||||
}
|
||||
if find.Type != nil {
|
||||
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type.String())
|
||||
}
|
||||
|
||||
query := "SELECT id, creator_id, type, level, payload, created_ts FROM activity WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -77,17 +56,17 @@ func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*s
|
||||
var payloadBytes []byte
|
||||
if err := rows.Scan(
|
||||
&activity.ID,
|
||||
&activity.CreatedTs,
|
||||
&activity.CreatorID,
|
||||
&activity.Type,
|
||||
&activity.Level,
|
||||
&payloadBytes,
|
||||
&activity.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := &storepb.ActivityPayload{}
|
||||
if err := protojson.Unmarshal(payloadBytes, payload); err != nil {
|
||||
if err := protojsonUnmarshaler.Unmarshal(payloadBytes, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
activity.Payload = payload
|
||||
|
@ -1,9 +1,26 @@
|
||||
package postgres
|
||||
|
||||
import "google.golang.org/protobuf/encoding/protojson"
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
var (
|
||||
protojsonUnmarshaler = protojson.UnmarshalOptions{
|
||||
DiscardUnknown: true,
|
||||
}
|
||||
)
|
||||
|
||||
func placeholder(n int) string {
|
||||
return "$" + fmt.Sprint(n)
|
||||
}
|
||||
|
||||
func placeholders(n int) string {
|
||||
list := []string{}
|
||||
for i := 0; i < n; i++ {
|
||||
list = append(list, placeholder(i+1))
|
||||
}
|
||||
return strings.Join(list, ", ")
|
||||
}
|
||||
|
@ -3,8 +3,8 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
@ -22,42 +22,34 @@ func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityP
|
||||
return nil, errors.Errorf("unsupported idp type %s", string(create.Type))
|
||||
}
|
||||
|
||||
qb := squirrel.Insert("idp").Columns("name", "type", "identifier_filter", "config")
|
||||
values := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)}
|
||||
|
||||
qb = qb.Values(values...).PlaceholderFormat(squirrel.Dollar)
|
||||
qb = qb.Suffix("RETURNING id")
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
fields := []string{"name", "type", "identifier_filter", "config"}
|
||||
args := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)}
|
||||
stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var id int32
|
||||
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
create.ID = id
|
||||
return create, nil
|
||||
identityProvider := create
|
||||
return identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) {
|
||||
qb := squirrel.Select("id", "name", "type", "identifier_filter", "config").
|
||||
From("idp").
|
||||
Where("1 = 1").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"id": *v})
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
query, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
name,
|
||||
type,
|
||||
identifier_filter,
|
||||
config
|
||||
FROM idp
|
||||
WHERE `+strings.Join(where, " AND ")+` ORDER BY id ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -111,15 +103,12 @@ func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityPr
|
||||
}
|
||||
|
||||
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
|
||||
qb := squirrel.Update("idp").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
var err error
|
||||
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.Name; v != nil {
|
||||
qb = qb.Set("name", *v)
|
||||
set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.IdentifierFilter; v != nil {
|
||||
qb = qb.Set("identifier_filter", *v)
|
||||
set, args = append(set, "identifier_filter = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Config; v != nil {
|
||||
var configBytes []byte
|
||||
@ -132,42 +121,53 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
|
||||
} else {
|
||||
return nil, errors.Errorf("unsupported idp type %s", string(update.Type))
|
||||
}
|
||||
qb = qb.Set("config", string(configBytes))
|
||||
set, args = append(set, "config = "+placeholder(len(args)+1)), append(args, string(configBytes))
|
||||
}
|
||||
|
||||
qb = qb.Where(squirrel.Eq{"id": update.ID})
|
||||
stmt := `
|
||||
UPDATE idp
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ` + placeholder(len(args)+1) + `
|
||||
RETURNING id, name, type, identifier_filter, config
|
||||
`
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
var identityProvider store.IdentityProvider
|
||||
var identityProviderConfig string
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&identityProvider.ID,
|
||||
&identityProvider.Name,
|
||||
&identityProvider.Type,
|
||||
&identityProvider.IdentifierFilter,
|
||||
&identityProviderConfig,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if identityProvider.Type == store.IdentityProviderOAuth2Type {
|
||||
oauth2Config := &store.IdentityProviderOAuth2Config{}
|
||||
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identityProvider.Config = &store.IdentityProviderConfig{
|
||||
OAuth2Config: oauth2Config,
|
||||
}
|
||||
} else {
|
||||
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
|
||||
}
|
||||
|
||||
return d.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &update.ID})
|
||||
return &identityProvider, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error {
|
||||
qb := squirrel.Delete("idp").
|
||||
Where(squirrel.Eq{"id": delete.ID}).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
where, args := []string{"id = $1"}, []any{delete.ID}
|
||||
stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -2,8 +2,8 @@ package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
@ -21,61 +21,54 @@ func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox
|
||||
messageString = string(bytes)
|
||||
}
|
||||
|
||||
qb := squirrel.Insert("inbox").
|
||||
Columns("sender_id", "receiver_id", "status", "message").
|
||||
Values(create.SenderID, create.ReceiverID, create.Status, messageString).
|
||||
Suffix("RETURNING id").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
fields := []string{"sender_id", "receiver_id", "status", "message"}
|
||||
args := []any{create.SenderID, create.ReceiverID, create.Status, messageString}
|
||||
stmt := "INSERT INTO inbox (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var id int32
|
||||
err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return d.GetInbox(ctx, &store.FindInbox{ID: &id})
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) {
|
||||
qb := squirrel.Select("id", "created_ts", "sender_id", "receiver_id", "status", "message").
|
||||
From("inbox").
|
||||
Where("1 = 1").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"id": *find.ID})
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
|
||||
}
|
||||
if find.SenderID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"sender_id": *find.SenderID})
|
||||
where, args = append(where, "sender_id = "+placeholder(len(args)+1)), append(args, *find.SenderID)
|
||||
}
|
||||
if find.ReceiverID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"receiver_id": *find.ReceiverID})
|
||||
where, args = append(where, "receiver_id = "+placeholder(len(args)+1)), append(args, *find.ReceiverID)
|
||||
}
|
||||
if find.Status != nil {
|
||||
qb = qb.Where(squirrel.Eq{"status": *find.Status})
|
||||
}
|
||||
|
||||
query, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
where, args = append(where, "status = "+placeholder(len(args)+1)), append(args, *find.Status)
|
||||
}
|
||||
|
||||
query := "SELECT id, created_ts, sender_id, receiver_id, status, message FROM inbox WHERE " + strings.Join(where, " AND ") + " ORDER BY created_ts DESC"
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var list []*store.Inbox
|
||||
list := []*store.Inbox{}
|
||||
for rows.Next() {
|
||||
inbox := &store.Inbox{}
|
||||
var messageBytes []byte
|
||||
if err := rows.Scan(&inbox.ID, &inbox.CreatedTs, &inbox.SenderID, &inbox.ReceiverID, &inbox.Status, &messageBytes); err != nil {
|
||||
if err := rows.Scan(
|
||||
&inbox.ID,
|
||||
&inbox.CreatedTs,
|
||||
&inbox.SenderID,
|
||||
&inbox.ReceiverID,
|
||||
&inbox.Status,
|
||||
&messageBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -87,7 +80,11 @@ func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.I
|
||||
list = append(list, inbox)
|
||||
}
|
||||
|
||||
return list, rows.Err()
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) {
|
||||
@ -102,39 +99,36 @@ func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox,
|
||||
}
|
||||
|
||||
func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) {
|
||||
qb := squirrel.Update("inbox").
|
||||
Set("status", update.Status.String()).
|
||||
Where(squirrel.Eq{"id": update.ID}).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
set, args := []string{"status = $1"}, []any{update.Status.String()}
|
||||
args = append(args, update.ID)
|
||||
query := "UPDATE inbox SET " + strings.Join(set, ", ") + " WHERE id = $2 RETURNING id, created_ts, sender_id, receiver_id, status, message"
|
||||
inbox := &store.Inbox{}
|
||||
var messageBytes []byte
|
||||
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
|
||||
&inbox.ID,
|
||||
&inbox.CreatedTs,
|
||||
&inbox.SenderID,
|
||||
&inbox.ReceiverID,
|
||||
&inbox.Status,
|
||||
&messageBytes,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
message := &storepb.InboxMessage{}
|
||||
if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return d.GetInbox(ctx, &store.FindInbox{ID: &update.ID})
|
||||
inbox.Message = message
|
||||
return inbox, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error {
|
||||
qb := squirrel.Delete("inbox").
|
||||
Where(squirrel.Eq{"id": delete.ID}).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
result, err := d.db.ExecContext(ctx, "DELETE FROM inbox WHERE id = $1", delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = result.RowsAffected()
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
@ -3,153 +3,127 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) {
|
||||
// Initialize a Squirrel statement builder for PostgreSQL
|
||||
builder := squirrel.Insert("memo").
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
Columns("creator_id", "content", "visibility")
|
||||
fields := []string{"creator_id", "content", "visibility"}
|
||||
args := []any{create.CreatorID, create.Content, create.Visibility}
|
||||
|
||||
// Add initial values for the columns
|
||||
values := []any{create.CreatorID, create.Content, create.Visibility}
|
||||
|
||||
// Add all the values at once
|
||||
builder = builder.Values(values...)
|
||||
|
||||
// Add the RETURNING clause to get the ID of the inserted row
|
||||
builder = builder.Suffix("RETURNING id")
|
||||
|
||||
// Prepare and execute the query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
stmt := "INSERT INTO memo (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts, row_status"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&create.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var id int32
|
||||
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Retrieve the newly created memo
|
||||
memo, err := d.GetMemo(ctx, &store.FindMemo{ID: &id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, errors.Errorf("failed to create memo")
|
||||
}
|
||||
|
||||
return memo, nil
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
|
||||
// Start building the SELECT statement
|
||||
builder := squirrel.Select(
|
||||
"memo.id AS id",
|
||||
"memo.creator_id AS creator_id",
|
||||
"memo.created_ts AS created_ts",
|
||||
"memo.updated_ts AS updated_ts",
|
||||
"memo.row_status AS row_status",
|
||||
"memo.content AS content",
|
||||
"memo.visibility AS visibility",
|
||||
"MAX(CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END) AS pinned").
|
||||
From("memo").
|
||||
LeftJoin("memo_organizer ON memo.id = memo_organizer.memo_id").
|
||||
LeftJoin("resource ON memo.id = resource.memo_id").
|
||||
GroupBy("memo.id").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
// Add conditional where clauses
|
||||
if v := find.ID; v != nil {
|
||||
builder = builder.Where("memo.id = ?", *v)
|
||||
where, args = append(where, "memo.id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
builder = builder.Where("memo.creator_id = ?", *v)
|
||||
where, args = append(where, "memo.creator_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
builder = builder.Where("memo.row_status = ?", *v)
|
||||
where, args = append(where, "memo.row_status = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.CreatedTsBefore; v != nil {
|
||||
builder = builder.Where("memo.created_ts < ?", *v)
|
||||
where, args = append(where, "memo.created_ts < "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.CreatedTsAfter; v != nil {
|
||||
builder = builder.Where("memo.created_ts > ?", *v)
|
||||
}
|
||||
if v := find.Pinned; v != nil {
|
||||
builder = builder.Where("memo_organizer.pinned = 1")
|
||||
where, args = append(where, "memo.created_ts > "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.ContentSearch; len(v) != 0 {
|
||||
for _, s := range v {
|
||||
builder = builder.Where("memo.content LIKE ?", "%"+s+"%")
|
||||
where, args = append(where, "memo.content LIKE "+placeholder(len(args)+1)), append(args, fmt.Sprintf("%%%s%%", s))
|
||||
}
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
holders := []string{}
|
||||
for _, visibility := range v {
|
||||
holders = append(holders, placeholder(len(args)+1))
|
||||
args = append(args, visibility.String())
|
||||
}
|
||||
where = append(where, fmt.Sprintf("memo.visibility in (%s)", strings.Join(holders, ", ")))
|
||||
}
|
||||
if v := find.Pinned; v != nil {
|
||||
where = append(where, "memo_organizer.pinned = 1")
|
||||
}
|
||||
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
placeholders := make([]string, len(v))
|
||||
args := make([]any, len(v))
|
||||
for i, visibility := range v {
|
||||
placeholders[i] = "?"
|
||||
args[i] = visibility // Assuming visibility can be directly used as an argument
|
||||
}
|
||||
inClause := strings.Join(placeholders, ",")
|
||||
builder = builder.Where("memo.visibility IN ("+inClause+")", args...)
|
||||
}
|
||||
// Add order by clauses
|
||||
orders := []string{}
|
||||
if find.OrderByPinned {
|
||||
builder = builder.OrderBy("pinned DESC")
|
||||
orders = append(orders, "pinned DESC")
|
||||
}
|
||||
if find.OrderByUpdatedTs {
|
||||
builder = builder.OrderBy("updated_ts DESC")
|
||||
orders = append(orders, "updated_ts DESC")
|
||||
} else {
|
||||
builder = builder.OrderBy("created_ts DESC")
|
||||
orders = append(orders, "created_ts DESC")
|
||||
}
|
||||
builder = builder.OrderBy("id DESC")
|
||||
orders = append(orders, "id DESC")
|
||||
|
||||
// Handle pagination
|
||||
fields := []string{
|
||||
`memo.id AS id`,
|
||||
`memo.creator_id AS creator_id`,
|
||||
`memo.created_ts AS created_ts`,
|
||||
`memo.updated_ts AS updated_ts`,
|
||||
`memo.row_status AS row_status`,
|
||||
`memo.visibility AS visibility`,
|
||||
`MAX(CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END) AS pinned`,
|
||||
}
|
||||
if !find.ExcludeContent {
|
||||
fields = append(fields, `memo.content AS content`)
|
||||
}
|
||||
|
||||
query := `SELECT ` + strings.Join(fields, ", ") + `
|
||||
FROM memo
|
||||
LEFT JOIN memo_organizer ON memo.id = memo_organizer.memo_id
|
||||
WHERE ` + strings.Join(where, " AND ") + `
|
||||
GROUP BY memo.id
|
||||
ORDER BY ` + strings.Join(orders, ", ")
|
||||
if find.Limit != nil {
|
||||
builder = builder.Limit(uint64(*find.Limit))
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
builder = builder.Offset(uint64(*find.Offset))
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare and execute the query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Process the result set
|
||||
list := make([]*store.Memo, 0)
|
||||
for rows.Next() {
|
||||
var memo store.Memo
|
||||
if err := rows.Scan(
|
||||
dests := []any{
|
||||
&memo.ID,
|
||||
&memo.CreatorID,
|
||||
&memo.CreatedTs,
|
||||
&memo.UpdatedTs,
|
||||
&memo.RowStatus,
|
||||
&memo.Content,
|
||||
&memo.Visibility,
|
||||
&memo.Pinned,
|
||||
); err != nil {
|
||||
}
|
||||
if !find.ExcludeContent {
|
||||
dests = append(dests, &memo.Content)
|
||||
}
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list = append(list, &memo)
|
||||
}
|
||||
|
||||
@ -174,51 +148,42 @@ func (d *DB) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, er
|
||||
}
|
||||
|
||||
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error {
|
||||
// Start building the update statement
|
||||
builder := squirrel.Update("memo").
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
Where("id = ?", update.ID)
|
||||
|
||||
// Conditionally add set clauses
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.CreatedTs; v != nil {
|
||||
builder = builder.Set("created_ts", *v)
|
||||
set, args = append(set, "created_ts = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
builder = builder.Set("updated_ts", *v)
|
||||
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
builder = builder.Set("row_status", *v)
|
||||
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Content; v != nil {
|
||||
builder = builder.Set("content", *v)
|
||||
set, args = append(set, "content = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Visibility; v != nil {
|
||||
builder = builder.Set("visibility", *v)
|
||||
set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
// Prepare and execute the query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
stmt := `UPDATE memo SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1)
|
||||
args = append(args, update.ID)
|
||||
if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
|
||||
stmt := `DELETE FROM memo WHERE id = $1`
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
where, args := []string{"id = " + placeholder(1)}, []any{delete.ID}
|
||||
stmt := `DELETE FROM memo WHERE ` + strings.Join(where, " AND ")
|
||||
println("stmt", stmt, delete.ID)
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, "failed to delete memo")
|
||||
}
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return d.Vacuum(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
|
||||
|
@ -4,8 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
@ -15,99 +14,94 @@ func (d *DB) UpsertMemoOrganizer(ctx context.Context, upsert *store.MemoOrganize
|
||||
if upsert.Pinned {
|
||||
pinned = 1
|
||||
}
|
||||
stmt := "INSERT INTO memo_organizer (memo_id, user_id, pinned) VALUES ($1, $2, $3) ON CONFLICT (memo_id, user_id) DO UPDATE SET pinned = $4"
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.MemoID, upsert.UserID, pinned, pinned); err != nil {
|
||||
stmt := `
|
||||
INSERT INTO memo_organizer (
|
||||
memo_id,
|
||||
user_id,
|
||||
pinned
|
||||
)
|
||||
VALUES (` + placeholders(3) + `)
|
||||
ON CONFLICT(memo_id, user_id) DO UPDATE
|
||||
SET pinned = EXCLUDED.pinned`
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.MemoID, upsert.UserID, pinned); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return upsert, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemoOrganizer(ctx context.Context, find *store.FindMemoOrganizer) ([]*store.MemoOrganizer, error) {
|
||||
qb := squirrel.Select("memo_id", "user_id", "pinned").
|
||||
From("memo_organizer").
|
||||
Where("1 = 1").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.MemoID != 0 {
|
||||
qb = qb.Where(squirrel.Eq{"memo_id": find.MemoID})
|
||||
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, find.MemoID)
|
||||
}
|
||||
if find.UserID != 0 {
|
||||
qb = qb.Where(squirrel.Eq{"user_id": find.UserID})
|
||||
}
|
||||
|
||||
query, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, find.UserID)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
memo_id,
|
||||
user_id,
|
||||
pinned
|
||||
FROM memo_organizer
|
||||
WHERE %s
|
||||
`, strings.Join(where, " AND "))
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var list []*store.MemoOrganizer
|
||||
list := []*store.MemoOrganizer{}
|
||||
for rows.Next() {
|
||||
memoOrganizer := &store.MemoOrganizer{}
|
||||
if err := rows.Scan(&memoOrganizer.MemoID, &memoOrganizer.UserID, &memoOrganizer.Pinned); err != nil {
|
||||
pinned := 0
|
||||
if err := rows.Scan(
|
||||
&memoOrganizer.MemoID,
|
||||
&memoOrganizer.UserID,
|
||||
&pinned,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
memoOrganizer.Pinned = pinned == 1
|
||||
list = append(list, memoOrganizer)
|
||||
}
|
||||
|
||||
return list, rows.Err()
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemoOrganizer(ctx context.Context, delete *store.DeleteMemoOrganizer) error {
|
||||
qb := squirrel.Delete("memo_organizer").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
where, args := []string{}, []any{}
|
||||
if v := delete.MemoID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"memo_id": *v})
|
||||
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := delete.UserID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"user_id": *v})
|
||||
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
stmt := `DELETE FROM memo_organizer WHERE ` + strings.Join(where, " AND ")
|
||||
if _, err := d.db.ExecContext(ctx, stmt, args...); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err = d.db.ExecContext(ctx, stmt, args...); err != nil {
|
||||
func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
memo_organizer
|
||||
WHERE
|
||||
memo_id NOT IN (SELECT id FROM memo)
|
||||
OR user_id NOT IN (SELECT id FROM "user")`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error {
|
||||
// First, build the subquery for memo_id
|
||||
subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build the subquery for user_id
|
||||
subQueryUser, subArgsUser, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now, build the main delete query using the subqueries
|
||||
query, args, err := squirrel.Delete("memo_organizer").
|
||||
Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
|
||||
Where(fmt.Sprintf("user_id NOT IN (%s)", subQueryUser), subArgsUser...).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine the arguments from both subqueries
|
||||
args = append(args, subArgsUser...)
|
||||
|
||||
// Execute the query
|
||||
_, err = tx.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
@ -3,127 +3,111 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) {
|
||||
qb := squirrel.Insert("memo_relation").
|
||||
Columns("memo_id", "related_memo_id", "type").
|
||||
Values(create.MemoID, create.RelatedMemoID, create.Type).
|
||||
Suffix("ON CONFLICT (version) DO NOTHING").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
stmt := `
|
||||
INSERT INTO memo_relation (
|
||||
memo_id,
|
||||
related_memo_id,
|
||||
type
|
||||
)
|
||||
VALUES (` + placeholders(3) + `)
|
||||
RETURNING memo_id, related_memo_id, type
|
||||
`
|
||||
memoRelation := &store.MemoRelation{}
|
||||
if err := d.db.QueryRowContext(
|
||||
ctx,
|
||||
stmt,
|
||||
create.MemoID,
|
||||
create.RelatedMemoID,
|
||||
create.Type,
|
||||
).Scan(
|
||||
&memoRelation.MemoID,
|
||||
&memoRelation.RelatedMemoID,
|
||||
&memoRelation.Type,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &store.MemoRelation{
|
||||
MemoID: create.MemoID,
|
||||
RelatedMemoID: create.RelatedMemoID,
|
||||
Type: create.Type,
|
||||
}, nil
|
||||
return memoRelation, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
|
||||
qb := squirrel.Select("memo_id", "related_memo_id", "type").
|
||||
From("memo_relation").
|
||||
Where("TRUE").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.MemoID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"memo_id": *find.MemoID})
|
||||
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, find.MemoID)
|
||||
}
|
||||
if find.RelatedMemoID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"related_memo_id": *find.RelatedMemoID})
|
||||
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, find.RelatedMemoID)
|
||||
}
|
||||
if find.Type != nil {
|
||||
qb = qb.Where(squirrel.Eq{"type": *find.Type})
|
||||
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type)
|
||||
}
|
||||
|
||||
query, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
memo_id,
|
||||
related_memo_id,
|
||||
type
|
||||
FROM memo_relation
|
||||
WHERE `+strings.Join(where, " AND "), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var list []*store.MemoRelation
|
||||
list := []*store.MemoRelation{}
|
||||
for rows.Next() {
|
||||
memoRelation := &store.MemoRelation{}
|
||||
if err := rows.Scan(&memoRelation.MemoID, &memoRelation.RelatedMemoID, &memoRelation.Type); err != nil {
|
||||
if err := rows.Scan(
|
||||
&memoRelation.MemoID,
|
||||
&memoRelation.RelatedMemoID,
|
||||
&memoRelation.Type,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, memoRelation)
|
||||
}
|
||||
|
||||
return list, rows.Err()
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
|
||||
qb := squirrel.Delete("memo_relation").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if delete.MemoID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"memo_id": *delete.MemoID})
|
||||
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, delete.MemoID)
|
||||
}
|
||||
if delete.RelatedMemoID != nil {
|
||||
qb = qb.Where(squirrel.Eq{"related_memo_id": *delete.RelatedMemoID})
|
||||
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, delete.RelatedMemoID)
|
||||
}
|
||||
if delete.Type != nil {
|
||||
qb = qb.Where(squirrel.Eq{"type": *delete.Type})
|
||||
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, delete.Type)
|
||||
}
|
||||
|
||||
stmt, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stmt := `DELETE FROM memo_relation WHERE ` + strings.Join(where, " AND ")
|
||||
result, err := d.db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = result.RowsAffected()
|
||||
return err
|
||||
if _, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error {
|
||||
// First, build the subquery for memo_id
|
||||
subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
DELETE FROM memo_relation
|
||||
WHERE memo_id NOT IN (SELECT id FROM memo) OR related_memo_id NOT IN (SELECT id FROM memo)
|
||||
`); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Note: The same subquery is used for related_memo_id as it's also checking against the "memo" table
|
||||
|
||||
// Now, build the main delete query using the subqueries
|
||||
query, args, err := squirrel.Delete("memo_relation").
|
||||
Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
|
||||
Where(fmt.Sprintf("related_memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Combine the arguments for both instances of the same subquery
|
||||
args = append(args, subArgsMemo...)
|
||||
|
||||
// Execute the query
|
||||
_, err = tx.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
@ -3,22 +3,12 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) {
|
||||
qb := squirrel.Select("version", "created_ts").
|
||||
From("migration_history").
|
||||
OrderBy("created_ts DESC")
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
query := "SELECT version, created_ts FROM migration_history ORDER BY created_ts DESC"
|
||||
rows, err := d.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -27,9 +17,13 @@ func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigratio
|
||||
list := make([]*store.MigrationHistory, 0)
|
||||
for rows.Next() {
|
||||
var migrationHistory store.MigrationHistory
|
||||
if err := rows.Scan(&migrationHistory.Version, &migrationHistory.CreatedTs); err != nil {
|
||||
if err := rows.Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list = append(list, &migrationHistory)
|
||||
}
|
||||
|
||||
@ -41,33 +35,21 @@ func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigratio
|
||||
}
|
||||
|
||||
func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) {
|
||||
qb := squirrel.Insert("migration_history").
|
||||
Columns("version").
|
||||
Values(upsert.Version).
|
||||
Suffix("ON CONFLICT (version) DO NOTHING").
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
query, args, err := qb.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = d.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stmt := `
|
||||
INSERT INTO migration_history (
|
||||
version
|
||||
)
|
||||
VALUES ($1)
|
||||
ON CONFLICT(version) DO UPDATE
|
||||
SET
|
||||
version=EXCLUDED.version
|
||||
RETURNING version, created_ts
|
||||
`
|
||||
var migrationHistory store.MigrationHistory
|
||||
query, args, err = squirrel.Select("version", "created_ts").
|
||||
From("migration_history").
|
||||
Where(squirrel.Eq{"version": upsert.Version}).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := d.db.QueryRowContext(ctx, query, args...).Scan(&migrationHistory.Version, &migrationHistory.CreatedTs); err != nil {
|
||||
if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -4,77 +4,61 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateResource(ctx context.Context, create *store.Resource) (*store.Resource, error) {
|
||||
qb := squirrel.Insert("resource").Columns("filename", "blob", "external_link", "type", "size", "creator_id", "internal_path", "memo_id")
|
||||
values := []any{create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath, create.MemoID}
|
||||
fields := []string{"filename", "blob", "external_link", "type", "size", "creator_id", "internal_path", "memo_id"}
|
||||
args := []any{create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath, create.MemoID}
|
||||
|
||||
qb = qb.Values(values...).Suffix("RETURNING id")
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
stmt := "INSERT INTO resource (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID, &create.CreatedTs, &create.UpdatedTs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var id int32
|
||||
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list, err := d.ListResources(ctx, &store.FindResource{ID: &id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list))
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*store.Resource, error) {
|
||||
qb := squirrel.Select("id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path", "memo_id").From("resource")
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"id": *v})
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"creator_id": *v})
|
||||
where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Filename; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"filename": *v})
|
||||
where, args = append(where, "filename = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.MemoID; v != nil {
|
||||
qb = qb.Where(squirrel.Eq{"memo_id": *v})
|
||||
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if find.HasRelatedMemo {
|
||||
qb = qb.Where("memo_id IS NOT NULL")
|
||||
where = append(where, "memo_id IS NOT NULL")
|
||||
}
|
||||
|
||||
fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path", "memo_id"}
|
||||
if find.GetBlob {
|
||||
qb = qb.Columns("blob")
|
||||
fields = append(fields, "blob")
|
||||
}
|
||||
|
||||
qb = qb.GroupBy("id").OrderBy("created_ts DESC")
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
%s
|
||||
FROM resource
|
||||
WHERE %s
|
||||
GROUP BY id
|
||||
ORDER BY created_ts DESC
|
||||
`, strings.Join(fields, ", "), strings.Join(where, " AND "))
|
||||
if find.Limit != nil {
|
||||
qb = qb.Limit(uint64(*find.Limit))
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
|
||||
if find.Offset != nil {
|
||||
qb = qb.Offset(uint64(*find.Offset))
|
||||
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -103,7 +87,6 @@ func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*st
|
||||
if err := rows.Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if memoID.Valid {
|
||||
resource.MemoID = &memoID.Int32
|
||||
}
|
||||
@ -118,88 +101,72 @@ func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*st
|
||||
}
|
||||
|
||||
func (d *DB) UpdateResource(ctx context.Context, update *store.UpdateResource) (*store.Resource, error) {
|
||||
qb := squirrel.Update("resource")
|
||||
set, args := []string{}, []any{}
|
||||
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
qb = qb.Set("updated_ts", time.Unix(0, *v))
|
||||
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Filename; v != nil {
|
||||
qb = qb.Set("filename", *v)
|
||||
set, args = append(set, "filename = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.InternalPath; v != nil {
|
||||
qb = qb.Set("internal_path", *v)
|
||||
set, args = append(set, "internal_path = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.MemoID; v != nil {
|
||||
qb = qb.Set("memo_id", *v)
|
||||
set, args = append(set, "memo_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Blob; v != nil {
|
||||
qb = qb.Set("blob", v)
|
||||
set, args = append(set, "blob = "+placeholder(len(args)+1)), append(args, v)
|
||||
}
|
||||
|
||||
qb = qb.Where(squirrel.Eq{"id": update.ID})
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
fields := []string{"id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path"}
|
||||
stmt := `
|
||||
UPDATE resource
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ` + placeholder(len(args)+1) + `
|
||||
RETURNING ` + strings.Join(fields, ", ")
|
||||
args = append(args, update.ID)
|
||||
resource := store.Resource{}
|
||||
dests := []any{
|
||||
&resource.ID,
|
||||
&resource.Filename,
|
||||
&resource.ExternalLink,
|
||||
&resource.Type,
|
||||
&resource.Size,
|
||||
&resource.CreatorID,
|
||||
&resource.CreatedTs,
|
||||
&resource.UpdatedTs,
|
||||
&resource.InternalPath,
|
||||
}
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(dests...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list, err := d.ListResources(ctx, &store.FindResource{ID: &update.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list))
|
||||
}
|
||||
|
||||
return list[0], nil
|
||||
return &resource, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteResource(ctx context.Context, delete *store.DeleteResource) error {
|
||||
qb := squirrel.Delete("resource").Where(squirrel.Eq{"id": delete.ID})
|
||||
|
||||
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
stmt := `DELETE FROM resource WHERE id = $1`
|
||||
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := d.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := d.Vacuum(ctx); err != nil {
|
||||
// Prevent linter warning.
|
||||
func vacuumResource(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
resource
|
||||
WHERE
|
||||
creator_id NOT IN (SELECT id FROM "user")`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumResource(ctx context.Context, tx *sql.Tx) error {
|
||||
// First, build the subquery
|
||||
subQuery, subArgs, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now, build the main delete query using the subquery
|
||||
query, args, err := squirrel.Delete("resource").
|
||||
Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute the query
|
||||
_, err = tx.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
|
||||
@ -82,22 +81,15 @@ func (d *DB) DeleteTag(ctx context.Context, delete *store.DeleteTag) error {
|
||||
}
|
||||
|
||||
func vacuumTag(ctx context.Context, tx *sql.Tx) error {
|
||||
// First, build the subquery for creator_id
|
||||
subQuery, subArgs, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
tag
|
||||
WHERE
|
||||
creator_id NOT IN (SELECT id FROM "user")`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now, build the main delete query using the subquery
|
||||
query, args, err := squirrel.Delete("tag").
|
||||
Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute the query
|
||||
_, err = tx.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
@ -2,132 +2,113 @@ package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
|
||||
// Start building the insert statement
|
||||
builder := squirrel.Insert(`"user"`).PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
columns := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"}
|
||||
builder = builder.Columns(columns...)
|
||||
|
||||
values := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
|
||||
|
||||
builder = builder.Values(values...)
|
||||
builder = builder.Suffix("RETURNING id")
|
||||
|
||||
// Prepare the final query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
fields := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"}
|
||||
args := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL}
|
||||
stmt := "INSERT INTO \"user\" (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, avatar_url, created_ts, updated_ts, row_status"
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.ID,
|
||||
&create.AvatarURL,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&create.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query and get the returned ID
|
||||
var id int32
|
||||
err = d.db.QueryRowContext(ctx, query, args...).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use the returned ID to retrieve the full user object
|
||||
user, err := d.GetUser(ctx, &store.FindUser{ID: &id})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
return create, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
|
||||
// Start building the update statement
|
||||
builder := squirrel.Update(`"user"`).PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
// Conditionally add set clauses
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.UpdatedTs; v != nil {
|
||||
builder = builder.Set("updated_ts", *v)
|
||||
set, args = append(set, "updated_ts = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.RowStatus; v != nil {
|
||||
builder = builder.Set("row_status", *v)
|
||||
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Username; v != nil {
|
||||
builder = builder.Set("username", *v)
|
||||
set, args = append(set, "username = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
builder = builder.Set("email", *v)
|
||||
set, args = append(set, "email = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
builder = builder.Set("nickname", *v)
|
||||
set, args = append(set, "nickname = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.AvatarURL; v != nil {
|
||||
builder = builder.Set("avatar_url", *v)
|
||||
set, args = append(set, "avatar_url = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
builder = builder.Set("password_hash", *v)
|
||||
set, args = append(set, "password_hash = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
// Add the WHERE clause
|
||||
builder = builder.Where(squirrel.Eq{"id": update.ID})
|
||||
|
||||
// Prepare the final query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
query := `
|
||||
UPDATE "user"
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ` + placeholder(len(args)+1) + `
|
||||
RETURNING id, username, role, email, nickname, password_hash, avatar_url, created_ts, updated_ts, row_status
|
||||
`
|
||||
args = append(args, update.ID)
|
||||
user := &store.User{}
|
||||
if err := d.db.QueryRowContext(ctx, query, args...).Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.Role,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.AvatarURL,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query with the context
|
||||
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Retrieve the updated user
|
||||
user, err := d.GetUser(ctx, &store.FindUser{ID: &update.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
|
||||
// Start building the SELECT statement
|
||||
builder := squirrel.Select("id", "username", "role", "email", "nickname", "password_hash", "avatar_url", "created_ts", "updated_ts", "row_status").
|
||||
From(`"user"`).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
// 1 = 1 is often used as a no-op in SQL, ensuring there's always a WHERE clause
|
||||
builder = builder.Where("1 = 1")
|
||||
|
||||
// Conditionally add where clauses
|
||||
if v := find.ID; v != nil {
|
||||
builder = builder.Where(squirrel.Eq{"id": *v})
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Username; v != nil {
|
||||
builder = builder.Where(squirrel.Eq{"username": *v})
|
||||
where, args = append(where, "username = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Role; v != nil {
|
||||
builder = builder.Where(squirrel.Eq{"role": *v})
|
||||
where, args = append(where, "role = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Email; v != nil {
|
||||
builder = builder.Where(squirrel.Eq{"email": *v})
|
||||
where, args = append(where, "email = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Nickname; v != nil {
|
||||
builder = builder.Where(squirrel.Eq{"nickname": *v})
|
||||
where, args = append(where, "nickname = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
// Add ordering
|
||||
builder = builder.OrderBy("created_ts DESC", "row_status DESC")
|
||||
|
||||
// Prepare the final query
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query with the context
|
||||
query := `
|
||||
SELECT
|
||||
id,
|
||||
username,
|
||||
role,
|
||||
email,
|
||||
nickname,
|
||||
password_hash,
|
||||
avatar_url,
|
||||
created_ts,
|
||||
updated_ts,
|
||||
row_status
|
||||
FROM "user"
|
||||
WHERE ` + strings.Join(where, " AND ") + `
|
||||
ORDER BY created_ts DESC, row_status DESC
|
||||
`
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -161,35 +142,13 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetUser(ctx context.Context, find *store.FindUser) (*store.User, error) {
|
||||
list, err := d.ListUsers(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) != 1 {
|
||||
return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list))
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
|
||||
// Start building the DELETE statement
|
||||
builder := squirrel.Delete(`"user"`).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
Where(squirrel.Eq{"id": delete.ID})
|
||||
|
||||
// Prepare the final query
|
||||
query, args, err := builder.ToSql()
|
||||
result, err := d.db.ExecContext(ctx, `
|
||||
DELETE FROM "user" WHERE id = $1
|
||||
`, delete.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute the query with the context
|
||||
result, err := d.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/pkg/errors"
|
||||
@ -130,22 +129,15 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
|
||||
}
|
||||
|
||||
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
|
||||
// First, build the subquery
|
||||
subQuery, subArgs, err := squirrel.Select("id").From(`"user"`).PlaceholderFormat(squirrel.Dollar).ToSql()
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
user_setting
|
||||
WHERE
|
||||
user_id NOT IN (SELECT id FROM "user")`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now, build the main delete query using the subquery
|
||||
query, args, err := squirrel.Delete("user_setting").
|
||||
Where(fmt.Sprintf("user_id NOT IN (%s)", subQuery), subArgs...).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute the query
|
||||
_, err = tx.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
@ -38,7 +38,6 @@ func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store
|
||||
|
||||
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.ID != nil {
|
||||
where, args = append(where, "`id` = ?"), append(args, *find.ID)
|
||||
}
|
||||
|
62
test/store/memo_organizer_test.go
Normal file
62
test/store/memo_organizer_test.go
Normal file
@ -0,0 +1,62 @@
|
||||
package teststore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestMemoOrganizerStore(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
memoCreate := &store.Memo{
|
||||
CreatorID: user.ID,
|
||||
Content: "main memo content",
|
||||
Visibility: store.Public,
|
||||
}
|
||||
memo, err := ts.CreateMemo(ctx, memoCreate)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, memoCreate.Content, memo.Content)
|
||||
|
||||
memoOrganizer, err := ts.UpsertMemoOrganizer(ctx, &store.MemoOrganizer{
|
||||
MemoID: memo.ID,
|
||||
UserID: user.ID,
|
||||
Pinned: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoOrganizer)
|
||||
require.Equal(t, memo.ID, memoOrganizer.MemoID)
|
||||
require.Equal(t, user.ID, memoOrganizer.UserID)
|
||||
require.Equal(t, true, memoOrganizer.Pinned)
|
||||
|
||||
memoOrganizerTemp, err := ts.GetMemoOrganizer(ctx, &store.FindMemoOrganizer{
|
||||
MemoID: memo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, memoOrganizer, memoOrganizerTemp)
|
||||
memoOrganizerTemp, err = ts.UpsertMemoOrganizer(ctx, &store.MemoOrganizer{
|
||||
MemoID: memo.ID,
|
||||
UserID: user.ID,
|
||||
Pinned: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, memoOrganizerTemp)
|
||||
require.Equal(t, memo.ID, memoOrganizerTemp.MemoID)
|
||||
require.Equal(t, user.ID, memoOrganizerTemp.UserID)
|
||||
require.Equal(t, false, memoOrganizerTemp.Pinned)
|
||||
err = ts.DeleteMemoOrganizer(ctx, &store.DeleteMemoOrganizer{
|
||||
MemoID: &memo.ID,
|
||||
UserID: &user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
memoOrganizers, err := ts.ListMemoOrganizer(ctx, &store.FindMemoOrganizer{
|
||||
UserID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(memoOrganizers))
|
||||
}
|
@ -59,3 +59,22 @@ func TestMemoStore(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(memoList))
|
||||
}
|
||||
|
||||
func TestDeleteMemoStore(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingHostUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
memoCreate := &store.Memo{
|
||||
CreatorID: user.ID,
|
||||
Content: "test_content",
|
||||
Visibility: store.Public,
|
||||
}
|
||||
memo, err := ts.CreateMemo(ctx, memoCreate)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, memoCreate.Content, memo.Content)
|
||||
err = ts.DeleteMemo(ctx, &store.DeleteMemo{
|
||||
ID: memo.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
// sqlite driver.
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"github.com/usememos/memos/server/profile"
|
||||
"github.com/usememos/memos/store"
|
||||
"github.com/usememos/memos/store/db"
|
||||
"github.com/usememos/memos/test"
|
||||
@ -19,6 +20,7 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
|
||||
if err != nil {
|
||||
fmt.Printf("failed to create db driver, error: %+v\n", err)
|
||||
}
|
||||
resetTestingDB(ctx, profile, dbDriver)
|
||||
if err := dbDriver.Migrate(ctx); err != nil {
|
||||
fmt.Printf("failed to migrate db, error: %+v\n", err)
|
||||
}
|
||||
@ -26,3 +28,13 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
|
||||
store := store.New(dbDriver, profile)
|
||||
return store
|
||||
}
|
||||
|
||||
func resetTestingDB(ctx context.Context, profile *profile.Profile, dbDriver store.Driver) {
|
||||
if profile.Driver == "postgres" {
|
||||
_, err := dbDriver.GetDB().ExecContext(ctx, `DROP SCHEMA public CASCADE; CREATE SCHEMA public;`)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to reset testing db, error: %+v\n", err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user