chore: update postgres stmt builder

This commit is contained in:
Steven 2024-01-06 17:12:10 +08:00
parent 8893a302e2
commit 9459ae8265
7 changed files with 155 additions and 225 deletions

3
go.mod
View File

@ -3,7 +3,6 @@ module github.com/usememos/memos
go 1.21
require (
github.com/Masterminds/squirrel v1.5.4
github.com/aws/aws-sdk-go-v2 v1.24.0
github.com/aws/aws-sdk-go-v2/config v1.26.2
github.com/aws/aws-sdk-go-v2/credentials v1.16.13
@ -50,8 +49,6 @@ require (
github.com/gorilla/css v1.0.1 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rs/cors v1.10.1 // indirect

6
go.sum
View File

@ -6,8 +6,6 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym
github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0=
github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc=
github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE=
github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM=
github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10=
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g=
@ -283,10 +281,6 @@ github.com/labstack/echo/v4 v4.11.4 h1:vDZmA+qNeh1pd/cCkEicDMrjtrnMGQ1QFI9gWN1zG
github.com/labstack/echo/v4 v4.11.4/go.mod h1:noh7EvLwqDsmh/X/HWKPUl1AjzJrhyptRyEbQJfxen8=
github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw=
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o=
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk=
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=

View File

@ -2,43 +2,43 @@ package postgres
import (
"context"
"github.com/Masterminds/squirrel"
"strings"
"github.com/usememos/memos/store"
)
func (d *DB) CreateStorage(ctx context.Context, create *store.Storage) (*store.Storage, error) {
qb := squirrel.Insert("storage").Columns("name", "type", "config")
values := []any{create.Name, create.Type, create.Config}
fields := []string{"name", "type", "config"}
args := []any{create.Name, create.Type, create.Config}
qb = qb.Values(values...).Suffix("RETURNING id")
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
stmt := "INSERT INTO storage (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.ID,
); err != nil {
return nil, err
}
err = d.db.QueryRowContext(ctx, query, args...).Scan(&create.ID)
if err != nil {
return nil, err
}
return create, nil
storage := create
return storage, nil
}
func (d *DB) ListStorages(ctx context.Context, find *store.FindStorage) ([]*store.Storage, error) {
qb := squirrel.Select("id", "name", "type", "config").From("storage").OrderBy("id DESC")
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
qb = qb.Where(squirrel.Eq{"id": *find.ID})
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
}
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
name,
type,
config
FROM storage
WHERE `+strings.Join(where, " AND ")+`
ORDER BY id DESC`,
args...,
)
if err != nil {
return nil, err
}
@ -47,7 +47,12 @@ func (d *DB) ListStorages(ctx context.Context, find *store.FindStorage) ([]*stor
list := []*store.Storage{}
for rows.Next() {
storage := &store.Storage{}
if err := rows.Scan(&storage.ID, &storage.Name, &storage.Type, &storage.Config); err != nil {
if err := rows.Scan(
&storage.ID,
&storage.Name,
&storage.Type,
&storage.Config,
); err != nil {
return nil, err
}
list = append(list, storage)
@ -61,38 +66,23 @@ func (d *DB) ListStorages(ctx context.Context, find *store.FindStorage) ([]*stor
}
func (d *DB) UpdateStorage(ctx context.Context, update *store.UpdateStorage) (*store.Storage, error) {
qb := squirrel.Update("storage")
set, args := []string{}, []any{}
if update.Name != nil {
qb = qb.Set("name", *update.Name)
set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *update.Name)
}
if update.Config != nil {
qb = qb.Set("config", *update.Config)
}
qb = qb.Where(squirrel.Eq{"id": update.ID})
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, query, args...)
if err != nil {
return nil, err
set, args = append(set, "config = "+placeholder(len(args)+1)), append(args, *update.Config)
}
stmt := `UPDATE storage SET ` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1) + ` RETURNING id, name, type, config`
args = append(args, update.ID)
storage := &store.Storage{}
query, args, err = squirrel.Select("id", "name", "type", "config").
From("storage").
Where(squirrel.Eq{"id": update.ID}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return nil, err
}
if err := d.db.QueryRowContext(ctx, query, args...).Scan(&storage.ID, &storage.Name, &storage.Type, &storage.Config); err != nil {
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&storage.ID,
&storage.Name,
&storage.Type,
&storage.Config,
); err != nil {
return nil, err
}
@ -100,21 +90,13 @@ func (d *DB) UpdateStorage(ctx context.Context, update *store.UpdateStorage) (*s
}
func (d *DB) DeleteStorage(ctx context.Context, delete *store.DeleteStorage) error {
qb := squirrel.Delete("storage").Where(squirrel.Eq{"id": delete.ID})
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
stmt := `DELETE FROM storage WHERE id = $1`
result, err := d.db.ExecContext(ctx, stmt, delete.ID)
if err != nil {
return err
}
result, err := d.db.ExecContext(ctx, query, args...)
if err != nil {
return err
}
if _, err := result.RowsAffected(); err != nil {
return err
}
return nil
}

View File

@ -2,26 +2,23 @@ package postgres
import (
"context"
"github.com/Masterminds/squirrel"
"strings"
"github.com/usememos/memos/store"
)
func (d *DB) UpsertSystemSetting(ctx context.Context, upsert *store.SystemSetting) (*store.SystemSetting, error) {
qb := squirrel.Insert("system_setting").
Columns("name", "value", "description").
Values(upsert.Name, upsert.Value, upsert.Description).
Suffix("ON CONFLICT (name) DO UPDATE SET value = EXCLUDED.value, description = EXCLUDED.description").
PlaceholderFormat(squirrel.Dollar)
query, args, err := qb.ToSql()
if err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, query, args...)
if err != nil {
stmt := `
INSERT INTO system_setting (
name, value, description
)
VALUES ($1, $2, $3)
ON CONFLICT(name) DO UPDATE
SET
value = EXCLUDED.value,
description = EXCLUDED.description
`
if _, err := d.db.ExecContext(ctx, stmt, upsert.Name, upsert.Value, upsert.Description); err != nil {
return nil, err
}
@ -29,16 +26,18 @@ func (d *DB) UpsertSystemSetting(ctx context.Context, upsert *store.SystemSettin
}
func (d *DB) ListSystemSettings(ctx context.Context, find *store.FindSystemSetting) ([]*store.SystemSetting, error) {
qb := squirrel.Select("name", "value", "description").From("system_setting")
where, args := []string{"1 = 1"}, []any{}
if find.Name != "" {
qb = qb.Where(squirrel.Eq{"name": find.Name})
where, args = append(where, "name = "+placeholder(len(args)+1)), append(args, find.Name)
}
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
query := `
SELECT
name,
value,
description
FROM system_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
@ -48,11 +47,15 @@ func (d *DB) ListSystemSettings(ctx context.Context, find *store.FindSystemSetti
list := []*store.SystemSetting{}
for rows.Next() {
systemSetting := &store.SystemSetting{}
if err := rows.Scan(&systemSetting.Name, &systemSetting.Value, &systemSetting.Description); err != nil {
systemSettingMessage := &store.SystemSetting{}
if err := rows.Scan(
&systemSettingMessage.Name,
&systemSettingMessage.Value,
&systemSettingMessage.Description,
); err != nil {
return nil, err
}
list = append(list, systemSetting)
list = append(list, systemSettingMessage)
}
if err := rows.Err(); err != nil {

View File

@ -3,8 +3,7 @@ package postgres
import (
"context"
"database/sql"
"github.com/Masterminds/squirrel"
"strings"
"github.com/usememos/memos/store"
)
@ -18,20 +17,20 @@ func (d *DB) UpsertTag(ctx context.Context, upsert *store.Tag) (*store.Tag, erro
}
func (d *DB) ListTags(ctx context.Context, find *store.FindTag) ([]*store.Tag, error) {
builder := squirrel.Select("name", "creator_id").From("tag").
Where("1 = 1").
OrderBy("name ASC").
PlaceholderFormat(squirrel.Dollar)
where, args := []string{"1 = 1"}, []any{}
if find.CreatorID != 0 {
builder = builder.Where("creator_id = ?", find.CreatorID)
}
query, args, err := builder.ToSql()
if err != nil {
return nil, err
where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, find.CreatorID)
}
query := `
SELECT
name,
creator_id
FROM tag
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY name ASC
`
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
@ -59,33 +58,20 @@ func (d *DB) ListTags(ctx context.Context, find *store.FindTag) ([]*store.Tag, e
}
func (d *DB) DeleteTag(ctx context.Context, delete *store.DeleteTag) error {
builder := squirrel.Delete("tag").
Where(squirrel.Eq{"name": delete.Name, "creator_id": delete.CreatorID}).
PlaceholderFormat(squirrel.Dollar)
query, args, err := builder.ToSql()
where, args := []string{"name = $1", "creator_id = $2"}, []any{delete.Name, delete.CreatorID}
stmt := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ")
result, err := d.db.ExecContext(ctx, stmt, args...)
if err != nil {
return err
}
result, err := d.db.ExecContext(ctx, query, args...)
if err != nil {
return err
}
if _, err = result.RowsAffected(); err != nil {
return err
}
return nil
}
func vacuumTag(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM
tag
WHERE
creator_id NOT IN (SELECT id FROM "user")`
stmt := `DELETE FROM tag WHERE creator_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err

View File

@ -3,8 +3,8 @@ package postgres
import (
"context"
"database/sql"
"strings"
"github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
@ -13,6 +13,14 @@ import (
)
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
stmt := `
INSERT INTO user_setting (
user_id, key, value
)
VALUES ($1, $2, $3)
ON CONFLICT(user_id, key) DO UPDATE
SET value = EXCLUDED.value
`
var valueString string
if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
valueBytes, err := protojson.Marshal(upsert.GetAccessTokens())
@ -32,20 +40,7 @@ func (d *DB) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting)
return nil, errors.Errorf("unknown user setting key: %s", upsert.Key.String())
}
// Construct the query using Squirrel
query, args, err := squirrel.
Insert("user_setting").
Columns("user_id", "key", "value").
Values(upsert.UserId, upsert.Key.String(), valueString).
Suffix("ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value").
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return nil, err
}
// Execute the query
if _, err := d.db.ExecContext(ctx, query, args...); err != nil {
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString); err != nil {
return nil, err
}
@ -53,31 +48,28 @@ func (d *DB) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting)
}
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*storepb.UserSetting, error) {
// Start building the query using Squirrel
qb := squirrel.Select("user_id", "key", "value").From("user_setting").PlaceholderFormat(squirrel.Dollar)
where, args := []string{"1 = 1"}, []any{}
// Add conditions based on the provided find parameters
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
qb = qb.Where(squirrel.Eq{"key": v.String()})
where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, v.String())
}
if v := find.UserID; v != nil {
qb = qb.Where(squirrel.Eq{"user_id": *v})
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *find.UserID)
}
// Finalize the query
query, args, err := qb.ToSql()
if err != nil {
return nil, err
}
// Execute the query
query := `
SELECT
user_id,
key,
value
FROM user_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
// Process the rows
userSettingList := make([]*storepb.UserSetting, 0)
for rows.Next() {
userSetting := &storepb.UserSetting{}
@ -129,11 +121,7 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
}
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM
user_setting
WHERE
user_id NOT IN (SELECT id FROM "user")`
stmt := `DELETE FROM user_setting WHERE user_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err

View File

@ -2,52 +2,54 @@ package postgres
import (
"context"
"github.com/Masterminds/squirrel"
"strings"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (d *DB) CreateWebhook(ctx context.Context, create *storepb.Webhook) (*storepb.Webhook, error) {
qb := squirrel.Insert("webhook").Columns("name", "url", "creator_id")
values := []any{create.Name, create.Url, create.CreatorId}
qb = qb.Values(values...).Suffix("RETURNING id")
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
fields := []string{"name", "url", "creator_id"}
args := []any{create.Name, create.Url, create.CreatorId}
stmt := "INSERT INTO webhook (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id, created_ts, updated_ts, row_status"
var rowStatus string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.Id,
&create.CreatedTs,
&create.UpdatedTs,
&rowStatus,
); err != nil {
return nil, err
}
err = d.db.QueryRowContext(ctx, query, args...).Scan(&create.Id)
if err != nil {
return nil, err
}
create, err = d.GetWebhook(ctx, &store.FindWebhook{ID: &create.Id})
if err != nil {
return nil, err
}
return create, nil
create.RowStatus = storepb.RowStatus(storepb.RowStatus_value[rowStatus])
webhook := create
return webhook, nil
}
func (d *DB) ListWebhooks(ctx context.Context, find *store.FindWebhook) ([]*storepb.Webhook, error) {
qb := squirrel.Select("id", "created_ts", "updated_ts", "row_status", "creator_id", "name", "url").From("webhook").OrderBy("id DESC")
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
qb = qb.Where(squirrel.Eq{"id": *find.ID})
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
}
if find.CreatorID != nil {
qb = qb.Where(squirrel.Eq{"creator_id": *find.CreatorID})
where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *find.CreatorID)
}
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return nil, err
}
rows, err := d.db.QueryContext(ctx, query, args...)
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
created_ts,
updated_ts,
row_status,
creator_id,
name,
url
FROM webhook
WHERE `+strings.Join(where, " AND ")+`
ORDER BY id DESC`,
args...,
)
if err != nil {
return nil, err
}
@ -68,9 +70,7 @@ func (d *DB) ListWebhooks(ctx context.Context, find *store.FindWebhook) ([]*stor
); err != nil {
return nil, err
}
webhook.RowStatus = storepb.RowStatus(storepb.RowStatus_value[rowStatus])
list = append(list, webhook)
}
@ -81,58 +81,38 @@ func (d *DB) ListWebhooks(ctx context.Context, find *store.FindWebhook) ([]*stor
return list, nil
}
func (d *DB) GetWebhook(ctx context.Context, find *store.FindWebhook) (*storepb.Webhook, error) {
list, err := d.ListWebhooks(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
return list[0], nil
}
func (d *DB) UpdateWebhook(ctx context.Context, update *store.UpdateWebhook) (*storepb.Webhook, error) {
qb := squirrel.Update("webhook")
set, args := []string{}, []any{}
if update.RowStatus != nil {
qb = qb.Set("row_status", update.RowStatus.String())
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, update.RowStatus.String())
}
if update.Name != nil {
qb = qb.Set("name", *update.Name)
set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *update.Name)
}
if update.URL != nil {
qb = qb.Set("url", *update.URL)
set, args = append(set, "url = "+placeholder(len(args)+1)), append(args, *update.URL)
}
qb = qb.Where(squirrel.Eq{"id": update.ID})
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
stmt := "UPDATE webhook SET " + strings.Join(set, ", ") + " WHERE id = " + placeholder(len(args)+1) + " RETURNING id, created_ts, updated_ts, row_status, creator_id, name, url"
args = append(args, update.ID)
webhook := &storepb.Webhook{}
var rowStatus string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&webhook.Id,
&webhook.CreatedTs,
&webhook.UpdatedTs,
&rowStatus,
&webhook.CreatorId,
&webhook.Name,
&webhook.Url,
); err != nil {
return nil, err
}
_, err = d.db.ExecContext(ctx, query, args...)
if err != nil {
return nil, err
}
webhook, err := d.GetWebhook(ctx, &store.FindWebhook{ID: &update.ID})
if err != nil {
return nil, err
}
webhook.RowStatus = storepb.RowStatus(storepb.RowStatus_value[rowStatus])
return webhook, nil
}
func (d *DB) DeleteWebhook(ctx context.Context, delete *store.DeleteWebhook) error {
qb := squirrel.Delete("webhook").Where(squirrel.Eq{"id": delete.ID})
query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql()
if err != nil {
return err
}
_, err = d.db.ExecContext(ctx, query, args...)
_, err := d.db.ExecContext(ctx, "DELETE FROM webhook WHERE id = $1", delete.ID)
return err
}