diff --git a/db/alter.go b/db/alter.go index 0564d3e..6028700 100644 --- a/db/alter.go +++ b/db/alter.go @@ -12,15 +12,15 @@ type AlterTableSqlBuilder struct { } func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder { - if colVal, err := col.ToSQL(b.Dialect); err == nil { + if colVal, err := col.CreateSQL(b.Dialect); err == nil { b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal)) } return b } func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder { - if colVal, err := col.ToSQL(b.Dialect); err == nil { - b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal)) + if colActions, err := col.AlterSQL(b.Dialect, name); err == nil { + b.Changes = append(b.Changes, colActions...) } return b } diff --git a/db/alter_test.go b/db/alter_test.go index 4bd58ac..4d47821 100644 --- a/db/alter_test.go +++ b/db/alter_test.go @@ -18,28 +18,43 @@ func TestAlterTableSqlBuilder_ToSQL(t *testing.T) { name: "MySQL add int", builder: DialectMySQL. AlterTable("the_table"). - AddColumn(DialectMySQL.Column("the_col", ColumnTypeInteger, UnsetSize)), - want: "ALTER TABLE the_table ADD COLUMN the_col INT NOT NULL", + AddColumn(NonNullableColumn("the_col", ColumnTypeInt{MaxBytes: 4})), + want: "ALTER TABLE the_table ADD COLUMN the_col INTEGER NOT NULL", wantErr: false, }, { name: "MySQL add string", builder: DialectMySQL. AlterTable("the_table"). - AddColumn(DialectMySQL.Column("the_col", ColumnTypeVarChar, OptionalInt{true, 128})), + AddColumn(NonNullableColumn("the_col", ColumnTypeString{MaxChars: 128})), want: "ALTER TABLE the_table ADD COLUMN the_col VARCHAR(128) NOT NULL", wantErr: false, }, - { name: "MySQL add int and string", builder: DialectMySQL. AlterTable("the_table"). - AddColumn(DialectMySQL.Column("first_col", ColumnTypeInteger, UnsetSize)). - AddColumn(DialectMySQL.Column("second_col", ColumnTypeVarChar, OptionalInt{true, 128})), + AddColumn(NonNullableColumn("first_col", ColumnTypeInt{MaxBytes: 4})). + AddColumn(NonNullableColumn("second_col", ColumnTypeString{MaxChars: 128})), want: "ALTER TABLE the_table ADD COLUMN first_col INT NOT NULL, ADD COLUMN second_col VARCHAR(128) NOT NULL", wantErr: false, }, + { + name: "MySQL change to string", + builder: DialectMySQL. + AlterTable("the_table"). + ChangeColumn("old_col", NonNullableColumn("new_col", ColumnTypeString{})), + want: "ALTER TABLE the_table RENAME COLUMN old_col TO new_col, MODIFY COLUMN new_col TEXT NOT NULL", + wantErr: false, + }, + { + name: "PostgreSQL change to int", + builder: DialectMySQL. + AlterTable("the_table"). + ChangeColumn("old_col", NullableColumn("new_col", ColumnTypeInt{MaxBytes: 4})), + want: "ALTER TABLE the_table RENAME COLUMN old_col TO new_col, ALTER COLUMN new_col TYPE INTEGER, ALTER COLUMN new_col DROP NOT NULL", + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/db/column.go b/db/column.go index 4e0bb6d..5bf64e4 100644 --- a/db/column.go +++ b/db/column.go @@ -91,25 +91,35 @@ func (intCol ColumnTypeInt) Name(d DialectType) (string, error) { case DialectSQLite: return "INTEGER", nil - case DialectMySQL: + case DialectMySQL, DialectPostgreSQL: var colName string switch intCol.MaxBytes { case 1: - colName = "TINYINT" + if d == DialectMySQL { + colName = "TINYINT" + } else { + colName = "SMALLINT" + } case 2: colName = "SMALLINT" case 3: - colName = "MEDIUMINT" + if d == DialectMySQL { + colName = "MEDIUMINT" + } else { + colName = "INTEGER" + } case 4: - colName = "INT" + colName = "INTEGER" default: colName = "BIGINT" } - if intCol.MaxDigits > 0 { - colName = fmt.Sprintf("%s(%d)", colName, intCol.MaxDigits) - } - if !intCol.IsSigned { - colName += " UNSIGNED" + if d == DialectMySQL { + if intCol.MaxDigits > 0 { + colName = fmt.Sprintf("%s(%d)", colName, intCol.MaxDigits) + } + if !intCol.IsSigned { + colName += " UNSIGNED" + } } return colName, nil @@ -119,15 +129,10 @@ func (intCol ColumnTypeInt) Name(d DialectType) (string, error) { } func (intCol ColumnTypeInt) Default(d DialectType) (string, error) { - switch d { - case DialectSQLite, DialectMySQL: - if intCol.HasDefault { - return fmt.Sprintf("%d", intCol.DefaultVal), nil - } - return "", nil - default: - return "", fmt.Errorf("dialect %d does not support defaulted integer columns", d) + if intCol.HasDefault { + return fmt.Sprintf("%d", intCol.DefaultVal), nil } + return "", nil } func (strCol ColumnTypeString) Name(d DialectType) (string, error) { @@ -135,7 +140,7 @@ func (strCol ColumnTypeString) Name(d DialectType) (string, error) { case DialectSQLite: return "TEXT", nil - case DialectMySQL: + case DialectMySQL, DialectPostgreSQL: if strCol.IsFixedLength { if strCol.MaxChars > 0 { return fmt.Sprintf("CHAR(%d)", strCol.MaxChars), nil @@ -157,22 +162,17 @@ func (strCol ColumnTypeString) Name(d DialectType) (string, error) { } func (strCol ColumnTypeString) Default(d DialectType) (string, error) { - switch d { - case DialectSQLite, DialectMySQL: - if strCol.HasDefault { - return EscapeSimple.SQLEscape(d, strCol.DefaultVal) - } - return "", nil - default: - return "", fmt.Errorf("dialect %d does not support defaulted string columns", d) + if strCol.HasDefault { + return EscapeSimple.SQLEscape(d, strCol.DefaultVal) } + return "", nil } func (boolCol ColumnTypeBool) Name(d DialectType) (string, error) { switch d { case DialectSQLite: return "INTEGER", nil - case DialectMySQL: + case DialectMySQL, DialectPostgreSQL: return "BOOL", nil default: return "", fmt.Errorf("boolean column type not supported for dialect %d", d) @@ -180,20 +180,15 @@ func (boolCol ColumnTypeBool) Name(d DialectType) (string, error) { } func (boolCol ColumnTypeBool) Default(d DialectType) (string, error) { - switch d { - case DialectSQLite, DialectMySQL: - switch boolCol.DefaultVal { - case NoDefault: - return "", nil - case DefaultFalse: - return "0", nil - case DefaultTrue: - return "1", nil - default: - return "", fmt.Errorf("boolean columns cannot default to %d for dialect %d", boolCol.DefaultVal, d) - } + switch boolCol.DefaultVal { + case NoDefault: + return "", nil + case DefaultFalse: + return "0", nil + case DefaultTrue: + return "1", nil default: - return "", fmt.Errorf("dialect %d does not support defaulted boolean columns", d) + return "", fmt.Errorf("boolean columns cannot default to %d for dialect %d", boolCol.DefaultVal, d) } } @@ -201,6 +196,8 @@ func (dateTimeCol ColumnTypeDateTime) Name(d DialectType) (string, error) { switch d { case DialectSQLite, DialectMySQL: return "DATETIME", nil + case DialectPostgreSQL: + return "TIMESTAMP", nil default: return "", fmt.Errorf("datetime column type not supported for dialect %d", d) } @@ -214,7 +211,7 @@ func (dateTimeCol ColumnTypeDateTime) Default(d DialectType) (string, error) { return "", nil case DefaultNow: switch d { - case DialectSQLite: + case DialectSQLite, DialectPostgreSQL: return "CURRENT_TIMESTAMP", nil case DialectMySQL: return "NOW()", nil @@ -246,7 +243,58 @@ func (c *Column) SetType(t ColumnType) *Column { return c } -func (c *Column) ToSQL(d DialectType) (string, error) { +func (c *Column) AlterSQL(d DialectType, oldName string) ([]string, error) { + var actions []string = make([]string, 0) + + switch d { + // MySQL does all modifications at once + case DialectMySQL: + sql, err := c.CreateSQL(d) + if err != nil { + return make([]string, 0), err + } + actions = append(actions, fmt.Sprintf("CHANGE COLUMN %s %s", oldName, sql)) + + // PostgreSQL does modifications piece by piece + case DialectPostgreSQL: + if oldName != c.Name { + actions = append(actions, fmt.Sprintf("RENAME COLUMN %s TO %s", oldName, c.Name)) + } + + typeStr, err := c.Type.Name(d) + if err != nil { + return make([]string, 0), err + } + + actions = append(actions, fmt.Sprintf("ALTER COLUMN %s TYPE %s", c.Name, typeStr)) + var nullAction string + if c.Nullable { + nullAction = "DROP" + } else { + nullAction = "SET" + } + actions = append(actions, fmt.Sprintf("ALTER COLUMN %s %s NOT NULL", c.Name, nullAction)) + + defaultStr, err := c.Type.Default(d) + if err != nil { + return make([]string, 0), err + } + if len(defaultStr) > 0 { + actions = append(actions, fmt.Sprintf("ALTER COLUMN %s SET DEFAULT %s", c.Name, defaultStr)) + } + + if c.PrimaryKey { + actions = append(actions, fmt.Sprintf("ADD PRIMARY KEY (%s)", c.Name)) + } + + default: + return make([]string, 0), fmt.Errorf("dialect %d doesn't support altering column data type", d) + } + + return actions, nil +} + +func (c *Column) CreateSQL(d DialectType) (string, error) { var str strings.Builder str.WriteString(c.Name) diff --git a/db/column_test.go b/db/column_test.go new file mode 100644 index 0000000..175ec3a --- /dev/null +++ b/db/column_test.go @@ -0,0 +1,151 @@ +package db + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestColumnType_Name(t *testing.T) { + tests := []struct { + name string + ty ColumnType + d DialectType + want string + wantErr bool + }{ + {"SQLite bool", ColumnTypeBool{}, DialectSQLite, "INTEGER", false}, + {"SQLite int", ColumnTypeInt{}, DialectSQLite, "INTEGER", false}, + {"SQLite string ", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "TEXT DEFAULT 'that''s a default'", false}, + {"SQLite datetime", ColumnTypeDateTime{}, DialectSQLite, "DATETIME", false}, + + {"MySQL bool", ColumnTypeBool{}, DialectMySQL, "BOOL", false}, + {"MySQL tiny int", ColumnTypeInt{MaxBytes: 1}, DialectMySQL, "TINYINT", false}, + {"MySQL tiny int with digits", ColumnTypeInt{MaxBytes: 1, MaxDigits: 2}, DialectMySQL, "TINYINT(2)", false}, + {"MySQL small int", ColumnTypeInt{MaxBytes: 2}, DialectMySQL, "SMALLINT", false}, + {"MySQL small int with digits", ColumnTypeInt{MaxBytes: 2, MaxDigits: 3}, DialectMySQL, "SMALLINT(3)", false}, + {"MySQL medium int", ColumnTypeInt{MaxBytes: 3}, DialectMySQL, "MEDIUMINT", false}, + {"MySQL medium int with digits", ColumnTypeInt{MaxBytes: 3, MaxDigits: 6}, DialectMySQL, "MEDIUMINT(6)", false}, + {"MySQL int", ColumnTypeInt{MaxBytes: 4}, DialectMySQL, "INTEGER", false}, + {"MySQL int with digits", ColumnTypeInt{MaxBytes: 4, MaxDigits: 11}, DialectMySQL, "INTEGER(11)", false}, + {"MySQL bigint", ColumnTypeInt{MaxBytes: 4}, DialectMySQL, "BIGINT", false}, + {"MySQL bigint with digits", ColumnTypeInt{MaxBytes: 4, MaxDigits: 15}, DialectMySQL, "BIGINT(15)", false}, + {"MySQL char", ColumnTypeString{IsFixedLength: true}, DialectMySQL, "CHAR", false}, + {"MySQL char with length", ColumnTypeString{IsFixedLength: true, MaxChars: 4}, DialectMySQL, "CHAR(4)", false}, + {"MySQL varchar with length", ColumnTypeString{MaxChars: 25}, DialectMySQL, "VARCHAR(25)", false}, + {"MySQL text", ColumnTypeString{}, DialectMySQL, "TEXT", false}, + {"MySQL datetime", ColumnTypeDateTime{}, DialectMySQL, "DATETIME", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.ty.Name(tt.d) + if (err != nil) != tt.wantErr { + t.Errorf("Name() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Name() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestColumnType_Default(t *testing.T) { + tests := []struct { + name string + ty ColumnType + d DialectType + want string + wantErr bool + }{ + {"SQLite bool none", ColumnTypeBool{}, DialectSQLite, "", false}, + {"SQLite bool false", ColumnTypeBool{}, DialectSQLite, "0", false}, + {"SQLite bool true", ColumnTypeBool{}, DialectSQLite, "1", false}, + {"SQLite int none", ColumnTypeInt{}, DialectSQLite, "", false}, + {"SQLite int empty", ColumnTypeInt{HasDefault: true}, DialectSQLite, "0", false}, + {"SQLite int", ColumnTypeInt{HasDefault: true, DefaultVal: 10}, DialectSQLite, "10", false}, + {"SQLite string none", ColumnTypeString{}, DialectSQLite, "", false}, + {"SQLite string empty", ColumnTypeString{HasDefault: true}, DialectSQLite, "''", false}, + {"SQLite string", ColumnTypeString{HasDefault: true, DefaultVal: "that's a default"}, DialectSQLite, "'that''s a default'", false}, + {"MySQL string", ColumnTypeString{HasDefault: true, DefaultVal: "%that's a default%"}, DialectMySQL, "'%that\\'s a default%'", false}, + + {"SQLite datetime none", ColumnTypeDateTime{}, DialectSQLite, "", false}, + {"SQLite datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectSQLite, "CURRENT_TIMESTAMP", false}, + {"MySQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectMySQL, "NOW()", false}, + {"PostgreSQL datetime now", ColumnTypeDateTime{DefaultVal: DefaultNow}, DialectPostgreSQL, "CURRENT_TIMESTAMP", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.ty.Default(tt.d) + if (err != nil) != tt.wantErr { + t.Errorf("Default() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Default() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestColumn_CreateSQL(t *testing.T) { + type fields struct { + Dialect DialectType + Name string + Nullable bool + Type ColumnType + PrimaryKey bool + } + tests := []struct { + name string + fields fields + want string + wantErr bool + }{ + {"SQLite bool", fields{DialectSQLite, "foo", false, ColumnTypeBool{}, false}, "foo INTEGER NOT NULL", false}, + {"SQLite bool nullable", fields{DialectSQLite, "foo", true, ColumnTypeBool{}, false}, "foo INTEGER", false}, + {"SQLite int", fields{DialectSQLite, "foo", false, ColumnTypeInt{}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false}, + {"SQLite int nullable", fields{DialectSQLite, "foo", true, ColumnTypeInt{}, false}, "foo INTEGER", false}, + {"SQLite text", fields{DialectSQLite, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false}, + {"SQLite text nullable", fields{DialectSQLite, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false}, + {"SQLite datetime", fields{DialectSQLite, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false}, + {"SQLite datetime nullable", fields{DialectSQLite, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false}, + + {"MySQL bool", fields{DialectMySQL, "foo", false, ColumnTypeBool{}, false}, "foo TINYINT(1) NOT NULL", false}, + {"MySQL bool nullable", fields{DialectMySQL, "foo", true, ColumnTypeBool{}, false}, "foo TINYINT(1)", false}, + {"MySQL tiny int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 1}, true}, "foo TINYINT NOT NULL PRIMARY KEY", false}, + {"MySQL tiny int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 1}, false}, "foo TINYINT", false}, + {"MySQL small int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 2}, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false}, + {"MySQL small int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 2}, false}, "foo SMALLINT", false}, + {"MySQL int", fields{DialectMySQL, "foo", false, ColumnTypeInt{MaxBytes: 4}, true}, "foo INTEGER NOT NULL PRIMARY KEY", false}, + {"MySQL int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{MaxBytes: 4}, false}, "foo INTEGER", false}, + {"MySQL big int", fields{DialectMySQL, "foo", false, ColumnTypeInt{}, true}, "foo BIGINT NOT NULL PRIMARY KEY", false}, + {"MySQL big int nullable", fields{DialectMySQL, "foo", true, ColumnTypeInt{}, false}, "foo BIGINT", false}, + {"MySQL char", fields{DialectMySQL, "foo", false, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR NOT NULL", false}, + {"MySQL char nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{IsFixedLength: true}, false}, "foo CHAR", false}, + {"MySQL varchar", fields{DialectMySQL, "foo", false, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255) NOT NULL", false}, + {"MySQL varchar nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{MaxChars: 255}, false}, "foo VARCHAR(255)", false}, + {"MySQL text", fields{DialectMySQL, "foo", false, ColumnTypeString{}, false}, "foo TEXT NOT NULL", false}, + {"MySQL text nullable", fields{DialectMySQL, "foo", true, ColumnTypeString{}, false}, "foo TEXT", false}, + {"MySQL datetime", fields{DialectMySQL, "foo", false, ColumnTypeDateTime{}, false}, "foo DATETIME NOT NULL", false}, + {"MySQL datetime nullable", fields{DialectMySQL, "foo", true, ColumnTypeDateTime{}, false}, "foo DATETIME", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Column{ + Name: tt.fields.Name, + Nullable: tt.fields.Nullable, + Type: tt.fields.Type, + PrimaryKey: tt.fields.PrimaryKey, + } + if got, err := c.CreateSQL(tt.fields.Dialect); got != tt.want { + if (err != nil) != tt.wantErr { + t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("String() got = %v, want %v", got, tt.want) + } + } + }) + } +} diff --git a/db/create.go b/db/create.go index a9fad98..1bef61e 100644 --- a/db/create.go +++ b/db/create.go @@ -64,7 +64,7 @@ func (b *CreateTableSqlBuilder) ToSQL() (string, error) { if !ok { return "", fmt.Errorf("column not found: %s", columnName) } - columnStr, err := column.ToSQL(b.Dialect) + columnStr, err := column.CreateSQL(b.Dialect) if err != nil { return "", err } diff --git a/db/create_test.go b/db/create_test.go index 369d5c1..09efd18 100644 --- a/db/create_test.go +++ b/db/create_test.go @@ -5,139 +5,13 @@ import ( "testing" ) -func TestDialect_Column(t *testing.T) { - c1 := DialectSQLite.Column("foo", ColumnTypeBool, UnsetSize) - assert.Equal(t, DialectSQLite, c1.Dialect) - c2 := DialectMySQL.Column("foo", ColumnTypeBool, UnsetSize) - assert.Equal(t, DialectMySQL, c2.Dialect) -} - -func TestColumnType_Format(t *testing.T) { - type args struct { - dialect DialectType - size OptionalInt - } - tests := []struct { - name string - d ColumnType - args args - want string - wantErr bool - }{ - {"Sqlite bool", ColumnTypeBool, args{dialect: DialectSQLite}, "INTEGER", false}, - {"Sqlite small int", ColumnTypeSmallInt, args{dialect: DialectSQLite}, "INTEGER", false}, - {"Sqlite int", ColumnTypeInteger, args{dialect: DialectSQLite}, "INTEGER", false}, - {"Sqlite char", ColumnTypeChar, args{dialect: DialectSQLite}, "TEXT", false}, - {"Sqlite varchar", ColumnTypeVarChar, args{dialect: DialectSQLite}, "TEXT", false}, - {"Sqlite text", ColumnTypeText, args{dialect: DialectSQLite}, "TEXT", false}, - {"Sqlite datetime", ColumnTypeDateTime, args{dialect: DialectSQLite}, "DATETIME", false}, - - {"MySQL bool", ColumnTypeBool, args{dialect: DialectMySQL}, "TINYINT(1)", false}, - {"MySQL small int", ColumnTypeSmallInt, args{dialect: DialectMySQL}, "SMALLINT", false}, - {"MySQL small int with param", ColumnTypeSmallInt, args{dialect: DialectMySQL, size: OptionalInt{true, 3}}, "SMALLINT(3)", false}, - {"MySQL int", ColumnTypeInteger, args{dialect: DialectMySQL}, "INT", false}, - {"MySQL int with param", ColumnTypeInteger, args{dialect: DialectMySQL, size: OptionalInt{true, 11}}, "INT(11)", false}, - {"MySQL char", ColumnTypeChar, args{dialect: DialectMySQL}, "CHAR", false}, - {"MySQL char with param", ColumnTypeChar, args{dialect: DialectMySQL, size: OptionalInt{true, 4}}, "CHAR(4)", false}, - {"MySQL varchar", ColumnTypeVarChar, args{dialect: DialectMySQL}, "VARCHAR", false}, - {"MySQL varchar with param", ColumnTypeVarChar, args{dialect: DialectMySQL, size: OptionalInt{true, 25}}, "VARCHAR(25)", false}, - {"MySQL text", ColumnTypeText, args{dialect: DialectMySQL}, "TEXT", false}, - {"MySQL datetime", ColumnTypeDateTime, args{dialect: DialectMySQL}, "DATETIME", false}, - - {"invalid column type", 10000, args{dialect: DialectMySQL}, "", true}, - {"invalid dialect", ColumnTypeBool, args{dialect: 10000}, "", true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.d.Format(tt.args.dialect, tt.args.size) - if (err != nil) != tt.wantErr { - t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("Format() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestColumn_Build(t *testing.T) { - type fields struct { - Dialect DialectType - Name string - Nullable bool - Default OptionalString - Type ColumnType - Size OptionalInt - PrimaryKey bool - } - tests := []struct { - name string - fields fields - want string - wantErr bool - }{ - {"Sqlite bool", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER NOT NULL", false}, - {"Sqlite bool nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo INTEGER", false}, - {"Sqlite small int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo INTEGER NOT NULL PRIMARY KEY", false}, - {"Sqlite small int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo INTEGER", false}, - {"Sqlite int", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER NOT NULL", false}, - {"Sqlite int nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INTEGER", false}, - {"Sqlite char", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT NOT NULL", false}, - {"Sqlite char nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo TEXT", false}, - {"Sqlite varchar", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT NOT NULL", false}, - {"Sqlite varchar nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo TEXT", false}, - {"Sqlite text", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false}, - {"Sqlite text nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false}, - {"Sqlite datetime", fields{DialectSQLite, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false}, - {"Sqlite datetime nullable", fields{DialectSQLite, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false}, - - {"MySQL bool", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1) NOT NULL", false}, - {"MySQL bool nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeBool, UnsetSize, false}, "foo TINYINT(1)", false}, - {"MySQL small int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeSmallInt, UnsetSize, true}, "foo SMALLINT NOT NULL PRIMARY KEY", false}, - {"MySQL small int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeSmallInt, UnsetSize, false}, "foo SMALLINT", false}, - {"MySQL int", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT NOT NULL", false}, - {"MySQL int nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeInteger, UnsetSize, false}, "foo INT", false}, - {"MySQL char", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR NOT NULL", false}, - {"MySQL char nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeChar, UnsetSize, false}, "foo CHAR", false}, - {"MySQL varchar", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR NOT NULL", false}, - {"MySQL varchar nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeVarChar, UnsetSize, false}, "foo VARCHAR", false}, - {"MySQL text", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT NOT NULL", false}, - {"MySQL text nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeText, UnsetSize, false}, "foo TEXT", false}, - {"MySQL datetime", fields{DialectMySQL, "foo", false, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME NOT NULL", false}, - {"MySQL datetime nullable", fields{DialectMySQL, "foo", true, UnsetDefault, ColumnTypeDateTime, UnsetSize, false}, "foo DATETIME", false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := &Column{ - Dialect: tt.fields.Dialect, - Name: tt.fields.Name, - Nullable: tt.fields.Nullable, - Default: tt.fields.Default, - Type: tt.fields.Type, - Size: tt.fields.Size, - PrimaryKey: tt.fields.PrimaryKey, - } - if got, err := c.String(); got != tt.want { - if (err != nil) != tt.wantErr { - t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("String() got = %v, want %v", got, tt.want) - } - } - }) - } -} - func TestCreateTableSqlBuilder_ToSQL(t *testing.T) { sql, err := DialectMySQL. Table("foo"). SetIfNotExists(true). - Column(DialectMySQL.Column("bar", ColumnTypeInteger, UnsetSize).SetPrimaryKey(true)). - Column(DialectMySQL.Column("baz", ColumnTypeText, UnsetSize)). - Column(DialectMySQL.Column("qux", ColumnTypeDateTime, UnsetSize).SetDefault("NOW()")). + Column(PrimaryKeyColumn("bar", ColumnTypeInt{MaxBytes: 4})). + Column(NonNullableColumn("baz", ColumnTypeString{})). + Column(NonNullableColumn("qux", ColumnTypeDateTime{DefaultVal: DefaultNow})). UniqueConstraint("bar"). UniqueConstraint("bar", "baz"). ToSQL() diff --git a/db/dialect.go b/db/dialect.go index ee1eb0f..3e2b90b 100644 --- a/db/dialect.go +++ b/db/dialect.go @@ -5,13 +5,14 @@ import "fmt" type DialectType int const ( - DialectSQLite DialectType = iota - DialectMySQL DialectType = iota + DialectSQLite DialectType = iota + DialectMySQL DialectType = iota + DialectPostgreSQL DialectType = iota ) func (d DialectType) IsKnown() bool { switch d { - case DialectSQLite, DialectMySQL: + case DialectSQLite, DialectMySQL, DialectPostgreSQL: return true default: return false