[chore] bump bun library versions (#2837)

This commit is contained in:
kim
2024-04-15 11:01:20 +01:00
committed by GitHub
parent 6bb43f3f9b
commit 1018cde107
22 changed files with 403 additions and 388 deletions

View File

@ -5,7 +5,6 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"time"
"github.com/jinzhu/inflection"
@ -52,12 +51,14 @@ type Table struct {
Alias string
SQLAlias Safe
allFields []*Field // all fields including scanonly
Fields []*Field // PKs + DataFields
PKs []*Field
DataFields []*Field
relFields []*Field
fieldsMapMu sync.RWMutex
FieldMap map[string]*Field
FieldMap map[string]*Field
StructMap map[string]*structField
Relations map[string]*Relation
Unique map[string][]*Field
@ -65,23 +66,38 @@ type Table struct {
SoftDeleteField *Field
UpdateSoftDeleteField func(fv reflect.Value, tm time.Time) error
allFields []*Field // read only
flags internal.Flag
}
func newTable(dialect Dialect, typ reflect.Type) *Table {
t := new(Table)
t.dialect = dialect
t.Type = typ
t.ZeroValue = reflect.New(t.Type).Elem()
t.ZeroIface = reflect.New(t.Type).Interface()
t.TypeName = internal.ToExported(t.Type.Name())
t.ModelName = internal.Underscore(t.Type.Name())
tableName := tableNameInflector(t.ModelName)
t.setName(tableName)
t.Alias = t.ModelName
t.SQLAlias = t.quoteIdent(t.ModelName)
type structField struct {
Index []int
Table *Table
}
func newTable(
dialect Dialect, typ reflect.Type, seen map[reflect.Type]*Table, canAddr bool,
) *Table {
if table, ok := seen[typ]; ok {
return table
}
table := new(Table)
seen[typ] = table
table.dialect = dialect
table.Type = typ
table.ZeroValue = reflect.New(table.Type).Elem()
table.ZeroIface = reflect.New(table.Type).Interface()
table.TypeName = internal.ToExported(table.Type.Name())
table.ModelName = internal.Underscore(table.Type.Name())
tableName := tableNameInflector(table.ModelName)
table.setName(tableName)
table.Alias = table.ModelName
table.SQLAlias = table.quoteIdent(table.ModelName)
table.Fields = make([]*Field, 0, typ.NumField())
table.FieldMap = make(map[string]*Field, typ.NumField())
table.processFields(typ, seen, canAddr)
hooks := []struct {
typ reflect.Type
@ -89,45 +105,168 @@ func newTable(dialect Dialect, typ reflect.Type) *Table {
}{
{beforeAppendModelHookType, beforeAppendModelHookFlag},
{beforeScanHookType, beforeScanHookFlag},
{afterScanHookType, afterScanHookFlag},
{beforeScanRowHookType, beforeScanRowHookFlag},
{afterScanRowHookType, afterScanRowHookFlag},
}
typ = reflect.PtrTo(t.Type)
typ = reflect.PtrTo(table.Type)
for _, hook := range hooks {
if typ.Implements(hook.typ) {
t.flags = t.flags.Set(hook.flag)
table.flags = table.flags.Set(hook.flag)
}
}
// Deprecated.
deprecatedHooks := []struct {
typ reflect.Type
flag internal.Flag
msg string
}{
{beforeScanHookType, beforeScanHookFlag, "rename BeforeScan hook to BeforeScanRow"},
{afterScanHookType, afterScanHookFlag, "rename AfterScan hook to AfterScanRow"},
return table
}
func (t *Table) init() {
for _, field := range t.relFields {
t.processRelation(field)
}
for _, hook := range deprecatedHooks {
if typ.Implements(hook.typ) {
internal.Deprecated.Printf("%s: %s", t.TypeName, hook.msg)
t.flags = t.flags.Set(hook.flag)
t.relFields = nil
}
func (t *Table) processFields(
typ reflect.Type,
seen map[reflect.Type]*Table,
canAddr bool,
) {
type embeddedField struct {
prefix string
index []int
unexported bool
subtable *Table
subfield *Field
}
names := make(map[string]struct{})
embedded := make([]embeddedField, 0, 10)
for i, n := 0, typ.NumField(); i < n; i++ {
sf := typ.Field(i)
unexported := sf.PkgPath != ""
tagstr := sf.Tag.Get("bun")
if tagstr == "-" {
names[sf.Name] = struct{}{}
continue
}
tag := tagparser.Parse(tagstr)
if unexported && !sf.Anonymous { // unexported
continue
}
if sf.Anonymous {
if sf.Name == "BaseModel" && sf.Type == baseModelType {
t.processBaseModelField(sf)
continue
}
sfType := sf.Type
if sfType.Kind() == reflect.Ptr {
sfType = sfType.Elem()
}
if sfType.Kind() != reflect.Struct { // ignore unexported non-struct types
continue
}
subtable := newTable(t.dialect, sfType, seen, canAddr)
for _, subfield := range subtable.allFields {
embedded = append(embedded, embeddedField{
index: sf.Index,
unexported: unexported,
subtable: subtable,
subfield: subfield,
})
}
if tagstr != "" {
tag := tagparser.Parse(tagstr)
if tag.HasOption("inherit") || tag.HasOption("extend") {
t.Name = subtable.Name
t.TypeName = subtable.TypeName
t.SQLName = subtable.SQLName
t.SQLNameForSelects = subtable.SQLNameForSelects
t.Alias = subtable.Alias
t.SQLAlias = subtable.SQLAlias
t.ModelName = subtable.ModelName
}
}
continue
}
if prefix, ok := tag.Option("embed"); ok {
fieldType := indirectType(sf.Type)
if fieldType.Kind() != reflect.Struct {
panic(fmt.Errorf("bun: embed %s.%s: got %s, wanted reflect.Struct",
t.TypeName, sf.Name, fieldType.Kind()))
}
subtable := newTable(t.dialect, fieldType, seen, canAddr)
for _, subfield := range subtable.allFields {
embedded = append(embedded, embeddedField{
prefix: prefix,
index: sf.Index,
unexported: unexported,
subtable: subtable,
subfield: subfield,
})
}
continue
}
field := t.newField(sf, tag)
t.addField(field)
names[field.Name] = struct{}{}
if field.IndirectType.Kind() == reflect.Struct {
if t.StructMap == nil {
t.StructMap = make(map[string]*structField)
}
t.StructMap[field.Name] = &structField{
Index: field.Index,
Table: newTable(t.dialect, field.IndirectType, seen, canAddr),
}
}
}
return t
}
// Only unambiguous embedded fields must be serialized.
ambiguousNames := make(map[string]int)
ambiguousTags := make(map[string]int)
func (t *Table) init1() {
t.initFields()
}
// Embedded types can never override a field that was already present at
// the top-level.
for name := range names {
ambiguousNames[name]++
ambiguousTags[name]++
}
func (t *Table) init2() {
t.initRelations()
for _, f := range embedded {
ambiguousNames[f.prefix+f.subfield.Name]++
if !f.subfield.Tag.IsZero() {
ambiguousTags[f.prefix+f.subfield.Name]++
}
}
for _, embfield := range embedded {
subfield := embfield.subfield.Clone()
if ambiguousNames[subfield.Name] > 1 &&
!(!subfield.Tag.IsZero() && ambiguousTags[subfield.Name] == 1) {
continue // ambiguous embedded field
}
subfield.Index = makeIndex(embfield.index, subfield.Index)
if embfield.prefix != "" {
subfield.Name = embfield.prefix + subfield.Name
subfield.SQLName = t.quoteIdent(subfield.Name)
}
t.addField(subfield)
}
}
func (t *Table) setName(name string) {
@ -152,30 +291,67 @@ func (t *Table) CheckPKs() error {
}
func (t *Table) addField(field *Field) {
t.allFields = append(t.allFields, field)
if field.Tag.HasOption("rel") || field.Tag.HasOption("m2m") {
t.relFields = append(t.relFields, field)
return
}
if field.Tag.HasOption("join") {
internal.Warn.Printf(
`%s.%s "join" option must come together with "rel" option`,
t.TypeName, field.GoName,
)
}
t.FieldMap[field.Name] = field
if altName, ok := field.Tag.Option("alt"); ok {
t.FieldMap[altName] = field
}
if field.Tag.HasOption("scanonly") {
return
}
if _, ok := field.Tag.Options["soft_delete"]; ok {
t.SoftDeleteField = field
t.UpdateSoftDeleteField = softDeleteFieldUpdater(field)
}
t.Fields = append(t.Fields, field)
if field.IsPK {
t.PKs = append(t.PKs, field)
} else {
t.DataFields = append(t.DataFields, field)
}
t.FieldMap[field.Name] = field
}
func (t *Table) removeField(field *Field) {
t.Fields = removeField(t.Fields, field)
if field.IsPK {
t.PKs = removeField(t.PKs, field)
} else {
t.DataFields = removeField(t.DataFields, field)
func (t *Table) LookupField(name string) *Field {
if field, ok := t.FieldMap[name]; ok {
return field
}
delete(t.FieldMap, field.Name)
}
func (t *Table) fieldWithLock(name string) *Field {
t.fieldsMapMu.RLock()
field := t.FieldMap[name]
t.fieldsMapMu.RUnlock()
return field
table := t
var index []int
for {
structName, columnName, ok := strings.Cut(name, "__")
if !ok {
field, ok := table.FieldMap[name]
if !ok {
return nil
}
return field.WithIndex(index)
}
name = columnName
strct := table.StructMap[structName]
if strct == nil {
return nil
}
table = strct.Table
index = append(index, strct.Index...)
}
}
func (t *Table) HasField(name string) bool {
@ -200,59 +376,6 @@ func (t *Table) fieldByGoName(name string) *Field {
return nil
}
func (t *Table) initFields() {
t.Fields = make([]*Field, 0, t.Type.NumField())
t.FieldMap = make(map[string]*Field, t.Type.NumField())
t.addFields(t.Type, "", nil)
}
func (t *Table) addFields(typ reflect.Type, prefix string, index []int) {
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
unexported := f.PkgPath != ""
if unexported && !f.Anonymous { // unexported
continue
}
if f.Tag.Get("bun") == "-" {
continue
}
if f.Anonymous {
if f.Name == "BaseModel" && f.Type == baseModelType {
if len(index) == 0 {
t.processBaseModelField(f)
}
continue
}
// If field is an embedded struct, add each field of the embedded struct.
fieldType := indirectType(f.Type)
if fieldType.Kind() == reflect.Struct {
t.addFields(fieldType, "", withIndex(index, f.Index))
tag := tagparser.Parse(f.Tag.Get("bun"))
if tag.HasOption("inherit") || tag.HasOption("extend") {
embeddedTable := t.dialect.Tables().Ref(fieldType)
t.TypeName = embeddedTable.TypeName
t.SQLName = embeddedTable.SQLName
t.SQLNameForSelects = embeddedTable.SQLNameForSelects
t.Alias = embeddedTable.Alias
t.SQLAlias = embeddedTable.SQLAlias
t.ModelName = embeddedTable.ModelName
}
continue
}
}
// If field is not a struct, add it.
// This will also add any embedded non-struct type as a field.
if field := t.newField(f, prefix, index); field != nil {
t.addField(field)
}
}
}
func (t *Table) processBaseModelField(f reflect.StructField) {
tag := tagparser.Parse(f.Tag.Get("bun"))
@ -288,58 +411,34 @@ func (t *Table) processBaseModelField(f reflect.StructField) {
}
// nolint
func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Field {
tag := tagparser.Parse(f.Tag.Get("bun"))
if nextPrefix, ok := tag.Option("embed"); ok {
fieldType := indirectType(f.Type)
if fieldType.Kind() != reflect.Struct {
panic(fmt.Errorf("bun: embed %s.%s: got %s, wanted reflect.Struct",
t.TypeName, f.Name, fieldType.Kind()))
}
t.addFields(fieldType, prefix+nextPrefix, withIndex(index, f.Index))
return nil
}
sqlName := internal.Underscore(f.Name)
func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field {
sqlName := internal.Underscore(sf.Name)
if tag.Name != "" && tag.Name != sqlName {
if isKnownFieldOption(tag.Name) {
internal.Warn.Printf(
"%s.%s tag name %q is also an option name, is it a mistake? Try column:%s.",
t.TypeName, f.Name, tag.Name, tag.Name,
t.TypeName, sf.Name, tag.Name, tag.Name,
)
}
sqlName = tag.Name
}
if s, ok := tag.Option("column"); ok {
sqlName = s
}
sqlName = prefix + sqlName
for name := range tag.Options {
if !isKnownFieldOption(name) {
internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, sf.Name, name)
}
}
index = withIndex(index, f.Index)
if field := t.fieldWithLock(sqlName); field != nil {
if indexEqual(field.Index, index) {
return field
}
t.removeField(field)
}
field := &Field{
StructField: f,
IsPtr: f.Type.Kind() == reflect.Ptr,
StructField: sf,
IsPtr: sf.Type.Kind() == reflect.Ptr,
Tag: tag,
IndirectType: indirectType(f.Type),
Index: index,
IndirectType: indirectType(sf.Type),
Index: sf.Index,
Name: sqlName,
GoName: f.Name,
GoName: sf.Name,
SQLName: t.quoteIdent(sqlName),
}
@ -386,63 +485,21 @@ func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Fie
field.Scan = FieldScanner(t.dialect, field)
field.IsZero = zeroChecker(field.StructField.Type)
if v, ok := tag.Option("alt"); ok {
t.FieldMap[v] = field
}
t.allFields = append(t.allFields, field)
if tag.HasOption("scanonly") {
t.FieldMap[field.Name] = field
if field.IndirectType.Kind() == reflect.Struct {
t.inlineFields(field, nil)
}
return nil
}
if _, ok := tag.Options["soft_delete"]; ok {
t.SoftDeleteField = field
t.UpdateSoftDeleteField = softDeleteFieldUpdater(field)
}
return field
}
//---------------------------------------------------------------------------------------
func (t *Table) initRelations() {
for i := 0; i < len(t.Fields); {
f := t.Fields[i]
if t.tryRelation(f) {
t.Fields = removeField(t.Fields, f)
t.DataFields = removeField(t.DataFields, f)
} else {
i++
}
if f.IndirectType.Kind() == reflect.Struct {
t.inlineFields(f, nil)
}
}
}
func (t *Table) tryRelation(field *Field) bool {
func (t *Table) processRelation(field *Field) {
if rel, ok := field.Tag.Option("rel"); ok {
t.initRelation(field, rel)
return true
return
}
if field.Tag.HasOption("m2m") {
t.addRelation(t.m2mRelation(field))
return true
return
}
if field.Tag.HasOption("join") {
internal.Warn.Printf(
`%s.%s "join" option must come together with "rel" option`,
t.TypeName, field.GoName,
)
}
return false
panic("not reached")
}
func (t *Table) initRelation(field *Field, rel string) {
@ -470,7 +527,7 @@ func (t *Table) addRelation(rel *Relation) {
}
func (t *Table) belongsToRelation(field *Field) *Relation {
joinTable := t.dialect.Tables().Ref(field.IndirectType)
joinTable := t.dialect.Tables().InProgress(field.IndirectType)
if err := joinTable.CheckPKs(); err != nil {
panic(err)
}
@ -519,7 +576,7 @@ func (t *Table) belongsToRelation(field *Field) *Relation {
for i, baseColumn := range baseColumns {
joinColumn := joinColumns[i]
if f := t.fieldWithLock(baseColumn); f != nil {
if f := t.FieldMap[baseColumn]; f != nil {
rel.BaseFields = append(rel.BaseFields, f)
} else {
panic(fmt.Errorf(
@ -528,7 +585,7 @@ func (t *Table) belongsToRelation(field *Field) *Relation {
))
}
if f := joinTable.fieldWithLock(joinColumn); f != nil {
if f := joinTable.FieldMap[joinColumn]; f != nil {
rel.JoinFields = append(rel.JoinFields, f)
} else {
panic(fmt.Errorf(
@ -544,12 +601,12 @@ func (t *Table) belongsToRelation(field *Field) *Relation {
fkPrefix := internal.Underscore(field.GoName) + "_"
for _, joinPK := range joinTable.PKs {
fkName := fkPrefix + joinPK.Name
if fk := t.fieldWithLock(fkName); fk != nil {
if fk := t.FieldMap[fkName]; fk != nil {
rel.BaseFields = append(rel.BaseFields, fk)
continue
}
if fk := t.fieldWithLock(joinPK.Name); fk != nil {
if fk := t.FieldMap[joinPK.Name]; fk != nil {
rel.BaseFields = append(rel.BaseFields, fk)
continue
}
@ -568,7 +625,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation {
panic(err)
}
joinTable := t.dialect.Tables().Ref(field.IndirectType)
joinTable := t.dialect.Tables().InProgress(field.IndirectType)
rel := &Relation{
Type: HasOneRelation,
Field: field,
@ -582,7 +639,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation {
if join, ok := field.Tag.Options["join"]; ok {
baseColumns, joinColumns := parseRelationJoin(join)
for i, baseColumn := range baseColumns {
if f := t.fieldWithLock(baseColumn); f != nil {
if f := t.FieldMap[baseColumn]; f != nil {
rel.BaseFields = append(rel.BaseFields, f)
} else {
panic(fmt.Errorf(
@ -592,7 +649,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation {
}
joinColumn := joinColumns[i]
if f := joinTable.fieldWithLock(joinColumn); f != nil {
if f := joinTable.FieldMap[joinColumn]; f != nil {
rel.JoinFields = append(rel.JoinFields, f)
} else {
panic(fmt.Errorf(
@ -608,12 +665,12 @@ func (t *Table) hasOneRelation(field *Field) *Relation {
fkPrefix := internal.Underscore(t.ModelName) + "_"
for _, pk := range t.PKs {
fkName := fkPrefix + pk.Name
if f := joinTable.fieldWithLock(fkName); f != nil {
if f := joinTable.FieldMap[fkName]; f != nil {
rel.JoinFields = append(rel.JoinFields, f)
continue
}
if f := joinTable.fieldWithLock(pk.Name); f != nil {
if f := joinTable.FieldMap[pk.Name]; f != nil {
rel.JoinFields = append(rel.JoinFields, f)
continue
}
@ -638,7 +695,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation {
))
}
joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
joinTable := t.dialect.Tables().InProgress(indirectType(field.IndirectType.Elem()))
polymorphicValue, isPolymorphic := field.Tag.Option("polymorphic")
rel := &Relation{
Type: HasManyRelation,
@ -662,7 +719,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation {
continue
}
if f := t.fieldWithLock(baseColumn); f != nil {
if f := t.FieldMap[baseColumn]; f != nil {
rel.BaseFields = append(rel.BaseFields, f)
} else {
panic(fmt.Errorf(
@ -671,7 +728,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation {
))
}
if f := joinTable.fieldWithLock(joinColumn); f != nil {
if f := joinTable.FieldMap[joinColumn]; f != nil {
rel.JoinFields = append(rel.JoinFields, f)
} else {
panic(fmt.Errorf(
@ -689,12 +746,12 @@ func (t *Table) hasManyRelation(field *Field) *Relation {
for _, pk := range t.PKs {
joinColumn := fkPrefix + pk.Name
if fk := joinTable.fieldWithLock(joinColumn); fk != nil {
if fk := joinTable.FieldMap[joinColumn]; fk != nil {
rel.JoinFields = append(rel.JoinFields, fk)
continue
}
if fk := joinTable.fieldWithLock(pk.Name); fk != nil {
if fk := joinTable.FieldMap[pk.Name]; fk != nil {
rel.JoinFields = append(rel.JoinFields, fk)
continue
}
@ -708,7 +765,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation {
}
if isPolymorphic {
rel.PolymorphicField = joinTable.fieldWithLock(polymorphicColumn)
rel.PolymorphicField = joinTable.FieldMap[polymorphicColumn]
if rel.PolymorphicField == nil {
panic(fmt.Errorf(
"bun: %s has-many %s: %s must have polymorphic column %s",
@ -732,7 +789,7 @@ func (t *Table) m2mRelation(field *Field) *Relation {
t.TypeName, field.GoName, field.IndirectType.Kind(),
))
}
joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
joinTable := t.dialect.Tables().InProgress(indirectType(field.IndirectType.Elem()))
if err := t.CheckPKs(); err != nil {
panic(err)
@ -805,40 +862,6 @@ func (t *Table) m2mRelation(field *Field) *Relation {
return rel
}
func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) {
if seen == nil {
seen = map[reflect.Type]struct{}{t.Type: {}}
}
if _, ok := seen[field.IndirectType]; ok {
return
}
seen[field.IndirectType] = struct{}{}
joinTable := t.dialect.Tables().Ref(field.IndirectType)
for _, f := range joinTable.allFields {
f = f.Clone()
f.GoName = field.GoName + "_" + f.GoName
f.Name = field.Name + "__" + f.Name
f.SQLName = t.quoteIdent(f.Name)
f.Index = withIndex(field.Index, f.Index)
t.fieldsMapMu.Lock()
if _, ok := t.FieldMap[f.Name]; !ok {
t.FieldMap[f.Name] = f
}
t.fieldsMapMu.Unlock()
if f.IndirectType.Kind() != reflect.Struct {
continue
}
if _, ok := seen[f.IndirectType]; !ok {
t.inlineFields(f, seen)
}
}
}
//------------------------------------------------------------------------------
func (t *Table) Dialect() Dialect { return t.dialect }
@ -890,7 +913,7 @@ func isKnownTableOption(name string) bool {
func isKnownFieldOption(name string) bool {
switch name {
case "column",
"alias",
"alt",
"type",
"array",
"hstore",
@ -931,15 +954,6 @@ func isKnownFKRule(name string) bool {
return false
}
func removeField(fields []*Field, field *Field) []*Field {
for i, f := range fields {
if f == field {
return append(fields[:i], fields[i+1:]...)
}
}
return fields
}
func parseRelationJoin(join []string) ([]string, []string) {
var ss []string
if len(join) == 1 {
@ -1026,7 +1040,7 @@ func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time
}
}
func withIndex(a, b []int) []int {
func makeIndex(a, b []int) []int {
dest := make([]int, 0, len(a)+len(b))
dest = append(dest, a...)
dest = append(dest, b...)