Compare commits
11 Commits
77de7233d6
...
187855c408
Author | SHA1 | Date |
---|---|---|
Clar Fon | 187855c408 | |
Matt Baer | 038a80c25e | |
Matt Baer | 9ece6682ef | |
Matt Baer | 41e1989345 | |
Matt Baer | 34d902062f | |
dependabot[bot] | ed9ff51b68 | |
Riley Chang | 83ffea7fa0 | |
dependabot[bot] | 3dd0a9b8dc | |
Matt Baer | e34a58d0ef | |
ltdk | 7eb35ff5d4 | |
ltdk | 940c50c067 |
|
@ -12,15 +12,15 @@ type AlterTableSqlBuilder struct {
|
|||
}
|
||||
|
||||
func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder {
|
||||
if colVal, err := col.String(); err == nil {
|
||||
if colVal, err := col.CreateSQL(b.Dialect); err == nil {
|
||||
b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder {
|
||||
if colVal, err := col.String(); err == nil {
|
||||
b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal))
|
||||
if colActions, err := col.AlterSQL(b.Dialect, name); err == nil {
|
||||
b.Changes = append(b.Changes, colActions...)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
|
|
@ -18,28 +18,43 @@ func TestAlterTableSqlBuilder_ToSQL(t *testing.T) {
|
|||
name: "MySQL add int",
|
||||
builder: DialectMySQL.
|
||||
AlterTable("the_table").
|
||||
AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)),
|
||||
want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL",
|
||||
AddColumn(NonNullableColumn("the_col", ColumnTypeInt{MaxBytes: 4})),
|
||||
want: "ALTER TABLE the_table ADD COLUMN the_col INTEGER NOT NULL",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "MySQL add string",
|
||||
builder: DialectMySQL.
|
||||
AlterTable("the_table").
|
||||
AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})),
|
||||
AddColumn(NonNullableColumn("the_col", ColumnTypeString{MaxChars: 128})),
|
||||
want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL",
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
{
|
||||
name: "MySQL add int and string",
|
||||
builder: DialectMySQL.
|
||||
AlterTable("the_table").
|
||||
AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)).
|
||||
AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})),
|
||||
AddColumn(NonNullableColumn("first_col", ColumnTypeInt{MaxBytes: 4})).
|
||||
AddColumn(NonNullableColumn("second_col", ColumnTypeString{MaxChars: 128})),
|
||||
want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "MySQL change to string",
|
||||
builder: DialectMySQL.
|
||||
AlterTable("the_table").
|
||||
ChangeColumn("old_col", NonNullableColumn("new_col", ColumnTypeString{})),
|
||||
want: "ALTER TABLE the_table RENAME COLUMN old_col TO new_col, MODIFY COLUMN new_col TEXT NOT NULL",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "PostgreSQL change to int",
|
||||
builder: DialectMySQL.
|
||||
AlterTable("the_table").
|
||||
ChangeColumn("old_col", NullableColumn("new_col", ColumnTypeInt{MaxBytes: 4})),
|
||||
want: "ALTER TABLE the_table RENAME COLUMN old_col TO new_col, ALTER COLUMN new_col TYPE INTEGER, ALTER COLUMN new_col DROP NOT NULL",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
/*
|
||||
* Copyright © 2019-2022 A Bunch Tell LLC.
|
||||
*
|
||||
* This file is part of WriteFreely.
|
||||
*
|
||||
* WriteFreely is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License, included
|
||||
* in the LICENSE file in this source code package.
|
||||
*/
|
||||
|
||||
package db
|
||||
|
||||
type SQLBuilder interface {
|
||||
ToSQL() (string, error)
|
||||
}
|
|
@ -0,0 +1,328 @@
|
|||
/*
|
||||
* Copyright © 2019-2022 A Bunch Tell LLC.
|
||||
*
|
||||
* This file is part of WriteFreely.
|
||||
*
|
||||
* WriteFreely is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License, included
|
||||
* in the LICENSE file in this source code package.
|
||||
*/
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Column struct {
|
||||
Name string
|
||||
Type ColumnType
|
||||
Nullable bool
|
||||
PrimaryKey bool
|
||||
}
|
||||
|
||||
func NullableColumn(name string, ty ColumnType) *Column {
|
||||
return &Column{
|
||||
Name: name,
|
||||
Type: ty,
|
||||
Nullable: true,
|
||||
PrimaryKey: false,
|
||||
}
|
||||
}
|
||||
|
||||
func NonNullableColumn(name string, ty ColumnType) *Column {
|
||||
return &Column{
|
||||
Name: name,
|
||||
Type: ty,
|
||||
Nullable: false,
|
||||
PrimaryKey: false,
|
||||
}
|
||||
}
|
||||
|
||||
func PrimaryKeyColumn(name string, ty ColumnType) *Column {
|
||||
return &Column{
|
||||
Name: name,
|
||||
Type: ty,
|
||||
Nullable: false,
|
||||
PrimaryKey: true,
|
||||
}
|
||||
}
|
||||
|
||||
type ColumnType interface {
|
||||
Name(DialectType) (string, error)
|
||||
Default(DialectType) (string, error)
|
||||
}
|
||||
|
||||
type ColumnTypeInt struct {
|
||||
IsSigned bool
|
||||
MaxBytes int
|
||||
MaxDigits int
|
||||
HasDefault bool
|
||||
DefaultVal int
|
||||
}
|
||||
|
||||
type ColumnTypeString struct {
|
||||
IsFixedLength bool
|
||||
MaxChars int
|
||||
HasDefault bool
|
||||
DefaultVal string
|
||||
}
|
||||
|
||||
type ColumnDefault int
|
||||
|
||||
type ColumnTypeBool struct {
|
||||
DefaultVal ColumnDefault
|
||||
}
|
||||
|
||||
const (
|
||||
NoDefault ColumnDefault = iota
|
||||
DefaultFalse ColumnDefault = iota
|
||||
DefaultTrue ColumnDefault = iota
|
||||
DefaultNow ColumnDefault = iota
|
||||
)
|
||||
|
||||
type ColumnTypeDateTime struct {
|
||||
DefaultVal ColumnDefault
|
||||
}
|
||||
|
||||
func (intCol ColumnTypeInt) Name(d DialectType) (string, error) {
|
||||
switch d {
|
||||
case DialectSQLite:
|
||||
return "INTEGER", nil
|
||||
|
||||
case DialectMySQL, DialectPostgreSQL:
|
||||
var colName string
|
||||
switch intCol.MaxBytes {
|
||||
case 1:
|
||||
if d == DialectMySQL {
|
||||
colName = "TINYINT"
|
||||
} else {
|
||||
colName = "SMALLINT"
|
||||
}
|
||||
case 2:
|
||||
colName = "SMALLINT"
|
||||
case 3:
|
||||
if d == DialectMySQL {
|
||||
colName = "MEDIUMINT"
|
||||
} else {
|
||||
colName = "INTEGER"
|
||||
}
|
||||
case 4:
|
||||
colName = "INTEGER"
|
||||
default:
|
||||
colName = "BIGINT"
|
||||
}
|
||||
if d == DialectMySQL {
|
||||
if intCol.MaxDigits > 0 {
|
||||
colName = fmt.Sprintf("%s(%d)", colName, intCol.MaxDigits)
|
||||
}
|
||||
if !intCol.IsSigned {
|
||||
colName += " UNSIGNED"
|
||||
}
|
||||
}
|
||||
return colName, nil
|
||||
|
||||
default:
|
||||
return "", fmt.Errorf("dialect %d does not support integer columns", d)
|
||||
}
|
||||
}
|
||||
|
||||
func (intCol ColumnTypeInt) Default(d DialectType) (string, error) {
|
||||
if intCol.HasDefault {
|
||||
return fmt.Sprintf("%d", intCol.DefaultVal), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (strCol ColumnTypeString) Name(d DialectType) (string, error) {
|
||||
switch d {
|
||||
case DialectSQLite:
|
||||
return "TEXT", nil
|
||||
|
||||
case DialectMySQL, DialectPostgreSQL:
|
||||
if strCol.IsFixedLength {
|
||||
if strCol.MaxChars > 0 {
|
||||
return fmt.Sprintf("CHAR(%d)", strCol.MaxChars), nil
|
||||
}
|
||||
return "CHAR", nil
|
||||
}
|
||||
|
||||
if strCol.MaxChars <= 0 {
|
||||
return "TEXT", nil
|
||||
}
|
||||
if strCol.MaxChars < (1 << 16) {
|
||||
return fmt.Sprintf("VARCHAR(%d)", strCol.MaxChars), nil
|
||||
}
|
||||
return "TEXT", nil
|
||||
|
||||
default:
|
||||
return "", fmt.Errorf("dialect %d does not support string columns", d)
|
||||
}
|
||||
}
|
||||
|
||||
func (strCol ColumnTypeString) Default(d DialectType) (string, error) {
|
||||
if strCol.HasDefault {
|
||||
return EscapeSimple.SQLEscape(d, strCol.DefaultVal)
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (boolCol ColumnTypeBool) Name(d DialectType) (string, error) {
|
||||
switch d {
|
||||
case DialectSQLite:
|
||||
return "INTEGER", nil
|
||||
case DialectMySQL, DialectPostgreSQL:
|
||||
return "BOOL", nil
|
||||
default:
|
||||
return "", fmt.Errorf("boolean column type not supported for dialect %d", d)
|
||||
}
|
||||
}
|
||||
|
||||
func (boolCol ColumnTypeBool) Default(d DialectType) (string, error) {
|
||||
switch boolCol.DefaultVal {
|
||||
case NoDefault:
|
||||
return "", nil
|
||||
case DefaultFalse:
|
||||
return "0", nil
|
||||
case DefaultTrue:
|
||||
return "1", nil
|
||||
default:
|
||||
return "", fmt.Errorf("boolean columns cannot default to %d for dialect %d", boolCol.DefaultVal, d)
|
||||
}
|
||||
}
|
||||
|
||||
func (dateTimeCol ColumnTypeDateTime) Name(d DialectType) (string, error) {
|
||||
switch d {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
return "DATETIME", nil
|
||||
case DialectPostgreSQL:
|
||||
return "TIMESTAMP", nil
|
||||
default:
|
||||
return "", fmt.Errorf("datetime column type not supported for dialect %d", d)
|
||||
}
|
||||
}
|
||||
|
||||
func (dateTimeCol ColumnTypeDateTime) Default(d DialectType) (string, error) {
|
||||
switch d {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
switch dateTimeCol.DefaultVal {
|
||||
case NoDefault:
|
||||
return "", nil
|
||||
case DefaultNow:
|
||||
switch d {
|
||||
case DialectSQLite, DialectPostgreSQL:
|
||||
return "CURRENT_TIMESTAMP", nil
|
||||
case DialectMySQL:
|
||||
return "NOW()", nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("datetime columns cannot default to %d for dialect %d", dateTimeCol.DefaultVal, d)
|
||||
default:
|
||||
return "", fmt.Errorf("dialect %d does not support defaulted datetime columns", d)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Column) SetName(name string) *Column {
|
||||
c.Name = name
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Column) SetNullable(nullable bool) *Column {
|
||||
c.Nullable = nullable
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Column) SetPrimaryKey(pk bool) *Column {
|
||||
c.PrimaryKey = pk
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Column) SetType(t ColumnType) *Column {
|
||||
c.Type = t
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Column) AlterSQL(d DialectType, oldName string) ([]string, error) {
|
||||
var actions []string = make([]string, 0)
|
||||
|
||||
switch d {
|
||||
// MySQL does all modifications at once
|
||||
case DialectMySQL:
|
||||
sql, err := c.CreateSQL(d)
|
||||
if err != nil {
|
||||
return make([]string, 0), err
|
||||
}
|
||||
actions = append(actions, fmt.Sprintf("CHANGE COLUMN %s %s", oldName, sql))
|
||||
|
||||
// PostgreSQL does modifications piece by piece
|
||||
case DialectPostgreSQL:
|
||||
if oldName != c.Name {
|
||||
actions = append(actions, fmt.Sprintf("RENAME COLUMN %s TO %s", oldName, c.Name))
|
||||
}
|
||||
|
||||
typeStr, err := c.Type.Name(d)
|
||||
if err != nil {
|
||||
return make([]string, 0), err
|
||||
}
|
||||
|
||||
actions = append(actions, fmt.Sprintf("ALTER COLUMN %s TYPE %s", c.Name, typeStr))
|
||||
var nullAction string
|
||||
if c.Nullable {
|
||||
nullAction = "DROP"
|
||||
} else {
|
||||
nullAction = "SET"
|
||||
}
|
||||
actions = append(actions, fmt.Sprintf("ALTER COLUMN %s %s NOT NULL", c.Name, nullAction))
|
||||
|
||||
defaultStr, err := c.Type.Default(d)
|
||||
if err != nil {
|
||||
return make([]string, 0), err
|
||||
}
|
||||
if len(defaultStr) > 0 {
|
||||
actions = append(actions, fmt.Sprintf("ALTER COLUMN %s SET DEFAULT %s", c.Name, defaultStr))
|
||||
}
|
||||
|
||||
if c.PrimaryKey {
|
||||
actions = append(actions, fmt.Sprintf("ADD PRIMARY KEY (%s)", c.Name))
|
||||
}
|
||||
|
||||
default:
|
||||
return make([]string, 0), fmt.Errorf("dialect %d doesn't support altering column data type", d)
|
||||
}
|
||||
|
||||
return actions, nil
|
||||
}
|
||||
|
||||
func (c *Column) CreateSQL(d DialectType) (string, error) {
|
||||
var str strings.Builder
|
||||
|
||||
str.WriteString(c.Name)
|
||||
|
||||
str.WriteString(" ")
|
||||
typeStr, err := c.Type.Name(d)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
str.WriteString(typeStr)
|
||||
|
||||
if !c.Nullable {
|
||||
str.WriteString(" NOT NULL")
|
||||
}
|
||||
|
||||
defaultStr, err := c.Type.Default(d)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(defaultStr) > 0 {
|
||||
str.WriteString(" DEFAULT ")
|
||||
str.WriteString(defaultStr)
|
||||
}
|
||||
|
||||
if c.PrimaryKey {
|
||||
str.WriteString(" PRIMARY KEY")
|
||||
}
|
||||
|
||||
return str.String(), nil
|
||||
}
|
|
@ -0,0 +1,151 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestColumnType_Name(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ty ColumnType
|
||||
d DialectType
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"SQLite bool", ColumnTypeBool{}, DialectSQLite, "INTEGER", false},
|
||||
{"SQLite int", ColumnTypeInt{}, DialectSQLite, "INTEGER", false},
|
||||
{"SQLite string ", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "TEXT DEFAULT 'that''s a default'", false},
|
||||
{"SQLite datetime", ColumnTypeDateTime{}, DialectSQLite, "DATETIME", false},
|
||||
|
||||
{"MySQL bool", ColumnTypeBool{}, DialectMySQL, "BOOL", false},
|
||||
{"MySQL tiny int", ColumnTypeInt{MaxBytes: 1}, DialectMySQL, "TINYINT", false},
|
||||
{"MySQL tiny int with digits", ColumnTypeInt{MaxBytes: 1, MaxDigits: 2}, DialectMySQL, "TINYINT(2)", false},
|
||||
{"MySQL small int", ColumnTypeInt{MaxBytes: 2}, DialectMySQL, "SMALLINT", false},
|
||||
{"MySQL small int with digits", ColumnTypeInt{MaxBytes: 2, MaxDigits: 3}, DialectMySQL, "SMALLINT(3)", false},
|
||||
{"MySQL medium int", ColumnTypeInt{MaxBytes: 3}, DialectMySQL, "MEDIUMINT", false},
|
||||
{"MySQL medium int with digits", ColumnTypeInt{MaxBytes: 3, MaxDigits: 6}, DialectMySQL, "MEDIUMINT(6)", false},
|
||||
{"MySQL int", ColumnTypeInt{MaxBytes: 4}, DialectMySQL, "INTEGER", false},
|
||||
{"MySQL int with digits", ColumnTypeInt{MaxBytes: 4, MaxDigits: 11}, DialectMySQL, "INTEGER(11)", false},
|
||||
{"MySQL bigint", ColumnTypeInt{MaxBytes: 4}, DialectMySQL, "BIGINT", false},
|
||||
{"MySQL bigint with digits", ColumnTypeInt{MaxBytes: 4, MaxDigits: 15}, DialectMySQL, "BIGINT(15)", false},
|
||||
{"MySQL char", ColumnTypeString{IsFixedLength: true}, DialectMySQL, "CHAR", false},
|
||||
{"MySQL char with length", ColumnTypeString{IsFixedLength: true, MaxChars: 4}, DialectMySQL, "CHAR(4)", false},
|
||||
{"MySQL varchar with length", ColumnTypeString{MaxChars: 25}, DialectMySQL, "VARCHAR(25)", false},
|
||||
{"MySQL text", ColumnTypeString{}, DialectMySQL, "TEXT", false},
|
||||
{"MySQL datetime", ColumnTypeDateTime{}, DialectMySQL, "DATETIME", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.ty.Name(tt.d)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Name() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("Name() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestColumnType_Default(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ty ColumnType
|
||||
d DialectType
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"SQLite bool none", ColumnTypeBool{}, DialectSQLite, "", false},
|
||||
{"SQLite bool false", ColumnTypeBool{}, DialectSQLite, "0", false},
|
||||
{"SQLite bool true", ColumnTypeBool{}, DialectSQLite, "1", false},
|
||||
{"SQLite int none", ColumnTypeInt{}, DialectSQLite, "", false},
|
||||
{"SQLite int empty", ColumnTypeInt{HasDefault: true}, DialectSQLite, "0", false},
|
||||
{"SQLite int", ColumnTypeInt{HasDefault: true, DefaultVal: 10}, DialectSQLite, "10", false},
|
||||
{"SQLite string none", ColumnTypeString{}, DialectSQLite, "", false},
|
||||
{"SQLite string empty", ColumnTypeString{HasDefault: true}, DialectSQLite, "''", false},
|
||||
{"SQLite string", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "'that''s a default'", false},
|
||||
{"MySQL string", ColumnTypeString{HasDefault: true, DefaultVal: "%that's a default%"}, DialectMySQL, "'%that\\'s a default%'", false},
|
||||
|
||||
{"SQLite datetime none", ColumnTypeDateTime{}, DialectSQLite, "", false},
|
||||
{"SQLite datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectSQLite, "CURRENT_TIMESTAMP", false},
|
||||
{"MySQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectMySQL, "NOW()", false},
|
||||
{"PostgreSQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectPostgreSQL, "CURRENT_TIMESTAMP", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.ty.Default(tt.d)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Default() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("Default() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestColumn_CreateSQL(t *testing.T) {
|
||||
type fields struct {
|
||||
Dialect DialectType
|
||||
Name string
|
||||
Nullable bool
|
||||
Type ColumnType
|
||||
PrimaryKey bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"SQLite bool", fields{DialectSQLite, "foo", false, ColumnTypeBool{}, false}, "foo INTEGER NOT NULL", false},
|
||||
{"SQLite bool nullable", fields{DialectSQLite, "foo", true, ColumnTypeBool{}, false}, "foo INTEGER", false},
|
||||
{"SQLite int", fields{DialectSQLite, "foo", false, ColumnTypeInt{}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
|
||||
{"SQLite int nullable", fields{DialectSQLite, "foo", true, ColumnTypeInt{}, false}, "foo INTEGER", false},
|
||||
{"SQLite text", fields{DialectSQLite, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false},
|
||||
{"SQLite text nullable", fields{DialectSQLite, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false},
|
||||
{"SQLite datetime", fields{DialectSQLite, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false},
|
||||
{"SQLite datetime nullable", fields{DialectSQLite, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false},
|
||||
|
||||
{"MySQL bool", fields{DialectMySQL, "foo", false, ColumnTypeBool{}, false}, "foo TINYINT(1) NOT NULL", false},
|
||||
{"MySQL bool nullable", fields{DialectMySQL, "foo", true, ColumnTypeBool{}, false}, "foo TINYINT(1)", false},
|
||||
{"MySQL tiny int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 1}, true}, "foo TINYINT NOT NULL PRIMARY KEY", false},
|
||||
{"MySQL tiny int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 1}, false}, "foo TINYINT", false},
|
||||
{"MySQL small int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 2}, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false},
|
||||
{"MySQL small int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 2}, false}, "foo SMALLINT", false},
|
||||
{"MySQL int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 4}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
|
||||
{"MySQL int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 4}, false}, "foo INTEGER", false},
|
||||
{"MySQL big int", fields{DialectMySQL, "foo", false, ColumnTypeInt{}, true}, "foo BIGINT NOT NULL PRIMARY KEY", false},
|
||||
{"MySQL big int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{}, false}, "foo BIGINT", false},
|
||||
{"MySQL char", fields{DialectMySQL, "foo", false, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR NOT NULL", false},
|
||||
{"MySQL char nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR", false},
|
||||
{"MySQL varchar", fields{DialectMySQL, "foo", false, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255) NOT NULL", false},
|
||||
{"MySQL varchar nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255)", false},
|
||||
{"MySQL text", fields{DialectMySQL, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false},
|
||||
{"MySQL text nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false},
|
||||
{"MySQL datetime", fields{DialectMySQL, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false},
|
||||
{"MySQL datetime nullable", fields{DialectMySQL, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Column{
|
||||
Name: tt.fields.Name,
|
||||
Nullable: tt.fields.Nullable,
|
||||
Type: tt.fields.Type,
|
||||
PrimaryKey: tt.fields.PrimaryKey,
|
||||
}
|
||||
if got, err := c.CreateSQL(tt.fields.Dialect); got != tt.want {
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("String() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
179
db/create.go
179
db/create.go
|
@ -15,32 +15,6 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
type ColumnType int
|
||||
|
||||
type OptionalInt struct {
|
||||
Set bool
|
||||
Value int
|
||||
}
|
||||
|
||||
type OptionalString struct {
|
||||
Set bool
|
||||
Value string
|
||||
}
|
||||
|
||||
type SQLBuilder interface {
|
||||
ToSQL() (string, error)
|
||||
}
|
||||
|
||||
type Column struct {
|
||||
Dialect DialectType
|
||||
Name string
|
||||
Nullable bool
|
||||
Default OptionalString
|
||||
Type ColumnType
|
||||
Size OptionalInt
|
||||
PrimaryKey bool
|
||||
}
|
||||
|
||||
type CreateTableSqlBuilder struct {
|
||||
Dialect DialectType
|
||||
Name string
|
||||
|
@ -50,157 +24,6 @@ type CreateTableSqlBuilder struct {
|
|||
Constraints []string
|
||||
}
|
||||
|
||||
const (
|
||||
ColumnTypeBool ColumnType = iota
|
||||
ColumnTypeSmallInt ColumnType = iota
|
||||
ColumnTypeInteger ColumnType = iota
|
||||
ColumnTypeChar ColumnType = iota
|
||||
ColumnTypeVarChar ColumnType = iota
|
||||
ColumnTypeText ColumnType = iota
|
||||
ColumnTypeDateTime ColumnType = iota
|
||||
)
|
||||
|
||||
var _ SQLBuilder = &CreateTableSqlBuilder{}
|
||||
|
||||
var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
|
||||
var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}
|
||||
|
||||
func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
|
||||
if dialect != DialectMySQL && dialect != DialectSQLite {
|
||||
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
|
||||
}
|
||||
switch d {
|
||||
case ColumnTypeSmallInt:
|
||||
{
|
||||
if dialect == DialectSQLite {
|
||||
return "INTEGER", nil
|
||||
}
|
||||
mod := ""
|
||||
if size.Set {
|
||||
mod = fmt.Sprintf("(%d)", size.Value)
|
||||
}
|
||||
return "SMALLINT" + mod, nil
|
||||
}
|
||||
case ColumnTypeInteger:
|
||||
{
|
||||
if dialect == DialectSQLite {
|
||||
return "INTEGER", nil
|
||||
}
|
||||
mod := ""
|
||||
if size.Set {
|
||||
mod = fmt.Sprintf("(%d)", size.Value)
|
||||
}
|
||||
return "INT" + mod, nil
|
||||
}
|
||||
case ColumnTypeChar:
|
||||
{
|
||||
if dialect == DialectSQLite {
|
||||
return "TEXT", nil
|
||||
}
|
||||
mod := ""
|
||||
if size.Set {
|
||||
mod = fmt.Sprintf("(%d)", size.Value)
|
||||
}
|
||||
return "CHAR" + mod, nil
|
||||
}
|
||||
case ColumnTypeVarChar:
|
||||
{
|
||||
if dialect == DialectSQLite {
|
||||
return "TEXT", nil
|
||||
}
|
||||
mod := ""
|
||||
if size.Set {
|
||||
mod = fmt.Sprintf("(%d)", size.Value)
|
||||
}
|
||||
return "VARCHAR" + mod, nil
|
||||
}
|
||||
case ColumnTypeBool:
|
||||
{
|
||||
if dialect == DialectSQLite {
|
||||
return "INTEGER", nil
|
||||
}
|
||||
return "TINYINT(1)", nil
|
||||
}
|
||||
case ColumnTypeDateTime:
|
||||
return "DATETIME", nil
|
||||
case ColumnTypeText:
|
||||
return "TEXT", nil
|
||||
}
|
||||
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
|
||||
}
|
||||
|
||||
func (c *Column) SetName(name string) *Column {
|
||||
c.Name = name
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Column) SetNullable(nullable bool) *Column {
|
||||
c.Nullable = nullable
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Column) SetPrimaryKey(pk bool) *Column {
|
||||
c.PrimaryKey = pk
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Column) SetDefault(value string) *Column {
|
||||
c.Default = OptionalString{Set: true, Value: value}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Column) SetDefaultCurrentTimestamp() *Column {
|
||||
def := "NOW()"
|
||||
if c.Dialect == DialectSQLite {
|
||||
def = "CURRENT_TIMESTAMP"
|
||||
}
|
||||
c.Default = OptionalString{Set: true, Value: def}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Column) SetType(t ColumnType) *Column {
|
||||
c.Type = t
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Column) SetSize(size int) *Column {
|
||||
c.Size = OptionalInt{Set: true, Value: size}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Column) String() (string, error) {
|
||||
var str strings.Builder
|
||||
|
||||
str.WriteString(c.Name)
|
||||
|
||||
str.WriteString(" ")
|
||||
typeStr, err := c.Type.Format(c.Dialect, c.Size)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
str.WriteString(typeStr)
|
||||
|
||||
if !c.Nullable {
|
||||
str.WriteString(" NOT NULL")
|
||||
}
|
||||
|
||||
if c.Default.Set {
|
||||
str.WriteString(" DEFAULT ")
|
||||
val := c.Default.Value
|
||||
if val == "" {
|
||||
val = "''"
|
||||
}
|
||||
str.WriteString(val)
|
||||
}
|
||||
|
||||
if c.PrimaryKey {
|
||||
str.WriteString(" PRIMARY KEY")
|
||||
}
|
||||
|
||||
return str.String(), nil
|
||||
}
|
||||
|
||||
func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder {
|
||||
if b.Columns == nil {
|
||||
b.Columns = make(map[string]*Column)
|
||||
|
@ -241,7 +64,7 @@ func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
|
|||
if !ok {
|
||||
return "", fmt.Errorf("column not found: %s", columnName)
|
||||
}
|
||||
columnStr, err := column.String()
|
||||
columnStr, err := column.CreateSQL(b.Dialect)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
@ -5,139 +5,13 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func TestDialect_Column(t *testing.T) {
|
||||
c1 := DialectSQLite.Column("foo", ColumnTypeBool, UnsetSize)
|
||||
assert.Equal(t, DialectSQLite, c1.Dialect)
|
||||
c2 := DialectMySQL.Column("foo", ColumnTypeBool, UnsetSize)
|
||||
assert.Equal(t, DialectMySQL, c2.Dialect)
|
||||
}
|
||||
|
||||
func TestColumnType_Format(t *testing.T) {
|
||||
type args struct {
|
||||
dialect DialectType
|
||||
size OptionalInt
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
d ColumnType
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"Sqlite bool", ColumnTypeBool, args{dialect: DialectSQLite}, "INTEGER", false},
|
||||
{"Sqlite small int", ColumnTypeSmallInt, args{dialect: DialectSQLite}, "INTEGER", false},
|
||||
{"Sqlite int", ColumnTypeInteger, args{dialect: DialectSQLite}, "INTEGER", false},
|
||||
{"Sqlite char", ColumnTypeChar, args{dialect: DialectSQLite}, "TEXT", false},
|
||||
{"Sqlite varchar", ColumnTypeVarChar, args{dialect: DialectSQLite}, "TEXT", false},
|
||||
{"Sqlite text", ColumnTypeText, args{dialect: DialectSQLite}, "TEXT", false},
|
||||
{"Sqlite datetime", ColumnTypeDateTime, args{dialect: DialectSQLite}, "DATETIME", false},
|
||||
|
||||
{"MySQL bool", ColumnTypeBool, args{dialect: DialectMySQL}, "TINYINT(1)", false},
|
||||
{"MySQL small int", ColumnTypeSmallInt, args{dialect: DialectMySQL}, "SMALLINT", false},
|
||||
{"MySQL small int with param", ColumnTypeSmallInt, args{dialect: DialectMySQL, size: OptionalInt{true, 3}}, "SMALLINT(3)", false},
|
||||
{"MySQL int", ColumnTypeInteger, args{dialect: DialectMySQL}, "INT", false},
|
||||
{"MySQL int with param", ColumnTypeInteger, args{dialect: DialectMySQL, size: OptionalInt{true, 11}}, "INT(11)", false},
|
||||
{"MySQL char", ColumnTypeChar, args{dialect: DialectMySQL}, "CHAR", false},
|
||||
{"MySQL char with param", ColumnTypeChar, args{dialect: DialectMySQL, size: OptionalInt{true, 4}}, "CHAR(4)", false},
|
||||
{"MySQL varchar", ColumnTypeVarChar, args{dialect: DialectMySQL}, "VARCHAR", false},
|
||||
{"MySQL varchar with param", ColumnTypeVarChar, args{dialect: DialectMySQL, size: OptionalInt{true, 25}}, "VARCHAR(25)", false},
|
||||
{"MySQL text", ColumnTypeText, args{dialect: DialectMySQL}, "TEXT", false},
|
||||
{"MySQL datetime", ColumnTypeDateTime, args{dialect: DialectMySQL}, "DATETIME", false},
|
||||
|
||||
{"invalid column type", 10000, args{dialect: DialectMySQL}, "", true},
|
||||
{"invalid dialect", ColumnTypeBool, args{dialect: 10000}, "", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.d.Format(tt.args.dialect, tt.args.size)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("Format() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestColumn_Build(t *testing.T) {
|
||||
type fields struct {
|
||||
Dialect DialectType
|
||||
Name string
|
||||
Nullable bool
|
||||
Default OptionalString
|
||||
Type ColumnType
|
||||
Size OptionalInt
|
||||
PrimaryKey bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"Sqlite bool", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER NOT NULL", false},
|
||||
{"Sqlite bool nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER", false},
|
||||
{"Sqlite small int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
|
||||
{"Sqlite small int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo INTEGER", false},
|
||||
{"Sqlite int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER NOT NULL", false},
|
||||
{"Sqlite int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER", false},
|
||||
{"Sqlite char", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT NOT NULL", false},
|
||||
{"Sqlite char nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT", false},
|
||||
{"Sqlite varchar", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT NOT NULL", false},
|
||||
{"Sqlite varchar nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT", false},
|
||||
{"Sqlite text", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false},
|
||||
{"Sqlite text nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false},
|
||||
{"Sqlite datetime", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false},
|
||||
{"Sqlite datetime nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false},
|
||||
|
||||
{"MySQL bool", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1) NOT NULL", false},
|
||||
{"MySQL bool nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1)", false},
|
||||
{"MySQL small int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false},
|
||||
{"MySQL small int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo SMALLINT", false},
|
||||
{"MySQL int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT NOT NULL", false},
|
||||
{"MySQL int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT", false},
|
||||
{"MySQL char", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR NOT NULL", false},
|
||||
{"MySQL char nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR", false},
|
||||
{"MySQL varchar", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR NOT NULL", false},
|
||||
{"MySQL varchar nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR", false},
|
||||
{"MySQL text", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false},
|
||||
{"MySQL text nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false},
|
||||
{"MySQL datetime", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false},
|
||||
{"MySQL datetime nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Column{
|
||||
Dialect: tt.fields.Dialect,
|
||||
Name: tt.fields.Name,
|
||||
Nullable: tt.fields.Nullable,
|
||||
Default: tt.fields.Default,
|
||||
Type: tt.fields.Type,
|
||||
Size: tt.fields.Size,
|
||||
PrimaryKey: tt.fields.PrimaryKey,
|
||||
}
|
||||
if got, err := c.String(); got != tt.want {
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("String() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTableSqlBuilder_ToSQL(t *testing.T) {
|
||||
sql, err := DialectMySQL.
|
||||
Table("foo").
|
||||
SetIfNotExists(true).
|
||||
Column(DialectMySQL.Column("bar", ColumnTypeInteger, UnsetSize).SetPrimaryKey(true)).
|
||||
Column(DialectMySQL.Column("baz", ColumnTypeText, UnsetSize)).
|
||||
Column(DialectMySQL.Column("qux", ColumnTypeDateTime, UnsetSize).SetDefault("NOW()")).
|
||||
Column(PrimaryKeyColumn("bar", ColumnTypeInt{MaxBytes: 4})).
|
||||
Column(NonNullableColumn("baz", ColumnTypeString{})).
|
||||
Column(NonNullableColumn("qux", ColumnTypeDateTime{DefaultVal: DefaultNow})).
|
||||
UniqueConstraint("bar").
|
||||
UniqueConstraint("bar", "baz").
|
||||
ToSQL()
|
||||
|
|
|
@ -5,72 +5,47 @@ import "fmt"
|
|||
type DialectType int
|
||||
|
||||
const (
|
||||
DialectSQLite DialectType = iota
|
||||
DialectMySQL DialectType = iota
|
||||
DialectSQLite DialectType = iota
|
||||
DialectMySQL DialectType = iota
|
||||
DialectPostgreSQL DialectType = iota
|
||||
)
|
||||
|
||||
func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
|
||||
func (d DialectType) IsKnown() bool {
|
||||
switch d {
|
||||
case DialectSQLite:
|
||||
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
|
||||
case DialectMySQL:
|
||||
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
|
||||
case DialectSQLite, DialectMySQL, DialectPostgreSQL:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (d DialectType) AssertKnown() {
|
||||
if !d.IsKnown() {
|
||||
panic(fmt.Sprintf("unexpected dialect: %d", d))
|
||||
}
|
||||
}
|
||||
|
||||
func (d DialectType) Table(name string) *CreateTableSqlBuilder {
|
||||
switch d {
|
||||
case DialectSQLite:
|
||||
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
|
||||
case DialectMySQL:
|
||||
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected dialect: %d", d))
|
||||
}
|
||||
d.AssertKnown()
|
||||
return &CreateTableSqlBuilder{Dialect: d, Name: name}
|
||||
}
|
||||
|
||||
func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder {
|
||||
switch d {
|
||||
case DialectSQLite:
|
||||
return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name}
|
||||
case DialectMySQL:
|
||||
return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name}
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected dialect: %d", d))
|
||||
}
|
||||
d.AssertKnown()
|
||||
return &AlterTableSqlBuilder{Dialect: d, Name: name}
|
||||
}
|
||||
|
||||
func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder {
|
||||
switch d {
|
||||
case DialectSQLite:
|
||||
return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns}
|
||||
case DialectMySQL:
|
||||
return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns}
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected dialect: %d", d))
|
||||
}
|
||||
d.AssertKnown()
|
||||
return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: true, Columns: columns}
|
||||
}
|
||||
|
||||
func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder {
|
||||
switch d {
|
||||
case DialectSQLite:
|
||||
return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns}
|
||||
case DialectMySQL:
|
||||
return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns}
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected dialect: %d", d))
|
||||
}
|
||||
d.AssertKnown()
|
||||
return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: false, Columns: columns}
|
||||
}
|
||||
|
||||
func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder {
|
||||
switch d {
|
||||
case DialectSQLite:
|
||||
return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table}
|
||||
case DialectMySQL:
|
||||
return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table}
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected dialect: %d", d))
|
||||
}
|
||||
d.AssertKnown()
|
||||
return &DropIndexSqlBuilder{Dialect: d, Name: name, Table: table}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Copyright © 2019-2022 A Bunch Tell LLC.
|
||||
*
|
||||
* This file is part of WriteFreely.
|
||||
*
|
||||
* WriteFreely is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License, included
|
||||
* in the LICENSE file in this source code package.
|
||||
*/
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type EscapeContext int
|
||||
|
||||
const (
|
||||
EscapeSimple EscapeContext = iota
|
||||
)
|
||||
|
||||
func (_ EscapeContext) SQLEscape(d DialectType, s string) (string, error) {
|
||||
builder := strings.Builder{}
|
||||
switch d {
|
||||
case DialectSQLite:
|
||||
builder.WriteRune('\'')
|
||||
for _, c := range s {
|
||||
if c == '\'' {
|
||||
builder.WriteString("''")
|
||||
} else {
|
||||
builder.WriteRune(c)
|
||||
}
|
||||
}
|
||||
builder.WriteRune('\'')
|
||||
case DialectMySQL:
|
||||
builder.WriteRune('\'')
|
||||
for _, c := range s {
|
||||
switch c {
|
||||
case 0:
|
||||
builder.WriteString("\\0")
|
||||
case '\'':
|
||||
builder.WriteString("\\'")
|
||||
case '"':
|
||||
builder.WriteString("\\\"")
|
||||
case '\b':
|
||||
builder.WriteString("\\b")
|
||||
case '\n':
|
||||
builder.WriteString("\\n")
|
||||
case '\r':
|
||||
builder.WriteString("\\r")
|
||||
case '\t':
|
||||
builder.WriteString("\\t")
|
||||
case '\\':
|
||||
builder.WriteString("\\\\")
|
||||
default:
|
||||
builder.WriteRune(c)
|
||||
}
|
||||
}
|
||||
builder.WriteRune('\'')
|
||||
}
|
||||
return builder.String(), nil
|
||||
}
|
8
go.mod
8
go.mod
|
@ -34,7 +34,7 @@ require (
|
|||
github.com/rainycape/unidecode v0.0.0-20150907023854-cb7f23ec59be // indirect
|
||||
github.com/smartystreets/assertions v0.0.0-20190116191733-b6c0e53d7304 // indirect
|
||||
github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c // indirect
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/urfave/cli/v2 v2.27.1
|
||||
github.com/writeas/activity v0.1.2
|
||||
github.com/writeas/activityserve v0.0.0-20230428180247-dc13a4f4d835
|
||||
|
@ -49,8 +49,8 @@ require (
|
|||
github.com/writeas/web-core v1.6.1-0.20231003013047-d81124d45431
|
||||
github.com/writefreely/go-gopher v0.0.0-20220429181814-40127126f83b
|
||||
github.com/writefreely/go-nodeinfo v1.2.0
|
||||
golang.org/x/crypto v0.18.0
|
||||
golang.org/x/net v0.20.0
|
||||
golang.org/x/crypto v0.21.0
|
||||
golang.org/x/net v0.22.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
@ -83,7 +83,7 @@ require (
|
|||
github.com/writeas/go-writeas/v2 v2.0.2 // indirect
|
||||
github.com/writeas/openssl-go v1.0.0 // indirect
|
||||
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
|
||||
golang.org/x/sys v0.16.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
gopkg.in/ini.v1 v1.62.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
|
16
go.sum
16
go.sum
|
@ -166,8 +166,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
|
|||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/urfave/cli/v2 v2.27.1 h1:8xSQ6szndafKVRmfyeUMxkNUJQMjL1F2zmsZ+qHpfho=
|
||||
github.com/urfave/cli/v2 v2.27.1/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ=
|
||||
github.com/writeas/activity v0.1.2 h1:Y12B5lIrabfqKE7e7HFCWiXrlfXljr9tlkFm2mp7DgY=
|
||||
|
@ -213,8 +213,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
|||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
|
||||
golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
|
||||
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
|
||||
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
|
@ -233,8 +233,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
|||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
||||
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
|
||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
|
||||
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
|
@ -261,8 +261,8 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
|
||||
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
|
||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
|
|
|
@ -26,8 +26,8 @@ func oauth(db *datastore) error {
|
|||
createTableUsersOauth, err := dialect.
|
||||
Table("oauth_users").
|
||||
SetIfNotExists(false).
|
||||
Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)).
|
||||
Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)).
|
||||
Column(wf_db.NonNullableColumn("user_id", wf_db.ColumnTypeInt{MaxBytes: 4})).
|
||||
Column(wf_db.NonNullableColumn("remote_user_id", wf_db.ColumnTypeInt{MaxBytes: 4})).
|
||||
ToSQL()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -35,9 +35,9 @@ func oauth(db *datastore) error {
|
|||
createTableOauthClientState, err := dialect.
|
||||
Table("oauth_client_states").
|
||||
SetIfNotExists(false).
|
||||
Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})).
|
||||
Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)).
|
||||
Column(dialect.Column("created_at", wf_db.ColumnTypeDateTime, wf_db.UnsetSize).SetDefaultCurrentTimestamp()).
|
||||
Column(wf_db.NonNullableColumn("state", wf_db.ColumnTypeString{MaxChars: 255})).
|
||||
Column(wf_db.NonNullableColumn("used", wf_db.ColumnTypeBool{})).
|
||||
Column(wf_db.NonNullableColumn("created_at", wf_db.ColumnTypeDateTime{DefaultVal: wf_db.DefaultNow})).
|
||||
UniqueConstraint("state").
|
||||
ToSQL()
|
||||
if err != nil {
|
||||
|
|
|
@ -26,39 +26,55 @@ func oauthSlack(db *datastore) error {
|
|||
builders := []wf_db.SQLBuilder{
|
||||
dialect.
|
||||
AlterTable("oauth_client_states").
|
||||
AddColumn(dialect.
|
||||
Column(
|
||||
AddColumn(wf_db.
|
||||
NonNullableColumn(
|
||||
"provider",
|
||||
wf_db.ColumnTypeVarChar,
|
||||
wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")),
|
||||
wf_db.ColumnTypeString{
|
||||
MaxChars: 24,
|
||||
HasDefault: true,
|
||||
DefaultVal: "",
|
||||
})),
|
||||
dialect.
|
||||
AlterTable("oauth_client_states").
|
||||
AddColumn(dialect.
|
||||
Column(
|
||||
AddColumn(wf_db.
|
||||
NonNullableColumn(
|
||||
"client_id",
|
||||
wf_db.ColumnTypeVarChar,
|
||||
wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")),
|
||||
wf_db.ColumnTypeString{
|
||||
MaxChars: 128,
|
||||
HasDefault: true,
|
||||
DefaultVal: "",
|
||||
},
|
||||
)),
|
||||
dialect.
|
||||
AlterTable("oauth_users").
|
||||
AddColumn(dialect.
|
||||
Column(
|
||||
AddColumn(wf_db.
|
||||
NonNullableColumn(
|
||||
"provider",
|
||||
wf_db.ColumnTypeVarChar,
|
||||
wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")),
|
||||
wf_db.ColumnTypeString{
|
||||
MaxChars: 24,
|
||||
HasDefault: true,
|
||||
DefaultVal: "",
|
||||
})),
|
||||
dialect.
|
||||
AlterTable("oauth_users").
|
||||
AddColumn(dialect.
|
||||
Column(
|
||||
AddColumn(wf_db.
|
||||
NonNullableColumn(
|
||||
"client_id",
|
||||
wf_db.ColumnTypeVarChar,
|
||||
wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")),
|
||||
wf_db.ColumnTypeString{
|
||||
MaxChars: 128,
|
||||
HasDefault: true,
|
||||
DefaultVal: "",
|
||||
})),
|
||||
dialect.
|
||||
AlterTable("oauth_users").
|
||||
AddColumn(dialect.
|
||||
Column(
|
||||
AddColumn(wf_db.
|
||||
NonNullableColumn(
|
||||
"access_token",
|
||||
wf_db.ColumnTypeVarChar,
|
||||
wf_db.OptionalInt{Set: true, Value: 512}).SetDefault("")),
|
||||
wf_db.ColumnTypeString{
|
||||
MaxChars: 512,
|
||||
HasDefault: true,
|
||||
DefaultVal: "",
|
||||
})),
|
||||
dialect.CreateUniqueIndex("oauth_users_uk", "oauth_users", "user_id", "provider", "client_id"),
|
||||
}
|
||||
|
||||
|
@ -67,11 +83,12 @@ func oauthSlack(db *datastore) error {
|
|||
builders = append(builders, dialect.
|
||||
AlterTable("oauth_users").
|
||||
ChangeColumn("remote_user_id",
|
||||
dialect.
|
||||
Column(
|
||||
wf_db.
|
||||
NonNullableColumn(
|
||||
"remote_user_id",
|
||||
wf_db.ColumnTypeVarChar,
|
||||
wf_db.OptionalInt{Set: true, Value: 128})))
|
||||
wf_db.ColumnTypeString{
|
||||
MaxChars: 128,
|
||||
})))
|
||||
}
|
||||
|
||||
for _, builder := range builders {
|
||||
|
|
|
@ -26,11 +26,13 @@ func oauthAttach(db *datastore) error {
|
|||
builders := []wf_db.SQLBuilder{
|
||||
dialect.
|
||||
AlterTable("oauth_client_states").
|
||||
AddColumn(dialect.
|
||||
Column(
|
||||
AddColumn(wf_db.
|
||||
NullableColumn(
|
||||
"attach_user_id",
|
||||
wf_db.ColumnTypeInteger,
|
||||
wf_db.OptionalInt{Set: true, Value: 24}).SetNullable(true)),
|
||||
wf_db.ColumnTypeInt{
|
||||
MaxBytes: 4,
|
||||
MaxDigits: 24,
|
||||
})),
|
||||
}
|
||||
for _, builder := range builders {
|
||||
query, err := builder.ToSQL()
|
||||
|
|
|
@ -26,10 +26,10 @@ func oauthInvites(db *datastore) error {
|
|||
builders := []wf_db.SQLBuilder{
|
||||
dialect.
|
||||
AlterTable("oauth_client_states").
|
||||
AddColumn(dialect.Column("invite_code", wf_db.ColumnTypeChar, wf_db.OptionalInt{
|
||||
Set: true,
|
||||
Value: 6,
|
||||
}).SetNullable(true)),
|
||||
AddColumn(wf_db.NullableColumn("invite_code", wf_db.ColumnTypeString{
|
||||
IsFixedLength: true,
|
||||
MaxChars: 6,
|
||||
})),
|
||||
}
|
||||
for _, builder := range builders {
|
||||
query, err := builder.ToSQL()
|
||||
|
|
2
posts.go
2
posts.go
|
@ -1669,7 +1669,7 @@ func (rp *RawPost) Updated8601() string {
|
|||
return rp.Updated.Format("2006-01-02T15:04:05Z")
|
||||
}
|
||||
|
||||
var imageURLRegex = regexp.MustCompile(`(?i)[^ ]+\.(gif|png|jpg|jpeg|image)$`)
|
||||
var imageURLRegex = regexp.MustCompile(`(?i)[^ ]+\.(gif|png|jpg|jpeg|avif|avifs|webp|jxl|image)$`)
|
||||
|
||||
func (p *Post) extractImages() {
|
||||
p.Images = extractImages(p.Content)
|
||||
|
|
8
read.go
8
read.go
|
@ -229,11 +229,9 @@ func showLocalTimeline(app *App, w http.ResponseWriter, r *http.Request, page in
|
|||
TotalPages: ttlPages,
|
||||
SelTopic: tag,
|
||||
}
|
||||
if app.cfg.App.Chorus {
|
||||
u := getUserSession(app, r)
|
||||
d.IsAdmin = u != nil && u.IsAdmin()
|
||||
d.CanInvite = canUserInvite(app.cfg, d.IsAdmin)
|
||||
}
|
||||
u := getUserSession(app, r)
|
||||
d.IsAdmin = u != nil && u.IsAdmin()
|
||||
d.CanInvite = canUserInvite(app.cfg, d.IsAdmin)
|
||||
c, err := getReaderSection(app)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
Loading…
Reference in New Issue