This commit is contained in:
Clar Fon 2024-04-24 04:21:45 -07:00 committed by GitHub
commit 187855c408
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 663 additions and 400 deletions

View File

@ -12,15 +12,15 @@ type AlterTableSqlBuilder struct {
}
func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder {
if colVal, err := col.String(); err == nil {
if colVal, err := col.CreateSQL(b.Dialect); err == nil {
b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal))
}
return b
}
func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder {
if colVal, err := col.String(); err == nil {
b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal))
if colActions, err := col.AlterSQL(b.Dialect, name); err == nil {
b.Changes = append(b.Changes, colActions...)
}
return b
}

View File

@ -18,28 +18,43 @@ func TestAlterTableSqlBuilder_ToSQL(t *testing.T) {
name: "MySQL add int",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)),
want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL",
AddColumn(NonNullableColumn("the_col", ColumnTypeInt{MaxBytes: 4})),
want: "ALTER TABLE the_table ADD COLUMN the_col INTEGER NOT NULL",
wantErr: false,
},
{
name: "MySQL add string",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})),
AddColumn(NonNullableColumn("the_col", ColumnTypeString{MaxChars: 128})),
want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL",
wantErr: false,
},
{
name: "MySQL add int and string",
builder: DialectMySQL.
AlterTable("the_table").
AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)).
AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})),
AddColumn(NonNullableColumn("first_col", ColumnTypeInt{MaxBytes: 4})).
AddColumn(NonNullableColumn("second_col", ColumnTypeString{MaxChars: 128})),
want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL",
wantErr: false,
},
{
name: "MySQL change to string",
builder: DialectMySQL.
AlterTable("the_table").
ChangeColumn("old_col", NonNullableColumn("new_col", ColumnTypeString{})),
want: "ALTER TABLE the_table RENAME COLUMN old_col TO new_col, MODIFY COLUMN new_col TEXT NOT NULL",
wantErr: false,
},
{
name: "PostgreSQL change to int",
builder: DialectMySQL.
AlterTable("the_table").
ChangeColumn("old_col", NullableColumn("new_col", ColumnTypeInt{MaxBytes: 4})),
want: "ALTER TABLE the_table RENAME COLUMN old_col TO new_col, ALTER COLUMN new_col TYPE INTEGER, ALTER COLUMN new_col DROP NOT NULL",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

15
db/builder.go Normal file
View File

@ -0,0 +1,15 @@
/*
* Copyright © 2019-2022 A Bunch Tell LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package db
type SQLBuilder interface {
ToSQL() (string, error)
}

328
db/column.go Normal file
View File

@ -0,0 +1,328 @@
/*
* Copyright © 2019-2022 A Bunch Tell LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package db
import (
"fmt"
"strings"
)
type Column struct {
Name string
Type ColumnType
Nullable bool
PrimaryKey bool
}
func NullableColumn(name string, ty ColumnType) *Column {
return &Column{
Name: name,
Type: ty,
Nullable: true,
PrimaryKey: false,
}
}
func NonNullableColumn(name string, ty ColumnType) *Column {
return &Column{
Name: name,
Type: ty,
Nullable: false,
PrimaryKey: false,
}
}
func PrimaryKeyColumn(name string, ty ColumnType) *Column {
return &Column{
Name: name,
Type: ty,
Nullable: false,
PrimaryKey: true,
}
}
type ColumnType interface {
Name(DialectType) (string, error)
Default(DialectType) (string, error)
}
type ColumnTypeInt struct {
IsSigned bool
MaxBytes int
MaxDigits int
HasDefault bool
DefaultVal int
}
type ColumnTypeString struct {
IsFixedLength bool
MaxChars int
HasDefault bool
DefaultVal string
}
type ColumnDefault int
type ColumnTypeBool struct {
DefaultVal ColumnDefault
}
const (
NoDefault ColumnDefault = iota
DefaultFalse ColumnDefault = iota
DefaultTrue ColumnDefault = iota
DefaultNow ColumnDefault = iota
)
type ColumnTypeDateTime struct {
DefaultVal ColumnDefault
}
func (intCol ColumnTypeInt) Name(d DialectType) (string, error) {
switch d {
case DialectSQLite:
return "INTEGER", nil
case DialectMySQL, DialectPostgreSQL:
var colName string
switch intCol.MaxBytes {
case 1:
if d == DialectMySQL {
colName = "TINYINT"
} else {
colName = "SMALLINT"
}
case 2:
colName = "SMALLINT"
case 3:
if d == DialectMySQL {
colName = "MEDIUMINT"
} else {
colName = "INTEGER"
}
case 4:
colName = "INTEGER"
default:
colName = "BIGINT"
}
if d == DialectMySQL {
if intCol.MaxDigits > 0 {
colName = fmt.Sprintf("%s(%d)", colName, intCol.MaxDigits)
}
if !intCol.IsSigned {
colName += " UNSIGNED"
}
}
return colName, nil
default:
return "", fmt.Errorf("dialect %d does not support integer columns", d)
}
}
func (intCol ColumnTypeInt) Default(d DialectType) (string, error) {
if intCol.HasDefault {
return fmt.Sprintf("%d", intCol.DefaultVal), nil
}
return "", nil
}
func (strCol ColumnTypeString) Name(d DialectType) (string, error) {
switch d {
case DialectSQLite:
return "TEXT", nil
case DialectMySQL, DialectPostgreSQL:
if strCol.IsFixedLength {
if strCol.MaxChars > 0 {
return fmt.Sprintf("CHAR(%d)", strCol.MaxChars), nil
}
return "CHAR", nil
}
if strCol.MaxChars <= 0 {
return "TEXT", nil
}
if strCol.MaxChars < (1 << 16) {
return fmt.Sprintf("VARCHAR(%d)", strCol.MaxChars), nil
}
return "TEXT", nil
default:
return "", fmt.Errorf("dialect %d does not support string columns", d)
}
}
func (strCol ColumnTypeString) Default(d DialectType) (string, error) {
if strCol.HasDefault {
return EscapeSimple.SQLEscape(d, strCol.DefaultVal)
}
return "", nil
}
func (boolCol ColumnTypeBool) Name(d DialectType) (string, error) {
switch d {
case DialectSQLite:
return "INTEGER", nil
case DialectMySQL, DialectPostgreSQL:
return "BOOL", nil
default:
return "", fmt.Errorf("boolean column type not supported for dialect %d", d)
}
}
func (boolCol ColumnTypeBool) Default(d DialectType) (string, error) {
switch boolCol.DefaultVal {
case NoDefault:
return "", nil
case DefaultFalse:
return "0", nil
case DefaultTrue:
return "1", nil
default:
return "", fmt.Errorf("boolean columns cannot default to %d for dialect %d", boolCol.DefaultVal, d)
}
}
func (dateTimeCol ColumnTypeDateTime) Name(d DialectType) (string, error) {
switch d {
case DialectSQLite, DialectMySQL:
return "DATETIME", nil
case DialectPostgreSQL:
return "TIMESTAMP", nil
default:
return "", fmt.Errorf("datetime column type not supported for dialect %d", d)
}
}
func (dateTimeCol ColumnTypeDateTime) Default(d DialectType) (string, error) {
switch d {
case DialectSQLite, DialectMySQL:
switch dateTimeCol.DefaultVal {
case NoDefault:
return "", nil
case DefaultNow:
switch d {
case DialectSQLite, DialectPostgreSQL:
return "CURRENT_TIMESTAMP", nil
case DialectMySQL:
return "NOW()", nil
}
}
return "", fmt.Errorf("datetime columns cannot default to %d for dialect %d", dateTimeCol.DefaultVal, d)
default:
return "", fmt.Errorf("dialect %d does not support defaulted datetime columns", d)
}
}
func (c *Column) SetName(name string) *Column {
c.Name = name
return c
}
func (c *Column) SetNullable(nullable bool) *Column {
c.Nullable = nullable
return c
}
func (c *Column) SetPrimaryKey(pk bool) *Column {
c.PrimaryKey = pk
return c
}
func (c *Column) SetType(t ColumnType) *Column {
c.Type = t
return c
}
func (c *Column) AlterSQL(d DialectType, oldName string) ([]string, error) {
var actions []string = make([]string, 0)
switch d {
// MySQL does all modifications at once
case DialectMySQL:
sql, err := c.CreateSQL(d)
if err != nil {
return make([]string, 0), err
}
actions = append(actions, fmt.Sprintf("CHANGE COLUMN %s %s", oldName, sql))
// PostgreSQL does modifications piece by piece
case DialectPostgreSQL:
if oldName != c.Name {
actions = append(actions, fmt.Sprintf("RENAME COLUMN %s TO %s", oldName, c.Name))
}
typeStr, err := c.Type.Name(d)
if err != nil {
return make([]string, 0), err
}
actions = append(actions, fmt.Sprintf("ALTER COLUMN %s TYPE %s", c.Name, typeStr))
var nullAction string
if c.Nullable {
nullAction = "DROP"
} else {
nullAction = "SET"
}
actions = append(actions, fmt.Sprintf("ALTER COLUMN %s %s NOT NULL", c.Name, nullAction))
defaultStr, err := c.Type.Default(d)
if err != nil {
return make([]string, 0), err
}
if len(defaultStr) > 0 {
actions = append(actions, fmt.Sprintf("ALTER COLUMN %s SET DEFAULT %s", c.Name, defaultStr))
}
if c.PrimaryKey {
actions = append(actions, fmt.Sprintf("ADD PRIMARY KEY (%s)", c.Name))
}
default:
return make([]string, 0), fmt.Errorf("dialect %d doesn't support altering column data type", d)
}
return actions, nil
}
func (c *Column) CreateSQL(d DialectType) (string, error) {
var str strings.Builder
str.WriteString(c.Name)
str.WriteString(" ")
typeStr, err := c.Type.Name(d)
if err != nil {
return "", err
}
str.WriteString(typeStr)
if !c.Nullable {
str.WriteString(" NOT NULL")
}
defaultStr, err := c.Type.Default(d)
if err != nil {
return "", err
}
if len(defaultStr) > 0 {
str.WriteString(" DEFAULT ")
str.WriteString(defaultStr)
}
if c.PrimaryKey {
str.WriteString(" PRIMARY KEY")
}
return str.String(), nil
}

151
db/column_test.go Normal file
View File

@ -0,0 +1,151 @@
package db
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestColumnType_Name(t *testing.T) {
tests := []struct {
name string
ty ColumnType
d DialectType
want string
wantErr bool
}{
{"SQLite bool", ColumnTypeBool{}, DialectSQLite, "INTEGER", false},
{"SQLite int", ColumnTypeInt{}, DialectSQLite, "INTEGER", false},
{"SQLite string ", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "TEXT DEFAULT 'that''s a default'", false},
{"SQLite datetime", ColumnTypeDateTime{}, DialectSQLite, "DATETIME", false},
{"MySQL bool", ColumnTypeBool{}, DialectMySQL, "BOOL", false},
{"MySQL tiny int", ColumnTypeInt{MaxBytes: 1}, DialectMySQL, "TINYINT", false},
{"MySQL tiny int with digits", ColumnTypeInt{MaxBytes: 1, MaxDigits: 2}, DialectMySQL, "TINYINT(2)", false},
{"MySQL small int", ColumnTypeInt{MaxBytes: 2}, DialectMySQL, "SMALLINT", false},
{"MySQL small int with digits", ColumnTypeInt{MaxBytes: 2, MaxDigits: 3}, DialectMySQL, "SMALLINT(3)", false},
{"MySQL medium int", ColumnTypeInt{MaxBytes: 3}, DialectMySQL, "MEDIUMINT", false},
{"MySQL medium int with digits", ColumnTypeInt{MaxBytes: 3, MaxDigits: 6}, DialectMySQL, "MEDIUMINT(6)", false},
{"MySQL int", ColumnTypeInt{MaxBytes: 4}, DialectMySQL, "INTEGER", false},
{"MySQL int with digits", ColumnTypeInt{MaxBytes: 4, MaxDigits: 11}, DialectMySQL, "INTEGER(11)", false},
{"MySQL bigint", ColumnTypeInt{MaxBytes: 4}, DialectMySQL, "BIGINT", false},
{"MySQL bigint with digits", ColumnTypeInt{MaxBytes: 4, MaxDigits: 15}, DialectMySQL, "BIGINT(15)", false},
{"MySQL char", ColumnTypeString{IsFixedLength: true}, DialectMySQL, "CHAR", false},
{"MySQL char with length", ColumnTypeString{IsFixedLength: true, MaxChars: 4}, DialectMySQL, "CHAR(4)", false},
{"MySQL varchar with length", ColumnTypeString{MaxChars: 25}, DialectMySQL, "VARCHAR(25)", false},
{"MySQL text", ColumnTypeString{}, DialectMySQL, "TEXT", false},
{"MySQL datetime", ColumnTypeDateTime{}, DialectMySQL, "DATETIME", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.ty.Name(tt.d)
if (err != nil) != tt.wantErr {
t.Errorf("Name() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Name() got = %v, want %v", got, tt.want)
}
})
}
}
func TestColumnType_Default(t *testing.T) {
tests := []struct {
name string
ty ColumnType
d DialectType
want string
wantErr bool
}{
{"SQLite bool none", ColumnTypeBool{}, DialectSQLite, "", false},
{"SQLite bool false", ColumnTypeBool{}, DialectSQLite, "0", false},
{"SQLite bool true", ColumnTypeBool{}, DialectSQLite, "1", false},
{"SQLite int none", ColumnTypeInt{}, DialectSQLite, "", false},
{"SQLite int empty", ColumnTypeInt{HasDefault: true}, DialectSQLite, "0", false},
{"SQLite int", ColumnTypeInt{HasDefault: true, DefaultVal: 10}, DialectSQLite, "10", false},
{"SQLite string none", ColumnTypeString{}, DialectSQLite, "", false},
{"SQLite string empty", ColumnTypeString{HasDefault: true}, DialectSQLite, "''", false},
{"SQLite string", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "'that''s a default'", false},
{"MySQL string", ColumnTypeString{HasDefault: true, DefaultVal: "%that's a default%"}, DialectMySQL, "'%that\\'s a default%'", false},
{"SQLite datetime none", ColumnTypeDateTime{}, DialectSQLite, "", false},
{"SQLite datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectSQLite, "CURRENT_TIMESTAMP", false},
{"MySQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectMySQL, "NOW()", false},
{"PostgreSQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectPostgreSQL, "CURRENT_TIMESTAMP", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.ty.Default(tt.d)
if (err != nil) != tt.wantErr {
t.Errorf("Default() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Default() got = %v, want %v", got, tt.want)
}
})
}
}
func TestColumn_CreateSQL(t *testing.T) {
type fields struct {
Dialect DialectType
Name string
Nullable bool
Type ColumnType
PrimaryKey bool
}
tests := []struct {
name string
fields fields
want string
wantErr bool
}{
{"SQLite bool", fields{DialectSQLite, "foo", false, ColumnTypeBool{}, false}, "foo INTEGER NOT NULL", false},
{"SQLite bool nullable", fields{DialectSQLite, "foo", true, ColumnTypeBool{}, false}, "foo INTEGER", false},
{"SQLite int", fields{DialectSQLite, "foo", false, ColumnTypeInt{}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
{"SQLite int nullable", fields{DialectSQLite, "foo", true, ColumnTypeInt{}, false}, "foo INTEGER", false},
{"SQLite text", fields{DialectSQLite, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false},
{"SQLite text nullable", fields{DialectSQLite, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false},
{"SQLite datetime", fields{DialectSQLite, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false},
{"SQLite datetime nullable", fields{DialectSQLite, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false},
{"MySQL bool", fields{DialectMySQL, "foo", false, ColumnTypeBool{}, false}, "foo TINYINT(1) NOT NULL", false},
{"MySQL bool nullable", fields{DialectMySQL, "foo", true, ColumnTypeBool{}, false}, "foo TINYINT(1)", false},
{"MySQL tiny int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 1}, true}, "foo TINYINT NOT NULL PRIMARY KEY", false},
{"MySQL tiny int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 1}, false}, "foo TINYINT", false},
{"MySQL small int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 2}, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false},
{"MySQL small int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 2}, false}, "foo SMALLINT", false},
{"MySQL int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 4}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
{"MySQL int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 4}, false}, "foo INTEGER", false},
{"MySQL big int", fields{DialectMySQL, "foo", false, ColumnTypeInt{}, true}, "foo BIGINT NOT NULL PRIMARY KEY", false},
{"MySQL big int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{}, false}, "foo BIGINT", false},
{"MySQL char", fields{DialectMySQL, "foo", false, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR NOT NULL", false},
{"MySQL char nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR", false},
{"MySQL varchar", fields{DialectMySQL, "foo", false, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255) NOT NULL", false},
{"MySQL varchar nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255)", false},
{"MySQL text", fields{DialectMySQL, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false},
{"MySQL text nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false},
{"MySQL datetime", fields{DialectMySQL, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false},
{"MySQL datetime nullable", fields{DialectMySQL, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Column{
Name: tt.fields.Name,
Nullable: tt.fields.Nullable,
Type: tt.fields.Type,
PrimaryKey: tt.fields.PrimaryKey,
}
if got, err := c.CreateSQL(tt.fields.Dialect); got != tt.want {
if (err != nil) != tt.wantErr {
t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("String() got = %v, want %v", got, tt.want)
}
}
})
}
}

View File

@ -15,32 +15,6 @@ import (
"strings"
)
type ColumnType int
type OptionalInt struct {
Set bool
Value int
}
type OptionalString struct {
Set bool
Value string
}
type SQLBuilder interface {
ToSQL() (string, error)
}
type Column struct {
Dialect DialectType
Name string
Nullable bool
Default OptionalString
Type ColumnType
Size OptionalInt
PrimaryKey bool
}
type CreateTableSqlBuilder struct {
Dialect DialectType
Name string
@ -50,157 +24,6 @@ type CreateTableSqlBuilder struct {
Constraints []string
}
const (
ColumnTypeBool ColumnType = iota
ColumnTypeSmallInt ColumnType = iota
ColumnTypeInteger ColumnType = iota
ColumnTypeChar ColumnType = iota
ColumnTypeVarChar ColumnType = iota
ColumnTypeText ColumnType = iota
ColumnTypeDateTime ColumnType = iota
)
var _ SQLBuilder = &CreateTableSqlBuilder{}
var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}
func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
if dialect != DialectMySQL && dialect != DialectSQLite {
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
}
switch d {
case ColumnTypeSmallInt:
{
if dialect == DialectSQLite {
return "INTEGER", nil
}
mod := ""
if size.Set {
mod = fmt.Sprintf("(%d)", size.Value)
}
return "SMALLINT" + mod, nil
}
case ColumnTypeInteger:
{
if dialect == DialectSQLite {
return "INTEGER", nil
}
mod := ""
if size.Set {
mod = fmt.Sprintf("(%d)", size.Value)
}
return "INT" + mod, nil
}
case ColumnTypeChar:
{
if dialect == DialectSQLite {
return "TEXT", nil
}
mod := ""
if size.Set {
mod = fmt.Sprintf("(%d)", size.Value)
}
return "CHAR" + mod, nil
}
case ColumnTypeVarChar:
{
if dialect == DialectSQLite {
return "TEXT", nil
}
mod := ""
if size.Set {
mod = fmt.Sprintf("(%d)", size.Value)
}
return "VARCHAR" + mod, nil
}
case ColumnTypeBool:
{
if dialect == DialectSQLite {
return "INTEGER", nil
}
return "TINYINT(1)", nil
}
case ColumnTypeDateTime:
return "DATETIME", nil
case ColumnTypeText:
return "TEXT", nil
}
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
}
func (c *Column) SetName(name string) *Column {
c.Name = name
return c
}
func (c *Column) SetNullable(nullable bool) *Column {
c.Nullable = nullable
return c
}
func (c *Column) SetPrimaryKey(pk bool) *Column {
c.PrimaryKey = pk
return c
}
func (c *Column) SetDefault(value string) *Column {
c.Default = OptionalString{Set: true, Value: value}
return c
}
func (c *Column) SetDefaultCurrentTimestamp() *Column {
def := "NOW()"
if c.Dialect == DialectSQLite {
def = "CURRENT_TIMESTAMP"
}
c.Default = OptionalString{Set: true, Value: def}
return c
}
func (c *Column) SetType(t ColumnType) *Column {
c.Type = t
return c
}
func (c *Column) SetSize(size int) *Column {
c.Size = OptionalInt{Set: true, Value: size}
return c
}
func (c *Column) String() (string, error) {
var str strings.Builder
str.WriteString(c.Name)
str.WriteString(" ")
typeStr, err := c.Type.Format(c.Dialect, c.Size)
if err != nil {
return "", err
}
str.WriteString(typeStr)
if !c.Nullable {
str.WriteString(" NOT NULL")
}
if c.Default.Set {
str.WriteString(" DEFAULT ")
val := c.Default.Value
if val == "" {
val = "''"
}
str.WriteString(val)
}
if c.PrimaryKey {
str.WriteString(" PRIMARY KEY")
}
return str.String(), nil
}
func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder {
if b.Columns == nil {
b.Columns = make(map[string]*Column)
@ -241,7 +64,7 @@ func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
if !ok {
return "", fmt.Errorf("column not found: %s", columnName)
}
columnStr, err := column.String()
columnStr, err := column.CreateSQL(b.Dialect)
if err != nil {
return "", err
}

View File

@ -5,139 +5,13 @@ import (
"testing"
)
func TestDialect_Column(t *testing.T) {
c1 := DialectSQLite.Column("foo", ColumnTypeBool, UnsetSize)
assert.Equal(t, DialectSQLite, c1.Dialect)
c2 := DialectMySQL.Column("foo", ColumnTypeBool, UnsetSize)
assert.Equal(t, DialectMySQL, c2.Dialect)
}
func TestColumnType_Format(t *testing.T) {
type args struct {
dialect DialectType
size OptionalInt
}
tests := []struct {
name string
d ColumnType
args args
want string
wantErr bool
}{
{"Sqlite bool", ColumnTypeBool, args{dialect: DialectSQLite}, "INTEGER", false},
{"Sqlite small int", ColumnTypeSmallInt, args{dialect: DialectSQLite}, "INTEGER", false},
{"Sqlite int", ColumnTypeInteger, args{dialect: DialectSQLite}, "INTEGER", false},
{"Sqlite char", ColumnTypeChar, args{dialect: DialectSQLite}, "TEXT", false},
{"Sqlite varchar", ColumnTypeVarChar, args{dialect: DialectSQLite}, "TEXT", false},
{"Sqlite text", ColumnTypeText, args{dialect: DialectSQLite}, "TEXT", false},
{"Sqlite datetime", ColumnTypeDateTime, args{dialect: DialectSQLite}, "DATETIME", false},
{"MySQL bool", ColumnTypeBool, args{dialect: DialectMySQL}, "TINYINT(1)", false},
{"MySQL small int", ColumnTypeSmallInt, args{dialect: DialectMySQL}, "SMALLINT", false},
{"MySQL small int with param", ColumnTypeSmallInt, args{dialect: DialectMySQL, size: OptionalInt{true, 3}}, "SMALLINT(3)", false},
{"MySQL int", ColumnTypeInteger, args{dialect: DialectMySQL}, "INT", false},
{"MySQL int with param", ColumnTypeInteger, args{dialect: DialectMySQL, size: OptionalInt{true, 11}}, "INT(11)", false},
{"MySQL char", ColumnTypeChar, args{dialect: DialectMySQL}, "CHAR", false},
{"MySQL char with param", ColumnTypeChar, args{dialect: DialectMySQL, size: OptionalInt{true, 4}}, "CHAR(4)", false},
{"MySQL varchar", ColumnTypeVarChar, args{dialect: DialectMySQL}, "VARCHAR", false},
{"MySQL varchar with param", ColumnTypeVarChar, args{dialect: DialectMySQL, size: OptionalInt{true, 25}}, "VARCHAR(25)", false},
{"MySQL text", ColumnTypeText, args{dialect: DialectMySQL}, "TEXT", false},
{"MySQL datetime", ColumnTypeDateTime, args{dialect: DialectMySQL}, "DATETIME", false},
{"invalid column type", 10000, args{dialect: DialectMySQL}, "", true},
{"invalid dialect", ColumnTypeBool, args{dialect: 10000}, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.d.Format(tt.args.dialect, tt.args.size)
if (err != nil) != tt.wantErr {
t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Format() got = %v, want %v", got, tt.want)
}
})
}
}
func TestColumn_Build(t *testing.T) {
type fields struct {
Dialect DialectType
Name string
Nullable bool
Default OptionalString
Type ColumnType
Size OptionalInt
PrimaryKey bool
}
tests := []struct {
name string
fields fields
want string
wantErr bool
}{
{"Sqlite bool", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER NOT NULL", false},
{"Sqlite bool nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER", false},
{"Sqlite small int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo INTEGER NOT NULL PRIMARY KEY", false},
{"Sqlite small int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo INTEGER", false},
{"Sqlite int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER NOT NULL", false},
{"Sqlite int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER", false},
{"Sqlite char", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT NOT NULL", false},
{"Sqlite char nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT", false},
{"Sqlite varchar", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT NOT NULL", false},
{"Sqlite varchar nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT", false},
{"Sqlite text", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false},
{"Sqlite text nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false},
{"Sqlite datetime", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false},
{"Sqlite datetime nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false},
{"MySQL bool", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1) NOT NULL", false},
{"MySQL bool nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1)", false},
{"MySQL small int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false},
{"MySQL small int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo SMALLINT", false},
{"MySQL int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT NOT NULL", false},
{"MySQL int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT", false},
{"MySQL char", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR NOT NULL", false},
{"MySQL char nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR", false},
{"MySQL varchar", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR NOT NULL", false},
{"MySQL varchar nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR", false},
{"MySQL text", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false},
{"MySQL text nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false},
{"MySQL datetime", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false},
{"MySQL datetime nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Column{
Dialect: tt.fields.Dialect,
Name: tt.fields.Name,
Nullable: tt.fields.Nullable,
Default: tt.fields.Default,
Type: tt.fields.Type,
Size: tt.fields.Size,
PrimaryKey: tt.fields.PrimaryKey,
}
if got, err := c.String(); got != tt.want {
if (err != nil) != tt.wantErr {
t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("String() got = %v, want %v", got, tt.want)
}
}
})
}
}
func TestCreateTableSqlBuilder_ToSQL(t *testing.T) {
sql, err := DialectMySQL.
Table("foo").
SetIfNotExists(true).
Column(DialectMySQL.Column("bar", ColumnTypeInteger, UnsetSize).SetPrimaryKey(true)).
Column(DialectMySQL.Column("baz", ColumnTypeText, UnsetSize)).
Column(DialectMySQL.Column("qux", ColumnTypeDateTime, UnsetSize).SetDefault("NOW()")).
Column(PrimaryKeyColumn("bar", ColumnTypeInt{MaxBytes: 4})).
Column(NonNullableColumn("baz", ColumnTypeString{})).
Column(NonNullableColumn("qux", ColumnTypeDateTime{DefaultVal: DefaultNow})).
UniqueConstraint("bar").
UniqueConstraint("bar", "baz").
ToSQL()

View File

@ -5,72 +5,47 @@ import "fmt"
type DialectType int
const (
DialectSQLite DialectType = iota
DialectMySQL DialectType = iota
DialectSQLite DialectType = iota
DialectMySQL DialectType = iota
DialectPostgreSQL DialectType = iota
)
func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
func (d DialectType) IsKnown() bool {
switch d {
case DialectSQLite:
return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
case DialectMySQL:
return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
case DialectSQLite, DialectMySQL, DialectPostgreSQL:
return true
default:
return false
}
}
func (d DialectType) AssertKnown() {
if !d.IsKnown() {
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
}
func (d DialectType) Table(name string) *CreateTableSqlBuilder {
switch d {
case DialectSQLite:
return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
d.AssertKnown()
return &CreateTableSqlBuilder{Dialect: d, Name: name}
}
func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder {
switch d {
case DialectSQLite:
return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name}
case DialectMySQL:
return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
d.AssertKnown()
return &AlterTableSqlBuilder{Dialect: d, Name: name}
}
func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder {
switch d {
case DialectSQLite:
return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns}
case DialectMySQL:
return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
d.AssertKnown()
return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: true, Columns: columns}
}
func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder {
switch d {
case DialectSQLite:
return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns}
case DialectMySQL:
return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
d.AssertKnown()
return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: false, Columns: columns}
}
func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder {
switch d {
case DialectSQLite:
return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table}
case DialectMySQL:
return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table}
default:
panic(fmt.Sprintf("unexpected dialect: %d", d))
}
d.AssertKnown()
return &DropIndexSqlBuilder{Dialect: d, Name: name, Table: table}
}

63
db/escape.go Normal file
View File

@ -0,0 +1,63 @@
/*
* Copyright © 2019-2022 A Bunch Tell LLC.
*
* This file is part of WriteFreely.
*
* WriteFreely is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, included
* in the LICENSE file in this source code package.
*/
package db
import (
"strings"
)
type EscapeContext int
const (
EscapeSimple EscapeContext = iota
)
func (_ EscapeContext) SQLEscape(d DialectType, s string) (string, error) {
builder := strings.Builder{}
switch d {
case DialectSQLite:
builder.WriteRune('\'')
for _, c := range s {
if c == '\'' {
builder.WriteString("''")
} else {
builder.WriteRune(c)
}
}
builder.WriteRune('\'')
case DialectMySQL:
builder.WriteRune('\'')
for _, c := range s {
switch c {
case 0:
builder.WriteString("\\0")
case '\'':
builder.WriteString("\\'")
case '"':
builder.WriteString("\\\"")
case '\b':
builder.WriteString("\\b")
case '\n':
builder.WriteString("\\n")
case '\r':
builder.WriteString("\\r")
case '\t':
builder.WriteString("\\t")
case '\\':
builder.WriteString("\\\\")
default:
builder.WriteRune(c)
}
}
builder.WriteRune('\'')
}
return builder.String(), nil
}

View File

@ -26,8 +26,8 @@ func oauth(db *datastore) error {
createTableUsersOauth, err := dialect.
Table("oauth_users").
SetIfNotExists(false).
Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)).
Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)).
Column(wf_db.NonNullableColumn("user_id", wf_db.ColumnTypeInt{MaxBytes: 4})).
Column(wf_db.NonNullableColumn("remote_user_id", wf_db.ColumnTypeInt{MaxBytes: 4})).
ToSQL()
if err != nil {
return err
@ -35,9 +35,9 @@ func oauth(db *datastore) error {
createTableOauthClientState, err := dialect.
Table("oauth_client_states").
SetIfNotExists(false).
Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})).
Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)).
Column(dialect.Column("created_at", wf_db.ColumnTypeDateTime, wf_db.UnsetSize).SetDefaultCurrentTimestamp()).
Column(wf_db.NonNullableColumn("state", wf_db.ColumnTypeString{MaxChars: 255})).
Column(wf_db.NonNullableColumn("used", wf_db.ColumnTypeBool{})).
Column(wf_db.NonNullableColumn("created_at", wf_db.ColumnTypeDateTime{DefaultVal: wf_db.DefaultNow})).
UniqueConstraint("state").
ToSQL()
if err != nil {

View File

@ -26,39 +26,55 @@ func oauthSlack(db *datastore) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
AddColumn(dialect.
Column(
AddColumn(wf_db.
NonNullableColumn(
"provider",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")),
wf_db.ColumnTypeString{
MaxChars: 24,
HasDefault: true,
DefaultVal: "",
})),
dialect.
AlterTable("oauth_client_states").
AddColumn(dialect.
Column(
AddColumn(wf_db.
NonNullableColumn(
"client_id",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")),
wf_db.ColumnTypeString{
MaxChars: 128,
HasDefault: true,
DefaultVal: "",
},
)),
dialect.
AlterTable("oauth_users").
AddColumn(dialect.
Column(
AddColumn(wf_db.
NonNullableColumn(
"provider",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")),
wf_db.ColumnTypeString{
MaxChars: 24,
HasDefault: true,
DefaultVal: "",
})),
dialect.
AlterTable("oauth_users").
AddColumn(dialect.
Column(
AddColumn(wf_db.
NonNullableColumn(
"client_id",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")),
wf_db.ColumnTypeString{
MaxChars: 128,
HasDefault: true,
DefaultVal: "",
})),
dialect.
AlterTable("oauth_users").
AddColumn(dialect.
Column(
AddColumn(wf_db.
NonNullableColumn(
"access_token",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 512}).SetDefault("")),
wf_db.ColumnTypeString{
MaxChars: 512,
HasDefault: true,
DefaultVal: "",
})),
dialect.CreateUniqueIndex("oauth_users_uk", "oauth_users", "user_id", "provider", "client_id"),
}
@ -67,11 +83,12 @@ func oauthSlack(db *datastore) error {
builders = append(builders, dialect.
AlterTable("oauth_users").
ChangeColumn("remote_user_id",
dialect.
Column(
wf_db.
NonNullableColumn(
"remote_user_id",
wf_db.ColumnTypeVarChar,
wf_db.OptionalInt{Set: true, Value: 128})))
wf_db.ColumnTypeString{
MaxChars: 128,
})))
}
for _, builder := range builders {

View File

@ -26,11 +26,13 @@ func oauthAttach(db *datastore) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
AddColumn(dialect.
Column(
AddColumn(wf_db.
NullableColumn(
"attach_user_id",
wf_db.ColumnTypeInteger,
wf_db.OptionalInt{Set: true, Value: 24}).SetNullable(true)),
wf_db.ColumnTypeInt{
MaxBytes: 4,
MaxDigits: 24,
})),
}
for _, builder := range builders {
query, err := builder.ToSQL()

View File

@ -26,10 +26,10 @@ func oauthInvites(db *datastore) error {
builders := []wf_db.SQLBuilder{
dialect.
AlterTable("oauth_client_states").
AddColumn(dialect.Column("invite_code", wf_db.ColumnTypeChar, wf_db.OptionalInt{
Set: true,
Value: 6,
}).SetNullable(true)),
AddColumn(wf_db.NullableColumn("invite_code", wf_db.ColumnTypeString{
IsFixedLength: true,
MaxChars: 6,
})),
}
for _, builder := range builders {
query, err := builder.ToSQL()