update dependencies (#296)

This commit is contained in:
tobi
2021-11-13 12:29:08 +01:00
committed by GitHub
parent 2aaec82732
commit 829a934d23
124 changed files with 2453 additions and 1588 deletions

View File

@ -0,0 +1,4 @@
root = true
[*]
end_of_line = lf

View File

@ -0,0 +1 @@
* text=auto eol=lf

View File

@ -4,4 +4,5 @@
1. Andrew Krasichkov @buglloc https://github.com/buglloc
1. Mike Samuel mikesamuel@gmail.com
1. Dmitri Shuralyov shurcooL@gmail.com
1. https://github.com/opennota
1. opennota https://github.com/opennota https://gitlab.com/opennota
1. Tom Anthony https://www.tomanthony.co.uk/

View File

@ -3,6 +3,7 @@
# all: Builds the code locally after testing
#
# fmt: Formats the source files
# fmt-check: Check if the source files are formatted
# build: Builds the code locally
# vet: Vets the code
# lint: Runs lint over the code (you do not need to fix everything)
@ -11,6 +12,8 @@
#
# install: Builds, tests and installs the code locally
GOFILES_NOVENDOR = $(shell find . -type f -name '*.go' -not -path "./vendor/*" -not -path "./.git/*")
.PHONY: all fmt build vet lint test cover install
# The first target is always the default action if `make` is called without
@ -19,7 +22,10 @@
all: fmt vet test install
fmt:
@gofmt -s -w ./$*
@gofmt -s -w ${GOFILES_NOVENDOR}
fmt-check:
@([ -z "$(shell gofmt -d $(GOFILES_NOVENDOR) | head)" ]) || (echo "Source is unformatted"; exit 1)
build:
@go build

View File

@ -180,7 +180,7 @@ p.AllowElementsMatching(regex.MustCompile(`^my-element-`))
Or add elements as a virtue of adding an attribute:
```go
// Not the recommended pattern, see the recommendation on using .Matching() below
// Note the recommended pattern, see the recommendation on using .Matching() below
p.AllowAttrs("nowrap").OnElements("td", "th")
```
@ -222,7 +222,7 @@ p.AllowElements("fieldset", "select", "option")
Although it's possible to handle inline CSS using `AllowAttrs` with a `Matching` rule, writing a single monolithic regular expression to safely process all inline CSS which you wish to allow is not a trivial task. Instead of attempting to do so, you can allow the `style` attribute on whichever element(s) you desire and use style policies to control and sanitize inline styles.
It is suggested that you use `Matching` (with a suitable regular expression)
It is strongly recommended that you use `Matching` (with a suitable regular expression)
`MatchingEnum`, or `MatchingHandler` to ensure each style matches your needs,
but default handlers are supplied for most widely used styles.
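
As a rough sketch of that approach (the elements, properties, and patterns below are illustrative, not part of this change; `regexp` is assumed to be imported):

```go
p := bluemonday.NewPolicy()

// Allow the style attribute only where it is needed.
p.AllowAttrs("style").OnElements("p", "span")

// Constrain individual properties instead of one monolithic regex.
p.AllowStyles("text-align").MatchingEnum("left", "center", "right").OnElements("p")
p.AllowStyles("color").Matching(regexp.MustCompile(`^#[0-9a-fA-F]{6}$`)).Globally()
```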
@ -379,6 +379,8 @@ Both examples exhibit the same issue, they declare attributes but do not then sp
We are not yet including any tools to help allow and sanitize CSS. Which means that unless you wish to do the heavy lifting in a single regular expression (inadvisable), **you should not allow the "style" attribute anywhere**.
In the same theme, both `<script>` and `<style>` are considered harmful. These elements (and their content) will not be rendered by default, and require you to explicitly set `p.AllowUnsafe(true)`. Be aware that allowing these elements defeats the purpose of using an HTML sanitizer, as you would be explicitly allowing either JavaScript (and any plainly written XSS) or CSS (which can modify a DOM to insert JS). Additionally, limitations in this library mean it is not aware of whether the HTML is validly structured, which can allow these elements to bypass some of the safety mechanisms built into the [WhatWG HTML parser standard](https://html.spec.whatwg.org/multipage/parsing.html#parsing-main-inselect).
It is not the job of bluemonday to fix your bad HTML, it is merely the job of bluemonday to prevent malicious HTML getting through. If you have mismatched HTML elements, or non-conforming nesting of elements, those will remain. But if you have well-structured HTML bluemonday will not break it.
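
A minimal sketch of opting in, for input you fully trust (the element and input here are illustrative):

```go
p := bluemonday.NewPolicy()
p.AllowElements("script")
p.AllowUnsafe(true) // without this, <script>/<style> and their content are dropped

// The script now survives sanitization, so only do this for trusted input.
out := p.Sanitize(`<script>console.log("hello")</script>`)
```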
## TODO

View File

@ -134,6 +134,19 @@ type Policy struct {
setOfElementsMatchingAllowedWithoutAttrs []*regexp.Regexp
setOfElementsToSkipContent map[string]struct{}
// Permits fundamentally unsafe elements.
//
// If false (default) then elements such as `style` and `script` will not be
// permitted even if declared in a policy. These elements when combined with
// untrusted input cannot be safely handled by bluemonday at this point in
// time.
//
// If true then `style` and `script` would be permitted by bluemonday if a
// policy declares them. However this is not recommended under any circumstance
// and can lead to XSS being rendered thus defeating the purpose of using a
// HTML sanitizer.
allowUnsafe bool
}
type attrPolicy struct {
@ -714,6 +727,23 @@ func (p *Policy) AllowElementsContent(names ...string) *Policy {
return p
}
// AllowUnsafe permits fundamentally unsafe elements.
//
// If false (default) then elements such as `style` and `script` will not be
// permitted even if declared in a policy. These elements when combined with
// untrusted input cannot be safely handled by bluemonday at this point in
// time.
//
// If true then `style` and `script` would be permitted by bluemonday if a
// policy declares them. However this is not recommended under any circumstance
// and can lead to XSS being rendered thus defeating the purpose of using a
// HTML sanitizer.
func (p *Policy) AllowUnsafe(allowUnsafe bool) *Policy {
p.init()
p.allowUnsafe = allowUnsafe
return p
}
// addDefaultElementsWithoutAttrs adds the HTML elements that we know are valid
// without any attributes to an internal map.
// i.e. we know that <table> is valid, but <bdo> isn't valid as the "dir" attr

View File

@ -130,7 +130,7 @@ func escapeUrlComponent(w stringWriterWriter, val string) error {
return err
}
// Query represents a single part of the query string, a query param
// Query represents a single part of the query string, a query param
type Query struct {
Key string
Value string
@ -293,6 +293,17 @@ func (p *Policy) sanitize(r io.Reader, w io.Writer) error {
mostRecentlyStartedToken = normaliseElementName(token.Data)
switch normaliseElementName(token.Data) {
case `script`:
if !p.allowUnsafe {
continue
}
case `style`:
if !p.allowUnsafe {
continue
}
}
aps, ok := p.elsAndAttrs[token.Data]
if !ok {
aa, matched := p.matchRegex(token.Data)
@ -341,6 +352,17 @@ func (p *Policy) sanitize(r io.Reader, w io.Writer) error {
mostRecentlyStartedToken = ""
}
switch normaliseElementName(token.Data) {
case `script`:
if !p.allowUnsafe {
continue
}
case `style`:
if !p.allowUnsafe {
continue
}
}
if skipClosingTag && closingTagToSkipStack[len(closingTagToSkipStack)-1] == token.Data {
closingTagToSkipStack = closingTagToSkipStack[:len(closingTagToSkipStack)-1]
if len(closingTagToSkipStack) == 0 {
@ -386,6 +408,17 @@ func (p *Policy) sanitize(r io.Reader, w io.Writer) error {
case html.SelfClosingTagToken:
switch normaliseElementName(token.Data) {
case `script`:
if !p.allowUnsafe {
continue
}
case `style`:
if !p.allowUnsafe {
continue
}
}
aps, ok := p.elsAndAttrs[token.Data]
if !ok {
aa, matched := p.matchRegex(token.Data)
@ -425,14 +458,22 @@ func (p *Policy) sanitize(r io.Reader, w io.Writer) error {
case `script`:
// not encouraged, but if a policy allows JavaScript we
// should not HTML escape it as that would break the output
if _, err := buff.WriteString(token.Data); err != nil {
return err
//
// requires p.AllowUnsafe()
if p.allowUnsafe {
if _, err := buff.WriteString(token.Data); err != nil {
return err
}
}
case "style":
// not encouraged, but if a policy allows CSS styles we
// should not HTML escape it as that would break the output
if _, err := buff.WriteString(token.Data); err != nil {
return err
//
// requires p.AllowUnsafe()
if p.allowUnsafe {
if _, err := buff.WriteString(token.Data); err != nil {
return err
}
}
default:
// HTML escape the text
@ -524,11 +565,11 @@ attrsLoop:
for _, ap := range apl {
if ap.regexp != nil {
if ap.regexp.MatchString(htmlAttr.Val) {
htmlAttr.Val = escapeAttribute(htmlAttr.Val)
htmlAttr.Val = escapeAttribute(htmlAttr.Val)
cleanAttrs = append(cleanAttrs, htmlAttr)
}
} else {
htmlAttr.Val = escapeAttribute(htmlAttr.Val)
htmlAttr.Val = escapeAttribute(htmlAttr.Val)
cleanAttrs = append(cleanAttrs, htmlAttr)
}
}
@ -1058,4 +1099,4 @@ func escapeAttribute(val string) string {
val = strings.Replace(val, string([]rune{'\u00A0'}), `&nbsp;`, -1)
val = strings.Replace(val, `"`, `&quot;`, -1)
return val
}
}

View File

@ -1,3 +1,4 @@
//go:build go1.12
// +build go1.12
package bluemonday

View File

@ -1,3 +1,4 @@
//go:build go1.1 && !go1.12
// +build go1.1,!go1.12
package bluemonday

View File

@ -1,3 +1,45 @@
## [1.0.17](https://github.com/uptrace/bun/compare/v1.0.16...v1.0.17) (2021-11-11)
### Bug Fixes
* don't call rollback when tx is already done ([8246c2a](https://github.com/uptrace/bun/commit/8246c2a63e2e6eba314201c6ba87f094edf098b9))
* **mysql:** escape backslash char in strings ([fb32029](https://github.com/uptrace/bun/commit/fb32029ea7604d066800b16df21f239b71bf121d))
## [1.0.16](https://github.com/uptrace/bun/compare/v1.0.15...v1.0.16) (2021-11-07)
### Bug Fixes
* call query hook when tx is started, committed, or rolled back ([30e85b5](https://github.com/uptrace/bun/commit/30e85b5366b2e51951ef17a0cf362b58f708dab1))
* **pgdialect:** auto-enable array support if the sql type is an array ([62c1012](https://github.com/uptrace/bun/commit/62c1012b2482e83969e5c6f5faf89e655ce78138))
### Features
* support multiple tag options join:left_col1=right_col1,join:left_col2=right_col2 ([78cd5aa](https://github.com/uptrace/bun/commit/78cd5aa60a5c7d1323bb89081db2b2b811113052))
* **tag:** log with bad tag name ([4e82d75](https://github.com/uptrace/bun/commit/4e82d75be2dabdba1a510df4e1fbb86092f92f4c))
## [1.0.15](https://github.com/uptrace/bun/compare/v1.0.14...v1.0.15) (2021-10-29)
### Bug Fixes
* fixed bug creating table when model has no columns ([042c50b](https://github.com/uptrace/bun/commit/042c50bfe41caaa6e279e02c887c3a84a3acd84f))
* init table with dialect once ([9a1ce1e](https://github.com/uptrace/bun/commit/9a1ce1e492602742bb2f587e9ed24e50d7d07cad))
### Features
* accept columns in WherePK ([b3e7035](https://github.com/uptrace/bun/commit/b3e70356db1aa4891115a10902316090fccbc8bf))
* support ADD COLUMN IF NOT EXISTS ([ca7357c](https://github.com/uptrace/bun/commit/ca7357cdfe283e2f0b94eb638372e18401c486e9))
## [1.0.14](https://github.com/uptrace/bun/compare/v1.0.13...v1.0.14) (2021-10-24)

View File

@ -30,8 +30,7 @@ Main features are:
Resources:
- To ask questions, join [Discord](https://discord.gg/rWtp5Aj) or use
[Discussions](https://github.com/uptrace/bun/discussions).
- [Discussions](https://github.com/uptrace/bun/discussions).
- [Newsletter](https://blog.uptrace.dev/pages/newsletter.html) to get latest updates.
- [Examples](https://github.com/uptrace/bun/tree/master/example)
- [Documentation](https://bun.uptrace.dev/)
@ -41,6 +40,7 @@ Resources:
Projects using Bun:
- [gotosocial](https://github.com/superseriousbusiness/gotosocial) - Golang fediverse server.
- [input-output-hk/cicero](https://github.com/input-output-hk/cicero)
- [RealWorld app](https://github.com/go-bun/bun-realworld-app)
<details>
@ -95,6 +95,55 @@ Projects using Bun:
</details>
## Why another database client?
So you can elegantly write complex queries:
```go
regionalSales := db.NewSelect().
ColumnExpr("region").
ColumnExpr("SUM(amount) AS total_sales").
TableExpr("orders").
GroupExpr("region")
topRegions := db.NewSelect().
ColumnExpr("region").
TableExpr("regional_sales").
Where("total_sales > (SELECT SUM(total_sales) / 10 FROM regional_sales)")
err := db.NewSelect().
With("regional_sales", regionalSales).
With("top_regions", topRegions).
ColumnExpr("region").
ColumnExpr("product").
ColumnExpr("SUM(quantity) AS product_units").
ColumnExpr("SUM(amount) AS product_sales").
TableExpr("orders").
Where("region IN (SELECT region FROM top_regions)").
GroupExpr("region").
GroupExpr("product").
Scan(ctx)
```
```sql
WITH regional_sales AS (
SELECT region, SUM(amount) AS total_sales
FROM orders
GROUP BY region
), top_regions AS (
SELECT region
FROM regional_sales
WHERE total_sales > (SELECT SUM(total_sales)/10 FROM regional_sales)
)
SELECT region,
product,
SUM(quantity) AS product_units,
SUM(amount) AS product_sales
FROM orders
WHERE region IN (SELECT region FROM top_regions)
GROUP BY region, product
```
## Installation
```go
@ -149,30 +198,6 @@ err := db.NewSelect().
Scan(ctx)
```
The code above is equivalent to:
```go
query := "SELECT id, name FROM users AS user WHERE name != '' ORDER BY id ASC LIMIT 1"
rows, err := sqldb.QueryContext(ctx, query)
if err != nil {
panic(err)
}
if !rows.Next() {
panic(sql.ErrNoRows)
}
user := new(User)
if err := db.ScanRow(ctx, rows, user); err != nil {
panic(err)
}
if err := rows.Err(); err != nil {
panic(err)
}
```
## Basic example
To provide initial data for our [example](/example/basic/), we will use Bun

vendor/github.com/uptrace/bun/db.go (generated, vendored)
View File

@ -356,7 +356,8 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (Stmt, error) {
//------------------------------------------------------------------------------
type Tx struct {
db *DB
ctx context.Context
db *DB
*sql.Tx
}
@ -369,11 +370,20 @@ func (db *DB) RunInTx(
if err != nil {
return err
}
defer tx.Rollback() //nolint:errcheck
var done bool
defer func() {
if !done {
_ = tx.Rollback()
}
}()
if err := fn(ctx, tx); err != nil {
return err
}
done = true
return tx.Commit()
}
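
A rough usage sketch of what this guards (an existing `*bun.DB` named `db`, a `ctx`, and a `book` model value are assumed): the new `done` flag means the deferred rollback only runs when `Commit` was never reached, so hooks no longer see a spurious ROLLBACK after a successful commit.

```go
err := db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
	if _, err := tx.NewInsert().Model(&book).Exec(ctx); err != nil {
		return err // fn returned an error: RunInTx rolls back
	}
	return nil // success: RunInTx commits and skips the deferred rollback
})
```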
@ -382,16 +392,33 @@ func (db *DB) Begin() (Tx, error) {
}
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
ctx, event := db.beforeQuery(ctx, nil, "BEGIN", nil, nil)
tx, err := db.DB.BeginTx(ctx, opts)
db.afterQuery(ctx, event, nil, err)
if err != nil {
return Tx{}, err
}
return Tx{
db: db,
Tx: tx,
ctx: ctx,
db: db,
Tx: tx,
}, nil
}
func (tx Tx) Commit() error {
ctx, event := tx.db.beforeQuery(tx.ctx, nil, "COMMIT", nil, nil)
err := tx.Tx.Commit()
tx.db.afterQuery(ctx, event, nil, err)
return err
}
func (tx Tx) Rollback() error {
ctx, event := tx.db.beforeQuery(tx.ctx, nil, "ROLLBACK", nil, nil)
err := tx.Tx.Rollback()
tx.db.afterQuery(ctx, event, nil, err)
return err
}
func (tx Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.ExecContext(context.TODO(), query, args...)
}

View File

@ -1,10 +1,8 @@
package dialect
import (
"encoding/hex"
"math"
"strconv"
"unicode/utf8"
"github.com/uptrace/bun/internal"
)
@ -48,50 +46,6 @@ func appendFloat(b []byte, v float64, bitSize int) []byte {
}
}
func AppendString(b []byte, s string) []byte {
b = append(b, '\'')
for _, r := range s {
if r == '\000' {
continue
}
if r == '\'' {
b = append(b, '\'', '\'')
continue
}
if r < utf8.RuneSelf {
b = append(b, byte(r))
continue
}
l := len(b)
if cap(b)-l < utf8.UTFMax {
b = append(b, make([]byte, utf8.UTFMax)...)
}
n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
b = b[:l+n]
}
b = append(b, '\'')
return b
}
func AppendBytes(b, bs []byte) []byte {
if bs == nil {
return AppendNull(b)
}
b = append(b, `'\x`...)
s := len(b)
b = append(b, make([]byte, hex.EncodedLen(len(bs)))...)
hex.Encode(b[s:], bs)
b = append(b, '\'')
return b
}
//------------------------------------------------------------------------------
func AppendIdent(b []byte, field string, quote byte) []byte {

View File

@ -3,6 +3,7 @@ package pgdialect
import (
"database/sql"
"strconv"
"strings"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
@ -68,7 +69,7 @@ func (d *Dialect) onField(field *schema.Field) {
}
}
if field.Tag.HasOption("array") {
if field.Tag.HasOption("array") || strings.HasSuffix(field.UserSQLType, "[]") {
field.Append = d.arrayAppender(field.StructField.Type)
field.Scan = arrayScanner(field.StructField.Type)
}
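
A small sketch of what this enables (the model is an assumption): with pgdialect, a field whose declared SQL type ends in `[]` now gets the array appender and scanner even without an explicit `,array` tag option.

```go
type Item struct {
	// Previously this needed `bun:",array"`; the `[]` suffix is now enough.
	Tags   []string `bun:"tags,type:varchar(64)[]"`
	Scores []int64  `bun:"scores,type:bigint[]"`
}
```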

View File

@ -4,7 +4,6 @@ import (
"encoding/json"
"net"
"reflect"
"time"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/schema"
@ -41,7 +40,6 @@ const (
)
var (
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
@ -52,11 +50,11 @@ func fieldSQLType(field *schema.Field) string {
return field.UserSQLType
}
if v, ok := field.Tag.Options["composite"]; ok {
if v, ok := field.Tag.Option("composite"); ok {
return v
}
if _, ok := field.Tag.Options["hstore"]; ok {
if _, ok := field.Tag.Option("hstore"); ok {
return "hstore"
}

View File

@ -6,7 +6,11 @@ import (
type Tag struct {
Name string
Options map[string]string
Options map[string][]string
}
func (t Tag) IsZero() bool {
return t.Name == "" && t.Options == nil
}
func (t Tag) HasOption(name string) bool {
@ -14,7 +18,17 @@ func (t Tag) HasOption(name string) bool {
return ok
}
func (t Tag) Option(name string) (string, bool) {
if vs, ok := t.Options[name]; ok {
return vs[len(vs)-1], true
}
return "", false
}
func Parse(s string) Tag {
if s == "" {
return Tag{}
}
p := parser{
s: s,
}
@ -45,9 +59,13 @@ func (p *parser) addOption(key, value string) {
return
}
if p.tag.Options == nil {
p.tag.Options = make(map[string]string)
p.tag.Options = make(map[string][]string)
}
if vs, ok := p.tag.Options[key]; ok {
p.tag.Options[key] = append(vs, value)
} else {
p.tag.Options[key] = []string{value}
}
p.tag.Options[key] = value
}
func (p *parser) parse() {

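To illustrate the new shape (the tag string is hypothetical, and `tagparser` is internal to bun, so this is illustrative only): every occurrence of an option is now collected, and `Option` returns the last value when an option repeats.

```go
tag := tagparser.Parse("rel:belongs-to,join:user_id=id,join:vendor_id=vendor_id")

fmt.Println(tag.Options["join"]) // should print [user_id=id vendor_id=vendor_id]
if v, ok := tag.Option("join"); ok {
	fmt.Println(v) // vendor_id=vendor_id (last occurrence wins)
}
```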
View File

@ -1,6 +1,6 @@
{
"name": "bun",
"version": "1.0.14",
"version": "1.0.17",
"main": "index.js",
"repository": "git@github.com:uptrace/bun.git",
"author": "Vladimir Mihailenco <vladimir.webdev@gmail.com>",

View File

@ -14,8 +14,7 @@ import (
)
const (
wherePKFlag internal.Flag = 1 << iota
forceDeleteFlag
forceDeleteFlag internal.Flag = 1 << iota
deletedFlag
allWithDeletedFlag
)
@ -580,7 +579,8 @@ func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) s
type whereBaseQuery struct {
baseQuery
where []schema.QueryWithSep
where []schema.QueryWithSep
whereFields []*schema.Field
}
func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) {
@ -601,10 +601,46 @@ func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep)
q.addWhere(schema.SafeQueryWithSep("", nil, ")"))
}
func (q *whereBaseQuery) addWhereCols(cols []string) {
if q.table == nil {
err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model)
q.setErr(err)
return
}
var fields []*schema.Field
if cols == nil {
if err := q.table.CheckPKs(); err != nil {
q.setErr(err)
return
}
fields = q.table.PKs
} else {
fields = make([]*schema.Field, len(cols))
for i, col := range cols {
field, err := q.table.Field(col)
if err != nil {
q.setErr(err)
return
}
fields[i] = field
}
}
if q.whereFields != nil {
err := errors.New("bun: WherePK can only be called once")
q.setErr(err)
return
}
q.whereFields = fields
}
func (q *whereBaseQuery) mustAppendWhere(
fmter schema.Formatter, b []byte, withAlias bool,
) ([]byte, error) {
if len(q.where) == 0 && !q.flags.Has(wherePKFlag) {
if len(q.where) == 0 && q.whereFields == nil {
err := errors.New("bun: Update and Delete queries require at least one Where")
return nil, err
}
@ -614,7 +650,7 @@ func (q *whereBaseQuery) mustAppendWhere(
func (q *whereBaseQuery) appendWhere(
fmter schema.Formatter, b []byte, withAlias bool,
) (_ []byte, err error) {
if len(q.where) == 0 && !q.isSoftDelete() && !q.flags.Has(wherePKFlag) {
if len(q.where) == 0 && q.whereFields == nil && !q.isSoftDelete() {
return b, nil
}
@ -656,11 +692,11 @@ func (q *whereBaseQuery) appendWhere(
}
}
if q.flags.Has(wherePKFlag) {
if q.whereFields != nil {
if len(b) > startLen {
b = append(b, " AND "...)
}
b, err = q.appendWherePK(fmter, b, withAlias)
b, err = q.appendWhereFields(fmter, b, q.whereFields, withAlias)
if err != nil {
return nil, err
}
@ -691,29 +727,30 @@ func appendWhere(
return b, nil
}
func (q *whereBaseQuery) appendWherePK(
fmter schema.Formatter, b []byte, withAlias bool,
func (q *whereBaseQuery) appendWhereFields(
fmter schema.Formatter, b []byte, fields []*schema.Field, withAlias bool,
) (_ []byte, err error) {
if q.table == nil {
err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model)
return nil, err
}
if err := q.table.CheckPKs(); err != nil {
err := fmt.Errorf("bun: got %T, but WherePK requires struct or slice-based model", q.model)
return nil, err
}
switch model := q.tableModel.(type) {
case *structTableModel:
return q.appendWherePKStruct(fmter, b, model, withAlias)
return q.appendWhereStructFields(fmter, b, model, fields, withAlias)
case *sliceTableModel:
return q.appendWherePKSlice(fmter, b, model, withAlias)
return q.appendWhereSliceFields(fmter, b, model, fields, withAlias)
default:
return nil, fmt.Errorf("bun: WhereColumn does not support %T", q.tableModel)
}
return nil, fmt.Errorf("bun: WherePK does not support %T", q.tableModel)
}
func (q *whereBaseQuery) appendWherePKStruct(
fmter schema.Formatter, b []byte, model *structTableModel, withAlias bool,
func (q *whereBaseQuery) appendWhereStructFields(
fmter schema.Formatter,
b []byte,
model *structTableModel,
fields []*schema.Field,
withAlias bool,
) (_ []byte, err error) {
if !model.strct.IsValid() {
return nil, errNilModel
@ -721,7 +758,7 @@ func (q *whereBaseQuery) appendWherePKStruct(
isTemplate := fmter.IsNop()
b = append(b, '(')
for i, f := range q.table.PKs {
for i, f := range fields {
if i > 0 {
b = append(b, " AND "...)
}
@ -741,18 +778,22 @@ func (q *whereBaseQuery) appendWherePKStruct(
return b, nil
}
func (q *whereBaseQuery) appendWherePKSlice(
fmter schema.Formatter, b []byte, model *sliceTableModel, withAlias bool,
func (q *whereBaseQuery) appendWhereSliceFields(
fmter schema.Formatter,
b []byte,
model *sliceTableModel,
fields []*schema.Field,
withAlias bool,
) (_ []byte, err error) {
if len(q.table.PKs) > 1 {
if len(fields) > 1 {
b = append(b, '(')
}
if withAlias {
b = appendColumns(b, q.table.SQLAlias, q.table.PKs)
b = appendColumns(b, q.table.SQLAlias, fields)
} else {
b = appendColumns(b, "", q.table.PKs)
b = appendColumns(b, "", fields)
}
if len(q.table.PKs) > 1 {
if len(fields) > 1 {
b = append(b, ')')
}
@ -771,10 +812,10 @@ func (q *whereBaseQuery) appendWherePKSlice(
el := indirect(slice.Index(i))
if len(q.table.PKs) > 1 {
if len(fields) > 1 {
b = append(b, '(')
}
for i, f := range q.table.PKs {
for i, f := range fields {
if i > 0 {
b = append(b, ", "...)
}
@ -784,7 +825,7 @@ func (q *whereBaseQuery) appendWherePKSlice(
b = f.AppendValue(fmter, b, el)
}
}
if len(q.table.PKs) > 1 {
if len(fields) > 1 {
b = append(b, ')')
}
}

View File

@ -11,6 +11,8 @@ import (
type AddColumnQuery struct {
baseQuery
ifNotExists bool
}
func NewAddColumnQuery(db *DB) *AddColumnQuery {
@ -59,6 +61,11 @@ func (q *AddColumnQuery) ColumnExpr(query string, args ...interface{}) *AddColum
return q
}
func (q *AddColumnQuery) IfNotExists() *AddColumnQuery {
q.ifNotExists = true
return q
}
//------------------------------------------------------------------------------
func (q *AddColumnQuery) Operation() string {
@ -82,6 +89,10 @@ func (q *AddColumnQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte
b = append(b, " ADD "...)
if q.ifNotExists {
b = append(b, "IF NOT EXISTS "...)
}
b, err = q.columns[0].AppendQuery(fmter, b)
if err != nil {
return nil, err
@ -99,11 +110,5 @@ func (q *AddColumnQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Res
}
query := internal.String(queryBytes)
res, err := q.exec(ctx, q, query)
if err != nil {
return nil, err
}
return res, nil
return q.exec(ctx, q, query)
}
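
A brief usage sketch (the model and column are assumptions), showing the new opt-in that emits `ADD COLUMN IF NOT EXISTS ...` for databases that accept it:

```go
_, err := db.NewAddColumn().
	Model((*User)(nil)).
	ColumnExpr("middle_name varchar(255)").
	IfNotExists().
	Exec(ctx)
```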

View File

@ -66,8 +66,8 @@ func (q *DeleteQuery) ModelTableExpr(query string, args ...interface{}) *DeleteQ
//------------------------------------------------------------------------------
func (q *DeleteQuery) WherePK() *DeleteQuery {
q.flags = q.flags.Set(wherePKFlag)
func (q *DeleteQuery) WherePK(cols ...string) *DeleteQuery {
q.addWhereCols(cols)
return q
}
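
Roughly how the widened signature is used (model and column names assumed): with no arguments `WherePK` still matches on the primary keys, while arguments make it match on the named columns of the model value instead.

```go
// Delete by primary key, as before.
_, err := db.NewDelete().Model(&user).WherePK().Exec(ctx)

// Delete matching on a non-PK column taken from the model value.
_, err = db.NewDelete().Model(&user).WherePK("email").Exec(ctx)
```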

View File

@ -10,7 +10,6 @@ import (
"strings"
"sync"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
@ -116,8 +115,8 @@ func (q *SelectQuery) ExcludeColumn(columns ...string) *SelectQuery {
//------------------------------------------------------------------------------
func (q *SelectQuery) WherePK() *SelectQuery {
q.flags = q.flags.Set(wherePKFlag)
func (q *SelectQuery) WherePK(cols ...string) *SelectQuery {
q.addWhereCols(cols)
return q
}
@ -542,7 +541,7 @@ func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte,
if len(q.table.Fields) > 10 && fmter.IsNop() {
b = append(b, q.table.SQLAlias...)
b = append(b, '.')
b = dialect.AppendString(b, fmt.Sprintf("%d columns", len(q.table.Fields)))
b = fmter.Dialect().AppendString(b, fmt.Sprintf("%d columns", len(q.table.Fields)))
} else {
b = appendColumns(b, q.table.SQLAlias, q.table.Fields)
}

View File

@ -44,7 +44,7 @@ func (q *CreateTableQuery) Model(model interface{}) *CreateTableQuery {
return q
}
//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------
func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery {
for _, table := range tables {
@ -68,7 +68,7 @@ func (q *CreateTableQuery) ColumnExpr(query string, args ...interface{}) *Create
return q
}
//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------
func (q *CreateTableQuery) Temp() *CreateTableQuery {
q.temp = true
@ -128,7 +128,7 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by
if field.NotNull {
b = append(b, " NOT NULL"...)
}
if q.db.features.Has(feature.AutoIncrement) && field.AutoIncrement {
if fmter.Dialect().Features().Has(feature.AutoIncrement) && field.AutoIncrement {
b = append(b, " AUTO_INCREMENT"...)
}
if field.SQLDefault != "" {
@ -137,8 +137,13 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by
}
}
for _, col := range q.columns {
b = append(b, ", "...)
for i, col := range q.columns {
// Only prepend the comma if we are on subsequent iterations, or if there were fields/columns appended before
// this. This way if we are only appending custom column expressions we will not produce a syntax error with a
// leading comma.
if i > 0 || len(q.table.Fields) > 0 {
b = append(b, ", "...)
}
b, err = col.AppendQuery(fmter, b)
if err != nil {
return nil, err
@ -147,7 +152,7 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by
b = q.appendPKConstraint(b, q.table.PKs)
b = q.appendUniqueConstraints(fmter, b)
b, err = q.appenFKConstraints(fmter, b)
b, err = q.appendFKConstraints(fmter, b)
if err != nil {
return nil, err
}
@ -226,7 +231,7 @@ func (q *CreateTableQuery) appendUniqueConstraint(
return b
}
func (q *CreateTableQuery) appenFKConstraints(
func (q *CreateTableQuery) appendFKConstraints(
fmter schema.Formatter, b []byte,
) (_ []byte, err error) {
for _, fk := range q.fks {
@ -250,7 +255,7 @@ func (q *CreateTableQuery) appendPKConstraint(b []byte, pks []*schema.Field) []b
return b
}
//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------
func (q *CreateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
if err := q.beforeCreateTableHook(ctx); err != nil {

View File

@ -107,8 +107,8 @@ func (q *UpdateQuery) OmitZero() *UpdateQuery {
//------------------------------------------------------------------------------
func (q *UpdateQuery) WherePK() *UpdateQuery {
q.flags = q.flags.Set(wherePKFlag)
func (q *UpdateQuery) WherePK(cols ...string) *UpdateQuery {
q.addWhereCols(cols)
return q
}

View File

@ -31,7 +31,7 @@ func Append(fmter Formatter, b []byte, v interface{}) []byte {
case float64:
return dialect.AppendFloat64(b, v)
case string:
return dialect.AppendString(b, v)
return fmter.Dialect().AppendString(b, v)
case time.Time:
return fmter.Dialect().AppendTime(b, v)
case []byte:

View File

@ -194,7 +194,7 @@ func appendArrayBytesValue(fmter Formatter, b []byte, v reflect.Value) []byte {
}
func AppendStringValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return dialect.AppendString(b, v.String())
return fmter.Dialect().AppendString(b, v.String())
}
func AppendJSONValue(fmter Formatter, b []byte, v reflect.Value) []byte {
@ -217,12 +217,12 @@ func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte {
func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte {
ip := v.Interface().(net.IP)
return dialect.AppendString(b, ip.String())
return fmter.Dialect().AppendString(b, ip.String())
}
func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte {
ipnet := v.Interface().(net.IPNet)
return dialect.AppendString(b, ipnet.String())
return fmter.Dialect().AppendString(b, ipnet.String())
}
func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte {
@ -230,7 +230,7 @@ func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byt
if bytes == nil {
return dialect.AppendNull(b)
}
return dialect.AppendString(b, internal.String(bytes))
return fmter.Dialect().AppendString(b, internal.String(bytes))
}
func appendQueryAppenderValue(fmter Formatter, b []byte, v reflect.Value) []byte {

View File

@ -2,8 +2,10 @@ package schema
import (
"database/sql"
"encoding/hex"
"strconv"
"time"
"unicode/utf8"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature"
@ -24,6 +26,7 @@ type Dialect interface {
AppendUint32(b []byte, n uint32) []byte
AppendUint64(b []byte, n uint64) []byte
AppendTime(b []byte, tm time.Time) []byte
AppendString(b []byte, s string) []byte
AppendBytes(b []byte, bs []byte) []byte
AppendJSON(b, jsonb []byte) []byte
}
@ -47,8 +50,48 @@ func (BaseDialect) AppendTime(b []byte, tm time.Time) []byte {
return b
}
func (BaseDialect) AppendString(b []byte, s string) []byte {
b = append(b, '\'')
for _, r := range s {
if r == '\000' {
continue
}
if r == '\'' {
b = append(b, '\'', '\'')
continue
}
if r < utf8.RuneSelf {
b = append(b, byte(r))
continue
}
l := len(b)
if cap(b)-l < utf8.UTFMax {
b = append(b, make([]byte, utf8.UTFMax)...)
}
n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
b = b[:l+n]
}
b = append(b, '\'')
return b
}
func (BaseDialect) AppendBytes(b, bs []byte) []byte {
return dialect.AppendBytes(b, bs)
if bs == nil {
return dialect.AppendNull(b)
}
b = append(b, `'\x`...)
s := len(b)
b = append(b, make([]byte, hex.EncodedLen(len(bs)))...)
hex.Encode(b[s:], bs)
b = append(b, '\'')
return b
}
func (BaseDialect) AppendJSON(b, jsonb []byte) []byte {

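A quick illustration of the escaping rules above (the expected output is shown in comments):

```go
b := schema.BaseDialect{}.AppendString(nil, "O'Brien\x00")
// b == []byte(`'O''Brien'`)  -- single quotes doubled, NUL runes dropped

b = schema.BaseDialect{}.AppendBytes(nil, []byte{0xde, 0xad})
// b == []byte(`'\xdead'`)    -- bytea-style hex literal
```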
View File

@ -300,11 +300,11 @@ func (t *Table) processBaseModelField(f reflect.StructField) {
t.setName(tag.Name)
}
if s, ok := tag.Options["select"]; ok {
if s, ok := tag.Option("select"); ok {
t.SQLNameForSelects = t.quoteTableName(s)
}
if s, ok := tag.Options["alias"]; ok {
if s, ok := tag.Option("alias"); ok {
t.Alias = s
t.SQLAlias = t.quoteIdent(s)
}
@ -315,17 +315,16 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
tag := tagparser.Parse(f.Tag.Get("bun"))
sqlName := internal.Underscore(f.Name)
if tag.Name != "" {
if tag.Name != "" && tag.Name != sqlName {
if isKnownFieldOption(tag.Name) {
internal.Warn.Printf(
"%s.%s tag name %q is also an option name; is it a mistake?",
t.TypeName, f.Name, tag.Name,
)
}
sqlName = tag.Name
}
if tag.Name != sqlName && isKnownFieldOption(tag.Name) {
internal.Warn.Printf(
"%s.%s tag name %q is also an option name; is it a mistake?",
t.TypeName, f.Name, tag.Name,
)
}
for name := range tag.Options {
if !isKnownFieldOption(name) {
internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name)
@ -360,20 +359,27 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
}
if v, ok := tag.Options["unique"]; ok {
// Split the value by comma, this will allow multiple names to be specified.
// We can use this to create multiple named unique constraints where a single column
// might be included in multiple constraints.
for _, uniqueName := range strings.Split(v, ",") {
var names []string
if len(v) == 1 {
// Split the value by comma, this will allow multiple names to be specified.
// We can use this to create multiple named unique constraints where a single column
// might be included in multiple constraints.
names = strings.Split(v[0], ",")
} else {
names = v
}
for _, uniqueName := range names {
if t.Unique == nil {
t.Unique = make(map[string][]*Field)
}
t.Unique[uniqueName] = append(t.Unique[uniqueName], field)
}
}
if s, ok := tag.Options["default"]; ok {
if s, ok := tag.Option("default"); ok {
field.SQLDefault = s
}
if s, ok := field.Tag.Options["type"]; ok {
if s, ok := field.Tag.Option("type"); ok {
field.UserSQLType = s
}
field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType)
@ -381,7 +387,7 @@ func (t *Table) newField(f reflect.StructField, index []int) *Field {
field.Scan = FieldScanner(t.dialect, field)
field.IsZero = zeroChecker(field.StructField.Type)
if v, ok := tag.Options["alt"]; ok {
if v, ok := tag.Option("alt"); ok {
t.FieldMap[v] = field
}
@ -433,7 +439,7 @@ func (t *Table) initRelations() {
}
func (t *Table) tryRelation(field *Field) bool {
if rel, ok := field.Tag.Options["rel"]; ok {
if rel, ok := field.Tag.Option("rel"); ok {
t.initRelation(field, rel)
return true
}
@ -444,7 +450,7 @@ func (t *Table) tryRelation(field *Field) bool {
if field.Tag.HasOption("join") {
internal.Warn.Printf(
`%s.%s option "join" requires a relation type`,
`%s.%s "join" option must come together with "rel" option`,
t.TypeName, field.GoName,
)
}
@ -609,7 +615,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation {
}
joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem()))
polymorphicValue, isPolymorphic := field.Tag.Options["polymorphic"]
polymorphicValue, isPolymorphic := field.Tag.Option("polymorphic")
rel := &Relation{
Type: HasManyRelation,
Field: field,
@ -706,7 +712,7 @@ func (t *Table) m2mRelation(field *Field) *Relation {
panic(err)
}
m2mTableName, ok := field.Tag.Options["m2m"]
m2mTableName, ok := field.Tag.Option("m2m")
if !ok {
panic(fmt.Errorf("bun: %s must have m2m tag option", field.GoName))
}
@ -891,8 +897,14 @@ func removeField(fields []*Field, field *Field) []*Field {
return fields
}
func parseRelationJoin(join string) ([]string, []string) {
ss := strings.Split(join, ",")
func parseRelationJoin(join []string) ([]string, []string) {
var ss []string
if len(join) == 1 {
ss = strings.Split(join[0], ",")
} else {
ss = join
}
baseColumns := make([]string, len(ss))
joinColumns := make([]string, len(ss))
for i, s := range ss {
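
For illustration (the struct and columns are assumptions), the repeated-option form from the changelog describes a multi-column join, with `parseRelationJoin` above collecting each occurrence:

```go
type Order struct {
	ID       int64
	UserID   int64
	VendorID int64

	// The "join" option may now be repeated, one left=right pair per occurrence.
	User *User `bun:"rel:belongs-to,join:user_id=id,join:vendor_id=vendor_id"`
}
```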

View File

@ -101,13 +101,15 @@ func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table {
return table
}
if inProgress.init2() {
t.mu.Lock()
delete(t.inProgress, typ)
t.tables.Store(typ, table)
t.mu.Unlock()
if !inProgress.init2() {
return table
}
t.mu.Lock()
delete(t.inProgress, typ)
t.tables.Store(typ, table)
t.mu.Unlock()
t.dialect.OnTable(table)
for _, field := range table.FieldMap {

View File

@ -2,5 +2,5 @@ package bun
// Version is the current release version.
func Version() string {
return "1.0.14"
return "1.0.17"
}