Update deps
This commit is contained in:
parent
35d7aa0603
commit
0059194a9e
11
go.mod
11
go.mod
|
@ -3,7 +3,7 @@ module github.com/dnscrypt/dnscrypt-proxy
|
|||
go 1.22
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.3.2
|
||||
github.com/BurntSushi/toml v1.4.0
|
||||
github.com/VividCortex/ewma v1.2.0
|
||||
github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf
|
||||
github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185
|
||||
|
@ -20,7 +20,7 @@ require (
|
|||
github.com/miekg/dns v1.1.59
|
||||
github.com/opencoff/go-sieve v0.2.1
|
||||
github.com/powerman/check v1.7.0
|
||||
github.com/quic-go/quic-go v0.43.1
|
||||
github.com/quic-go/quic-go v0.44.0
|
||||
golang.org/x/crypto v0.23.0
|
||||
golang.org/x/net v0.25.0
|
||||
golang.org/x/sys v0.20.0
|
||||
|
@ -41,10 +41,11 @@ require (
|
|||
github.com/quic-go/qpack v0.4.0 // indirect
|
||||
github.com/smartystreets/goconvey v1.7.2 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
|
||||
golang.org/x/mod v0.16.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
golang.org/x/tools v0.19.0 // indirect
|
||||
golang.org/x/tools v0.21.0 // indirect
|
||||
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect
|
||||
google.golang.org/grpc v1.53.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
|
|
28
go.sum
28
go.sum
|
@ -1,5 +1,5 @@
|
|||
github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8=
|
||||
github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||
github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0=
|
||||
github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/VividCortex/ewma v1.2.0 h1:f58SaIzcDXrSy3kWaHNvuJgJ3Nmz59Zji6XoJR/q1ow=
|
||||
github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4=
|
||||
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
|
||||
|
@ -20,8 +20,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
|
|||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
|
||||
|
@ -73,8 +73,8 @@ github.com/powerman/deepequal v0.1.0 h1:sVwtyTsBuYIvdbLR1O2wzRY63YgPqdGZmk/o80l+
|
|||
github.com/powerman/deepequal v0.1.0/go.mod h1:3k7aG/slufBhUANdN67o/UPg8i5YaiJ6FmibWX0cn04=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/quic-go v0.43.1 h1:fLiMNfQVe9q2JvSsiXo4fXOEguXHGGl9+6gLp4RPeZQ=
|
||||
github.com/quic-go/quic-go v0.43.1/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
|
||||
github.com/quic-go/quic-go v0.44.0 h1:So5wOr7jyO4vzL2sd8/pD9Kesciv91zSk8BoFngItQ0=
|
||||
github.com/quic-go/quic-go v0.44.0/go.mod h1:z4cx/9Ny9UtGITIPzmPTXh1ULfOyWh4qGQlpnPcWmek=
|
||||
github.com/smartystreets/assertions v1.2.0 h1:42S6lae5dvLc7BrLu/0ugRtcFVjoJNMC/N3yZFZkDFs=
|
||||
github.com/smartystreets/assertions v1.2.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
|
||||
github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg=
|
||||
|
@ -87,15 +87,15 @@ go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
|||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o=
|
||||
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
|
||||
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
|
||||
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190529164535-6a60838ec259/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
|
@ -108,8 +108,8 @@ golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
|||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
|
||||
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
|
||||
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
|
||||
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f h1:BWUVssLB0HVOSY78gIdvk1dTVYtT1y8SBWtPYuTJ/6w=
|
||||
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM=
|
||||
|
|
|
@ -9,7 +9,7 @@ See the [releases page](https://github.com/BurntSushi/toml/releases) for a
|
|||
changelog; this information is also in the git tag annotations (e.g. `git show
|
||||
v0.4.0`).
|
||||
|
||||
This library requires Go 1.13 or newer; add it to your go.mod with:
|
||||
This library requires Go 1.18 or newer; add it to your go.mod with:
|
||||
|
||||
% go get github.com/BurntSushi/toml@latest
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"io/fs"
|
||||
"math"
|
||||
"os"
|
||||
"reflect"
|
||||
|
@ -18,13 +18,13 @@ import (
|
|||
// Unmarshaler is the interface implemented by objects that can unmarshal a
|
||||
// TOML description of themselves.
|
||||
type Unmarshaler interface {
|
||||
UnmarshalTOML(interface{}) error
|
||||
UnmarshalTOML(any) error
|
||||
}
|
||||
|
||||
// Unmarshal decodes the contents of data in TOML format into a pointer v.
|
||||
//
|
||||
// See [Decoder] for a description of the decoding process.
|
||||
func Unmarshal(data []byte, v interface{}) error {
|
||||
func Unmarshal(data []byte, v any) error {
|
||||
_, err := NewDecoder(bytes.NewReader(data)).Decode(v)
|
||||
return err
|
||||
}
|
||||
|
@ -32,12 +32,12 @@ func Unmarshal(data []byte, v interface{}) error {
|
|||
// Decode the TOML data in to the pointer v.
|
||||
//
|
||||
// See [Decoder] for a description of the decoding process.
|
||||
func Decode(data string, v interface{}) (MetaData, error) {
|
||||
func Decode(data string, v any) (MetaData, error) {
|
||||
return NewDecoder(strings.NewReader(data)).Decode(v)
|
||||
}
|
||||
|
||||
// DecodeFile reads the contents of a file and decodes it with [Decode].
|
||||
func DecodeFile(path string, v interface{}) (MetaData, error) {
|
||||
func DecodeFile(path string, v any) (MetaData, error) {
|
||||
fp, err := os.Open(path)
|
||||
if err != nil {
|
||||
return MetaData{}, err
|
||||
|
@ -46,6 +46,17 @@ func DecodeFile(path string, v interface{}) (MetaData, error) {
|
|||
return NewDecoder(fp).Decode(v)
|
||||
}
|
||||
|
||||
// DecodeFS reads the contents of a file from [fs.FS] and decodes it with
|
||||
// [Decode].
|
||||
func DecodeFS(fsys fs.FS, path string, v any) (MetaData, error) {
|
||||
fp, err := fsys.Open(path)
|
||||
if err != nil {
|
||||
return MetaData{}, err
|
||||
}
|
||||
defer fp.Close()
|
||||
return NewDecoder(fp).Decode(v)
|
||||
}
|
||||
|
||||
// Primitive is a TOML value that hasn't been decoded into a Go value.
|
||||
//
|
||||
// This type can be used for any value, which will cause decoding to be delayed.
|
||||
|
@ -58,7 +69,7 @@ func DecodeFile(path string, v interface{}) (MetaData, error) {
|
|||
// overhead of reflection. They can be useful when you don't know the exact type
|
||||
// of TOML data until runtime.
|
||||
type Primitive struct {
|
||||
undecoded interface{}
|
||||
undecoded any
|
||||
context Key
|
||||
}
|
||||
|
||||
|
@ -122,7 +133,7 @@ var (
|
|||
)
|
||||
|
||||
// Decode TOML data in to the pointer `v`.
|
||||
func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
|
||||
func (dec *Decoder) Decode(v any) (MetaData, error) {
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
s := "%q"
|
||||
|
@ -136,8 +147,8 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
|
|||
return MetaData{}, fmt.Errorf("toml: cannot decode to nil value of %q", reflect.TypeOf(v))
|
||||
}
|
||||
|
||||
// Check if this is a supported type: struct, map, interface{}, or something
|
||||
// that implements UnmarshalTOML or UnmarshalText.
|
||||
// Check if this is a supported type: struct, map, any, or something that
|
||||
// implements UnmarshalTOML or UnmarshalText.
|
||||
rv = indirect(rv)
|
||||
rt := rv.Type()
|
||||
if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map &&
|
||||
|
@ -148,7 +159,7 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
|
|||
|
||||
// TODO: parser should read from io.Reader? Or at the very least, make it
|
||||
// read from []byte rather than string
|
||||
data, err := ioutil.ReadAll(dec.r)
|
||||
data, err := io.ReadAll(dec.r)
|
||||
if err != nil {
|
||||
return MetaData{}, err
|
||||
}
|
||||
|
@ -179,7 +190,7 @@ func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
|
|||
// will only reflect keys that were decoded. Namely, any keys hidden behind a
|
||||
// Primitive will be considered undecoded. Executing this method will update the
|
||||
// undecoded keys in the meta data. (See the example.)
|
||||
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
|
||||
func (md *MetaData) PrimitiveDecode(primValue Primitive, v any) error {
|
||||
md.context = primValue.context
|
||||
defer func() { md.context = nil }()
|
||||
return md.unify(primValue.undecoded, rvalue(v))
|
||||
|
@ -190,7 +201,7 @@ func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
|
|||
//
|
||||
// Any type mismatch produces an error. Finding a type that we don't know
|
||||
// how to handle produces an unsupported type error.
|
||||
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
|
||||
func (md *MetaData) unify(data any, rv reflect.Value) error {
|
||||
// Special case. Look for a `Primitive` value.
|
||||
// TODO: #76 would make this superfluous after implemented.
|
||||
if rv.Type() == primitiveType {
|
||||
|
@ -207,7 +218,11 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
|
|||
|
||||
rvi := rv.Interface()
|
||||
if v, ok := rvi.(Unmarshaler); ok {
|
||||
return v.UnmarshalTOML(data)
|
||||
err := v.UnmarshalTOML(data)
|
||||
if err != nil {
|
||||
return md.parseErr(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if v, ok := rvi.(encoding.TextUnmarshaler); ok {
|
||||
return md.unifyText(data, v)
|
||||
|
@ -227,14 +242,6 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
|
|||
return md.unifyInt(data, rv)
|
||||
}
|
||||
switch k {
|
||||
case reflect.Ptr:
|
||||
elem := reflect.New(rv.Type().Elem())
|
||||
err := md.unify(data, reflect.Indirect(elem))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rv.Set(elem)
|
||||
return nil
|
||||
case reflect.Struct:
|
||||
return md.unifyStruct(data, rv)
|
||||
case reflect.Map:
|
||||
|
@ -258,14 +265,13 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
|
|||
return md.e("unsupported type %s", rv.Kind())
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
|
||||
tmap, ok := mapping.(map[string]interface{})
|
||||
func (md *MetaData) unifyStruct(mapping any, rv reflect.Value) error {
|
||||
tmap, ok := mapping.(map[string]any)
|
||||
if !ok {
|
||||
if mapping == nil {
|
||||
return nil
|
||||
}
|
||||
return md.e("type mismatch for %s: expected table but found %T",
|
||||
rv.Type().String(), mapping)
|
||||
return md.e("type mismatch for %s: expected table but found %s", rv.Type().String(), fmtType(mapping))
|
||||
}
|
||||
|
||||
for key, datum := range tmap {
|
||||
|
@ -304,14 +310,14 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
|
||||
func (md *MetaData) unifyMap(mapping any, rv reflect.Value) error {
|
||||
keyType := rv.Type().Key().Kind()
|
||||
if keyType != reflect.String && keyType != reflect.Interface {
|
||||
return fmt.Errorf("toml: cannot decode to a map with non-string key type (%s in %q)",
|
||||
keyType, rv.Type())
|
||||
}
|
||||
|
||||
tmap, ok := mapping.(map[string]interface{})
|
||||
tmap, ok := mapping.(map[string]any)
|
||||
if !ok {
|
||||
if tmap == nil {
|
||||
return nil
|
||||
|
@ -347,7 +353,7 @@ func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
|
||||
func (md *MetaData) unifyArray(data any, rv reflect.Value) error {
|
||||
datav := reflect.ValueOf(data)
|
||||
if datav.Kind() != reflect.Slice {
|
||||
if !datav.IsValid() {
|
||||
|
@ -361,7 +367,7 @@ func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
|
|||
return md.unifySliceArray(datav, rv)
|
||||
}
|
||||
|
||||
func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
|
||||
func (md *MetaData) unifySlice(data any, rv reflect.Value) error {
|
||||
datav := reflect.ValueOf(data)
|
||||
if datav.Kind() != reflect.Slice {
|
||||
if !datav.IsValid() {
|
||||
|
@ -388,7 +394,7 @@ func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
|
||||
func (md *MetaData) unifyString(data any, rv reflect.Value) error {
|
||||
_, ok := rv.Interface().(json.Number)
|
||||
if ok {
|
||||
if i, ok := data.(int64); ok {
|
||||
|
@ -408,7 +414,7 @@ func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
|
|||
return md.badtype("string", data)
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
|
||||
func (md *MetaData) unifyFloat64(data any, rv reflect.Value) error {
|
||||
rvk := rv.Kind()
|
||||
|
||||
if num, ok := data.(float64); ok {
|
||||
|
@ -429,7 +435,7 @@ func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
|
|||
if num, ok := data.(int64); ok {
|
||||
if (rvk == reflect.Float32 && (num < -maxSafeFloat32Int || num > maxSafeFloat32Int)) ||
|
||||
(rvk == reflect.Float64 && (num < -maxSafeFloat64Int || num > maxSafeFloat64Int)) {
|
||||
return md.parseErr(errParseRange{i: num, size: rvk.String()})
|
||||
return md.parseErr(errUnsafeFloat{i: num, size: rvk.String()})
|
||||
}
|
||||
rv.SetFloat(float64(num))
|
||||
return nil
|
||||
|
@ -438,7 +444,7 @@ func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
|
|||
return md.badtype("float", data)
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
|
||||
func (md *MetaData) unifyInt(data any, rv reflect.Value) error {
|
||||
_, ok := rv.Interface().(time.Duration)
|
||||
if ok {
|
||||
// Parse as string duration, and fall back to regular integer parsing
|
||||
|
@ -481,7 +487,7 @@ func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
|
||||
func (md *MetaData) unifyBool(data any, rv reflect.Value) error {
|
||||
if b, ok := data.(bool); ok {
|
||||
rv.SetBool(b)
|
||||
return nil
|
||||
|
@ -489,12 +495,12 @@ func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
|
|||
return md.badtype("boolean", data)
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
|
||||
func (md *MetaData) unifyAnything(data any, rv reflect.Value) error {
|
||||
rv.Set(reflect.ValueOf(data))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) error {
|
||||
func (md *MetaData) unifyText(data any, v encoding.TextUnmarshaler) error {
|
||||
var s string
|
||||
switch sdata := data.(type) {
|
||||
case Marshaler:
|
||||
|
@ -523,13 +529,13 @@ func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) erro
|
|||
return md.badtype("primitive (string-like)", data)
|
||||
}
|
||||
if err := v.UnmarshalText([]byte(s)); err != nil {
|
||||
return err
|
||||
return md.parseErr(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *MetaData) badtype(dst string, data interface{}) error {
|
||||
return md.e("incompatible types: TOML value has type %T; destination has type %s", data, dst)
|
||||
func (md *MetaData) badtype(dst string, data any) error {
|
||||
return md.e("incompatible types: TOML value has type %s; destination has type %s", fmtType(data), dst)
|
||||
}
|
||||
|
||||
func (md *MetaData) parseErr(err error) error {
|
||||
|
@ -543,7 +549,7 @@ func (md *MetaData) parseErr(err error) error {
|
|||
}
|
||||
}
|
||||
|
||||
func (md *MetaData) e(format string, args ...interface{}) error {
|
||||
func (md *MetaData) e(format string, args ...any) error {
|
||||
f := "toml: "
|
||||
if len(md.context) > 0 {
|
||||
f = fmt.Sprintf("toml: (last key %q): ", md.context)
|
||||
|
@ -556,7 +562,7 @@ func (md *MetaData) e(format string, args ...interface{}) error {
|
|||
}
|
||||
|
||||
// rvalue returns a reflect.Value of `v`. All pointers are resolved.
|
||||
func rvalue(v interface{}) reflect.Value {
|
||||
func rvalue(v any) reflect.Value {
|
||||
return indirect(reflect.ValueOf(v))
|
||||
}
|
||||
|
||||
|
@ -600,3 +606,8 @@ func isUnifiable(rv reflect.Value) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// fmt %T with "interface {}" replaced with "any", which is far more readable.
|
||||
func fmtType(t any) string {
|
||||
return strings.ReplaceAll(fmt.Sprintf("%T", t), "interface {}", "any")
|
||||
}
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
//go:build go1.16
|
||||
// +build go1.16
|
||||
|
||||
package toml
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
)
|
||||
|
||||
// DecodeFS reads the contents of a file from [fs.FS] and decodes it with
|
||||
// [Decode].
|
||||
func DecodeFS(fsys fs.FS, path string, v interface{}) (MetaData, error) {
|
||||
fp, err := fsys.Open(path)
|
||||
if err != nil {
|
||||
return MetaData{}, err
|
||||
}
|
||||
defer fp.Close()
|
||||
return NewDecoder(fp).Decode(v)
|
||||
}
|
|
@ -15,15 +15,15 @@ type TextMarshaler encoding.TextMarshaler
|
|||
// Deprecated: use encoding.TextUnmarshaler
|
||||
type TextUnmarshaler encoding.TextUnmarshaler
|
||||
|
||||
// PrimitiveDecode is an alias for MetaData.PrimitiveDecode().
|
||||
//
|
||||
// Deprecated: use MetaData.PrimitiveDecode.
|
||||
func PrimitiveDecode(primValue Primitive, v interface{}) error {
|
||||
md := MetaData{decoded: make(map[string]struct{})}
|
||||
return md.unify(primValue.undecoded, rvalue(v))
|
||||
}
|
||||
|
||||
// DecodeReader is an alias for NewDecoder(r).Decode(v).
|
||||
//
|
||||
// Deprecated: use NewDecoder(reader).Decode(&value).
|
||||
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) { return NewDecoder(r).Decode(v) }
|
||||
func DecodeReader(r io.Reader, v any) (MetaData, error) { return NewDecoder(r).Decode(v) }
|
||||
|
||||
// PrimitiveDecode is an alias for MetaData.PrimitiveDecode().
|
||||
//
|
||||
// Deprecated: use MetaData.PrimitiveDecode.
|
||||
func PrimitiveDecode(primValue Primitive, v any) error {
|
||||
md := MetaData{decoded: make(map[string]struct{})}
|
||||
return md.unify(primValue.undecoded, rvalue(v))
|
||||
}
|
||||
|
|
|
@ -2,9 +2,6 @@
|
|||
//
|
||||
// This package supports TOML v1.0.0, as specified at https://toml.io
|
||||
//
|
||||
// There is also support for delaying decoding with the Primitive type, and
|
||||
// querying the set of keys in a TOML document with the MetaData type.
|
||||
//
|
||||
// The github.com/BurntSushi/toml/cmd/tomlv package implements a TOML validator,
|
||||
// and can be used to verify if TOML document is valid. It can also be used to
|
||||
// print the type of each key.
|
||||
|
|
|
@ -2,6 +2,7 @@ package toml
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
@ -76,6 +77,17 @@ type Marshaler interface {
|
|||
MarshalTOML() ([]byte, error)
|
||||
}
|
||||
|
||||
// Marshal returns a TOML representation of the Go value.
|
||||
//
|
||||
// See [Encoder] for a description of the encoding process.
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
buff := new(bytes.Buffer)
|
||||
if err := NewEncoder(buff).Encode(v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buff.Bytes(), nil
|
||||
}
|
||||
|
||||
// Encoder encodes a Go to a TOML document.
|
||||
//
|
||||
// The mapping between Go values and TOML values should be precisely the same as
|
||||
|
@ -115,26 +127,21 @@ type Marshaler interface {
|
|||
// NOTE: only exported keys are encoded due to the use of reflection. Unexported
|
||||
// keys are silently discarded.
|
||||
type Encoder struct {
|
||||
// String to use for a single indentation level; default is two spaces.
|
||||
Indent string
|
||||
|
||||
Indent string // string for a single indentation level; default is two spaces.
|
||||
hasWritten bool // written any output to w yet?
|
||||
w *bufio.Writer
|
||||
hasWritten bool // written any output to w yet?
|
||||
}
|
||||
|
||||
// NewEncoder create a new Encoder.
|
||||
func NewEncoder(w io.Writer) *Encoder {
|
||||
return &Encoder{
|
||||
w: bufio.NewWriter(w),
|
||||
Indent: " ",
|
||||
}
|
||||
return &Encoder{w: bufio.NewWriter(w), Indent: " "}
|
||||
}
|
||||
|
||||
// Encode writes a TOML representation of the Go value to the [Encoder]'s writer.
|
||||
//
|
||||
// An error is returned if the value given cannot be encoded to a valid TOML
|
||||
// document.
|
||||
func (enc *Encoder) Encode(v interface{}) error {
|
||||
func (enc *Encoder) Encode(v any) error {
|
||||
rv := eindirect(reflect.ValueOf(v))
|
||||
err := enc.safeEncode(Key([]string{}), rv)
|
||||
if err != nil {
|
||||
|
@ -280,18 +287,30 @@ func (enc *Encoder) eElement(rv reflect.Value) {
|
|||
case reflect.Float32:
|
||||
f := rv.Float()
|
||||
if math.IsNaN(f) {
|
||||
if math.Signbit(f) {
|
||||
enc.wf("-")
|
||||
}
|
||||
enc.wf("nan")
|
||||
} else if math.IsInf(f, 0) {
|
||||
enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)])
|
||||
if math.Signbit(f) {
|
||||
enc.wf("-")
|
||||
}
|
||||
enc.wf("inf")
|
||||
} else {
|
||||
enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 32)))
|
||||
}
|
||||
case reflect.Float64:
|
||||
f := rv.Float()
|
||||
if math.IsNaN(f) {
|
||||
if math.Signbit(f) {
|
||||
enc.wf("-")
|
||||
}
|
||||
enc.wf("nan")
|
||||
} else if math.IsInf(f, 0) {
|
||||
enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)])
|
||||
if math.Signbit(f) {
|
||||
enc.wf("-")
|
||||
}
|
||||
enc.wf("inf")
|
||||
} else {
|
||||
enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 64)))
|
||||
}
|
||||
|
@ -304,7 +323,7 @@ func (enc *Encoder) eElement(rv reflect.Value) {
|
|||
case reflect.Interface:
|
||||
enc.eElement(rv.Elem())
|
||||
default:
|
||||
encPanic(fmt.Errorf("unexpected type: %T", rv.Interface()))
|
||||
encPanic(fmt.Errorf("unexpected type: %s", fmtType(rv.Interface())))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -712,7 +731,7 @@ func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) {
|
|||
}
|
||||
}
|
||||
|
||||
func (enc *Encoder) wf(format string, v ...interface{}) {
|
||||
func (enc *Encoder) wf(format string, v ...any) {
|
||||
_, err := fmt.Fprintf(enc.w, format, v...)
|
||||
if err != nil {
|
||||
encPanic(err)
|
||||
|
|
|
@ -114,13 +114,22 @@ func (pe ParseError) ErrorWithPosition() string {
|
|||
msg, pe.Position.Line, col, col+pe.Position.Len)
|
||||
}
|
||||
if pe.Position.Line > 2 {
|
||||
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-2, lines[pe.Position.Line-3])
|
||||
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-2, expandTab(lines[pe.Position.Line-3]))
|
||||
}
|
||||
if pe.Position.Line > 1 {
|
||||
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-1, lines[pe.Position.Line-2])
|
||||
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-1, expandTab(lines[pe.Position.Line-2]))
|
||||
}
|
||||
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line, lines[pe.Position.Line-1])
|
||||
fmt.Fprintf(b, "% 10s%s%s\n", "", strings.Repeat(" ", col), strings.Repeat("^", pe.Position.Len))
|
||||
|
||||
/// Expand tabs, so that the ^^^s are at the correct position, but leave
|
||||
/// "column 10-13" intact. Adjusting this to the visual column would be
|
||||
/// better, but we don't know the tabsize of the user in their editor, which
|
||||
/// can be 8, 4, 2, or something else. We can't know. So leaving it as the
|
||||
/// character index is probably the "most correct".
|
||||
expanded := expandTab(lines[pe.Position.Line-1])
|
||||
diff := len(expanded) - len(lines[pe.Position.Line-1])
|
||||
|
||||
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line, expanded)
|
||||
fmt.Fprintf(b, "% 10s%s%s\n", "", strings.Repeat(" ", col+diff), strings.Repeat("^", pe.Position.Len))
|
||||
return b.String()
|
||||
}
|
||||
|
||||
|
@ -159,17 +168,47 @@ func (pe ParseError) column(lines []string) int {
|
|||
return col
|
||||
}
|
||||
|
||||
func expandTab(s string) string {
|
||||
var (
|
||||
b strings.Builder
|
||||
l int
|
||||
fill = func(n int) string {
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = ' '
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
)
|
||||
b.Grow(len(s))
|
||||
for _, r := range s {
|
||||
switch r {
|
||||
case '\t':
|
||||
tw := 8 - l%8
|
||||
b.WriteString(fill(tw))
|
||||
l += tw
|
||||
default:
|
||||
b.WriteRune(r)
|
||||
l += 1
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type (
|
||||
errLexControl struct{ r rune }
|
||||
errLexEscape struct{ r rune }
|
||||
errLexUTF8 struct{ b byte }
|
||||
errLexInvalidNum struct{ v string }
|
||||
errLexInvalidDate struct{ v string }
|
||||
errParseDate struct{ v string }
|
||||
errLexInlineTableNL struct{}
|
||||
errLexStringNL struct{}
|
||||
errParseRange struct {
|
||||
i interface{} // int or float
|
||||
size string // "int64", "uint16", etc.
|
||||
i any // int or float
|
||||
size string // "int64", "uint16", etc.
|
||||
}
|
||||
errUnsafeFloat struct {
|
||||
i interface{} // float32 or float64
|
||||
size string // "float32" or "float64"
|
||||
}
|
||||
errParseDuration struct{ d string }
|
||||
)
|
||||
|
@ -183,18 +222,20 @@ func (e errLexEscape) Error() string { return fmt.Sprintf(`invalid escape
|
|||
func (e errLexEscape) Usage() string { return usageEscape }
|
||||
func (e errLexUTF8) Error() string { return fmt.Sprintf("invalid UTF-8 byte: 0x%02x", e.b) }
|
||||
func (e errLexUTF8) Usage() string { return "" }
|
||||
func (e errLexInvalidNum) Error() string { return fmt.Sprintf("invalid number: %q", e.v) }
|
||||
func (e errLexInvalidNum) Usage() string { return "" }
|
||||
func (e errLexInvalidDate) Error() string { return fmt.Sprintf("invalid date: %q", e.v) }
|
||||
func (e errLexInvalidDate) Usage() string { return "" }
|
||||
func (e errParseDate) Error() string { return fmt.Sprintf("invalid datetime: %q", e.v) }
|
||||
func (e errParseDate) Usage() string { return usageDate }
|
||||
func (e errLexInlineTableNL) Error() string { return "newlines not allowed within inline tables" }
|
||||
func (e errLexInlineTableNL) Usage() string { return usageInlineNewline }
|
||||
func (e errLexStringNL) Error() string { return "strings cannot contain newlines" }
|
||||
func (e errLexStringNL) Usage() string { return usageStringNewline }
|
||||
func (e errParseRange) Error() string { return fmt.Sprintf("%v is out of range for %s", e.i, e.size) }
|
||||
func (e errParseRange) Usage() string { return usageIntOverflow }
|
||||
func (e errParseDuration) Error() string { return fmt.Sprintf("invalid duration: %q", e.d) }
|
||||
func (e errParseDuration) Usage() string { return usageDuration }
|
||||
func (e errUnsafeFloat) Error() string {
|
||||
return fmt.Sprintf("%v is out of the safe %s range", e.i, e.size)
|
||||
}
|
||||
func (e errUnsafeFloat) Usage() string { return usageUnsafeFloat }
|
||||
func (e errParseDuration) Error() string { return fmt.Sprintf("invalid duration: %q", e.d) }
|
||||
func (e errParseDuration) Usage() string { return usageDuration }
|
||||
|
||||
const usageEscape = `
|
||||
A '\' inside a "-delimited string is interpreted as an escape character.
|
||||
|
@ -251,19 +292,35 @@ bug in the program that uses too small of an integer.
|
|||
The maximum and minimum values are:
|
||||
|
||||
size │ lowest │ highest
|
||||
───────┼────────────────┼──────────
|
||||
───────┼────────────────┼──────────────
|
||||
int8 │ -128 │ 127
|
||||
int16 │ -32,768 │ 32,767
|
||||
int32 │ -2,147,483,648 │ 2,147,483,647
|
||||
int64 │ -9.2 × 10¹⁷ │ 9.2 × 10¹⁷
|
||||
uint8 │ 0 │ 255
|
||||
uint16 │ 0 │ 65535
|
||||
uint32 │ 0 │ 4294967295
|
||||
uint16 │ 0 │ 65,535
|
||||
uint32 │ 0 │ 4,294,967,295
|
||||
uint64 │ 0 │ 1.8 × 10¹⁸
|
||||
|
||||
int refers to int32 on 32-bit systems and int64 on 64-bit systems.
|
||||
`
|
||||
|
||||
const usageUnsafeFloat = `
|
||||
This number is outside of the "safe" range for floating point numbers; whole
|
||||
(non-fractional) numbers outside the below range can not always be represented
|
||||
accurately in a float, leading to some loss of accuracy.
|
||||
|
||||
Explicitly mark a number as a fractional unit by adding ".0", which will incur
|
||||
some loss of accuracy; for example:
|
||||
|
||||
f = 2_000_000_000.0
|
||||
|
||||
Accuracy ranges:
|
||||
|
||||
float32 = 16,777,215
|
||||
float64 = 9,007,199,254,740,991
|
||||
`
|
||||
|
||||
const usageDuration = `
|
||||
A duration must be as "number<unit>", without any spaces. Valid units are:
|
||||
|
||||
|
@ -277,3 +334,23 @@ A duration must be as "number<unit>", without any spaces. Valid units are:
|
|||
You can combine multiple units; for example "5m10s" for 5 minutes and 10
|
||||
seconds.
|
||||
`
|
||||
|
||||
const usageDate = `
|
||||
A TOML datetime must be in one of the following formats:
|
||||
|
||||
2006-01-02T15:04:05Z07:00 Date and time, with timezone.
|
||||
2006-01-02T15:04:05 Date and time, but without timezone.
|
||||
2006-01-02 Date without a time or timezone.
|
||||
15:04:05 Just a time, without any timezone.
|
||||
|
||||
Seconds may optionally have a fraction, up to nanosecond precision:
|
||||
|
||||
15:04:05.123
|
||||
15:04:05.856018510
|
||||
`
|
||||
|
||||
// TOML 1.1:
|
||||
// The seconds part in times is optional, and may be omitted:
|
||||
// 2006-01-02T15:04Z07:00
|
||||
// 2006-01-02T15:04
|
||||
// 15:04
|
||||
|
|
|
@ -17,6 +17,7 @@ const (
|
|||
itemEOF
|
||||
itemText
|
||||
itemString
|
||||
itemStringEsc
|
||||
itemRawString
|
||||
itemMultilineString
|
||||
itemRawMultilineString
|
||||
|
@ -53,6 +54,7 @@ type lexer struct {
|
|||
state stateFn
|
||||
items chan item
|
||||
tomlNext bool
|
||||
esc bool
|
||||
|
||||
// Allow for backing up up to 4 runes. This is necessary because TOML
|
||||
// contains 3-rune tokens (""" and ''').
|
||||
|
@ -164,7 +166,7 @@ func (lx *lexer) next() (r rune) {
|
|||
}
|
||||
|
||||
r, w := utf8.DecodeRuneInString(lx.input[lx.pos:])
|
||||
if r == utf8.RuneError {
|
||||
if r == utf8.RuneError && w == 1 {
|
||||
lx.error(errLexUTF8{lx.input[lx.pos]})
|
||||
return utf8.RuneError
|
||||
}
|
||||
|
@ -270,7 +272,7 @@ func (lx *lexer) errorPos(start, length int, err error) stateFn {
|
|||
}
|
||||
|
||||
// errorf is like error, and creates a new error.
|
||||
func (lx *lexer) errorf(format string, values ...interface{}) stateFn {
|
||||
func (lx *lexer) errorf(format string, values ...any) stateFn {
|
||||
if lx.atEOF {
|
||||
pos := lx.getPos()
|
||||
pos.Line--
|
||||
|
@ -333,9 +335,7 @@ func lexTopEnd(lx *lexer) stateFn {
|
|||
lx.emit(itemEOF)
|
||||
return nil
|
||||
}
|
||||
return lx.errorf(
|
||||
"expected a top-level item to end with a newline, comment, or EOF, but got %q instead",
|
||||
r)
|
||||
return lx.errorf("expected a top-level item to end with a newline, comment, or EOF, but got %q instead", r)
|
||||
}
|
||||
|
||||
// lexTable lexes the beginning of a table. Namely, it makes sure that
|
||||
|
@ -698,7 +698,12 @@ func lexString(lx *lexer) stateFn {
|
|||
return lexStringEscape
|
||||
case r == '"':
|
||||
lx.backup()
|
||||
lx.emit(itemString)
|
||||
if lx.esc {
|
||||
lx.esc = false
|
||||
lx.emit(itemStringEsc)
|
||||
} else {
|
||||
lx.emit(itemString)
|
||||
}
|
||||
lx.next()
|
||||
lx.ignore()
|
||||
return lx.pop()
|
||||
|
@ -748,6 +753,7 @@ func lexMultilineString(lx *lexer) stateFn {
|
|||
lx.backup() /// backup: don't include the """ in the item.
|
||||
lx.backup()
|
||||
lx.backup()
|
||||
lx.esc = false
|
||||
lx.emit(itemMultilineString)
|
||||
lx.next() /// Read over ''' again and discard it.
|
||||
lx.next()
|
||||
|
@ -837,6 +843,7 @@ func lexMultilineStringEscape(lx *lexer) stateFn {
|
|||
}
|
||||
|
||||
func lexStringEscape(lx *lexer) stateFn {
|
||||
lx.esc = true
|
||||
r := lx.next()
|
||||
switch r {
|
||||
case 'e':
|
||||
|
@ -879,10 +886,8 @@ func lexHexEscape(lx *lexer) stateFn {
|
|||
var r rune
|
||||
for i := 0; i < 2; i++ {
|
||||
r = lx.next()
|
||||
if !isHexadecimal(r) {
|
||||
return lx.errorf(
|
||||
`expected two hexadecimal digits after '\x', but got %q instead`,
|
||||
lx.current())
|
||||
if !isHex(r) {
|
||||
return lx.errorf(`expected two hexadecimal digits after '\x', but got %q instead`, lx.current())
|
||||
}
|
||||
}
|
||||
return lx.pop()
|
||||
|
@ -892,10 +897,8 @@ func lexShortUnicodeEscape(lx *lexer) stateFn {
|
|||
var r rune
|
||||
for i := 0; i < 4; i++ {
|
||||
r = lx.next()
|
||||
if !isHexadecimal(r) {
|
||||
return lx.errorf(
|
||||
`expected four hexadecimal digits after '\u', but got %q instead`,
|
||||
lx.current())
|
||||
if !isHex(r) {
|
||||
return lx.errorf(`expected four hexadecimal digits after '\u', but got %q instead`, lx.current())
|
||||
}
|
||||
}
|
||||
return lx.pop()
|
||||
|
@ -905,10 +908,8 @@ func lexLongUnicodeEscape(lx *lexer) stateFn {
|
|||
var r rune
|
||||
for i := 0; i < 8; i++ {
|
||||
r = lx.next()
|
||||
if !isHexadecimal(r) {
|
||||
return lx.errorf(
|
||||
`expected eight hexadecimal digits after '\U', but got %q instead`,
|
||||
lx.current())
|
||||
if !isHex(r) {
|
||||
return lx.errorf(`expected eight hexadecimal digits after '\U', but got %q instead`, lx.current())
|
||||
}
|
||||
}
|
||||
return lx.pop()
|
||||
|
@ -975,7 +976,7 @@ func lexDatetime(lx *lexer) stateFn {
|
|||
// lexHexInteger consumes a hexadecimal integer after seeing the '0x' prefix.
|
||||
func lexHexInteger(lx *lexer) stateFn {
|
||||
r := lx.next()
|
||||
if isHexadecimal(r) {
|
||||
if isHex(r) {
|
||||
return lexHexInteger
|
||||
}
|
||||
switch r {
|
||||
|
@ -1109,7 +1110,7 @@ func lexBaseNumberOrDate(lx *lexer) stateFn {
|
|||
return lexOctalInteger
|
||||
case 'x':
|
||||
r = lx.peek()
|
||||
if !isHexadecimal(r) {
|
||||
if !isHex(r) {
|
||||
lx.errorf("not a hexidecimal number: '%s%c'", lx.current(), r)
|
||||
}
|
||||
return lexHexInteger
|
||||
|
@ -1207,7 +1208,7 @@ func (itype itemType) String() string {
|
|||
return "EOF"
|
||||
case itemText:
|
||||
return "Text"
|
||||
case itemString, itemRawString, itemMultilineString, itemRawMultilineString:
|
||||
case itemString, itemStringEsc, itemRawString, itemMultilineString, itemRawMultilineString:
|
||||
return "String"
|
||||
case itemBool:
|
||||
return "Bool"
|
||||
|
@ -1240,7 +1241,7 @@ func (itype itemType) String() string {
|
|||
}
|
||||
|
||||
func (item item) String() string {
|
||||
return fmt.Sprintf("(%s, %s)", item.typ.String(), item.val)
|
||||
return fmt.Sprintf("(%s, %s)", item.typ, item.val)
|
||||
}
|
||||
|
||||
func isWhitespace(r rune) bool { return r == '\t' || r == ' ' }
|
||||
|
@ -1256,10 +1257,7 @@ func isControl(r rune) bool { // Control characters except \t, \r, \n
|
|||
func isDigit(r rune) bool { return r >= '0' && r <= '9' }
|
||||
func isBinary(r rune) bool { return r == '0' || r == '1' }
|
||||
func isOctal(r rune) bool { return r >= '0' && r <= '7' }
|
||||
func isHexadecimal(r rune) bool {
|
||||
return (r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F')
|
||||
}
|
||||
|
||||
func isHex(r rune) bool { return (r >= '0' && r <= '9') || (r|0x20 >= 'a' && r|0x20 <= 'f') }
|
||||
func isBareKeyChar(r rune, tomlNext bool) bool {
|
||||
if tomlNext {
|
||||
return (r >= 'A' && r <= 'Z') ||
|
||||
|
|
|
@ -13,7 +13,7 @@ type MetaData struct {
|
|||
context Key // Used only during decoding.
|
||||
|
||||
keyInfo map[string]keyInfo
|
||||
mapping map[string]interface{}
|
||||
mapping map[string]any
|
||||
keys []Key
|
||||
decoded map[string]struct{}
|
||||
data []byte // Input file; for errors.
|
||||
|
@ -31,12 +31,12 @@ func (md *MetaData) IsDefined(key ...string) bool {
|
|||
}
|
||||
|
||||
var (
|
||||
hash map[string]interface{}
|
||||
hash map[string]any
|
||||
ok bool
|
||||
hashOrVal interface{} = md.mapping
|
||||
hashOrVal any = md.mapping
|
||||
)
|
||||
for _, k := range key {
|
||||
if hash, ok = hashOrVal.(map[string]interface{}); !ok {
|
||||
if hash, ok = hashOrVal.(map[string]any); !ok {
|
||||
return false
|
||||
}
|
||||
if hashOrVal, ok = hash[k]; !ok {
|
||||
|
@ -94,28 +94,55 @@ func (md *MetaData) Undecoded() []Key {
|
|||
type Key []string
|
||||
|
||||
func (k Key) String() string {
|
||||
ss := make([]string, len(k))
|
||||
for i := range k {
|
||||
ss[i] = k.maybeQuoted(i)
|
||||
// This is called quite often, so it's a bit funky to make it faster.
|
||||
var b strings.Builder
|
||||
b.Grow(len(k) * 25)
|
||||
outer:
|
||||
for i, kk := range k {
|
||||
if i > 0 {
|
||||
b.WriteByte('.')
|
||||
}
|
||||
if kk == "" {
|
||||
b.WriteString(`""`)
|
||||
} else {
|
||||
for _, r := range kk {
|
||||
// "Inline" isBareKeyChar
|
||||
if !((r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' || r == '-') {
|
||||
b.WriteByte('"')
|
||||
b.WriteString(dblQuotedReplacer.Replace(kk))
|
||||
b.WriteByte('"')
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
b.WriteString(kk)
|
||||
}
|
||||
}
|
||||
return strings.Join(ss, ".")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (k Key) maybeQuoted(i int) string {
|
||||
if k[i] == "" {
|
||||
return `""`
|
||||
}
|
||||
for _, c := range k[i] {
|
||||
if !isBareKeyChar(c, false) {
|
||||
return `"` + dblQuotedReplacer.Replace(k[i]) + `"`
|
||||
for _, r := range k[i] {
|
||||
if (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' || r == '-' {
|
||||
continue
|
||||
}
|
||||
return `"` + dblQuotedReplacer.Replace(k[i]) + `"`
|
||||
}
|
||||
return k[i]
|
||||
}
|
||||
|
||||
// Like append(), but only increase the cap by 1.
|
||||
func (k Key) add(piece string) Key {
|
||||
if cap(k) > len(k) {
|
||||
return append(k, piece)
|
||||
}
|
||||
newKey := make(Key, len(k)+1)
|
||||
copy(newKey, k)
|
||||
newKey[len(k)] = piece
|
||||
return newKey
|
||||
}
|
||||
|
||||
func (k Key) parent() Key { return k[:len(k)-1] } // all except the last piece.
|
||||
func (k Key) last() string { return k[len(k)-1] } // last piece of this key.
|
||||
|
|
|
@ -2,6 +2,7 @@ package toml
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -20,9 +21,9 @@ type parser struct {
|
|||
|
||||
ordered []Key // List of keys in the order that they appear in the TOML data.
|
||||
|
||||
keyInfo map[string]keyInfo // Map keyname → info about the TOML key.
|
||||
mapping map[string]interface{} // Map keyname → key value.
|
||||
implicits map[string]struct{} // Record implicit keys (e.g. "key.group.names").
|
||||
keyInfo map[string]keyInfo // Map keyname → info about the TOML key.
|
||||
mapping map[string]any // Map keyname → key value.
|
||||
implicits map[string]struct{} // Record implicit keys (e.g. "key.group.names").
|
||||
}
|
||||
|
||||
type keyInfo struct {
|
||||
|
@ -49,6 +50,7 @@ func parse(data string) (p *parser, err error) {
|
|||
// it anyway.
|
||||
if strings.HasPrefix(data, "\xff\xfe") || strings.HasPrefix(data, "\xfe\xff") { // UTF-16
|
||||
data = data[2:]
|
||||
//lint:ignore S1017 https://github.com/dominikh/go-tools/issues/1447
|
||||
} else if strings.HasPrefix(data, "\xef\xbb\xbf") { // UTF-8
|
||||
data = data[3:]
|
||||
}
|
||||
|
@ -71,7 +73,7 @@ func parse(data string) (p *parser, err error) {
|
|||
|
||||
p = &parser{
|
||||
keyInfo: make(map[string]keyInfo),
|
||||
mapping: make(map[string]interface{}),
|
||||
mapping: make(map[string]any),
|
||||
lx: lex(data, tomlNext),
|
||||
ordered: make([]Key, 0),
|
||||
implicits: make(map[string]struct{}),
|
||||
|
@ -97,7 +99,7 @@ func (p *parser) panicErr(it item, err error) {
|
|||
})
|
||||
}
|
||||
|
||||
func (p *parser) panicItemf(it item, format string, v ...interface{}) {
|
||||
func (p *parser) panicItemf(it item, format string, v ...any) {
|
||||
panic(ParseError{
|
||||
Message: fmt.Sprintf(format, v...),
|
||||
Position: it.pos,
|
||||
|
@ -106,7 +108,7 @@ func (p *parser) panicItemf(it item, format string, v ...interface{}) {
|
|||
})
|
||||
}
|
||||
|
||||
func (p *parser) panicf(format string, v ...interface{}) {
|
||||
func (p *parser) panicf(format string, v ...any) {
|
||||
panic(ParseError{
|
||||
Message: fmt.Sprintf(format, v...),
|
||||
Position: p.pos,
|
||||
|
@ -139,7 +141,7 @@ func (p *parser) nextPos() item {
|
|||
return it
|
||||
}
|
||||
|
||||
func (p *parser) bug(format string, v ...interface{}) {
|
||||
func (p *parser) bug(format string, v ...any) {
|
||||
panic(fmt.Sprintf("BUG: "+format+"\n\n", v...))
|
||||
}
|
||||
|
||||
|
@ -194,11 +196,11 @@ func (p *parser) topLevel(item item) {
|
|||
p.assertEqual(itemKeyEnd, k.typ)
|
||||
|
||||
/// The current key is the last part.
|
||||
p.currentKey = key[len(key)-1]
|
||||
p.currentKey = key.last()
|
||||
|
||||
/// All the other parts (if any) are the context; need to set each part
|
||||
/// as implicit.
|
||||
context := key[:len(key)-1]
|
||||
context := key.parent()
|
||||
for i := range context {
|
||||
p.addImplicitContext(append(p.context, context[i:i+1]...))
|
||||
}
|
||||
|
@ -207,7 +209,8 @@ func (p *parser) topLevel(item item) {
|
|||
/// Set value.
|
||||
vItem := p.next()
|
||||
val, typ := p.value(vItem, false)
|
||||
p.set(p.currentKey, val, typ, vItem.pos)
|
||||
p.setValue(p.currentKey, val)
|
||||
p.setType(p.currentKey, typ, vItem.pos)
|
||||
|
||||
/// Remove the context we added (preserving any context from [tbl] lines).
|
||||
p.context = outerContext
|
||||
|
@ -222,7 +225,7 @@ func (p *parser) keyString(it item) string {
|
|||
switch it.typ {
|
||||
case itemText:
|
||||
return it.val
|
||||
case itemString, itemMultilineString,
|
||||
case itemString, itemStringEsc, itemMultilineString,
|
||||
itemRawString, itemRawMultilineString:
|
||||
s, _ := p.value(it, false)
|
||||
return s.(string)
|
||||
|
@ -239,9 +242,11 @@ var datetimeRepl = strings.NewReplacer(
|
|||
|
||||
// value translates an expected value from the lexer into a Go value wrapped
|
||||
// as an empty interface.
|
||||
func (p *parser) value(it item, parentIsArray bool) (interface{}, tomlType) {
|
||||
func (p *parser) value(it item, parentIsArray bool) (any, tomlType) {
|
||||
switch it.typ {
|
||||
case itemString:
|
||||
return it.val, p.typeOfPrimitive(it)
|
||||
case itemStringEsc:
|
||||
return p.replaceEscapes(it, it.val), p.typeOfPrimitive(it)
|
||||
case itemMultilineString:
|
||||
return p.replaceEscapes(it, p.stripEscapedNewlines(stripFirstNewline(it.val))), p.typeOfPrimitive(it)
|
||||
|
@ -274,7 +279,7 @@ func (p *parser) value(it item, parentIsArray bool) (interface{}, tomlType) {
|
|||
panic("unreachable")
|
||||
}
|
||||
|
||||
func (p *parser) valueInteger(it item) (interface{}, tomlType) {
|
||||
func (p *parser) valueInteger(it item) (any, tomlType) {
|
||||
if !numUnderscoresOK(it.val) {
|
||||
p.panicItemf(it, "Invalid integer %q: underscores must be surrounded by digits", it.val)
|
||||
}
|
||||
|
@ -298,7 +303,7 @@ func (p *parser) valueInteger(it item) (interface{}, tomlType) {
|
|||
return num, p.typeOfPrimitive(it)
|
||||
}
|
||||
|
||||
func (p *parser) valueFloat(it item) (interface{}, tomlType) {
|
||||
func (p *parser) valueFloat(it item) (any, tomlType) {
|
||||
parts := strings.FieldsFunc(it.val, func(r rune) bool {
|
||||
switch r {
|
||||
case '.', 'e', 'E':
|
||||
|
@ -322,7 +327,9 @@ func (p *parser) valueFloat(it item) (interface{}, tomlType) {
|
|||
p.panicItemf(it, "Invalid float %q: '.' must be followed by one or more digits", it.val)
|
||||
}
|
||||
val := strings.Replace(it.val, "_", "", -1)
|
||||
if val == "+nan" || val == "-nan" { // Go doesn't support this, but TOML spec does.
|
||||
signbit := false
|
||||
if val == "+nan" || val == "-nan" {
|
||||
signbit = val == "-nan"
|
||||
val = "nan"
|
||||
}
|
||||
num, err := strconv.ParseFloat(val, 64)
|
||||
|
@ -333,6 +340,9 @@ func (p *parser) valueFloat(it item) (interface{}, tomlType) {
|
|||
p.panicItemf(it, "Invalid float value: %q", it.val)
|
||||
}
|
||||
}
|
||||
if signbit {
|
||||
num = math.Copysign(num, -1)
|
||||
}
|
||||
return num, p.typeOfPrimitive(it)
|
||||
}
|
||||
|
||||
|
@ -352,7 +362,7 @@ var dtTypes = []struct {
|
|||
{"15:04", internal.LocalTime, true},
|
||||
}
|
||||
|
||||
func (p *parser) valueDatetime(it item) (interface{}, tomlType) {
|
||||
func (p *parser) valueDatetime(it item) (any, tomlType) {
|
||||
it.val = datetimeRepl.Replace(it.val)
|
||||
var (
|
||||
t time.Time
|
||||
|
@ -365,26 +375,44 @@ func (p *parser) valueDatetime(it item) (interface{}, tomlType) {
|
|||
}
|
||||
t, err = time.ParseInLocation(dt.fmt, it.val, dt.zone)
|
||||
if err == nil {
|
||||
if missingLeadingZero(it.val, dt.fmt) {
|
||||
p.panicErr(it, errParseDate{it.val})
|
||||
}
|
||||
ok = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
p.panicItemf(it, "Invalid TOML Datetime: %q.", it.val)
|
||||
p.panicErr(it, errParseDate{it.val})
|
||||
}
|
||||
return t, p.typeOfPrimitive(it)
|
||||
}
|
||||
|
||||
func (p *parser) valueArray(it item) (interface{}, tomlType) {
|
||||
// Go's time.Parse() will accept numbers without a leading zero; there isn't any
|
||||
// way to require it. https://github.com/golang/go/issues/29911
|
||||
//
|
||||
// Depend on the fact that the separators (- and :) should always be at the same
|
||||
// location.
|
||||
func missingLeadingZero(d, l string) bool {
|
||||
for i, c := range []byte(l) {
|
||||
if c == '.' || c == 'Z' {
|
||||
return false
|
||||
}
|
||||
if (c < '0' || c > '9') && d[i] != c {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *parser) valueArray(it item) (any, tomlType) {
|
||||
p.setType(p.currentKey, tomlArray, it.pos)
|
||||
|
||||
var (
|
||||
types []tomlType
|
||||
|
||||
// Initialize to a non-nil empty slice. This makes it consistent with
|
||||
// how S = [] decodes into a non-nil slice inside something like struct
|
||||
// { S []string }. See #338
|
||||
array = []interface{}{}
|
||||
// Initialize to a non-nil slice to make it consistent with how S = []
|
||||
// decodes into a non-nil slice inside something like struct { S
|
||||
// []string }. See #338
|
||||
array = make([]any, 0, 2)
|
||||
)
|
||||
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
|
||||
if it.typ == itemCommentStart {
|
||||
|
@ -394,21 +422,20 @@ func (p *parser) valueArray(it item) (interface{}, tomlType) {
|
|||
|
||||
val, typ := p.value(it, true)
|
||||
array = append(array, val)
|
||||
types = append(types, typ)
|
||||
|
||||
// XXX: types isn't used here, we need it to record the accurate type
|
||||
// XXX: type isn't used here, we need it to record the accurate type
|
||||
// information.
|
||||
//
|
||||
// Not entirely sure how to best store this; could use "key[0]",
|
||||
// "key[1]" notation, or maybe store it on the Array type?
|
||||
_ = types
|
||||
_ = typ
|
||||
}
|
||||
return array, tomlArray
|
||||
}
|
||||
|
||||
func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tomlType) {
|
||||
func (p *parser) valueInlineTable(it item, parentIsArray bool) (any, tomlType) {
|
||||
var (
|
||||
hash = make(map[string]interface{})
|
||||
topHash = make(map[string]any)
|
||||
outerContext = p.context
|
||||
outerKey = p.currentKey
|
||||
)
|
||||
|
@ -436,11 +463,11 @@ func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tom
|
|||
p.assertEqual(itemKeyEnd, k.typ)
|
||||
|
||||
/// The current key is the last part.
|
||||
p.currentKey = key[len(key)-1]
|
||||
p.currentKey = key.last()
|
||||
|
||||
/// All the other parts (if any) are the context; need to set each part
|
||||
/// as implicit.
|
||||
context := key[:len(key)-1]
|
||||
context := key.parent()
|
||||
for i := range context {
|
||||
p.addImplicitContext(append(p.context, context[i:i+1]...))
|
||||
}
|
||||
|
@ -448,7 +475,21 @@ func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tom
|
|||
|
||||
/// Set the value.
|
||||
val, typ := p.value(p.next(), false)
|
||||
p.set(p.currentKey, val, typ, it.pos)
|
||||
p.setValue(p.currentKey, val)
|
||||
p.setType(p.currentKey, typ, it.pos)
|
||||
|
||||
hash := topHash
|
||||
for _, c := range context {
|
||||
h, ok := hash[c]
|
||||
if !ok {
|
||||
h = make(map[string]any)
|
||||
hash[c] = h
|
||||
}
|
||||
hash, ok = h.(map[string]any)
|
||||
if !ok {
|
||||
p.panicf("%q is not a table", p.context)
|
||||
}
|
||||
}
|
||||
hash[p.currentKey] = val
|
||||
|
||||
/// Restore context.
|
||||
|
@ -456,7 +497,7 @@ func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tom
|
|||
}
|
||||
p.context = outerContext
|
||||
p.currentKey = outerKey
|
||||
return hash, tomlHash
|
||||
return topHash, tomlHash
|
||||
}
|
||||
|
||||
// numHasLeadingZero checks if this number has leading zeroes, allowing for '0',
|
||||
|
@ -486,9 +527,9 @@ func numUnderscoresOK(s string) bool {
|
|||
}
|
||||
}
|
||||
|
||||
// isHexadecimal is a superset of all the permissable characters
|
||||
// surrounding an underscore.
|
||||
accept = isHexadecimal(r)
|
||||
// isHexis a superset of all the permissable characters surrounding an
|
||||
// underscore.
|
||||
accept = isHex(r)
|
||||
}
|
||||
return accept
|
||||
}
|
||||
|
@ -511,21 +552,19 @@ func numPeriodsOK(s string) bool {
|
|||
// Establishing the context also makes sure that the key isn't a duplicate, and
|
||||
// will create implicit hashes automatically.
|
||||
func (p *parser) addContext(key Key, array bool) {
|
||||
var ok bool
|
||||
|
||||
// Always start at the top level and drill down for our context.
|
||||
/// Always start at the top level and drill down for our context.
|
||||
hashContext := p.mapping
|
||||
keyContext := make(Key, 0)
|
||||
keyContext := make(Key, 0, len(key)-1)
|
||||
|
||||
// We only need implicit hashes for key[0:-1]
|
||||
for _, k := range key[0 : len(key)-1] {
|
||||
_, ok = hashContext[k]
|
||||
/// We only need implicit hashes for the parents.
|
||||
for _, k := range key.parent() {
|
||||
_, ok := hashContext[k]
|
||||
keyContext = append(keyContext, k)
|
||||
|
||||
// No key? Make an implicit hash and move on.
|
||||
if !ok {
|
||||
p.addImplicit(keyContext)
|
||||
hashContext[k] = make(map[string]interface{})
|
||||
hashContext[k] = make(map[string]any)
|
||||
}
|
||||
|
||||
// If the hash context is actually an array of tables, then set
|
||||
|
@ -534,9 +573,9 @@ func (p *parser) addContext(key Key, array bool) {
|
|||
// Otherwise, it better be a table, since this MUST be a key group (by
|
||||
// virtue of it not being the last element in a key).
|
||||
switch t := hashContext[k].(type) {
|
||||
case []map[string]interface{}:
|
||||
case []map[string]any:
|
||||
hashContext = t[len(t)-1]
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
hashContext = t
|
||||
default:
|
||||
p.panicf("Key '%s' was already created as a hash.", keyContext)
|
||||
|
@ -547,39 +586,33 @@ func (p *parser) addContext(key Key, array bool) {
|
|||
if array {
|
||||
// If this is the first element for this array, then allocate a new
|
||||
// list of tables for it.
|
||||
k := key[len(key)-1]
|
||||
k := key.last()
|
||||
if _, ok := hashContext[k]; !ok {
|
||||
hashContext[k] = make([]map[string]interface{}, 0, 4)
|
||||
hashContext[k] = make([]map[string]any, 0, 4)
|
||||
}
|
||||
|
||||
// Add a new table. But make sure the key hasn't already been used
|
||||
// for something else.
|
||||
if hash, ok := hashContext[k].([]map[string]interface{}); ok {
|
||||
hashContext[k] = append(hash, make(map[string]interface{}))
|
||||
if hash, ok := hashContext[k].([]map[string]any); ok {
|
||||
hashContext[k] = append(hash, make(map[string]any))
|
||||
} else {
|
||||
p.panicf("Key '%s' was already created and cannot be used as an array.", key)
|
||||
}
|
||||
} else {
|
||||
p.setValue(key[len(key)-1], make(map[string]interface{}))
|
||||
p.setValue(key.last(), make(map[string]any))
|
||||
}
|
||||
p.context = append(p.context, key[len(key)-1])
|
||||
}
|
||||
|
||||
// set calls setValue and setType.
|
||||
func (p *parser) set(key string, val interface{}, typ tomlType, pos Position) {
|
||||
p.setValue(key, val)
|
||||
p.setType(key, typ, pos)
|
||||
p.context = append(p.context, key.last())
|
||||
}
|
||||
|
||||
// setValue sets the given key to the given value in the current context.
|
||||
// It will make sure that the key hasn't already been defined, account for
|
||||
// implicit key groups.
|
||||
func (p *parser) setValue(key string, value interface{}) {
|
||||
func (p *parser) setValue(key string, value any) {
|
||||
var (
|
||||
tmpHash interface{}
|
||||
tmpHash any
|
||||
ok bool
|
||||
hash = p.mapping
|
||||
keyContext Key
|
||||
keyContext = make(Key, 0, len(p.context)+1)
|
||||
)
|
||||
for _, k := range p.context {
|
||||
keyContext = append(keyContext, k)
|
||||
|
@ -587,11 +620,11 @@ func (p *parser) setValue(key string, value interface{}) {
|
|||
p.bug("Context for key '%s' has not been established.", keyContext)
|
||||
}
|
||||
switch t := tmpHash.(type) {
|
||||
case []map[string]interface{}:
|
||||
case []map[string]any:
|
||||
// The context is a table of hashes. Pick the most recent table
|
||||
// defined as the current hash.
|
||||
hash = t[len(t)-1]
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
hash = t
|
||||
default:
|
||||
p.panicf("Key '%s' has already been defined.", keyContext)
|
||||
|
@ -618,9 +651,8 @@ func (p *parser) setValue(key string, value interface{}) {
|
|||
p.removeImplicit(keyContext)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, we have a concrete key trying to override a previous
|
||||
// key, which is *always* wrong.
|
||||
// Otherwise, we have a concrete key trying to override a previous key,
|
||||
// which is *always* wrong.
|
||||
p.panicf("Key '%s' has already been defined.", keyContext)
|
||||
}
|
||||
|
||||
|
@ -683,8 +715,11 @@ func stripFirstNewline(s string) string {
|
|||
// the next newline. After a line-ending backslash, all whitespace is removed
|
||||
// until the next non-whitespace character.
|
||||
func (p *parser) stripEscapedNewlines(s string) string {
|
||||
var b strings.Builder
|
||||
var i int
|
||||
var (
|
||||
b strings.Builder
|
||||
i int
|
||||
)
|
||||
b.Grow(len(s))
|
||||
for {
|
||||
ix := strings.Index(s[i:], `\`)
|
||||
if ix < 0 {
|
||||
|
@ -714,9 +749,8 @@ func (p *parser) stripEscapedNewlines(s string) string {
|
|||
continue
|
||||
}
|
||||
if !strings.Contains(s[i:j], "\n") {
|
||||
// This is not a line-ending backslash.
|
||||
// (It's a bad escape sequence, but we can let
|
||||
// replaceEscapes catch it.)
|
||||
// This is not a line-ending backslash. (It's a bad escape sequence,
|
||||
// but we can let replaceEscapes catch it.)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
@ -727,79 +761,78 @@ func (p *parser) stripEscapedNewlines(s string) string {
|
|||
}
|
||||
|
||||
func (p *parser) replaceEscapes(it item, str string) string {
|
||||
replaced := make([]rune, 0, len(str))
|
||||
s := []byte(str)
|
||||
r := 0
|
||||
for r < len(s) {
|
||||
if s[r] != '\\' {
|
||||
c, size := utf8.DecodeRune(s[r:])
|
||||
r += size
|
||||
replaced = append(replaced, c)
|
||||
var (
|
||||
b strings.Builder
|
||||
skip = 0
|
||||
)
|
||||
b.Grow(len(str))
|
||||
for i, c := range str {
|
||||
if skip > 0 {
|
||||
skip--
|
||||
continue
|
||||
}
|
||||
r += 1
|
||||
if r >= len(s) {
|
||||
if c != '\\' {
|
||||
b.WriteRune(c)
|
||||
continue
|
||||
}
|
||||
|
||||
if i >= len(str) {
|
||||
p.bug("Escape sequence at end of string.")
|
||||
return ""
|
||||
}
|
||||
switch s[r] {
|
||||
switch str[i+1] {
|
||||
default:
|
||||
p.bug("Expected valid escape code after \\, but got %q.", s[r])
|
||||
p.bug("Expected valid escape code after \\, but got %q.", str[i+1])
|
||||
case ' ', '\t':
|
||||
p.panicItemf(it, "invalid escape: '\\%c'", s[r])
|
||||
p.panicItemf(it, "invalid escape: '\\%c'", str[i+1])
|
||||
case 'b':
|
||||
replaced = append(replaced, rune(0x0008))
|
||||
r += 1
|
||||
b.WriteByte(0x08)
|
||||
skip = 1
|
||||
case 't':
|
||||
replaced = append(replaced, rune(0x0009))
|
||||
r += 1
|
||||
b.WriteByte(0x09)
|
||||
skip = 1
|
||||
case 'n':
|
||||
replaced = append(replaced, rune(0x000A))
|
||||
r += 1
|
||||
b.WriteByte(0x0a)
|
||||
skip = 1
|
||||
case 'f':
|
||||
replaced = append(replaced, rune(0x000C))
|
||||
r += 1
|
||||
b.WriteByte(0x0c)
|
||||
skip = 1
|
||||
case 'r':
|
||||
replaced = append(replaced, rune(0x000D))
|
||||
r += 1
|
||||
b.WriteByte(0x0d)
|
||||
skip = 1
|
||||
case 'e':
|
||||
if p.tomlNext {
|
||||
replaced = append(replaced, rune(0x001B))
|
||||
r += 1
|
||||
b.WriteByte(0x1b)
|
||||
skip = 1
|
||||
}
|
||||
case '"':
|
||||
replaced = append(replaced, rune(0x0022))
|
||||
r += 1
|
||||
b.WriteByte(0x22)
|
||||
skip = 1
|
||||
case '\\':
|
||||
replaced = append(replaced, rune(0x005C))
|
||||
r += 1
|
||||
b.WriteByte(0x5c)
|
||||
skip = 1
|
||||
// The lexer guarantees the correct number of characters are present;
|
||||
// don't need to check here.
|
||||
case 'x':
|
||||
if p.tomlNext {
|
||||
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+3])
|
||||
replaced = append(replaced, escaped)
|
||||
r += 3
|
||||
escaped := p.asciiEscapeToUnicode(it, str[i+2:i+4])
|
||||
b.WriteRune(escaped)
|
||||
skip = 3
|
||||
}
|
||||
case 'u':
|
||||
// At this point, we know we have a Unicode escape of the form
|
||||
// `uXXXX` at [r, r+5). (Because the lexer guarantees this
|
||||
// for us.)
|
||||
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+5])
|
||||
replaced = append(replaced, escaped)
|
||||
r += 5
|
||||
escaped := p.asciiEscapeToUnicode(it, str[i+2:i+6])
|
||||
b.WriteRune(escaped)
|
||||
skip = 5
|
||||
case 'U':
|
||||
// At this point, we know we have a Unicode escape of the form
|
||||
// `uXXXX` at [r, r+9). (Because the lexer guarantees this
|
||||
// for us.)
|
||||
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+9])
|
||||
replaced = append(replaced, escaped)
|
||||
r += 9
|
||||
escaped := p.asciiEscapeToUnicode(it, str[i+2:i+10])
|
||||
b.WriteRune(escaped)
|
||||
skip = 9
|
||||
}
|
||||
}
|
||||
return string(replaced)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (p *parser) asciiEscapeToUnicode(it item, bs []byte) rune {
|
||||
s := string(bs)
|
||||
func (p *parser) asciiEscapeToUnicode(it item, s string) rune {
|
||||
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
|
||||
if err != nil {
|
||||
p.bug("Could not parse '%s' as a hexadecimal number, but the lexer claims it's OK: %s", s, err)
|
||||
|
|
|
@ -25,10 +25,8 @@ type field struct {
|
|||
// breaking ties with index sequence.
|
||||
type byName []field
|
||||
|
||||
func (x byName) Len() int { return len(x) }
|
||||
|
||||
func (x byName) Len() int { return len(x) }
|
||||
func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
|
||||
|
||||
func (x byName) Less(i, j int) bool {
|
||||
if x[i].name != x[j].name {
|
||||
return x[i].name < x[j].name
|
||||
|
@ -45,10 +43,8 @@ func (x byName) Less(i, j int) bool {
|
|||
// byIndex sorts field by index sequence.
|
||||
type byIndex []field
|
||||
|
||||
func (x byIndex) Len() int { return len(x) }
|
||||
|
||||
func (x byIndex) Len() int { return len(x) }
|
||||
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
|
||||
|
||||
func (x byIndex) Less(i, j int) bool {
|
||||
for k, xik := range x[i].index {
|
||||
if k >= len(x[j].index) {
|
||||
|
|
|
@ -22,13 +22,8 @@ func typeIsTable(t tomlType) bool {
|
|||
|
||||
type tomlBaseType string
|
||||
|
||||
func (btype tomlBaseType) typeString() string {
|
||||
return string(btype)
|
||||
}
|
||||
|
||||
func (btype tomlBaseType) String() string {
|
||||
return btype.typeString()
|
||||
}
|
||||
func (btype tomlBaseType) typeString() string { return string(btype) }
|
||||
func (btype tomlBaseType) String() string { return btype.typeString() }
|
||||
|
||||
var (
|
||||
tomlInteger tomlBaseType = "Integer"
|
||||
|
@ -54,7 +49,7 @@ func (p *parser) typeOfPrimitive(lexItem item) tomlType {
|
|||
return tomlFloat
|
||||
case itemDatetime:
|
||||
return tomlDatetime
|
||||
case itemString:
|
||||
case itemString, itemStringEsc:
|
||||
return tomlString
|
||||
case itemMultilineString:
|
||||
return tomlString
|
||||
|
|
|
@ -1,6 +1,3 @@
|
|||
run:
|
||||
skip-files:
|
||||
- internal/handshake/cipher_suite.go
|
||||
linters-settings:
|
||||
misspell:
|
||||
ignore-words:
|
||||
|
@ -26,6 +23,7 @@ linters:
|
|||
- gofmt # redundant, since gofmt *should* be a no-op after gofumpt
|
||||
- gofumpt
|
||||
- gosimple
|
||||
- govet
|
||||
- ineffassign
|
||||
- misspell
|
||||
- prealloc
|
||||
|
@ -34,10 +32,14 @@ linters:
|
|||
- unconvert
|
||||
- unparam
|
||||
- unused
|
||||
- vet
|
||||
|
||||
issues:
|
||||
exclude-files:
|
||||
- internal/handshake/cipher_suite.go
|
||||
exclude-rules:
|
||||
- path: internal/qtls
|
||||
linters:
|
||||
- depguard
|
||||
- path: _test\.go
|
||||
linters:
|
||||
- exhaustive
|
||||
|
|
|
@ -39,6 +39,12 @@ func validateConfig(config *Config) error {
|
|||
if config.MaxConnectionReceiveWindow > quicvarint.Max {
|
||||
config.MaxConnectionReceiveWindow = quicvarint.Max
|
||||
}
|
||||
if config.InitialPacketSize > 0 && config.InitialPacketSize < protocol.MinInitialPacketSize {
|
||||
config.InitialPacketSize = protocol.MinInitialPacketSize
|
||||
}
|
||||
if config.InitialPacketSize > protocol.MaxPacketBufferSize {
|
||||
config.InitialPacketSize = protocol.MaxPacketBufferSize
|
||||
}
|
||||
// check that all QUIC versions are actually supported
|
||||
for _, v := range config.Versions {
|
||||
if !protocol.IsValidVersion(v) {
|
||||
|
@ -94,6 +100,10 @@ func populateConfig(config *Config) *Config {
|
|||
} else if maxIncomingUniStreams < 0 {
|
||||
maxIncomingUniStreams = 0
|
||||
}
|
||||
initialPacketSize := config.InitialPacketSize
|
||||
if initialPacketSize == 0 {
|
||||
initialPacketSize = protocol.InitialPacketSize
|
||||
}
|
||||
|
||||
return &Config{
|
||||
GetConfigForClient: config.GetConfigForClient,
|
||||
|
@ -110,6 +120,7 @@ func populateConfig(config *Config) *Config {
|
|||
MaxIncomingUniStreams: maxIncomingUniStreams,
|
||||
TokenStore: config.TokenStore,
|
||||
EnableDatagrams: config.EnableDatagrams,
|
||||
InitialPacketSize: initialPacketSize,
|
||||
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
|
||||
Allow0RTT: config.Allow0RTT,
|
||||
Tracer: config.Tracer,
|
||||
|
|
|
@ -153,7 +153,9 @@ type connection struct {
|
|||
unpacker unpacker
|
||||
frameParser wire.FrameParser
|
||||
packer packer
|
||||
mtuDiscoverer mtuDiscoverer // initialized when the handshake completes
|
||||
mtuDiscoverer mtuDiscoverer // initialized when the transport parameters are received
|
||||
|
||||
maxPayloadSizeEstimate atomic.Uint32
|
||||
|
||||
initialStream cryptoStream
|
||||
handshakeStream cryptoStream
|
||||
|
@ -276,7 +278,7 @@ var newConnection = func(
|
|||
s.preSetup()
|
||||
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
|
||||
0,
|
||||
getMaxPacketSize(s.conn.RemoteAddr()),
|
||||
protocol.ByteCount(s.config.InitialPacketSize),
|
||||
s.rttStats,
|
||||
clientAddressValidated,
|
||||
s.conn.capabilities().ECN,
|
||||
|
@ -284,7 +286,7 @@ var newConnection = func(
|
|||
s.tracer,
|
||||
s.logger,
|
||||
)
|
||||
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
|
||||
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
|
||||
params := &wire.TransportParameters{
|
||||
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
||||
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
||||
|
@ -295,6 +297,7 @@ var newConnection = func(
|
|||
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
|
||||
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
|
||||
AckDelayExponent: protocol.AckDelayExponent,
|
||||
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
|
||||
DisableActiveMigration: true,
|
||||
StatelessResetToken: &statelessResetToken,
|
||||
OriginalDestinationConnectionID: origDestConnID,
|
||||
|
@ -385,7 +388,7 @@ var newClientConnection = func(
|
|||
s.preSetup()
|
||||
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
|
||||
initialPacketNumber,
|
||||
getMaxPacketSize(s.conn.RemoteAddr()),
|
||||
protocol.ByteCount(s.config.InitialPacketSize),
|
||||
s.rttStats,
|
||||
false, // has no effect
|
||||
s.conn.capabilities().ECN,
|
||||
|
@ -393,7 +396,7 @@ var newClientConnection = func(
|
|||
s.tracer,
|
||||
s.logger,
|
||||
)
|
||||
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
|
||||
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
|
||||
oneRTTStream := newCryptoStream()
|
||||
params := &wire.TransportParameters{
|
||||
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
|
||||
|
@ -404,6 +407,7 @@ var newClientConnection = func(
|
|||
MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams),
|
||||
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
|
||||
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
|
||||
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
|
||||
AckDelayExponent: protocol.AckDelayExponent,
|
||||
DisableActiveMigration: true,
|
||||
// For interoperability with quic-go versions before May 2023, this value must be set to a value
|
||||
|
@ -781,11 +785,7 @@ func (s *connection) handleHandshakeConfirmed() error {
|
|||
s.cryptoStreamHandler.SetHandshakeConfirmed()
|
||||
|
||||
if !s.config.DisablePathMTUDiscovery && s.conn.capabilities().DF {
|
||||
maxPacketSize := s.peerParams.MaxUDPPayloadSize
|
||||
if maxPacketSize == 0 {
|
||||
maxPacketSize = protocol.MaxByteCount
|
||||
}
|
||||
s.mtuDiscoverer.Start(min(maxPacketSize, protocol.MaxPacketBufferSize))
|
||||
s.mtuDiscoverer.Start()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -1774,6 +1774,17 @@ func (s *connection) applyTransportParameters() {
|
|||
// Retire the connection ID.
|
||||
s.connIDManager.AddFromPreferredAddress(params.PreferredAddress.ConnectionID, params.PreferredAddress.StatelessResetToken)
|
||||
}
|
||||
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
|
||||
if params.MaxUDPPayloadSize > 0 && params.MaxUDPPayloadSize < maxPacketSize {
|
||||
maxPacketSize = params.MaxUDPPayloadSize
|
||||
}
|
||||
s.mtuDiscoverer = newMTUDiscoverer(
|
||||
s.rttStats,
|
||||
protocol.ByteCount(s.config.InitialPacketSize),
|
||||
maxPacketSize,
|
||||
s.onMTUIncreased,
|
||||
s.tracer,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *connection) triggerSending(now time.Time) error {
|
||||
|
@ -1862,7 +1873,7 @@ func (s *connection) sendPackets(now time.Time) error {
|
|||
}
|
||||
|
||||
if !s.handshakeConfirmed {
|
||||
packet, err := s.packer.PackCoalescedPacket(false, s.mtuDiscoverer.CurrentSize(), s.version)
|
||||
packet, err := s.packer.PackCoalescedPacket(false, s.maxPacketSize(), s.version)
|
||||
if err != nil || packet == nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1889,7 +1900,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
|
|||
for {
|
||||
buf := getPacketBuffer()
|
||||
ecn := s.sentPacketHandler.ECNMode(true)
|
||||
if _, err := s.appendOneShortHeaderPacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil {
|
||||
if _, err := s.appendOneShortHeaderPacket(buf, s.maxPacketSize(), ecn, now); err != nil {
|
||||
if err == errNothingToPack {
|
||||
buf.Release()
|
||||
return nil
|
||||
|
@ -1920,7 +1931,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
|
|||
|
||||
func (s *connection) sendPacketsWithGSO(now time.Time) error {
|
||||
buf := getLargePacketBuffer()
|
||||
maxSize := s.mtuDiscoverer.CurrentSize()
|
||||
maxSize := s.maxPacketSize()
|
||||
|
||||
ecn := s.sentPacketHandler.ECNMode(true)
|
||||
for {
|
||||
|
@ -1989,7 +2000,7 @@ func (s *connection) resetPacingDeadline() {
|
|||
func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
|
||||
if !s.handshakeConfirmed {
|
||||
ecn := s.sentPacketHandler.ECNMode(false)
|
||||
packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version)
|
||||
packet, err := s.packer.PackCoalescedPacket(true, s.maxPacketSize(), s.version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -2000,7 +2011,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
|
|||
}
|
||||
|
||||
ecn := s.sentPacketHandler.ECNMode(true)
|
||||
p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version)
|
||||
p, buf, err := s.packer.PackAckOnlyPacket(s.maxPacketSize(), s.version)
|
||||
if err != nil {
|
||||
if err == errNothingToPack {
|
||||
return nil
|
||||
|
@ -2022,7 +2033,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
|
|||
break
|
||||
}
|
||||
var err error
|
||||
packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
|
||||
packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), s.version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -2033,7 +2044,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
|
|||
if packet == nil {
|
||||
s.retransmissionQueue.AddPing(encLevel)
|
||||
var err error
|
||||
packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
|
||||
packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), s.version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -2112,14 +2123,14 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
|
|||
var transportErr *qerr.TransportError
|
||||
var applicationErr *qerr.ApplicationError
|
||||
if errors.As(e, &transportErr) {
|
||||
packet, err = s.packer.PackConnectionClose(transportErr, s.mtuDiscoverer.CurrentSize(), s.version)
|
||||
packet, err = s.packer.PackConnectionClose(transportErr, s.maxPacketSize(), s.version)
|
||||
} else if errors.As(e, &applicationErr) {
|
||||
packet, err = s.packer.PackApplicationClose(applicationErr, s.mtuDiscoverer.CurrentSize(), s.version)
|
||||
packet, err = s.packer.PackApplicationClose(applicationErr, s.maxPacketSize(), s.version)
|
||||
} else {
|
||||
packet, err = s.packer.PackConnectionClose(&qerr.TransportError{
|
||||
ErrorCode: qerr.InternalError,
|
||||
ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()),
|
||||
}, s.mtuDiscoverer.CurrentSize(), s.version)
|
||||
}, s.maxPacketSize(), s.version)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -2129,6 +2140,24 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
|
|||
return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn)
|
||||
}
|
||||
|
||||
func (s *connection) maxPacketSize() protocol.ByteCount {
|
||||
if s.mtuDiscoverer == nil {
|
||||
// Use the configured packet size on the client side.
|
||||
// If the server sends a max_udp_payload_size that's smaller than this size, we can ignore this:
|
||||
// Apparently the server still processed the (fully padded) Initial packet anyway.
|
||||
if s.perspective == protocol.PerspectiveClient {
|
||||
return protocol.ByteCount(s.config.InitialPacketSize)
|
||||
}
|
||||
// On the server side, there's no downside to using 1200 bytes until we received the client's transport
|
||||
// parameters:
|
||||
// * If the first packet didn't contain the entire ClientHello, all we can do is ACK that packet. We don't
|
||||
// need a lot of bytes for that.
|
||||
// * If it did, we will have processed the transport parameters and initialized the MTU discoverer.
|
||||
return protocol.MinInitialPacketSize
|
||||
}
|
||||
return s.mtuDiscoverer.CurrentSize()
|
||||
}
|
||||
|
||||
func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) {
|
||||
// quic-go logging
|
||||
if s.logger.Debug() {
|
||||
|
@ -2352,13 +2381,23 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *connection) onMTUIncreased(mtu protocol.ByteCount) {
|
||||
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(mtu)))
|
||||
s.sentPacketHandler.SetMaxDatagramSize(mtu)
|
||||
}
|
||||
|
||||
func (s *connection) SendDatagram(p []byte) error {
|
||||
if !s.supportsDatagrams() {
|
||||
return errors.New("datagram support disabled")
|
||||
}
|
||||
|
||||
f := &wire.DatagramFrame{DataLenPresent: true}
|
||||
maxDataLen := f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version)
|
||||
// The payload size estimate is conservative.
|
||||
// Under many circumstances we could send a few more bytes.
|
||||
maxDataLen := min(
|
||||
f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version),
|
||||
protocol.ByteCount(s.maxPayloadSizeEstimate.Load()),
|
||||
)
|
||||
if protocol.ByteCount(len(p)) > maxDataLen {
|
||||
return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)}
|
||||
}
|
||||
|
@ -2391,3 +2430,10 @@ func (s *connection) NextConnection() Connection {
|
|||
s.streamsMap.UseResetMaps()
|
||||
return s
|
||||
}
|
||||
|
||||
// estimateMaxPayloadSize estimates the maximum payload size for short header packets.
|
||||
// It is not very sophisticated: it just subtracts the size of header (assuming the maximum
|
||||
// connection ID length), and the size of the encryption tag.
|
||||
func estimateMaxPayloadSize(mtu protocol.ByteCount) protocol.ByteCount {
|
||||
return mtu - 1 /* type byte */ - 20 /* maximum connection ID length */ - 16 /* tag size */
|
||||
}
|
||||
|
|
|
@ -21,13 +21,29 @@ func (r *exactReader) Read(b []byte) (int, error) {
|
|||
return n, err
|
||||
}
|
||||
|
||||
type countingByteReader struct {
|
||||
io.ByteReader
|
||||
Read int
|
||||
}
|
||||
|
||||
func (r *countingByteReader) ReadByte() (byte, error) {
|
||||
b, err := r.ByteReader.ReadByte()
|
||||
if err == nil {
|
||||
r.Read++
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
|
||||
// ParseCapsule parses the header of a Capsule.
|
||||
// It returns an io.LimitedReader that can be used to read the Capsule value.
|
||||
// The Capsule value must be read entirely (i.e. until the io.EOF) before using r again.
|
||||
func ParseCapsule(r quicvarint.Reader) (CapsuleType, io.Reader, error) {
|
||||
ct, err := quicvarint.Read(r)
|
||||
cbr := countingByteReader{ByteReader: r}
|
||||
ct, err := quicvarint.Read(&cbr)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// If an io.EOF is returned without consuming any bytes, return it unmodified.
|
||||
// Otherwise, return an io.ErrUnexpectedEOF.
|
||||
if err == io.EOF && cbr.Read > 0 {
|
||||
return 0, nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
return 0, nil, err
|
||||
|
|
|
@ -121,12 +121,16 @@ func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() {
|
|||
}
|
||||
return
|
||||
}
|
||||
go func(str quic.Stream) {
|
||||
_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
|
||||
fp := &frameParser{
|
||||
r: str,
|
||||
conn: c.hconn,
|
||||
unknownFrameHandler: func(ft FrameType, e error) (processed bool, err error) {
|
||||
id := c.hconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
|
||||
return c.StreamHijacker(ft, id, str, e)
|
||||
})
|
||||
if err == errHijacked {
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
if _, err := fp.ParseNext(); err == errHijacked {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -135,7 +139,7 @@ func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() {
|
|||
}
|
||||
}
|
||||
c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
|
||||
}(str)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
@ -71,27 +70,11 @@ func newConnection(
|
|||
return c
|
||||
}
|
||||
|
||||
func (c *connection) onStreamStateChange(id quic.StreamID, state streamState, e error) {
|
||||
func (c *connection) clearStream(id quic.StreamID) {
|
||||
c.streamMx.Lock()
|
||||
defer c.streamMx.Unlock()
|
||||
|
||||
d, ok := c.streams[id]
|
||||
if !ok { // should never happen
|
||||
return
|
||||
}
|
||||
var isDone bool
|
||||
//nolint:exhaustive // These are all the cases we care about.
|
||||
switch state {
|
||||
case streamStateReceiveClosed:
|
||||
isDone = d.SetReceiveError(e)
|
||||
case streamStateSendClosed:
|
||||
isDone = d.SetSendError(e)
|
||||
default:
|
||||
return
|
||||
}
|
||||
if isDone {
|
||||
delete(c.streams, id)
|
||||
}
|
||||
delete(c.streams, id)
|
||||
}
|
||||
|
||||
func (c *connection) openRequestStream(
|
||||
|
@ -109,7 +92,7 @@ func (c *connection) openRequestStream(
|
|||
c.streamMx.Lock()
|
||||
c.streams[str.StreamID()] = datagrams
|
||||
c.streamMx.Unlock()
|
||||
qstr := newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) })
|
||||
qstr := newStateTrackingStream(str, c, datagrams)
|
||||
hstr := newStream(qstr, c, datagrams)
|
||||
return newRequestStream(hstr, requestWriter, reqDone, c.decoder, disableCompression, maxHeaderBytes), nil
|
||||
}
|
||||
|
@ -121,10 +104,11 @@ func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagramme
|
|||
}
|
||||
datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })
|
||||
if c.perspective == protocol.PerspectiveServer {
|
||||
strID := str.StreamID()
|
||||
c.streamMx.Lock()
|
||||
c.streams[str.StreamID()] = datagrams
|
||||
c.streams[strID] = datagrams
|
||||
c.streamMx.Unlock()
|
||||
str = newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) })
|
||||
str = newStateTrackingStream(str, c, datagrams)
|
||||
}
|
||||
return str, datagrams, nil
|
||||
}
|
||||
|
@ -201,7 +185,8 @@ func (c *connection) HandleUnidirectionalStreams(hijack func(StreamType, quic.Co
|
|||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream")
|
||||
return
|
||||
}
|
||||
f, err := parseNextFrame(str, nil)
|
||||
fp := &frameParser{conn: c.Connection, r: str}
|
||||
f, err := fp.ParseNext()
|
||||
if err != nil {
|
||||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
|
||||
return
|
||||
|
@ -252,9 +237,7 @@ func (c *connection) receiveDatagrams() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: this is quite wasteful in terms of allocations
|
||||
r := bytes.NewReader(b)
|
||||
quarterStreamID, err := quicvarint.Read(r)
|
||||
quarterStreamID, n, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
|
||||
return fmt.Errorf("could not read quarter stream id: %w", err)
|
||||
|
@ -271,7 +254,7 @@ func (c *connection) receiveDatagrams() error {
|
|||
return nil
|
||||
}
|
||||
c.streamMx.Unlock()
|
||||
dg.enqueue(b[len(b)-r.Len():])
|
||||
dg.enqueue(b[n:])
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -27,21 +27,19 @@ func newDatagrammer(sendDatagram func([]byte) error) *datagrammer {
|
|||
}
|
||||
}
|
||||
|
||||
func (d *datagrammer) SetReceiveError(err error) (isDone bool) {
|
||||
func (d *datagrammer) SetReceiveError(err error) {
|
||||
d.mx.Lock()
|
||||
defer d.mx.Unlock()
|
||||
|
||||
d.receiveErr = err
|
||||
d.signalHasData()
|
||||
return d.sendErr != nil
|
||||
}
|
||||
|
||||
func (d *datagrammer) SetSendError(err error) (isDone bool) {
|
||||
func (d *datagrammer) SetSendError(err error) {
|
||||
d.mx.Lock()
|
||||
defer d.mx.Unlock()
|
||||
|
||||
d.sendErr = err
|
||||
return d.receiveErr != nil
|
||||
}
|
||||
|
||||
func (d *datagrammer) Send(b []byte) error {
|
||||
|
@ -85,9 +83,9 @@ start:
|
|||
d.mx.Unlock()
|
||||
return data, nil
|
||||
}
|
||||
if d.receiveErr != nil {
|
||||
if receiveErr := d.receiveErr; receiveErr != nil {
|
||||
d.mx.Unlock()
|
||||
return nil, d.receiveErr
|
||||
return nil, receiveErr
|
||||
}
|
||||
d.mx.Unlock()
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
|
@ -18,13 +19,19 @@ type frame interface{}
|
|||
|
||||
var errHijacked = errors.New("hijacked")
|
||||
|
||||
func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) {
|
||||
qr := quicvarint.NewReader(r)
|
||||
type frameParser struct {
|
||||
r io.Reader
|
||||
conn quic.Connection
|
||||
unknownFrameHandler unknownFrameHandlerFunc
|
||||
}
|
||||
|
||||
func (p *frameParser) ParseNext() (frame, error) {
|
||||
qr := quicvarint.NewReader(p.r)
|
||||
for {
|
||||
t, err := quicvarint.Read(qr)
|
||||
if err != nil {
|
||||
if unknownFrameHandler != nil {
|
||||
hijacked, err := unknownFrameHandler(0, err)
|
||||
if p.unknownFrameHandler != nil {
|
||||
hijacked, err := p.unknownFrameHandler(0, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -35,8 +42,8 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f
|
|||
return nil, err
|
||||
}
|
||||
// Call the unknownFrameHandler for frames not defined in the HTTP/3 spec
|
||||
if t > 0xd && unknownFrameHandler != nil {
|
||||
hijacked, err := unknownFrameHandler(FrameType(t), nil)
|
||||
if t > 0xd && p.unknownFrameHandler != nil {
|
||||
hijacked, err := p.unknownFrameHandler(FrameType(t), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -56,11 +63,14 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f
|
|||
case 0x1:
|
||||
return &headersFrame{Length: l}, nil
|
||||
case 0x4:
|
||||
return parseSettingsFrame(r, l)
|
||||
return parseSettingsFrame(p.r, l)
|
||||
case 0x3: // CANCEL_PUSH
|
||||
case 0x5: // PUSH_PROMISE
|
||||
case 0x7: // GOAWAY
|
||||
case 0xd: // MAX_PUSH_ID
|
||||
case 0x2, 0x6, 0x8, 0x9:
|
||||
p.conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "")
|
||||
return nil, fmt.Errorf("http3: reserved frame type: %d", t)
|
||||
}
|
||||
// skip over unknown frames
|
||||
if _, err := io.CopyN(io.Discard, qr, int64(l)); err != nil {
|
||||
|
|
|
@ -63,10 +63,14 @@ func newStream(str quic.Stream, conn *connection, datagrams *datagrammer) *strea
|
|||
}
|
||||
|
||||
func (s *stream) Read(b []byte) (int, error) {
|
||||
fp := &frameParser{
|
||||
r: s.Stream,
|
||||
conn: s.conn,
|
||||
}
|
||||
if s.bytesRemainingInFrame == 0 {
|
||||
parseLoop:
|
||||
for {
|
||||
frame, err := parseNextFrame(s.Stream, nil)
|
||||
frame, err := fp.ParseNext()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -177,7 +181,11 @@ func (s *requestStream) SendRequestHeader(req *http.Request) error {
|
|||
}
|
||||
|
||||
func (s *requestStream) ReadResponse() (*http.Response, error) {
|
||||
frame, err := parseNextFrame(s.Stream, nil)
|
||||
fp := &frameParser{
|
||||
r: s.Stream,
|
||||
conn: s.conn,
|
||||
}
|
||||
frame, err := fp.ParseNext()
|
||||
if err != nil {
|
||||
s.Stream.CancelRead(quic.StreamErrorCode(ErrCodeFrameError))
|
||||
s.Stream.CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
|
||||
|
@ -250,7 +258,7 @@ func (s *requestStream) ReadResponse() (*http.Response, error) {
|
|||
|
||||
func (s *stream) SendDatagram(b []byte) error {
|
||||
// TODO: reject if datagrams are not negotiated (yet)
|
||||
return s.conn.sendDatagram(s.Stream.StreamID(), b)
|
||||
return s.datagrams.Send(b)
|
||||
}
|
||||
|
||||
func (s *stream) ReceiveDatagram(ctx context.Context) ([]byte, error) {
|
||||
|
|
|
@ -477,7 +477,8 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
|
|||
)
|
||||
}
|
||||
}
|
||||
frame, err := parseNextFrame(str, ufh)
|
||||
fp := &frameParser{conn: conn, r: str, unknownFrameHandler: ufh}
|
||||
frame, err := fp.ParseNext()
|
||||
if err != nil {
|
||||
if !errors.Is(err, errHijacked) {
|
||||
str.CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete))
|
||||
|
@ -665,11 +666,16 @@ func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) er
|
|||
return server.ListenAndServeTLS(certFile, keyFile)
|
||||
}
|
||||
|
||||
// ListenAndServe listens on the given network address for both TLS/TCP and QUIC
|
||||
// Deprecated: use ListenAndServeTLS instead.
|
||||
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
|
||||
return ListenAndServeTLS(addr, certFile, keyFile, handler)
|
||||
}
|
||||
|
||||
// ListenAndServeTLS listens on the given network address for both TLS/TCP and QUIC
|
||||
// connections in parallel. It returns if one of the two returns an error.
|
||||
// http.DefaultServeMux is used when handler is nil.
|
||||
// The correct Alt-Svc headers for QUIC are set.
|
||||
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
|
||||
func ListenAndServeTLS(addr, certFile, keyFile string, handler http.Handler) error {
|
||||
// Load certs
|
||||
var err error
|
||||
certs := make([]tls.Certificate, 1)
|
||||
|
|
|
@ -1,62 +1,87 @@
|
|||
package http3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
type streamState uint8
|
||||
|
||||
const (
|
||||
streamStateOpen streamState = iota
|
||||
streamStateReceiveClosed
|
||||
streamStateSendClosed
|
||||
streamStateSendAndReceiveClosed
|
||||
)
|
||||
var _ quic.Stream = &stateTrackingStream{}
|
||||
|
||||
// stateTrackingStream is an implementation of quic.Stream that delegates
|
||||
// to an underlying stream
|
||||
// it takes care of proxying send and receive errors onto an implementation of
|
||||
// the errorSetter interface (intended to be occupied by a datagrammer)
|
||||
// it is also responsible for clearing the stream based on its ID from its
|
||||
// parent connection, this is done through the streamClearer interface when
|
||||
// both the send and receive sides are closed
|
||||
type stateTrackingStream struct {
|
||||
quic.Stream
|
||||
|
||||
mx sync.Mutex
|
||||
state streamState
|
||||
mx sync.Mutex
|
||||
sendErr error
|
||||
recvErr error
|
||||
|
||||
onStateChange func(streamState, error)
|
||||
clearer streamClearer
|
||||
setter errorSetter
|
||||
}
|
||||
|
||||
func newStateTrackingStream(s quic.Stream, onStateChange func(streamState, error)) *stateTrackingStream {
|
||||
return &stateTrackingStream{
|
||||
Stream: s,
|
||||
state: streamStateOpen,
|
||||
onStateChange: onStateChange,
|
||||
type streamClearer interface {
|
||||
clearStream(quic.StreamID)
|
||||
}
|
||||
|
||||
type errorSetter interface {
|
||||
SetSendError(error)
|
||||
SetReceiveError(error)
|
||||
}
|
||||
|
||||
func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream {
|
||||
t := &stateTrackingStream{
|
||||
Stream: s,
|
||||
clearer: clearer,
|
||||
setter: setter,
|
||||
}
|
||||
}
|
||||
|
||||
var _ quic.Stream = &stateTrackingStream{}
|
||||
context.AfterFunc(s.Context(), func() {
|
||||
t.closeSend(context.Cause(s.Context()))
|
||||
})
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) closeSend(e error) {
|
||||
s.mx.Lock()
|
||||
defer s.mx.Unlock()
|
||||
|
||||
if s.state == streamStateReceiveClosed || s.state == streamStateSendAndReceiveClosed {
|
||||
s.state = streamStateSendAndReceiveClosed
|
||||
} else {
|
||||
s.state = streamStateSendClosed
|
||||
// clear the stream the first time both the send
|
||||
// and receive are finished
|
||||
if s.sendErr == nil {
|
||||
if s.recvErr != nil {
|
||||
s.clearer.clearStream(s.StreamID())
|
||||
}
|
||||
|
||||
s.setter.SetSendError(e)
|
||||
s.sendErr = e
|
||||
}
|
||||
s.onStateChange(s.state, e)
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) closeReceive(e error) {
|
||||
s.mx.Lock()
|
||||
defer s.mx.Unlock()
|
||||
|
||||
if s.state == streamStateSendClosed || s.state == streamStateSendAndReceiveClosed {
|
||||
s.state = streamStateSendAndReceiveClosed
|
||||
} else {
|
||||
s.state = streamStateReceiveClosed
|
||||
// clear the stream the first time both the send
|
||||
// and receive are finished
|
||||
if s.recvErr == nil {
|
||||
if s.sendErr != nil {
|
||||
s.clearer.clearStream(s.StreamID())
|
||||
}
|
||||
|
||||
s.setter.SetReceiveError(e)
|
||||
s.recvErr = e
|
||||
}
|
||||
s.onStateChange(s.state, e)
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) Close() error {
|
||||
|
@ -71,7 +96,7 @@ func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) {
|
|||
|
||||
func (s *stateTrackingStream) Write(b []byte) (int, error) {
|
||||
n, err := s.Stream.Write(b)
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
s.closeSend(err)
|
||||
}
|
||||
return n, err
|
||||
|
@ -84,7 +109,7 @@ func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) {
|
|||
|
||||
func (s *stateTrackingStream) Read(b []byte) (int, error) {
|
||||
n, err := s.Stream.Read(b)
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
s.closeReceive(err)
|
||||
}
|
||||
return n, err
|
||||
|
|
|
@ -325,10 +325,16 @@ type Config struct {
|
|||
// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
|
||||
// every half of MaxIdleTimeout, whichever is smaller).
|
||||
KeepAlivePeriod time.Duration
|
||||
// InitialPacketSize is the initial size of packets sent.
|
||||
// It is usually not necessary to manually set this value,
|
||||
// since Path MTU discovery very quickly finds the path's MTU.
|
||||
// If set too high, the path might not support packets that large, leading to a timeout of the QUIC handshake.
|
||||
// Values below 1200 are invalid.
|
||||
InitialPacketSize uint16
|
||||
// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
|
||||
// This allows the sending of QUIC packets that fully utilize the available MTU of the path.
|
||||
// Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
|
||||
// If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
|
||||
// If unavailable or disabled, packets will be at most 1280 bytes in size.
|
||||
DisablePathMTUDiscovery bool
|
||||
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
|
||||
// Only valid for the server.
|
||||
|
|
|
@ -17,11 +17,11 @@ import (
|
|||
// 1024*1024^3 (first 1024 is from 0.100^3)
|
||||
// where 0.100 is 100 ms which is the scaling round trip time.
|
||||
const (
|
||||
cubeScale = 40
|
||||
cubeCongestionWindowScale = 410
|
||||
cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
|
||||
cubeScale = 40
|
||||
cubeCongestionWindowScale = 410
|
||||
cubeFactor = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
|
||||
// TODO: when re-enabling cubic, make sure to use the actual packet size here
|
||||
maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
|
||||
maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSize)
|
||||
)
|
||||
|
||||
const defaultNumConnections = 1
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
const (
|
||||
// maxDatagramSize is the default maximum packet size used in the Linux TCP implementation.
|
||||
// Used in QUIC for congestion window computations in bytes.
|
||||
initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
|
||||
initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSize)
|
||||
maxBurstPackets = 3
|
||||
renoBeta = 0.7 // Reno backoff factor.
|
||||
minCongestionWindowPackets = 2
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
|
@ -338,25 +337,26 @@ func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (a
|
|||
return false
|
||||
}
|
||||
|
||||
func decodeDataFromSessionState(data []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
|
||||
r := bytes.NewReader(data)
|
||||
ver, err := quicvarint.Read(r)
|
||||
func decodeDataFromSessionState(b []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
|
||||
ver, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
b = b[l:]
|
||||
if ver != clientSessionStateRevision {
|
||||
return 0, nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
|
||||
}
|
||||
rttEncoded, err := quicvarint.Read(r)
|
||||
rttEncoded, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
b = b[l:]
|
||||
rtt := time.Duration(rttEncoded) * time.Microsecond
|
||||
if !earlyData {
|
||||
return rtt, nil, nil
|
||||
}
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
|
||||
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
return rtt, &tp, nil
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
@ -28,25 +27,26 @@ func (t *sessionTicket) Marshal() []byte {
|
|||
}
|
||||
|
||||
func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
|
||||
r := bytes.NewReader(b)
|
||||
rev, err := quicvarint.Read(r)
|
||||
rev, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return errors.New("failed to read session ticket revision")
|
||||
}
|
||||
b = b[l:]
|
||||
if rev != sessionTicketRevision {
|
||||
return fmt.Errorf("unknown session ticket revision: %d", rev)
|
||||
}
|
||||
rtt, err := quicvarint.Read(r)
|
||||
rtt, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return errors.New("failed to read RTT")
|
||||
}
|
||||
b = b[l:]
|
||||
if using0RTT {
|
||||
var tp wire.TransportParameters
|
||||
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
|
||||
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
|
||||
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
|
||||
}
|
||||
t.Parameters = &tp
|
||||
} else if r.Len() > 0 {
|
||||
} else if len(b) > 0 {
|
||||
return fmt.Errorf("the session ticket has more bytes than expected")
|
||||
}
|
||||
t.RTT = time.Duration(rtt) * time.Microsecond
|
||||
|
|
|
@ -8,11 +8,8 @@ const DesiredReceiveBufferSize = (1 << 20) * 7 // 7 MB
|
|||
// DesiredSendBufferSize is the kernel UDP send buffer size that we'd like to use.
|
||||
const DesiredSendBufferSize = (1 << 20) * 7 // 7 MB
|
||||
|
||||
// InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets.
|
||||
const InitialPacketSizeIPv4 = 1252
|
||||
|
||||
// InitialPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
|
||||
const InitialPacketSizeIPv6 = 1232
|
||||
// InitialPacketSize is the initial (before Path MTU discovery) maximum packet size used.
|
||||
const InitialPacketSize = 1280
|
||||
|
||||
// MaxCongestionWindowPackets is the maximum congestion window in packet.
|
||||
const MaxCongestionWindowPackets = 10000
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"sort"
|
||||
"time"
|
||||
|
@ -22,18 +21,21 @@ type AckFrame struct {
|
|||
}
|
||||
|
||||
// parseAckFrame reads an ACK frame
|
||||
func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.Version) error {
|
||||
func parseAckFrame(frame *AckFrame, b []byte, typ uint64, ackDelayExponent uint8, _ protocol.Version) (int, error) {
|
||||
startLen := len(b)
|
||||
ecn := typ == ackECNFrameType
|
||||
|
||||
la, err := quicvarint.Read(r)
|
||||
la, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
largestAcked := protocol.PacketNumber(la)
|
||||
delay, err := quicvarint.Read(r)
|
||||
delay, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
|
||||
delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
|
||||
if delayTime < 0 {
|
||||
|
@ -42,71 +44,78 @@ func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponen
|
|||
}
|
||||
frame.DelayTime = delayTime
|
||||
|
||||
numBlocks, err := quicvarint.Read(r)
|
||||
numBlocks, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
|
||||
// read the first ACK range
|
||||
ab, err := quicvarint.Read(r)
|
||||
ab, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
ackBlock := protocol.PacketNumber(ab)
|
||||
if ackBlock > largestAcked {
|
||||
return errors.New("invalid first ACK range")
|
||||
return 0, errors.New("invalid first ACK range")
|
||||
}
|
||||
smallest := largestAcked - ackBlock
|
||||
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked})
|
||||
|
||||
// read all the other ACK ranges
|
||||
for i := uint64(0); i < numBlocks; i++ {
|
||||
g, err := quicvarint.Read(r)
|
||||
g, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
gap := protocol.PacketNumber(g)
|
||||
if smallest < gap+2 {
|
||||
return errInvalidAckRanges
|
||||
return 0, errInvalidAckRanges
|
||||
}
|
||||
largest := smallest - gap - 2
|
||||
|
||||
ab, err := quicvarint.Read(r)
|
||||
ab, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
ackBlock := protocol.PacketNumber(ab)
|
||||
|
||||
if ackBlock > largest {
|
||||
return errInvalidAckRanges
|
||||
return 0, errInvalidAckRanges
|
||||
}
|
||||
smallest = largest - ackBlock
|
||||
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest})
|
||||
}
|
||||
|
||||
if !frame.validateAckRanges() {
|
||||
return errInvalidAckRanges
|
||||
return 0, errInvalidAckRanges
|
||||
}
|
||||
|
||||
if ecn {
|
||||
ect0, err := quicvarint.Read(r)
|
||||
ect0, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
frame.ECT0 = ect0
|
||||
ect1, err := quicvarint.Read(r)
|
||||
ect1, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
frame.ECT1 = ect1
|
||||
ecnce, err := quicvarint.Read(r)
|
||||
ecnce, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
frame.ECNCE = ecnce
|
||||
}
|
||||
|
||||
return nil
|
||||
return startLen - len(b), nil
|
||||
}
|
||||
|
||||
// Append appends an ACK frame.
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -16,40 +15,38 @@ type ConnectionCloseFrame struct {
|
|||
ReasonPhrase string
|
||||
}
|
||||
|
||||
func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, error) {
|
||||
func parseConnectionCloseFrame(b []byte, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, int, error) {
|
||||
startLen := len(b)
|
||||
f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType}
|
||||
ec, err := quicvarint.Read(r)
|
||||
ec, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
f.ErrorCode = ec
|
||||
// read the Frame Type, if this is not an application error
|
||||
if !f.IsApplicationError {
|
||||
ft, err := quicvarint.Read(r)
|
||||
ft, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
f.FrameType = ft
|
||||
}
|
||||
var reasonPhraseLen uint64
|
||||
reasonPhraseLen, err = quicvarint.Read(r)
|
||||
reasonPhraseLen, l, err = quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
// shortcut to prevent the unnecessary allocation of dataLen bytes
|
||||
// if the dataLen is larger than the remaining length of the packet
|
||||
// reading the whole reason phrase would result in EOF when attempting to READ
|
||||
if int(reasonPhraseLen) > r.Len() {
|
||||
return nil, io.EOF
|
||||
b = b[l:]
|
||||
if int(reasonPhraseLen) > len(b) {
|
||||
return nil, 0, io.EOF
|
||||
}
|
||||
|
||||
reasonPhrase := make([]byte, reasonPhraseLen)
|
||||
if _, err := io.ReadFull(r, reasonPhrase); err != nil {
|
||||
// this should never happen, since we already checked the reasonPhraseLen earlier
|
||||
return nil, err
|
||||
}
|
||||
copy(reasonPhrase, b)
|
||||
f.ReasonPhrase = string(reasonPhrase)
|
||||
return f, nil
|
||||
return f, startLen - len(b) + int(reasonPhraseLen), nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -14,28 +13,28 @@ type CryptoFrame struct {
|
|||
Data []byte
|
||||
}
|
||||
|
||||
func parseCryptoFrame(r *bytes.Reader, _ protocol.Version) (*CryptoFrame, error) {
|
||||
func parseCryptoFrame(b []byte, _ protocol.Version) (*CryptoFrame, int, error) {
|
||||
startLen := len(b)
|
||||
frame := &CryptoFrame{}
|
||||
offset, err := quicvarint.Read(r)
|
||||
offset, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
frame.Offset = protocol.ByteCount(offset)
|
||||
dataLen, err := quicvarint.Read(r)
|
||||
dataLen, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
if dataLen > uint64(r.Len()) {
|
||||
return nil, io.EOF
|
||||
b = b[l:]
|
||||
if dataLen > uint64(len(b)) {
|
||||
return nil, 0, io.EOF
|
||||
}
|
||||
if dataLen != 0 {
|
||||
frame.Data = make([]byte, dataLen)
|
||||
if _, err := io.ReadFull(r, frame.Data); err != nil {
|
||||
// this should never happen, since we already checked the dataLen earlier
|
||||
return nil, err
|
||||
}
|
||||
copy(frame.Data, b)
|
||||
}
|
||||
return frame, nil
|
||||
return frame, startLen - len(b) + int(dataLen), nil
|
||||
}
|
||||
|
||||
func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
@ -12,12 +10,12 @@ type DataBlockedFrame struct {
|
|||
MaximumData protocol.ByteCount
|
||||
}
|
||||
|
||||
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*DataBlockedFrame, error) {
|
||||
offset, err := quicvarint.Read(r)
|
||||
func parseDataBlockedFrame(b []byte, _ protocol.Version) (*DataBlockedFrame, int, error) {
|
||||
offset, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, nil
|
||||
return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, l, nil
|
||||
}
|
||||
|
||||
func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -20,29 +19,29 @@ type DatagramFrame struct {
|
|||
Data []byte
|
||||
}
|
||||
|
||||
func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*DatagramFrame, error) {
|
||||
func parseDatagramFrame(b []byte, typ uint64, _ protocol.Version) (*DatagramFrame, int, error) {
|
||||
startLen := len(b)
|
||||
f := &DatagramFrame{}
|
||||
f.DataLenPresent = typ&0x1 > 0
|
||||
|
||||
var length uint64
|
||||
if f.DataLenPresent {
|
||||
var err error
|
||||
len, err := quicvarint.Read(r)
|
||||
var l int
|
||||
length, l, err = quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
if len > uint64(r.Len()) {
|
||||
return nil, io.EOF
|
||||
b = b[l:]
|
||||
if length > uint64(len(b)) {
|
||||
return nil, 0, io.EOF
|
||||
}
|
||||
length = len
|
||||
} else {
|
||||
length = uint64(r.Len())
|
||||
length = uint64(len(b))
|
||||
}
|
||||
f.Data = make([]byte, length)
|
||||
if _, err := io.ReadFull(r, f.Data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, nil
|
||||
copy(f.Data, b)
|
||||
return f, startLen - len(b) + int(length), nil
|
||||
}
|
||||
|
||||
func (f *DatagramFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -38,8 +38,6 @@ const (
|
|||
|
||||
// The FrameParser parses QUIC frames, one by one.
|
||||
type FrameParser struct {
|
||||
r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them
|
||||
|
||||
ackDelayExponent uint8
|
||||
supportsDatagrams bool
|
||||
|
||||
|
@ -51,7 +49,6 @@ type FrameParser struct {
|
|||
// NewFrameParser creates a new frame parser.
|
||||
func NewFrameParser(supportsDatagrams bool) *FrameParser {
|
||||
return &FrameParser{
|
||||
r: *bytes.NewReader(nil),
|
||||
supportsDatagrams: supportsDatagrams,
|
||||
ackFrame: &AckFrame{},
|
||||
}
|
||||
|
@ -60,45 +57,46 @@ func NewFrameParser(supportsDatagrams bool) *FrameParser {
|
|||
// ParseNext parses the next frame.
|
||||
// It skips PADDING frames.
|
||||
func (p *FrameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (int, Frame, error) {
|
||||
startLen := len(data)
|
||||
p.r.Reset(data)
|
||||
frame, err := p.parseNext(&p.r, encLevel, v)
|
||||
n := startLen - p.r.Len()
|
||||
p.r.Reset(nil)
|
||||
return n, frame, err
|
||||
frame, l, err := p.parseNext(data, encLevel, v)
|
||||
return l, frame, err
|
||||
}
|
||||
|
||||
func (p *FrameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
|
||||
for r.Len() != 0 {
|
||||
typ, err := quicvarint.Read(r)
|
||||
func (p *FrameParser) parseNext(b []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) {
|
||||
var parsed int
|
||||
for len(b) != 0 {
|
||||
typ, l, err := quicvarint.Parse(b)
|
||||
parsed += l
|
||||
if err != nil {
|
||||
return nil, &qerr.TransportError{
|
||||
return nil, parsed, &qerr.TransportError{
|
||||
ErrorCode: qerr.FrameEncodingError,
|
||||
ErrorMessage: err.Error(),
|
||||
}
|
||||
}
|
||||
b = b[l:]
|
||||
if typ == 0x0 { // skip PADDING frames
|
||||
continue
|
||||
}
|
||||
|
||||
f, err := p.parseFrame(r, typ, encLevel, v)
|
||||
f, l, err := p.parseFrame(b, typ, encLevel, v)
|
||||
parsed += l
|
||||
if err != nil {
|
||||
return nil, &qerr.TransportError{
|
||||
return nil, parsed, &qerr.TransportError{
|
||||
FrameType: typ,
|
||||
ErrorCode: qerr.FrameEncodingError,
|
||||
ErrorMessage: err.Error(),
|
||||
}
|
||||
}
|
||||
return f, nil
|
||||
return f, parsed, nil
|
||||
}
|
||||
return nil, nil
|
||||
return nil, parsed, nil
|
||||
}
|
||||
|
||||
func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
|
||||
func (p *FrameParser) parseFrame(b []byte, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) {
|
||||
var frame Frame
|
||||
var err error
|
||||
var l int
|
||||
if typ&0xf8 == 0x8 {
|
||||
frame, err = parseStreamFrame(r, typ, v)
|
||||
frame, l, err = parseStreamFrame(b, typ, v)
|
||||
} else {
|
||||
switch typ {
|
||||
case pingFrameType:
|
||||
|
@ -109,43 +107,43 @@ func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
|
|||
ackDelayExponent = protocol.DefaultAckDelayExponent
|
||||
}
|
||||
p.ackFrame.Reset()
|
||||
err = parseAckFrame(p.ackFrame, r, typ, ackDelayExponent, v)
|
||||
l, err = parseAckFrame(p.ackFrame, b, typ, ackDelayExponent, v)
|
||||
frame = p.ackFrame
|
||||
case resetStreamFrameType:
|
||||
frame, err = parseResetStreamFrame(r, v)
|
||||
frame, l, err = parseResetStreamFrame(b, v)
|
||||
case stopSendingFrameType:
|
||||
frame, err = parseStopSendingFrame(r, v)
|
||||
frame, l, err = parseStopSendingFrame(b, v)
|
||||
case cryptoFrameType:
|
||||
frame, err = parseCryptoFrame(r, v)
|
||||
frame, l, err = parseCryptoFrame(b, v)
|
||||
case newTokenFrameType:
|
||||
frame, err = parseNewTokenFrame(r, v)
|
||||
frame, l, err = parseNewTokenFrame(b, v)
|
||||
case maxDataFrameType:
|
||||
frame, err = parseMaxDataFrame(r, v)
|
||||
frame, l, err = parseMaxDataFrame(b, v)
|
||||
case maxStreamDataFrameType:
|
||||
frame, err = parseMaxStreamDataFrame(r, v)
|
||||
frame, l, err = parseMaxStreamDataFrame(b, v)
|
||||
case bidiMaxStreamsFrameType, uniMaxStreamsFrameType:
|
||||
frame, err = parseMaxStreamsFrame(r, typ, v)
|
||||
frame, l, err = parseMaxStreamsFrame(b, typ, v)
|
||||
case dataBlockedFrameType:
|
||||
frame, err = parseDataBlockedFrame(r, v)
|
||||
frame, l, err = parseDataBlockedFrame(b, v)
|
||||
case streamDataBlockedFrameType:
|
||||
frame, err = parseStreamDataBlockedFrame(r, v)
|
||||
frame, l, err = parseStreamDataBlockedFrame(b, v)
|
||||
case bidiStreamBlockedFrameType, uniStreamBlockedFrameType:
|
||||
frame, err = parseStreamsBlockedFrame(r, typ, v)
|
||||
frame, l, err = parseStreamsBlockedFrame(b, typ, v)
|
||||
case newConnectionIDFrameType:
|
||||
frame, err = parseNewConnectionIDFrame(r, v)
|
||||
frame, l, err = parseNewConnectionIDFrame(b, v)
|
||||
case retireConnectionIDFrameType:
|
||||
frame, err = parseRetireConnectionIDFrame(r, v)
|
||||
frame, l, err = parseRetireConnectionIDFrame(b, v)
|
||||
case pathChallengeFrameType:
|
||||
frame, err = parsePathChallengeFrame(r, v)
|
||||
frame, l, err = parsePathChallengeFrame(b, v)
|
||||
case pathResponseFrameType:
|
||||
frame, err = parsePathResponseFrame(r, v)
|
||||
frame, l, err = parsePathResponseFrame(b, v)
|
||||
case connectionCloseFrameType, applicationCloseFrameType:
|
||||
frame, err = parseConnectionCloseFrame(r, typ, v)
|
||||
frame, l, err = parseConnectionCloseFrame(b, typ, v)
|
||||
case handshakeDoneFrameType:
|
||||
frame = &HandshakeDoneFrame{}
|
||||
case 0x30, 0x31:
|
||||
if p.supportsDatagrams {
|
||||
frame, err = parseDatagramFrame(r, typ, v)
|
||||
frame, l, err = parseDatagramFrame(b, typ, v)
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
|
@ -154,12 +152,12 @@ func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
|
|||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
if !p.isAllowedAtEncLevel(frame, encLevel) {
|
||||
return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel)
|
||||
return nil, l, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel)
|
||||
}
|
||||
return frame, nil
|
||||
return frame, l, nil
|
||||
}
|
||||
|
||||
func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool {
|
||||
|
@ -190,3 +188,10 @@ func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionL
|
|||
func (p *FrameParser) SetAckDelayExponent(exp uint8) {
|
||||
p.ackDelayExponent = exp
|
||||
}
|
||||
|
||||
func replaceUnexpectedEOF(e error) error {
|
||||
if e == io.ErrUnexpectedEOF {
|
||||
return io.EOF
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
|
@ -139,18 +138,18 @@ type Header struct {
|
|||
parsedLen protocol.ByteCount // how many bytes were read while parsing this header
|
||||
}
|
||||
|
||||
// ParsePacket parses a packet.
|
||||
// If the packet has a long header, the packet is cut according to the length field.
|
||||
// If we understand the version, the packet is header up unto the packet number.
|
||||
// ParsePacket parses a long header packet.
|
||||
// The packet is cut according to the length field.
|
||||
// If we understand the version, the packet is parsed up unto the packet number.
|
||||
// Otherwise, only the invariant part of the header is parsed.
|
||||
func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
|
||||
if len(data) == 0 || !IsLongHeaderPacket(data[0]) {
|
||||
return nil, nil, nil, errors.New("not a long header packet")
|
||||
}
|
||||
hdr, err := parseHeader(bytes.NewReader(data))
|
||||
hdr, err := parseHeader(data)
|
||||
if err != nil {
|
||||
if err == ErrUnsupportedVersion {
|
||||
return hdr, nil, nil, ErrUnsupportedVersion
|
||||
if errors.Is(err, ErrUnsupportedVersion) {
|
||||
return hdr, nil, nil, err
|
||||
}
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
@ -161,55 +160,55 @@ func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
|
|||
return hdr, data[:packetLen], data[packetLen:], nil
|
||||
}
|
||||
|
||||
// ParseHeader parses the header.
|
||||
// For short header packets: up to the packet number.
|
||||
// For long header packets:
|
||||
// ParseHeader parses the header:
|
||||
// * if we understand the version: up to the packet number
|
||||
// * if not, only the invariant part of the header
|
||||
func parseHeader(b *bytes.Reader) (*Header, error) {
|
||||
startLen := b.Len()
|
||||
typeByte, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func parseHeader(b []byte) (*Header, error) {
|
||||
if len(b) == 0 {
|
||||
return nil, io.EOF
|
||||
}
|
||||
typeByte := b[0]
|
||||
|
||||
h := &Header{typeByte: typeByte}
|
||||
err = h.parseLongHeader(b)
|
||||
h.parsedLen = protocol.ByteCount(startLen - b.Len())
|
||||
l, err := h.parseLongHeader(b[1:])
|
||||
h.parsedLen = protocol.ByteCount(l) + 1
|
||||
return h, err
|
||||
}
|
||||
|
||||
func (h *Header) parseLongHeader(b *bytes.Reader) error {
|
||||
v, err := utils.BigEndian.ReadUint32(b)
|
||||
if err != nil {
|
||||
return err
|
||||
func (h *Header) parseLongHeader(b []byte) (int, error) {
|
||||
startLen := len(b)
|
||||
if len(b) < 5 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
h.Version = protocol.Version(v)
|
||||
h.Version = protocol.Version(binary.BigEndian.Uint32(b[:4]))
|
||||
if h.Version != 0 && h.typeByte&0x40 == 0 {
|
||||
return errors.New("not a QUIC packet")
|
||||
return startLen - len(b), errors.New("not a QUIC packet")
|
||||
}
|
||||
destConnIDLen, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
destConnIDLen := int(b[4])
|
||||
if destConnIDLen > protocol.MaxConnIDLen {
|
||||
return startLen - len(b), protocol.ErrInvalidConnectionIDLen
|
||||
}
|
||||
h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen))
|
||||
if err != nil {
|
||||
return err
|
||||
b = b[5:]
|
||||
if len(b) < destConnIDLen+1 {
|
||||
return startLen - len(b), io.EOF
|
||||
}
|
||||
srcConnIDLen, err := b.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
h.DestConnectionID = protocol.ParseConnectionID(b[:destConnIDLen])
|
||||
srcConnIDLen := int(b[destConnIDLen])
|
||||
if srcConnIDLen > protocol.MaxConnIDLen {
|
||||
return startLen - len(b), protocol.ErrInvalidConnectionIDLen
|
||||
}
|
||||
h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen))
|
||||
if err != nil {
|
||||
return err
|
||||
b = b[destConnIDLen+1:]
|
||||
if len(b) < srcConnIDLen {
|
||||
return startLen - len(b), io.EOF
|
||||
}
|
||||
h.SrcConnectionID = protocol.ParseConnectionID(b[:srcConnIDLen])
|
||||
b = b[srcConnIDLen:]
|
||||
if h.Version == 0 { // version negotiation packet
|
||||
return nil
|
||||
return startLen - len(b), nil
|
||||
}
|
||||
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
|
||||
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
|
||||
return ErrUnsupportedVersion
|
||||
return startLen - len(b), ErrUnsupportedVersion
|
||||
}
|
||||
|
||||
if h.Version == protocol.Version2 {
|
||||
|
@ -237,38 +236,35 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error {
|
|||
}
|
||||
|
||||
if h.Type == protocol.PacketTypeRetry {
|
||||
tokenLen := b.Len() - 16
|
||||
tokenLen := len(b) - 16
|
||||
if tokenLen <= 0 {
|
||||
return io.EOF
|
||||
return startLen - len(b), io.EOF
|
||||
}
|
||||
h.Token = make([]byte, tokenLen)
|
||||
if _, err := io.ReadFull(b, h.Token); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := b.Seek(16, io.SeekCurrent)
|
||||
return err
|
||||
copy(h.Token, b[:tokenLen])
|
||||
return startLen - len(b) + tokenLen + 16, nil
|
||||
}
|
||||
|
||||
if h.Type == protocol.PacketTypeInitial {
|
||||
tokenLen, err := quicvarint.Read(b)
|
||||
tokenLen, n, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return startLen - len(b), err
|
||||
}
|
||||
if tokenLen > uint64(b.Len()) {
|
||||
return io.EOF
|
||||
b = b[n:]
|
||||
if tokenLen > uint64(len(b)) {
|
||||
return startLen - len(b), io.EOF
|
||||
}
|
||||
h.Token = make([]byte, tokenLen)
|
||||
if _, err := io.ReadFull(b, h.Token); err != nil {
|
||||
return err
|
||||
}
|
||||
copy(h.Token, b[:tokenLen])
|
||||
b = b[tokenLen:]
|
||||
}
|
||||
|
||||
pl, err := quicvarint.Read(b)
|
||||
pl, n, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
h.Length = protocol.ByteCount(pl)
|
||||
return nil
|
||||
return startLen - len(b) + n, nil
|
||||
}
|
||||
|
||||
// ParsedLen returns the number of bytes that were consumed when parsing the header
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
@ -13,14 +11,14 @@ type MaxDataFrame struct {
|
|||
}
|
||||
|
||||
// parseMaxDataFrame parses a MAX_DATA frame
|
||||
func parseMaxDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxDataFrame, error) {
|
||||
func parseMaxDataFrame(b []byte, _ protocol.Version) (*MaxDataFrame, int, error) {
|
||||
frame := &MaxDataFrame{}
|
||||
byteOffset, err := quicvarint.Read(r)
|
||||
byteOffset, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
frame.MaximumData = protocol.ByteCount(byteOffset)
|
||||
return frame, nil
|
||||
return frame, l, nil
|
||||
}
|
||||
|
||||
func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
@ -13,20 +11,23 @@ type MaxStreamDataFrame struct {
|
|||
MaximumStreamData protocol.ByteCount
|
||||
}
|
||||
|
||||
func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxStreamDataFrame, error) {
|
||||
sid, err := quicvarint.Read(r)
|
||||
func parseMaxStreamDataFrame(b []byte, _ protocol.Version) (*MaxStreamDataFrame, int, error) {
|
||||
startLen := len(b)
|
||||
sid, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
offset, err := quicvarint.Read(r)
|
||||
b = b[l:]
|
||||
offset, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
|
||||
return &MaxStreamDataFrame{
|
||||
StreamID: protocol.StreamID(sid),
|
||||
MaximumStreamData: protocol.ByteCount(offset),
|
||||
}, nil
|
||||
}, startLen - len(b), nil
|
||||
}
|
||||
|
||||
func (f *MaxStreamDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -14,7 +13,7 @@ type MaxStreamsFrame struct {
|
|||
MaxStreamNum protocol.StreamNum
|
||||
}
|
||||
|
||||
func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*MaxStreamsFrame, error) {
|
||||
func parseMaxStreamsFrame(b []byte, typ uint64, _ protocol.Version) (*MaxStreamsFrame, int, error) {
|
||||
f := &MaxStreamsFrame{}
|
||||
switch typ {
|
||||
case bidiMaxStreamsFrameType:
|
||||
|
@ -22,15 +21,15 @@ func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*Max
|
|||
case uniMaxStreamsFrameType:
|
||||
f.Type = protocol.StreamTypeUni
|
||||
}
|
||||
streamID, err := quicvarint.Read(r)
|
||||
streamID, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
f.MaxStreamNum = protocol.StreamNum(streamID)
|
||||
if f.MaxStreamNum > protocol.MaxStreamCount {
|
||||
return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum)
|
||||
return nil, 0, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum)
|
||||
}
|
||||
return f, nil
|
||||
return f, l, nil
|
||||
}
|
||||
|
||||
func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -18,43 +17,47 @@ type NewConnectionIDFrame struct {
|
|||
StatelessResetToken protocol.StatelessResetToken
|
||||
}
|
||||
|
||||
func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*NewConnectionIDFrame, error) {
|
||||
seq, err := quicvarint.Read(r)
|
||||
func parseNewConnectionIDFrame(b []byte, _ protocol.Version) (*NewConnectionIDFrame, int, error) {
|
||||
startLen := len(b)
|
||||
seq, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
ret, err := quicvarint.Read(r)
|
||||
b = b[l:]
|
||||
ret, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
if ret > seq {
|
||||
//nolint:stylecheck
|
||||
return nil, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq)
|
||||
return nil, 0, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq)
|
||||
}
|
||||
connIDLen, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if len(b) == 0 {
|
||||
return nil, 0, io.EOF
|
||||
}
|
||||
connIDLen := int(b[0])
|
||||
b = b[1:]
|
||||
if connIDLen == 0 {
|
||||
return nil, errors.New("invalid zero-length connection ID")
|
||||
return nil, 0, errors.New("invalid zero-length connection ID")
|
||||
}
|
||||
connID, err := protocol.ReadConnectionID(r, int(connIDLen))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if connIDLen > protocol.MaxConnIDLen {
|
||||
return nil, 0, protocol.ErrInvalidConnectionIDLen
|
||||
}
|
||||
if len(b) < connIDLen {
|
||||
return nil, 0, io.EOF
|
||||
}
|
||||
frame := &NewConnectionIDFrame{
|
||||
SequenceNumber: seq,
|
||||
RetirePriorTo: ret,
|
||||
ConnectionID: connID,
|
||||
ConnectionID: protocol.ParseConnectionID(b[:connIDLen]),
|
||||
}
|
||||
if _, err := io.ReadFull(r, frame.StatelessResetToken[:]); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return nil, err
|
||||
b = b[connIDLen:]
|
||||
if len(b) < len(frame.StatelessResetToken) {
|
||||
return nil, 0, io.EOF
|
||||
}
|
||||
|
||||
return frame, nil
|
||||
copy(frame.StatelessResetToken[:], b)
|
||||
return frame, startLen - len(b) + len(frame.StatelessResetToken), nil
|
||||
}
|
||||
|
||||
func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
|
@ -14,22 +13,21 @@ type NewTokenFrame struct {
|
|||
Token []byte
|
||||
}
|
||||
|
||||
func parseNewTokenFrame(r *bytes.Reader, _ protocol.Version) (*NewTokenFrame, error) {
|
||||
tokenLen, err := quicvarint.Read(r)
|
||||
func parseNewTokenFrame(b []byte, _ protocol.Version) (*NewTokenFrame, int, error) {
|
||||
tokenLen, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if uint64(r.Len()) < tokenLen {
|
||||
return nil, io.EOF
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
if tokenLen == 0 {
|
||||
return nil, errors.New("token must not be empty")
|
||||
return nil, 0, errors.New("token must not be empty")
|
||||
}
|
||||
if uint64(len(b)) < tokenLen {
|
||||
return nil, 0, io.EOF
|
||||
}
|
||||
token := make([]byte, int(tokenLen))
|
||||
if _, err := io.ReadFull(r, token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &NewTokenFrame{Token: token}, nil
|
||||
copy(token, b)
|
||||
return &NewTokenFrame{Token: token}, l + int(tokenLen), nil
|
||||
}
|
||||
|
||||
func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -12,15 +11,13 @@ type PathChallengeFrame struct {
|
|||
Data [8]byte
|
||||
}
|
||||
|
||||
func parsePathChallengeFrame(r *bytes.Reader, _ protocol.Version) (*PathChallengeFrame, error) {
|
||||
frame := &PathChallengeFrame{}
|
||||
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return nil, err
|
||||
func parsePathChallengeFrame(b []byte, _ protocol.Version) (*PathChallengeFrame, int, error) {
|
||||
f := &PathChallengeFrame{}
|
||||
if len(b) < 8 {
|
||||
return nil, 0, io.EOF
|
||||
}
|
||||
return frame, nil
|
||||
copy(f.Data[:], b)
|
||||
return f, 8, nil
|
||||
}
|
||||
|
||||
func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -12,15 +11,13 @@ type PathResponseFrame struct {
|
|||
Data [8]byte
|
||||
}
|
||||
|
||||
func parsePathResponseFrame(r *bytes.Reader, _ protocol.Version) (*PathResponseFrame, error) {
|
||||
frame := &PathResponseFrame{}
|
||||
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return nil, err
|
||||
func parsePathResponseFrame(b []byte, _ protocol.Version) (*PathResponseFrame, int, error) {
|
||||
f := &PathResponseFrame{}
|
||||
if len(b) < 8 {
|
||||
return nil, 0, io.EOF
|
||||
}
|
||||
return frame, nil
|
||||
copy(f.Data[:], b)
|
||||
return f, 8, nil
|
||||
}
|
||||
|
||||
func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
|
@ -15,21 +13,24 @@ type ResetStreamFrame struct {
|
|||
FinalSize protocol.ByteCount
|
||||
}
|
||||
|
||||
func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFrame, error) {
|
||||
func parseResetStreamFrame(b []byte, _ protocol.Version) (*ResetStreamFrame, int, error) {
|
||||
startLen := len(b)
|
||||
var streamID protocol.StreamID
|
||||
var byteOffset protocol.ByteCount
|
||||
sid, err := quicvarint.Read(r)
|
||||
sid, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
streamID = protocol.StreamID(sid)
|
||||
errorCode, err := quicvarint.Read(r)
|
||||
errorCode, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
bo, err := quicvarint.Read(r)
|
||||
b = b[l:]
|
||||
bo, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
byteOffset = protocol.ByteCount(bo)
|
||||
|
||||
|
@ -37,7 +38,7 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFra
|
|||
StreamID: streamID,
|
||||
ErrorCode: qerr.StreamErrorCode(errorCode),
|
||||
FinalSize: byteOffset,
|
||||
}, nil
|
||||
}, startLen - len(b) + l, nil
|
||||
}
|
||||
|
||||
func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
@ -12,12 +10,12 @@ type RetireConnectionIDFrame struct {
|
|||
SequenceNumber uint64
|
||||
}
|
||||
|
||||
func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*RetireConnectionIDFrame, error) {
|
||||
seq, err := quicvarint.Read(r)
|
||||
func parseRetireConnectionIDFrame(b []byte, _ protocol.Version) (*RetireConnectionIDFrame, int, error) {
|
||||
seq, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
return &RetireConnectionIDFrame{SequenceNumber: seq}, nil
|
||||
return &RetireConnectionIDFrame{SequenceNumber: seq}, l, nil
|
||||
}
|
||||
|
||||
func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
|
@ -15,20 +13,23 @@ type StopSendingFrame struct {
|
|||
}
|
||||
|
||||
// parseStopSendingFrame parses a STOP_SENDING frame
|
||||
func parseStopSendingFrame(r *bytes.Reader, _ protocol.Version) (*StopSendingFrame, error) {
|
||||
streamID, err := quicvarint.Read(r)
|
||||
func parseStopSendingFrame(b []byte, _ protocol.Version) (*StopSendingFrame, int, error) {
|
||||
startLen := len(b)
|
||||
streamID, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
errorCode, err := quicvarint.Read(r)
|
||||
b = b[l:]
|
||||
errorCode, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
|
||||
return &StopSendingFrame{
|
||||
StreamID: protocol.StreamID(streamID),
|
||||
ErrorCode: qerr.StreamErrorCode(errorCode),
|
||||
}, nil
|
||||
}, startLen - len(b), nil
|
||||
}
|
||||
|
||||
// Length of a written frame
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
@ -13,20 +11,22 @@ type StreamDataBlockedFrame struct {
|
|||
MaximumStreamData protocol.ByteCount
|
||||
}
|
||||
|
||||
func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*StreamDataBlockedFrame, error) {
|
||||
sid, err := quicvarint.Read(r)
|
||||
func parseStreamDataBlockedFrame(b []byte, _ protocol.Version) (*StreamDataBlockedFrame, int, error) {
|
||||
startLen := len(b)
|
||||
sid, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
offset, err := quicvarint.Read(r)
|
||||
b = b[l:]
|
||||
offset, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
|
||||
return &StreamDataBlockedFrame{
|
||||
StreamID: protocol.StreamID(sid),
|
||||
MaximumStreamData: protocol.ByteCount(offset),
|
||||
}, nil
|
||||
}, startLen - len(b) + l, nil
|
||||
}
|
||||
|
||||
func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
|
@ -20,33 +19,41 @@ type StreamFrame struct {
|
|||
fromPool bool
|
||||
}
|
||||
|
||||
func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamFrame, error) {
|
||||
func parseStreamFrame(b []byte, typ uint64, _ protocol.Version) (*StreamFrame, int, error) {
|
||||
startLen := len(b)
|
||||
hasOffset := typ&0b100 > 0
|
||||
fin := typ&0b1 > 0
|
||||
hasDataLen := typ&0b10 > 0
|
||||
|
||||
streamID, err := quicvarint.Read(r)
|
||||
streamID, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
var offset uint64
|
||||
if hasOffset {
|
||||
offset, err = quicvarint.Read(r)
|
||||
offset, l, err = quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
}
|
||||
|
||||
var dataLen uint64
|
||||
if hasDataLen {
|
||||
var err error
|
||||
dataLen, err = quicvarint.Read(r)
|
||||
var l int
|
||||
dataLen, l, err = quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
b = b[l:]
|
||||
if dataLen > uint64(len(b)) {
|
||||
return nil, 0, io.EOF
|
||||
}
|
||||
} else {
|
||||
// The rest of the packet is data
|
||||
dataLen = uint64(r.Len())
|
||||
dataLen = uint64(len(b))
|
||||
}
|
||||
|
||||
var frame *StreamFrame
|
||||
|
@ -57,7 +64,7 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamF
|
|||
// The STREAM frame can't be larger than the StreamFrame we obtained from the buffer,
|
||||
// since those StreamFrames have a buffer length of the maximum packet size.
|
||||
if dataLen > uint64(cap(frame.Data)) {
|
||||
return nil, io.EOF
|
||||
return nil, 0, io.EOF
|
||||
}
|
||||
frame.Data = frame.Data[:dataLen]
|
||||
}
|
||||
|
@ -68,17 +75,14 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamF
|
|||
frame.DataLenPresent = hasDataLen
|
||||
|
||||
if dataLen != 0 {
|
||||
if _, err := io.ReadFull(r, frame.Data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(frame.Data, b)
|
||||
}
|
||||
if frame.Offset+frame.DataLen() > protocol.MaxByteCount {
|
||||
return nil, errors.New("stream data overflows maximum offset")
|
||||
return nil, 0, errors.New("stream data overflows maximum offset")
|
||||
}
|
||||
return frame, nil
|
||||
return frame, startLen - len(b) + int(dataLen), nil
|
||||
}
|
||||
|
||||
// Write writes a STREAM frame
|
||||
func (f *StreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
if len(f.Data) == 0 && !f.Fin {
|
||||
return nil, errors.New("StreamFrame: attempting to write empty frame without FIN")
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -14,7 +13,7 @@ type StreamsBlockedFrame struct {
|
|||
StreamLimit protocol.StreamNum
|
||||
}
|
||||
|
||||
func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, error) {
|
||||
func parseStreamsBlockedFrame(b []byte, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, int, error) {
|
||||
f := &StreamsBlockedFrame{}
|
||||
switch typ {
|
||||
case bidiStreamBlockedFrameType:
|
||||
|
@ -22,15 +21,15 @@ func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (
|
|||
case uniStreamBlockedFrameType:
|
||||
f.Type = protocol.StreamTypeUni
|
||||
}
|
||||
streamLimit, err := quicvarint.Read(r)
|
||||
streamLimit, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, replaceUnexpectedEOF(err)
|
||||
}
|
||||
f.StreamLimit = protocol.StreamNum(streamLimit)
|
||||
if f.StreamLimit > protocol.MaxStreamCount {
|
||||
return nil, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit)
|
||||
return nil, 0, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit)
|
||||
}
|
||||
return f, nil
|
||||
return f, l, nil
|
||||
}
|
||||
|
||||
func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
|
||||
|
|
|
@ -1,19 +1,17 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
|
@ -89,7 +87,7 @@ type TransportParameters struct {
|
|||
|
||||
// Unmarshal the transport parameters
|
||||
func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error {
|
||||
if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil {
|
||||
if err := p.unmarshal(data, sentBy, false); err != nil {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.TransportParameterError,
|
||||
ErrorMessage: err.Error(),
|
||||
|
@ -98,9 +96,9 @@ func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspective, fromSessionTicket bool) error {
|
||||
func (p *TransportParameters) unmarshal(b []byte, sentBy protocol.Perspective, fromSessionTicket bool) error {
|
||||
// needed to check that every parameter is only sent at most once
|
||||
var parameterIDs []transportParameterID
|
||||
parameterIDs := make([]transportParameterID, 0, 32)
|
||||
|
||||
var (
|
||||
readOriginalDestinationConnectionID bool
|
||||
|
@ -112,18 +110,20 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
|
|||
p.MaxAckDelay = protocol.DefaultMaxAckDelay
|
||||
p.MaxDatagramFrameSize = protocol.InvalidByteCount
|
||||
|
||||
for r.Len() > 0 {
|
||||
paramIDInt, err := quicvarint.Read(r)
|
||||
for len(b) > 0 {
|
||||
paramIDInt, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
paramID := transportParameterID(paramIDInt)
|
||||
paramLen, err := quicvarint.Read(r)
|
||||
b = b[l:]
|
||||
paramLen, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if uint64(r.Len()) < paramLen {
|
||||
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen)
|
||||
b = b[l:]
|
||||
if uint64(len(b)) < paramLen {
|
||||
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(b), paramLen)
|
||||
}
|
||||
parameterIDs = append(parameterIDs, paramID)
|
||||
switch paramID {
|
||||
|
@ -141,16 +141,18 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
|
|||
maxAckDelayParameterID,
|
||||
maxDatagramFrameSizeParameterID,
|
||||
ackDelayExponentParameterID:
|
||||
if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil {
|
||||
if err := p.readNumericTransportParameter(b, paramID, int(paramLen)); err != nil {
|
||||
return err
|
||||
}
|
||||
b = b[paramLen:]
|
||||
case preferredAddressParameterID:
|
||||
if sentBy == protocol.PerspectiveClient {
|
||||
return errors.New("client sent a preferred_address")
|
||||
}
|
||||
if err := p.readPreferredAddress(r, int(paramLen)); err != nil {
|
||||
if err := p.readPreferredAddress(b, int(paramLen)); err != nil {
|
||||
return err
|
||||
}
|
||||
b = b[paramLen:]
|
||||
case disableActiveMigrationParameterID:
|
||||
if paramLen != 0 {
|
||||
return fmt.Errorf("wrong length for disable_active_migration: %d (expected empty)", paramLen)
|
||||
|
@ -164,25 +166,41 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
|
|||
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
|
||||
}
|
||||
var token protocol.StatelessResetToken
|
||||
r.Read(token[:])
|
||||
if len(b) < len(token) {
|
||||
return io.EOF
|
||||
}
|
||||
copy(token[:], b)
|
||||
b = b[len(token):]
|
||||
p.StatelessResetToken = &token
|
||||
case originalDestinationConnectionIDParameterID:
|
||||
if sentBy == protocol.PerspectiveClient {
|
||||
return errors.New("client sent an original_destination_connection_id")
|
||||
}
|
||||
p.OriginalDestinationConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
|
||||
if paramLen > protocol.MaxConnIDLen {
|
||||
return protocol.ErrInvalidConnectionIDLen
|
||||
}
|
||||
p.OriginalDestinationConnectionID = protocol.ParseConnectionID(b[:paramLen])
|
||||
b = b[paramLen:]
|
||||
readOriginalDestinationConnectionID = true
|
||||
case initialSourceConnectionIDParameterID:
|
||||
p.InitialSourceConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
|
||||
if paramLen > protocol.MaxConnIDLen {
|
||||
return protocol.ErrInvalidConnectionIDLen
|
||||
}
|
||||
p.InitialSourceConnectionID = protocol.ParseConnectionID(b[:paramLen])
|
||||
b = b[paramLen:]
|
||||
readInitialSourceConnectionID = true
|
||||
case retrySourceConnectionIDParameterID:
|
||||
if sentBy == protocol.PerspectiveClient {
|
||||
return errors.New("client sent a retry_source_connection_id")
|
||||
}
|
||||
connID, _ := protocol.ReadConnectionID(r, int(paramLen))
|
||||
if paramLen > protocol.MaxConnIDLen {
|
||||
return protocol.ErrInvalidConnectionIDLen
|
||||
}
|
||||
connID := protocol.ParseConnectionID(b[:paramLen])
|
||||
b = b[paramLen:]
|
||||
p.RetrySourceConnectionID = &connID
|
||||
default:
|
||||
r.Seek(int64(paramLen), io.SeekCurrent)
|
||||
b = b[paramLen:]
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -202,7 +220,12 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
|
|||
}
|
||||
|
||||
// check that every transport parameter was sent at most once
|
||||
sort.Slice(parameterIDs, func(i, j int) bool { return parameterIDs[i] < parameterIDs[j] })
|
||||
slices.SortFunc(parameterIDs, func(a, b transportParameterID) int {
|
||||
if a < b {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
})
|
||||
for i := 0; i < len(parameterIDs)-1; i++ {
|
||||
if parameterIDs[i] == parameterIDs[i+1] {
|
||||
return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i])
|
||||
|
@ -212,60 +235,47 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error {
|
||||
remainingLen := r.Len()
|
||||
func (p *TransportParameters) readPreferredAddress(b []byte, expectedLen int) error {
|
||||
remainingLen := len(b)
|
||||
pa := &PreferredAddress{}
|
||||
if len(b) < 4+2+16+2+1 {
|
||||
return io.EOF
|
||||
}
|
||||
var ipv4 [4]byte
|
||||
if _, err := io.ReadFull(r, ipv4[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
port, err := utils.BigEndian.ReadUint16(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port)
|
||||
copy(ipv4[:], b[:4])
|
||||
port4 := binary.BigEndian.Uint16(b[4:])
|
||||
b = b[4+2:]
|
||||
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port4)
|
||||
var ipv6 [16]byte
|
||||
if _, err := io.ReadFull(r, ipv6[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
port, err = utils.BigEndian.ReadUint16(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port)
|
||||
connIDLen, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
copy(ipv6[:], b[:16])
|
||||
port6 := binary.BigEndian.Uint16(b[16:])
|
||||
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port6)
|
||||
b = b[16+2:]
|
||||
connIDLen := int(b[0])
|
||||
b = b[1:]
|
||||
if connIDLen == 0 || connIDLen > protocol.MaxConnIDLen {
|
||||
return fmt.Errorf("invalid connection ID length: %d", connIDLen)
|
||||
}
|
||||
connID, err := protocol.ReadConnectionID(r, int(connIDLen))
|
||||
if err != nil {
|
||||
return err
|
||||
if len(b) < connIDLen+len(pa.StatelessResetToken) {
|
||||
return io.EOF
|
||||
}
|
||||
pa.ConnectionID = connID
|
||||
if _, err := io.ReadFull(r, pa.StatelessResetToken[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if bytesRead := remainingLen - r.Len(); bytesRead != expectedLen {
|
||||
pa.ConnectionID = protocol.ParseConnectionID(b[:connIDLen])
|
||||
b = b[connIDLen:]
|
||||
copy(pa.StatelessResetToken[:], b)
|
||||
b = b[len(pa.StatelessResetToken):]
|
||||
if bytesRead := remainingLen - len(b); bytesRead != expectedLen {
|
||||
return fmt.Errorf("expected preferred_address to be %d long, read %d bytes", expectedLen, bytesRead)
|
||||
}
|
||||
p.PreferredAddress = pa
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *TransportParameters) readNumericTransportParameter(
|
||||
r *bytes.Reader,
|
||||
paramID transportParameterID,
|
||||
expectedLen int,
|
||||
) error {
|
||||
remainingLen := r.Len()
|
||||
val, err := quicvarint.Read(r)
|
||||
func (p *TransportParameters) readNumericTransportParameter(b []byte, paramID transportParameterID, expectedLen int) error {
|
||||
val, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err)
|
||||
}
|
||||
if remainingLen-r.Len() != expectedLen {
|
||||
if l != expectedLen {
|
||||
return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID)
|
||||
}
|
||||
//nolint:exhaustive // This only covers the numeric transport parameters.
|
||||
|
@ -292,7 +302,7 @@ func (p *TransportParameters) readNumericTransportParameter(
|
|||
p.MaxIdleTimeout = max(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond)
|
||||
case maxUDPPayloadSizeParameterID:
|
||||
if val < 1200 {
|
||||
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val)
|
||||
return fmt.Errorf("invalid value for max_udp_payload_size: %d (minimum 1200)", val)
|
||||
}
|
||||
p.MaxUDPPayloadSize = protocol.ByteCount(val)
|
||||
case ackDelayExponentParameterID:
|
||||
|
@ -347,8 +357,10 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
|
|||
b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum))
|
||||
// idle_timeout
|
||||
b = p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond))
|
||||
// max_packet_size
|
||||
b = p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize))
|
||||
// max_udp_payload_size
|
||||
if p.MaxUDPPayloadSize > 0 {
|
||||
b = p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(p.MaxUDPPayloadSize))
|
||||
}
|
||||
// max_ack_delay
|
||||
// Only send it if is different from the default value.
|
||||
if p.MaxAckDelay != protocol.DefaultMaxAckDelay {
|
||||
|
@ -457,15 +469,15 @@ func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte {
|
|||
}
|
||||
|
||||
// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket.
|
||||
func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error {
|
||||
version, err := quicvarint.Read(r)
|
||||
func (p *TransportParameters) UnmarshalFromSessionTicket(b []byte) error {
|
||||
version, l, err := quicvarint.Parse(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if version != transportParameterMarshalingVersion {
|
||||
return fmt.Errorf("unknown transport parameter marshaling version: %d", version)
|
||||
}
|
||||
return p.unmarshal(r, protocol.PerspectiveServer, true)
|
||||
return p.unmarshal(b[l:], protocol.PerspectiveServer, true)
|
||||
}
|
||||
|
||||
// ValidFor0RTT checks if the transport parameters match those saved in the session ticket.
|
||||
|
|
|
@ -24,6 +24,7 @@ type ConnectionTracer struct {
|
|||
UpdatedMetrics func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int)
|
||||
AcknowledgedPacket func(EncryptionLevel, PacketNumber)
|
||||
LostPacket func(EncryptionLevel, PacketNumber, PacketLossReason)
|
||||
UpdatedMTU func(mtu ByteCount, done bool)
|
||||
UpdatedCongestionState func(CongestionState)
|
||||
UpdatedPTOCount func(value uint32)
|
||||
UpdatedKeyFromTLS func(EncryptionLevel, Perspective)
|
||||
|
@ -168,6 +169,13 @@ func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTra
|
|||
}
|
||||
}
|
||||
},
|
||||
UpdatedMTU: func(mtu ByteCount, done bool) {
|
||||
for _, t := range tracers {
|
||||
if t.UpdatedMTU != nil {
|
||||
t.UpdatedMTU(mtu, done)
|
||||
}
|
||||
}
|
||||
},
|
||||
UpdatedCongestionState: func(state CongestionState) {
|
||||
for _, t := range tracers {
|
||||
if t.UpdatedCongestionState != nil {
|
||||
|
|
|
@ -1,19 +1,19 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/ackhandler"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
)
|
||||
|
||||
type mtuDiscoverer interface {
|
||||
// Start starts the MTU discovery process.
|
||||
// It's unnecessary to call ShouldSendProbe before that.
|
||||
Start(maxPacketSize protocol.ByteCount)
|
||||
Start()
|
||||
ShouldSendProbe(now time.Time) bool
|
||||
CurrentSize() protocol.ByteCount
|
||||
GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount)
|
||||
|
@ -27,20 +27,6 @@ const (
|
|||
mtuProbeDelay = 5
|
||||
)
|
||||
|
||||
func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
|
||||
maxSize := protocol.ByteCount(protocol.MinInitialPacketSize)
|
||||
// If this is not a UDP address, we don't know anything about the MTU.
|
||||
// Use the minimum size of an Initial packet as the max packet size.
|
||||
if udpAddr, ok := addr.(*net.UDPAddr); ok {
|
||||
if utils.IsIPv4(udpAddr.IP) {
|
||||
maxSize = protocol.InitialPacketSizeIPv4
|
||||
} else {
|
||||
maxSize = protocol.InitialPacketSizeIPv6
|
||||
}
|
||||
}
|
||||
return maxSize
|
||||
}
|
||||
|
||||
type mtuFinder struct {
|
||||
lastProbeTime time.Time
|
||||
mtuIncreased func(protocol.ByteCount)
|
||||
|
@ -49,16 +35,25 @@ type mtuFinder struct {
|
|||
inFlight protocol.ByteCount // the size of the probe packet currently in flight. InvalidByteCount if none is in flight
|
||||
current protocol.ByteCount
|
||||
max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer)
|
||||
|
||||
tracer *logging.ConnectionTracer
|
||||
}
|
||||
|
||||
var _ mtuDiscoverer = &mtuFinder{}
|
||||
|
||||
func newMTUDiscoverer(rttStats *utils.RTTStats, start protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuFinder {
|
||||
func newMTUDiscoverer(
|
||||
rttStats *utils.RTTStats,
|
||||
start, max protocol.ByteCount,
|
||||
mtuIncreased func(protocol.ByteCount),
|
||||
tracer *logging.ConnectionTracer,
|
||||
) *mtuFinder {
|
||||
return &mtuFinder{
|
||||
inFlight: protocol.InvalidByteCount,
|
||||
current: start,
|
||||
max: max,
|
||||
rttStats: rttStats,
|
||||
mtuIncreased: mtuIncreased,
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -66,9 +61,15 @@ func (f *mtuFinder) done() bool {
|
|||
return f.max-f.current <= maxMTUDiff+1
|
||||
}
|
||||
|
||||
func (f *mtuFinder) Start(maxPacketSize protocol.ByteCount) {
|
||||
func (f *mtuFinder) SetMax(max protocol.ByteCount) {
|
||||
f.max = max
|
||||
}
|
||||
|
||||
func (f *mtuFinder) Start() {
|
||||
if f.max == protocol.InvalidByteCount {
|
||||
panic("invalid")
|
||||
}
|
||||
f.lastProbeTime = time.Now() // makes sure the first probe packet is not sent immediately
|
||||
f.max = maxPacketSize
|
||||
}
|
||||
|
||||
func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
|
||||
|
@ -87,7 +88,7 @@ func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) {
|
|||
f.inFlight = size
|
||||
return ackhandler.Frame{
|
||||
Frame: &wire.PingFrame{},
|
||||
Handler: (*mtuFinderAckHandler)(f),
|
||||
Handler: &mtuFinderAckHandler{f},
|
||||
}, size
|
||||
}
|
||||
|
||||
|
@ -95,7 +96,9 @@ func (f *mtuFinder) CurrentSize() protocol.ByteCount {
|
|||
return f.current
|
||||
}
|
||||
|
||||
type mtuFinderAckHandler mtuFinder
|
||||
type mtuFinderAckHandler struct {
|
||||
*mtuFinder
|
||||
}
|
||||
|
||||
var _ ackhandler.FrameHandler = &mtuFinderAckHandler{}
|
||||
|
||||
|
@ -106,6 +109,9 @@ func (h *mtuFinderAckHandler) OnAcked(wire.Frame) {
|
|||
}
|
||||
h.inFlight = protocol.InvalidByteCount
|
||||
h.current = size
|
||||
if h.tracer != nil && h.tracer.UpdatedMTU != nil {
|
||||
h.tracer.UpdatedMTU(size, h.done())
|
||||
}
|
||||
h.mtuIncreased(size)
|
||||
}
|
||||
|
||||
|
|
|
@ -26,16 +26,16 @@ func Read(r io.ByteReader) (uint64, error) {
|
|||
return 0, err
|
||||
}
|
||||
// the first two bits of the first byte encode the length
|
||||
len := 1 << ((firstByte & 0xc0) >> 6)
|
||||
l := 1 << ((firstByte & 0xc0) >> 6)
|
||||
b1 := firstByte & (0xff - 0xc0)
|
||||
if len == 1 {
|
||||
if l == 1 {
|
||||
return uint64(b1), nil
|
||||
}
|
||||
b2, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len == 2 {
|
||||
if l == 2 {
|
||||
return uint64(b2) + uint64(b1)<<8, nil
|
||||
}
|
||||
b3, err := r.ReadByte()
|
||||
|
@ -46,7 +46,7 @@ func Read(r io.ByteReader) (uint64, error) {
|
|||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len == 4 {
|
||||
if l == 4 {
|
||||
return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil
|
||||
}
|
||||
b5, err := r.ReadByte()
|
||||
|
@ -68,6 +68,31 @@ func Read(r io.ByteReader) (uint64, error) {
|
|||
return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil
|
||||
}
|
||||
|
||||
// Parse reads a number in the QUIC varint format.
|
||||
// It returns the number of bytes consumed.
|
||||
func Parse(b []byte) (uint64 /* value */, int /* bytes consumed */, error) {
|
||||
if len(b) == 0 {
|
||||
return 0, 0, io.EOF
|
||||
}
|
||||
firstByte := b[0]
|
||||
// the first two bits of the first byte encode the length
|
||||
l := 1 << ((firstByte & 0xc0) >> 6)
|
||||
if len(b) < l {
|
||||
return 0, 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
b0 := firstByte & (0xff - 0xc0)
|
||||
if l == 1 {
|
||||
return uint64(b0), 1, nil
|
||||
}
|
||||
if l == 2 {
|
||||
return uint64(b[1]) + uint64(b0)<<8, 2, nil
|
||||
}
|
||||
if l == 4 {
|
||||
return uint64(b[3]) + uint64(b[2])<<8 + uint64(b[1])<<16 + uint64(b0)<<24, 4, nil
|
||||
}
|
||||
return uint64(b[7]) + uint64(b[6])<<8 + uint64(b[5])<<16 + uint64(b[4])<<24 + uint64(b[3])<<32 + uint64(b[2])<<40 + uint64(b[1])<<48 + uint64(b0)<<56, 8, nil
|
||||
}
|
||||
|
||||
// Append appends i in the QUIC varint format.
|
||||
func Append(b []byte, i uint64) []byte {
|
||||
if i <= maxVarInt1 {
|
||||
|
|
|
@ -225,7 +225,7 @@ func (x *FileSyntax) Cleanup() {
|
|||
if ww == 0 {
|
||||
continue
|
||||
}
|
||||
if ww == 1 {
|
||||
if ww == 1 && len(stmt.RParen.Comments.Before) == 0 {
|
||||
// Collapse block into single line.
|
||||
line := &Line{
|
||||
Comments: Comments{
|
||||
|
|
|
@ -975,6 +975,8 @@ func (f *File) AddGoStmt(version string) error {
|
|||
var hint Expr
|
||||
if f.Module != nil && f.Module.Syntax != nil {
|
||||
hint = f.Module.Syntax
|
||||
} else if f.Syntax == nil {
|
||||
f.Syntax = new(FileSyntax)
|
||||
}
|
||||
f.Go = &Go{
|
||||
Version: version,
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,22 @@
|
|||
Additional IP Rights Grant (Patents)
|
||||
|
||||
"This implementation" means the copyrightable works distributed by
|
||||
Google as part of the Go project.
|
||||
|
||||
Google hereby grants to You a perpetual, worldwide, non-exclusive,
|
||||
no-charge, royalty-free, irrevocable (except as stated in this section)
|
||||
patent license to make, have made, use, offer to sell, sell, import,
|
||||
transfer and otherwise run, modify and propagate the contents of this
|
||||
implementation of Go, where such license applies only to those patent
|
||||
claims, both currently owned or controlled by Google and acquired in
|
||||
the future, licensable by Google that are necessarily infringed by this
|
||||
implementation of Go. This grant does not include claims that would be
|
||||
infringed only as a consequence of further modification of this
|
||||
implementation. If you or your agent or exclusive licensee institute or
|
||||
order or agree to the institution of patent litigation against any
|
||||
entity (including a cross-claim or counterclaim in a lawsuit) alleging
|
||||
that this implementation of Go or any code incorporated within this
|
||||
implementation of Go constitutes direct or contributory patent
|
||||
infringement, or inducement of patent infringement, then any patent
|
||||
rights granted to you under this License for this implementation of Go
|
||||
shall terminate as of the date such litigation is filed.
|
|
@ -0,0 +1,135 @@
|
|||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package errgroup provides synchronization, error propagation, and Context
|
||||
// cancelation for groups of goroutines working on subtasks of a common task.
|
||||
//
|
||||
// [errgroup.Group] is related to [sync.WaitGroup] but adds handling of tasks
|
||||
// returning errors.
|
||||
package errgroup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type token struct{}
|
||||
|
||||
// A Group is a collection of goroutines working on subtasks that are part of
|
||||
// the same overall task.
|
||||
//
|
||||
// A zero Group is valid, has no limit on the number of active goroutines,
|
||||
// and does not cancel on error.
|
||||
type Group struct {
|
||||
cancel func(error)
|
||||
|
||||
wg sync.WaitGroup
|
||||
|
||||
sem chan token
|
||||
|
||||
errOnce sync.Once
|
||||
err error
|
||||
}
|
||||
|
||||
func (g *Group) done() {
|
||||
if g.sem != nil {
|
||||
<-g.sem
|
||||
}
|
||||
g.wg.Done()
|
||||
}
|
||||
|
||||
// WithContext returns a new Group and an associated Context derived from ctx.
|
||||
//
|
||||
// The derived Context is canceled the first time a function passed to Go
|
||||
// returns a non-nil error or the first time Wait returns, whichever occurs
|
||||
// first.
|
||||
func WithContext(ctx context.Context) (*Group, context.Context) {
|
||||
ctx, cancel := withCancelCause(ctx)
|
||||
return &Group{cancel: cancel}, ctx
|
||||
}
|
||||
|
||||
// Wait blocks until all function calls from the Go method have returned, then
|
||||
// returns the first non-nil error (if any) from them.
|
||||
func (g *Group) Wait() error {
|
||||
g.wg.Wait()
|
||||
if g.cancel != nil {
|
||||
g.cancel(g.err)
|
||||
}
|
||||
return g.err
|
||||
}
|
||||
|
||||
// Go calls the given function in a new goroutine.
|
||||
// It blocks until the new goroutine can be added without the number of
|
||||
// active goroutines in the group exceeding the configured limit.
|
||||
//
|
||||
// The first call to return a non-nil error cancels the group's context, if the
|
||||
// group was created by calling WithContext. The error will be returned by Wait.
|
||||
func (g *Group) Go(f func() error) {
|
||||
if g.sem != nil {
|
||||
g.sem <- token{}
|
||||
}
|
||||
|
||||
g.wg.Add(1)
|
||||
go func() {
|
||||
defer g.done()
|
||||
|
||||
if err := f(); err != nil {
|
||||
g.errOnce.Do(func() {
|
||||
g.err = err
|
||||
if g.cancel != nil {
|
||||
g.cancel(g.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// TryGo calls the given function in a new goroutine only if the number of
|
||||
// active goroutines in the group is currently below the configured limit.
|
||||
//
|
||||
// The return value reports whether the goroutine was started.
|
||||
func (g *Group) TryGo(f func() error) bool {
|
||||
if g.sem != nil {
|
||||
select {
|
||||
case g.sem <- token{}:
|
||||
// Note: this allows barging iff channels in general allow barging.
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
g.wg.Add(1)
|
||||
go func() {
|
||||
defer g.done()
|
||||
|
||||
if err := f(); err != nil {
|
||||
g.errOnce.Do(func() {
|
||||
g.err = err
|
||||
if g.cancel != nil {
|
||||
g.cancel(g.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
return true
|
||||
}
|
||||
|
||||
// SetLimit limits the number of active goroutines in this group to at most n.
|
||||
// A negative value indicates no limit.
|
||||
//
|
||||
// Any subsequent call to the Go method will block until it can add an active
|
||||
// goroutine without exceeding the configured limit.
|
||||
//
|
||||
// The limit must not be modified while any goroutines in the group are active.
|
||||
func (g *Group) SetLimit(n int) {
|
||||
if n < 0 {
|
||||
g.sem = nil
|
||||
return
|
||||
}
|
||||
if len(g.sem) != 0 {
|
||||
panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
|
||||
}
|
||||
g.sem = make(chan token, n)
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build go1.20
|
||||
|
||||
package errgroup
|
||||
|
||||
import "context"
|
||||
|
||||
func withCancelCause(parent context.Context) (context.Context, func(error)) {
|
||||
return context.WithCancelCause(parent)
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build !go1.20
|
||||
|
||||
package errgroup
|
||||
|
||||
import "context"
|
||||
|
||||
func withCancelCause(parent context.Context) (context.Context, func(error)) {
|
||||
ctx, cancel := context.WithCancel(parent)
|
||||
return ctx, func(error) { cancel() }
|
||||
}
|
|
@ -9,6 +9,7 @@ package packages
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
|
@ -24,6 +25,8 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"golang.org/x/tools/go/gcexportdata"
|
||||
"golang.org/x/tools/internal/gocommand"
|
||||
"golang.org/x/tools/internal/packagesinternal"
|
||||
|
@ -126,9 +129,8 @@ type Config struct {
|
|||
Mode LoadMode
|
||||
|
||||
// Context specifies the context for the load operation.
|
||||
// If the context is cancelled, the loader may stop early
|
||||
// and return an ErrCancelled error.
|
||||
// If Context is nil, the load cannot be cancelled.
|
||||
// Cancelling the context may cause [Load] to abort and
|
||||
// return an error.
|
||||
Context context.Context
|
||||
|
||||
// Logf is the logger for the config.
|
||||
|
@ -211,8 +213,8 @@ type Config struct {
|
|||
// Config specifies loading options;
|
||||
// nil behaves the same as an empty Config.
|
||||
//
|
||||
// Load returns an error if any of the patterns was invalid
|
||||
// as defined by the underlying build system.
|
||||
// If any of the patterns was invalid as defined by the
|
||||
// underlying build system, Load returns an error.
|
||||
// It may return an empty list of packages without an error,
|
||||
// for instance for an empty expansion of a valid wildcard.
|
||||
// Errors associated with a particular package are recorded in the
|
||||
|
@ -255,8 +257,27 @@ func Load(cfg *Config, patterns ...string) ([]*Package, error) {
|
|||
// defaultDriver will fall back to the go list driver.
|
||||
// The boolean result indicates that an external driver handled the request.
|
||||
func defaultDriver(cfg *Config, patterns ...string) (*DriverResponse, bool, error) {
|
||||
const (
|
||||
// windowsArgMax specifies the maximum command line length for
|
||||
// the Windows' CreateProcess function.
|
||||
windowsArgMax = 32767
|
||||
// maxEnvSize is a very rough estimation of the maximum environment
|
||||
// size of a user.
|
||||
maxEnvSize = 16384
|
||||
// safeArgMax specifies the maximum safe command line length to use
|
||||
// by the underlying driver excl. the environment. We choose the Windows'
|
||||
// ARG_MAX as the starting point because it's one of the lowest ARG_MAX
|
||||
// constants out of the different supported platforms,
|
||||
// e.g., https://www.in-ulm.de/~mascheck/various/argmax/#results.
|
||||
safeArgMax = windowsArgMax - maxEnvSize
|
||||
)
|
||||
chunks, err := splitIntoChunks(patterns, safeArgMax)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if driver := findExternalDriver(cfg); driver != nil {
|
||||
response, err := driver(cfg, patterns...)
|
||||
response, err := callDriverOnChunks(driver, cfg, chunks)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
} else if !response.NotHandled {
|
||||
|
@ -265,11 +286,82 @@ func defaultDriver(cfg *Config, patterns ...string) (*DriverResponse, bool, erro
|
|||
// (fall through)
|
||||
}
|
||||
|
||||
response, err := goListDriver(cfg, patterns...)
|
||||
response, err := callDriverOnChunks(goListDriver, cfg, chunks)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return response, false, nil
|
||||
return response, false, err
|
||||
}
|
||||
|
||||
// splitIntoChunks chunks the slice so that the total number of characters
|
||||
// in a chunk is no longer than argMax.
|
||||
func splitIntoChunks(patterns []string, argMax int) ([][]string, error) {
|
||||
if argMax <= 0 {
|
||||
return nil, errors.New("failed to split patterns into chunks, negative safe argMax value")
|
||||
}
|
||||
var chunks [][]string
|
||||
charsInChunk := 0
|
||||
nextChunkStart := 0
|
||||
for i, v := range patterns {
|
||||
vChars := len(v)
|
||||
if vChars > argMax {
|
||||
// a single pattern is longer than the maximum safe ARG_MAX, hardly should happen
|
||||
return nil, errors.New("failed to split patterns into chunks, a pattern is too long")
|
||||
}
|
||||
charsInChunk += vChars + 1 // +1 is for a whitespace between patterns that has to be counted too
|
||||
if charsInChunk > argMax {
|
||||
chunks = append(chunks, patterns[nextChunkStart:i])
|
||||
nextChunkStart = i
|
||||
charsInChunk = vChars
|
||||
}
|
||||
}
|
||||
// add the last chunk
|
||||
if nextChunkStart < len(patterns) {
|
||||
chunks = append(chunks, patterns[nextChunkStart:])
|
||||
}
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
func callDriverOnChunks(driver driver, cfg *Config, chunks [][]string) (*DriverResponse, error) {
|
||||
if len(chunks) == 0 {
|
||||
return driver(cfg)
|
||||
}
|
||||
responses := make([]*DriverResponse, len(chunks))
|
||||
errNotHandled := errors.New("driver returned NotHandled")
|
||||
var g errgroup.Group
|
||||
for i, chunk := range chunks {
|
||||
i := i
|
||||
chunk := chunk
|
||||
g.Go(func() (err error) {
|
||||
responses[i], err = driver(cfg, chunk...)
|
||||
if responses[i] != nil && responses[i].NotHandled {
|
||||
err = errNotHandled
|
||||
}
|
||||
return err
|
||||
})
|
||||
}
|
||||
if err := g.Wait(); err != nil {
|
||||
if errors.Is(err, errNotHandled) {
|
||||
return &DriverResponse{NotHandled: true}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return mergeResponses(responses...), nil
|
||||
}
|
||||
|
||||
func mergeResponses(responses ...*DriverResponse) *DriverResponse {
|
||||
if len(responses) == 0 {
|
||||
return nil
|
||||
}
|
||||
response := newDeduper()
|
||||
response.dr.NotHandled = false
|
||||
response.dr.Compiler = responses[0].Compiler
|
||||
response.dr.Arch = responses[0].Arch
|
||||
response.dr.GoVersion = responses[0].GoVersion
|
||||
for _, v := range responses {
|
||||
response.addAll(v)
|
||||
}
|
||||
return response.dr
|
||||
}
|
||||
|
||||
// A Package describes a loaded Go package.
|
||||
|
@ -335,6 +427,10 @@ type Package struct {
|
|||
// The NeedTypes LoadMode bit sets this field for packages matching the
|
||||
// patterns; type information for dependencies may be missing or incomplete,
|
||||
// unless NeedDeps and NeedImports are also set.
|
||||
//
|
||||
// Each call to [Load] returns a consistent set of type
|
||||
// symbols, as defined by the comment at [types.Identical].
|
||||
// Avoid mixing type information from two or more calls to [Load].
|
||||
Types *types.Package
|
||||
|
||||
// Fset provides position information for Types, TypesInfo, and Syntax.
|
||||
|
@ -761,6 +857,12 @@ func (ld *loader) refine(response *DriverResponse) ([]*Package, error) {
|
|||
wg.Wait()
|
||||
}
|
||||
|
||||
// If the context is done, return its error and
|
||||
// throw out [likely] incomplete packages.
|
||||
if err := ld.Context.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*Package, len(initial))
|
||||
for i, lpkg := range initial {
|
||||
result[i] = lpkg.Package
|
||||
|
@ -856,6 +958,14 @@ func (ld *loader) loadPackage(lpkg *loaderPackage) {
|
|||
lpkg.Types = types.NewPackage(lpkg.PkgPath, lpkg.Name)
|
||||
lpkg.Fset = ld.Fset
|
||||
|
||||
// Start shutting down if the context is done and do not load
|
||||
// source or export data files.
|
||||
// Packages that import this one will have ld.Context.Err() != nil.
|
||||
// ld.Context.Err() will be returned later by refine.
|
||||
if ld.Context.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Subtle: we populate all Types fields with an empty Package
|
||||
// before loading export data so that export data processing
|
||||
// never has to create a types.Package for an indirect dependency,
|
||||
|
@ -975,6 +1085,13 @@ func (ld *loader) loadPackage(lpkg *loaderPackage) {
|
|||
return
|
||||
}
|
||||
|
||||
// Start shutting down if the context is done and do not type check.
|
||||
// Packages that import this one will have ld.Context.Err() != nil.
|
||||
// ld.Context.Err() will be returned later by refine.
|
||||
if ld.Context.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
lpkg.TypesInfo = &types.Info{
|
||||
Types: make(map[ast.Expr]types.TypeAndValue),
|
||||
Defs: make(map[*ast.Ident]types.Object),
|
||||
|
@ -1025,7 +1142,7 @@ func (ld *loader) loadPackage(lpkg *loaderPackage) {
|
|||
Sizes: ld.sizes, // may be nil
|
||||
}
|
||||
if lpkg.Module != nil && lpkg.Module.GoVersion != "" {
|
||||
typesinternal.SetGoVersion(tc, "go"+lpkg.Module.GoVersion)
|
||||
tc.GoVersion = "go" + lpkg.Module.GoVersion
|
||||
}
|
||||
if (ld.Mode & typecheckCgo) != 0 {
|
||||
if !typesinternal.SetUsesCgo(tc) {
|
||||
|
@ -1036,10 +1153,24 @@ func (ld *loader) loadPackage(lpkg *loaderPackage) {
|
|||
return
|
||||
}
|
||||
}
|
||||
types.NewChecker(tc, ld.Fset, lpkg.Types, lpkg.TypesInfo).Files(lpkg.Syntax)
|
||||
|
||||
typErr := types.NewChecker(tc, ld.Fset, lpkg.Types, lpkg.TypesInfo).Files(lpkg.Syntax)
|
||||
lpkg.importErrors = nil // no longer needed
|
||||
|
||||
// In go/types go1.21 and go1.22, Checker.Files failed fast with a
|
||||
// a "too new" error, without calling tc.Error and without
|
||||
// proceeding to type-check the package (#66525).
|
||||
// We rely on the runtimeVersion error to give the suggested remedy.
|
||||
if typErr != nil && len(lpkg.Errors) == 0 && len(lpkg.Syntax) > 0 {
|
||||
if msg := typErr.Error(); strings.HasPrefix(msg, "package requires newer Go version") {
|
||||
appendError(types.Error{
|
||||
Fset: ld.Fset,
|
||||
Pos: lpkg.Syntax[0].Package,
|
||||
Msg: msg,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// If !Cgo, the type-checker uses FakeImportC mode, so
|
||||
// it doesn't invoke the importer for import "C",
|
||||
// nor report an error for the import,
|
||||
|
@ -1061,6 +1192,12 @@ func (ld *loader) loadPackage(lpkg *loaderPackage) {
|
|||
}
|
||||
}
|
||||
|
||||
// If types.Checker.Files had an error that was unreported,
|
||||
// make sure to report the unknown error so the package is illTyped.
|
||||
if typErr != nil && len(lpkg.Errors) == 0 {
|
||||
appendError(typErr)
|
||||
}
|
||||
|
||||
// Record accumulated errors.
|
||||
illTyped := len(lpkg.Errors) > 0
|
||||
if !illTyped {
|
||||
|
@ -1132,11 +1269,6 @@ func (ld *loader) parseFiles(filenames []string) ([]*ast.File, []error) {
|
|||
parsed := make([]*ast.File, n)
|
||||
errors := make([]error, n)
|
||||
for i, file := range filenames {
|
||||
if ld.Config.Context.Err() != nil {
|
||||
parsed[i] = nil
|
||||
errors[i] = ld.Config.Context.Err()
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(i int, filename string) {
|
||||
parsed[i], errors[i] = ld.parseFile(filename)
|
||||
|
|
|
@ -30,7 +30,6 @@ import (
|
|||
"strings"
|
||||
|
||||
"golang.org/x/tools/internal/aliases"
|
||||
"golang.org/x/tools/internal/typeparams"
|
||||
"golang.org/x/tools/internal/typesinternal"
|
||||
)
|
||||
|
||||
|
@ -395,7 +394,7 @@ func (enc *Encoder) concreteMethod(meth *types.Func) (Path, bool) {
|
|||
// of objectpath will only be giving us origin methods, anyway, as referring
|
||||
// to instantiated methods is usually not useful.
|
||||
|
||||
if typeparams.OriginMethod(meth) != meth {
|
||||
if meth.Origin() != meth {
|
||||
return "", false
|
||||
}
|
||||
|
||||
|
|
|
@ -16,10 +16,14 @@ import (
|
|||
// NewAlias creates a new TypeName in Package pkg that
|
||||
// is an alias for the type rhs.
|
||||
//
|
||||
// When GoVersion>=1.22 and GODEBUG=gotypesalias=1,
|
||||
// the Type() of the return value is a *types.Alias.
|
||||
func NewAlias(pos token.Pos, pkg *types.Package, name string, rhs types.Type) *types.TypeName {
|
||||
if enabled() {
|
||||
// The enabled parameter determines whether the resulting [TypeName]'s
|
||||
// type is an [types.Alias]. Its value must be the result of a call to
|
||||
// [Enabled], which computes the effective value of
|
||||
// GODEBUG=gotypesalias=... by invoking the type checker. The Enabled
|
||||
// function is expensive and should be called once per task (e.g.
|
||||
// package import), not once per call to NewAlias.
|
||||
func NewAlias(enabled bool, pos token.Pos, pkg *types.Package, name string, rhs types.Type) *types.TypeName {
|
||||
if enabled {
|
||||
tname := types.NewTypeName(pos, pkg, name, nil)
|
||||
newAlias(tname, rhs)
|
||||
return tname
|
||||
|
|
|
@ -15,16 +15,17 @@ import (
|
|||
// It will never be created by go/types.
|
||||
type Alias struct{}
|
||||
|
||||
func (*Alias) String() string { panic("unreachable") }
|
||||
|
||||
func (*Alias) String() string { panic("unreachable") }
|
||||
func (*Alias) Underlying() types.Type { panic("unreachable") }
|
||||
|
||||
func (*Alias) Obj() *types.TypeName { panic("unreachable") }
|
||||
func (*Alias) Obj() *types.TypeName { panic("unreachable") }
|
||||
func Rhs(alias *Alias) types.Type { panic("unreachable") }
|
||||
|
||||
// Unalias returns the type t for go <=1.21.
|
||||
func Unalias(t types.Type) types.Type { return t }
|
||||
|
||||
// Always false for go <=1.21. Ignores GODEBUG.
|
||||
func enabled() bool { return false }
|
||||
|
||||
func newAlias(name *types.TypeName, rhs types.Type) *Alias { panic("unreachable") }
|
||||
|
||||
// Enabled reports whether [NewAlias] should create [types.Alias] types.
|
||||
//
|
||||
// Before go1.22, this function always returns false.
|
||||
func Enabled() bool { return false }
|
||||
|
|
|
@ -12,14 +12,22 @@ import (
|
|||
"go/parser"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Alias is an alias of types.Alias.
|
||||
type Alias = types.Alias
|
||||
|
||||
// Rhs returns the type on the right-hand side of the alias declaration.
|
||||
func Rhs(alias *Alias) types.Type {
|
||||
if alias, ok := any(alias).(interface{ Rhs() types.Type }); ok {
|
||||
return alias.Rhs() // go1.23+
|
||||
}
|
||||
|
||||
// go1.22's Alias didn't have the Rhs method,
|
||||
// so Unalias is the best we can do.
|
||||
return Unalias(alias)
|
||||
}
|
||||
|
||||
// Unalias is a wrapper of types.Unalias.
|
||||
func Unalias(t types.Type) types.Type { return types.Unalias(t) }
|
||||
|
||||
|
@ -33,40 +41,23 @@ func newAlias(tname *types.TypeName, rhs types.Type) *Alias {
|
|||
return a
|
||||
}
|
||||
|
||||
// enabled returns true when types.Aliases are enabled.
|
||||
func enabled() bool {
|
||||
// Use the gotypesalias value in GODEBUG if set.
|
||||
godebug := os.Getenv("GODEBUG")
|
||||
value := -1 // last set value.
|
||||
for _, f := range strings.Split(godebug, ",") {
|
||||
switch f {
|
||||
case "gotypesalias=1":
|
||||
value = 1
|
||||
case "gotypesalias=0":
|
||||
value = 0
|
||||
}
|
||||
}
|
||||
switch value {
|
||||
case 0:
|
||||
return false
|
||||
case 1:
|
||||
return true
|
||||
default:
|
||||
return aliasesDefault()
|
||||
}
|
||||
// Enabled reports whether [NewAlias] should create [types.Alias] types.
|
||||
//
|
||||
// This function is expensive! Call it sparingly.
|
||||
func Enabled() bool {
|
||||
// The only reliable way to compute the answer is to invoke go/types.
|
||||
// We don't parse the GODEBUG environment variable, because
|
||||
// (a) it's tricky to do so in a manner that is consistent
|
||||
// with the godebug package; in particular, a simple
|
||||
// substring check is not good enough. The value is a
|
||||
// rightmost-wins list of options. But more importantly:
|
||||
// (b) it is impossible to detect changes to the effective
|
||||
// setting caused by os.Setenv("GODEBUG"), as happens in
|
||||
// many tests. Therefore any attempt to cache the result
|
||||
// is just incorrect.
|
||||
fset := token.NewFileSet()
|
||||
f, _ := parser.ParseFile(fset, "a.go", "package p; type A = int", 0)
|
||||
pkg, _ := new(types.Config).Check("p", fset, []*ast.File{f}, nil)
|
||||
_, enabled := pkg.Scope().Lookup("A").Type().(*types.Alias)
|
||||
return enabled
|
||||
}
|
||||
|
||||
// aliasesDefault reports if aliases are enabled by default.
|
||||
func aliasesDefault() bool {
|
||||
// Dynamically check if Aliases will be produced from go/types.
|
||||
aliasesDefaultOnce.Do(func() {
|
||||
fset := token.NewFileSet()
|
||||
f, _ := parser.ParseFile(fset, "a.go", "package p; type A = int", 0)
|
||||
pkg, _ := new(types.Config).Check("p", fset, []*ast.File{f}, nil)
|
||||
_, gotypesaliasDefault = pkg.Scope().Lookup("A").Type().(*types.Alias)
|
||||
})
|
||||
return gotypesaliasDefault
|
||||
}
|
||||
|
||||
var gotypesaliasDefault bool
|
||||
var aliasesDefaultOnce sync.Once
|
||||
|
|
|
@ -1,59 +0,0 @@
|
|||
// Copyright 2019 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package tag provides the labels used for telemetry throughout gopls.
|
||||
package tag
|
||||
|
||||
import (
|
||||
"golang.org/x/tools/internal/event/keys"
|
||||
)
|
||||
|
||||
var (
|
||||
// create the label keys we use
|
||||
Method = keys.NewString("method", "")
|
||||
StatusCode = keys.NewString("status.code", "")
|
||||
StatusMessage = keys.NewString("status.message", "")
|
||||
RPCID = keys.NewString("id", "")
|
||||
RPCDirection = keys.NewString("direction", "")
|
||||
File = keys.NewString("file", "")
|
||||
Directory = keys.New("directory", "")
|
||||
URI = keys.New("URI", "")
|
||||
Package = keys.NewString("package", "") // sorted comma-separated list of Package IDs
|
||||
PackagePath = keys.NewString("package_path", "")
|
||||
Query = keys.New("query", "")
|
||||
Snapshot = keys.NewUInt64("snapshot", "")
|
||||
Operation = keys.NewString("operation", "")
|
||||
|
||||
Position = keys.New("position", "")
|
||||
Category = keys.NewString("category", "")
|
||||
PackageCount = keys.NewInt("packages", "")
|
||||
Files = keys.New("files", "")
|
||||
Port = keys.NewInt("port", "")
|
||||
Type = keys.New("type", "")
|
||||
HoverKind = keys.NewString("hoverkind", "")
|
||||
|
||||
NewServer = keys.NewString("new_server", "A new server was added")
|
||||
EndServer = keys.NewString("end_server", "A server was shut down")
|
||||
|
||||
ServerID = keys.NewString("server", "The server ID an event is related to")
|
||||
Logfile = keys.NewString("logfile", "")
|
||||
DebugAddress = keys.NewString("debug_address", "")
|
||||
GoplsPath = keys.NewString("gopls_path", "")
|
||||
ClientID = keys.NewString("client_id", "")
|
||||
|
||||
Level = keys.NewInt("level", "The logging level")
|
||||
)
|
||||
|
||||
var (
|
||||
// create the stats we measure
|
||||
Started = keys.NewInt64("started", "Count of started RPCs.")
|
||||
ReceivedBytes = keys.NewInt64("received_bytes", "Bytes received.") //, unit.Bytes)
|
||||
SentBytes = keys.NewInt64("sent_bytes", "Bytes sent.") //, unit.Bytes)
|
||||
Latency = keys.NewFloat64("latency_ms", "Elapsed time in milliseconds") //, unit.Milliseconds)
|
||||
)
|
||||
|
||||
const (
|
||||
Inbound = "in"
|
||||
Outbound = "out"
|
||||
)
|
|
@ -464,7 +464,7 @@ func (p *iexporter) doDecl(obj types.Object) {
|
|||
|
||||
switch obj := obj.(type) {
|
||||
case *types.Var:
|
||||
w.tag('V')
|
||||
w.tag(varTag)
|
||||
w.pos(obj.Pos())
|
||||
w.typ(obj.Type(), obj.Pkg())
|
||||
|
||||
|
@ -482,9 +482,9 @@ func (p *iexporter) doDecl(obj types.Object) {
|
|||
|
||||
// Function.
|
||||
if sig.TypeParams().Len() == 0 {
|
||||
w.tag('F')
|
||||
w.tag(funcTag)
|
||||
} else {
|
||||
w.tag('G')
|
||||
w.tag(genericFuncTag)
|
||||
}
|
||||
w.pos(obj.Pos())
|
||||
// The tparam list of the function type is the declaration of the type
|
||||
|
@ -500,7 +500,7 @@ func (p *iexporter) doDecl(obj types.Object) {
|
|||
w.signature(sig)
|
||||
|
||||
case *types.Const:
|
||||
w.tag('C')
|
||||
w.tag(constTag)
|
||||
w.pos(obj.Pos())
|
||||
w.value(obj.Type(), obj.Val())
|
||||
|
||||
|
@ -508,7 +508,7 @@ func (p *iexporter) doDecl(obj types.Object) {
|
|||
t := obj.Type()
|
||||
|
||||
if tparam, ok := aliases.Unalias(t).(*types.TypeParam); ok {
|
||||
w.tag('P')
|
||||
w.tag(typeParamTag)
|
||||
w.pos(obj.Pos())
|
||||
constraint := tparam.Constraint()
|
||||
if p.version >= iexportVersionGo1_18 {
|
||||
|
@ -523,8 +523,13 @@ func (p *iexporter) doDecl(obj types.Object) {
|
|||
}
|
||||
|
||||
if obj.IsAlias() {
|
||||
w.tag('A')
|
||||
w.tag(aliasTag)
|
||||
w.pos(obj.Pos())
|
||||
if alias, ok := t.(*aliases.Alias); ok {
|
||||
// Preserve materialized aliases,
|
||||
// even of non-exported types.
|
||||
t = aliases.Rhs(alias)
|
||||
}
|
||||
w.typ(t, obj.Pkg())
|
||||
break
|
||||
}
|
||||
|
@ -536,9 +541,9 @@ func (p *iexporter) doDecl(obj types.Object) {
|
|||
}
|
||||
|
||||
if named.TypeParams().Len() == 0 {
|
||||
w.tag('T')
|
||||
w.tag(typeTag)
|
||||
} else {
|
||||
w.tag('U')
|
||||
w.tag(genericTypeTag)
|
||||
}
|
||||
w.pos(obj.Pos())
|
||||
|
||||
|
@ -548,7 +553,7 @@ func (p *iexporter) doDecl(obj types.Object) {
|
|||
w.tparamList(obj.Name(), named.TypeParams(), obj.Pkg())
|
||||
}
|
||||
|
||||
underlying := obj.Type().Underlying()
|
||||
underlying := named.Underlying()
|
||||
w.typ(underlying, obj.Pkg())
|
||||
|
||||
if types.IsInterface(t) {
|
||||
|
@ -739,7 +744,10 @@ func (w *exportWriter) doTyp(t types.Type, pkg *types.Package) {
|
|||
}()
|
||||
}
|
||||
switch t := t.(type) {
|
||||
// TODO(adonovan): support types.Alias.
|
||||
case *aliases.Alias:
|
||||
// TODO(adonovan): support parameterized aliases, following *types.Named.
|
||||
w.startType(aliasType)
|
||||
w.qualifiedType(t.Obj())
|
||||
|
||||
case *types.Named:
|
||||
if targs := t.TypeArgs(); targs.Len() > 0 {
|
||||
|
|
|
@ -80,6 +80,20 @@ const (
|
|||
typeParamType
|
||||
instanceType
|
||||
unionType
|
||||
aliasType
|
||||
)
|
||||
|
||||
// Object tags
|
||||
const (
|
||||
varTag = 'V'
|
||||
funcTag = 'F'
|
||||
genericFuncTag = 'G'
|
||||
constTag = 'C'
|
||||
aliasTag = 'A'
|
||||
genericAliasTag = 'B'
|
||||
typeParamTag = 'P'
|
||||
typeTag = 'T'
|
||||
genericTypeTag = 'U'
|
||||
)
|
||||
|
||||
// IImportData imports a package from the serialized package data
|
||||
|
@ -196,6 +210,7 @@ func iimportCommon(fset *token.FileSet, getPackages GetPackagesFunc, data []byte
|
|||
p := iimporter{
|
||||
version: int(version),
|
||||
ipath: path,
|
||||
aliases: aliases.Enabled(),
|
||||
shallow: shallow,
|
||||
reportf: reportf,
|
||||
|
||||
|
@ -324,7 +339,7 @@ func iimportCommon(fset *token.FileSet, getPackages GetPackagesFunc, data []byte
|
|||
}
|
||||
|
||||
// SetConstraint can't be called if the constraint type is not yet complete.
|
||||
// When type params are created in the 'P' case of (*importReader).obj(),
|
||||
// When type params are created in the typeParamTag case of (*importReader).obj(),
|
||||
// the associated constraint type may not be complete due to recursion.
|
||||
// Therefore, we defer calling SetConstraint there, and call it here instead
|
||||
// after all types are complete.
|
||||
|
@ -355,6 +370,7 @@ type iimporter struct {
|
|||
version int
|
||||
ipath string
|
||||
|
||||
aliases bool
|
||||
shallow bool
|
||||
reportf ReportFunc // if non-nil, used to report bugs
|
||||
|
||||
|
@ -546,25 +562,29 @@ func (r *importReader) obj(name string) {
|
|||
pos := r.pos()
|
||||
|
||||
switch tag {
|
||||
case 'A':
|
||||
case aliasTag:
|
||||
typ := r.typ()
|
||||
// TODO(adonovan): support generic aliases:
|
||||
// if tag == genericAliasTag {
|
||||
// tparams := r.tparamList()
|
||||
// alias.SetTypeParams(tparams)
|
||||
// }
|
||||
r.declare(aliases.NewAlias(r.p.aliases, pos, r.currPkg, name, typ))
|
||||
|
||||
r.declare(types.NewTypeName(pos, r.currPkg, name, typ))
|
||||
|
||||
case 'C':
|
||||
case constTag:
|
||||
typ, val := r.value()
|
||||
|
||||
r.declare(types.NewConst(pos, r.currPkg, name, typ, val))
|
||||
|
||||
case 'F', 'G':
|
||||
case funcTag, genericFuncTag:
|
||||
var tparams []*types.TypeParam
|
||||
if tag == 'G' {
|
||||
if tag == genericFuncTag {
|
||||
tparams = r.tparamList()
|
||||
}
|
||||
sig := r.signature(nil, nil, tparams)
|
||||
r.declare(types.NewFunc(pos, r.currPkg, name, sig))
|
||||
|
||||
case 'T', 'U':
|
||||
case typeTag, genericTypeTag:
|
||||
// Types can be recursive. We need to setup a stub
|
||||
// declaration before recursing.
|
||||
obj := types.NewTypeName(pos, r.currPkg, name, nil)
|
||||
|
@ -572,7 +592,7 @@ func (r *importReader) obj(name string) {
|
|||
// Declare obj before calling r.tparamList, so the new type name is recognized
|
||||
// if used in the constraint of one of its own typeparams (see #48280).
|
||||
r.declare(obj)
|
||||
if tag == 'U' {
|
||||
if tag == genericTypeTag {
|
||||
tparams := r.tparamList()
|
||||
named.SetTypeParams(tparams)
|
||||
}
|
||||
|
@ -604,7 +624,7 @@ func (r *importReader) obj(name string) {
|
|||
}
|
||||
}
|
||||
|
||||
case 'P':
|
||||
case typeParamTag:
|
||||
// We need to "declare" a typeparam in order to have a name that
|
||||
// can be referenced recursively (if needed) in the type param's
|
||||
// bound.
|
||||
|
@ -637,7 +657,7 @@ func (r *importReader) obj(name string) {
|
|||
// completely set up all types in ImportData.
|
||||
r.p.later = append(r.p.later, setConstraintArgs{t: t, constraint: constraint})
|
||||
|
||||
case 'V':
|
||||
case varTag:
|
||||
typ := r.typ()
|
||||
|
||||
r.declare(types.NewVar(pos, r.currPkg, name, typ))
|
||||
|
@ -854,7 +874,7 @@ func (r *importReader) doType(base *types.Named) (res types.Type) {
|
|||
errorf("unexpected kind tag in %q: %v", r.p.ipath, k)
|
||||
return nil
|
||||
|
||||
case definedType:
|
||||
case aliasType, definedType:
|
||||
pkg, name := r.qualifiedIdent()
|
||||
r.p.doDecl(pkg, name)
|
||||
return pkg.Scope().Lookup(name).(*types.TypeName).Type()
|
||||
|
|
|
@ -26,6 +26,7 @@ type pkgReader struct {
|
|||
|
||||
ctxt *types.Context
|
||||
imports map[string]*types.Package // previously imported packages, indexed by path
|
||||
aliases bool // create types.Alias nodes
|
||||
|
||||
// lazily initialized arrays corresponding to the unified IR
|
||||
// PosBase, Pkg, and Type sections, respectively.
|
||||
|
@ -99,6 +100,7 @@ func readUnifiedPackage(fset *token.FileSet, ctxt *types.Context, imports map[st
|
|||
|
||||
ctxt: ctxt,
|
||||
imports: imports,
|
||||
aliases: aliases.Enabled(),
|
||||
|
||||
posBases: make([]string, input.NumElems(pkgbits.RelocPosBase)),
|
||||
pkgs: make([]*types.Package, input.NumElems(pkgbits.RelocPkg)),
|
||||
|
@ -524,7 +526,7 @@ func (pr *pkgReader) objIdx(idx pkgbits.Index) (*types.Package, string) {
|
|||
case pkgbits.ObjAlias:
|
||||
pos := r.pos()
|
||||
typ := r.typ()
|
||||
declare(types.NewTypeName(pos, objPkg, objName, typ))
|
||||
declare(aliases.NewAlias(r.p.aliases, pos, objPkg, objName, typ))
|
||||
|
||||
case pkgbits.ObjConst:
|
||||
pos := r.pos()
|
||||
|
|
|
@ -25,7 +25,6 @@ import (
|
|||
"golang.org/x/tools/internal/event"
|
||||
"golang.org/x/tools/internal/event/keys"
|
||||
"golang.org/x/tools/internal/event/label"
|
||||
"golang.org/x/tools/internal/event/tag"
|
||||
)
|
||||
|
||||
// An Runner will run go command invocations and serialize
|
||||
|
@ -55,11 +54,14 @@ func (runner *Runner) initialize() {
|
|||
// 1.14: go: updating go.mod: existing contents have changed since last read
|
||||
var modConcurrencyError = regexp.MustCompile(`go:.*go.mod.*contents have changed`)
|
||||
|
||||
// verb is an event label for the go command verb.
|
||||
var verb = keys.NewString("verb", "go command verb")
|
||||
// event keys for go command invocations
|
||||
var (
|
||||
verb = keys.NewString("verb", "go command verb")
|
||||
directory = keys.NewString("directory", "")
|
||||
)
|
||||
|
||||
func invLabels(inv Invocation) []label.Label {
|
||||
return []label.Label{verb.Of(inv.Verb), tag.Directory.Of(inv.WorkingDir)}
|
||||
return []label.Label{verb.Of(inv.Verb), directory.Of(inv.WorkingDir)}
|
||||
}
|
||||
|
||||
// Run is a convenience wrapper around RunRaw.
|
||||
|
@ -158,12 +160,15 @@ type Invocation struct {
|
|||
BuildFlags []string
|
||||
|
||||
// If ModFlag is set, the go command is invoked with -mod=ModFlag.
|
||||
// TODO(rfindley): remove, in favor of Args.
|
||||
ModFlag string
|
||||
|
||||
// If ModFile is set, the go command is invoked with -modfile=ModFile.
|
||||
// TODO(rfindley): remove, in favor of Args.
|
||||
ModFile string
|
||||
|
||||
// If Overlay is set, the go command is invoked with -overlay=Overlay.
|
||||
// TODO(rfindley): remove, in favor of Args.
|
||||
Overlay string
|
||||
|
||||
// If CleanEnv is set, the invocation will run only with the environment
|
||||
|
|
|
@ -107,3 +107,57 @@ func getMainModuleAnd114(ctx context.Context, inv Invocation, r *Runner) (*Modul
|
|||
}
|
||||
return mod, lines[4] == "go1.14", nil
|
||||
}
|
||||
|
||||
// WorkspaceVendorEnabled reports whether workspace vendoring is enabled. It takes a *Runner to execute Go commands
|
||||
// with the supplied context.Context and Invocation. The Invocation can contain pre-defined fields,
|
||||
// of which only Verb and Args are modified to run the appropriate Go command.
|
||||
// Inspired by setDefaultBuildMod in modload/init.go
|
||||
func WorkspaceVendorEnabled(ctx context.Context, inv Invocation, r *Runner) (bool, []*ModuleJSON, error) {
|
||||
inv.Verb = "env"
|
||||
inv.Args = []string{"GOWORK"}
|
||||
stdout, err := r.Run(ctx, inv)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
goWork := string(bytes.TrimSpace(stdout.Bytes()))
|
||||
if fi, err := os.Stat(filepath.Join(filepath.Dir(goWork), "vendor")); err == nil && fi.IsDir() {
|
||||
mainMods, err := getWorkspaceMainModules(ctx, inv, r)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
return true, mainMods, nil
|
||||
}
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// getWorkspaceMainModules gets the main modules' information.
|
||||
// This is the information needed to figure out if vendoring should be enabled.
|
||||
func getWorkspaceMainModules(ctx context.Context, inv Invocation, r *Runner) ([]*ModuleJSON, error) {
|
||||
const format = `{{.Path}}
|
||||
{{.Dir}}
|
||||
{{.GoMod}}
|
||||
{{.GoVersion}}
|
||||
`
|
||||
inv.Verb = "list"
|
||||
inv.Args = []string{"-m", "-f", format}
|
||||
stdout, err := r.Run(ctx, inv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(stdout.String(), "\n"), "\n")
|
||||
if len(lines) < 4 {
|
||||
return nil, fmt.Errorf("unexpected stdout: %q", stdout.String())
|
||||
}
|
||||
mods := make([]*ModuleJSON, 0, len(lines)/4)
|
||||
for i := 0; i < len(lines); i += 4 {
|
||||
mods = append(mods, &ModuleJSON{
|
||||
Path: lines[i],
|
||||
Dir: lines[i+1],
|
||||
GoMod: lines[i+2],
|
||||
GoVersion: lines[i+3],
|
||||
Main: true,
|
||||
})
|
||||
}
|
||||
return mods, nil
|
||||
}
|
||||
|
|
|
@ -31,6 +31,7 @@ import (
|
|||
"golang.org/x/tools/internal/event"
|
||||
"golang.org/x/tools/internal/gocommand"
|
||||
"golang.org/x/tools/internal/gopathwalk"
|
||||
"golang.org/x/tools/internal/stdlib"
|
||||
)
|
||||
|
||||
// importToGroup is a list of functions which map from an import path to
|
||||
|
@ -300,6 +301,20 @@ func (p *pass) loadPackageNames(imports []*ImportInfo) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// if there is a trailing major version, remove it
|
||||
func withoutVersion(nm string) string {
|
||||
if v := path.Base(nm); len(v) > 0 && v[0] == 'v' {
|
||||
if _, err := strconv.Atoi(v[1:]); err == nil {
|
||||
// this is, for instance, called with rand/v2 and returns rand
|
||||
if len(v) < len(nm) {
|
||||
xnm := nm[:len(nm)-len(v)-1]
|
||||
return path.Base(xnm)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nm
|
||||
}
|
||||
|
||||
// importIdentifier returns the identifier that imp will introduce. It will
|
||||
// guess if the package name has not been loaded, e.g. because the source
|
||||
// is not available.
|
||||
|
@ -309,7 +324,7 @@ func (p *pass) importIdentifier(imp *ImportInfo) string {
|
|||
}
|
||||
known := p.knownPackages[imp.ImportPath]
|
||||
if known != nil && known.name != "" {
|
||||
return known.name
|
||||
return withoutVersion(known.name)
|
||||
}
|
||||
return ImportPathToAssumedName(imp.ImportPath)
|
||||
}
|
||||
|
@ -511,9 +526,9 @@ func (p *pass) assumeSiblingImportsValid() {
|
|||
}
|
||||
for left, rights := range refs {
|
||||
if imp, ok := importsByName[left]; ok {
|
||||
if m, ok := stdlib[imp.ImportPath]; ok {
|
||||
if m, ok := stdlib.PackageSymbols[imp.ImportPath]; ok {
|
||||
// We have the stdlib in memory; no need to guess.
|
||||
rights = copyExports(m)
|
||||
rights = symbolNameSet(m)
|
||||
}
|
||||
p.addCandidate(imp, &packageInfo{
|
||||
// no name; we already know it.
|
||||
|
@ -641,7 +656,7 @@ func getCandidatePkgs(ctx context.Context, wrappedCallback *scanCallback, filena
|
|||
dupCheck := map[string]struct{}{}
|
||||
|
||||
// Start off with the standard library.
|
||||
for importPath, exports := range stdlib {
|
||||
for importPath, symbols := range stdlib.PackageSymbols {
|
||||
p := &pkg{
|
||||
dir: filepath.Join(goenv["GOROOT"], "src", importPath),
|
||||
importPathShort: importPath,
|
||||
|
@ -650,6 +665,13 @@ func getCandidatePkgs(ctx context.Context, wrappedCallback *scanCallback, filena
|
|||
}
|
||||
dupCheck[importPath] = struct{}{}
|
||||
if notSelf(p) && wrappedCallback.dirFound(p) && wrappedCallback.packageNameLoaded(p) {
|
||||
var exports []stdlib.Symbol
|
||||
for _, sym := range symbols {
|
||||
switch sym.Kind {
|
||||
case stdlib.Func, stdlib.Type, stdlib.Var, stdlib.Const:
|
||||
exports = append(exports, sym)
|
||||
}
|
||||
}
|
||||
wrappedCallback.exportsLoaded(p, exports)
|
||||
}
|
||||
}
|
||||
|
@ -670,7 +692,7 @@ func getCandidatePkgs(ctx context.Context, wrappedCallback *scanCallback, filena
|
|||
dupCheck[pkg.importPathShort] = struct{}{}
|
||||
return notSelf(pkg) && wrappedCallback.packageNameLoaded(pkg)
|
||||
},
|
||||
exportsLoaded: func(pkg *pkg, exports []string) {
|
||||
exportsLoaded: func(pkg *pkg, exports []stdlib.Symbol) {
|
||||
// If we're an x_test, load the package under test's test variant.
|
||||
if strings.HasSuffix(filePkg, "_test") && pkg.dir == filepath.Dir(filename) {
|
||||
var err error
|
||||
|
@ -795,7 +817,7 @@ func GetImportPaths(ctx context.Context, wrapped func(ImportFix), searchPrefix,
|
|||
// A PackageExport is a package and its exports.
|
||||
type PackageExport struct {
|
||||
Fix *ImportFix
|
||||
Exports []string
|
||||
Exports []stdlib.Symbol
|
||||
}
|
||||
|
||||
// GetPackageExports returns all known packages with name pkg and their exports.
|
||||
|
@ -810,8 +832,8 @@ func GetPackageExports(ctx context.Context, wrapped func(PackageExport), searchP
|
|||
packageNameLoaded: func(pkg *pkg) bool {
|
||||
return pkg.packageName == searchPkg
|
||||
},
|
||||
exportsLoaded: func(pkg *pkg, exports []string) {
|
||||
sort.Strings(exports)
|
||||
exportsLoaded: func(pkg *pkg, exports []stdlib.Symbol) {
|
||||
sortSymbols(exports)
|
||||
wrapped(PackageExport{
|
||||
Fix: &ImportFix{
|
||||
StmtInfo: ImportInfo{
|
||||
|
@ -988,8 +1010,10 @@ func (e *ProcessEnv) GetResolver() (Resolver, error) {
|
|||
// already know the view type.
|
||||
if len(e.Env["GOMOD"]) == 0 && len(e.Env["GOWORK"]) == 0 {
|
||||
e.resolver = newGopathResolver(e)
|
||||
} else if r, err := newModuleResolver(e, e.ModCache); err != nil {
|
||||
e.resolverErr = err
|
||||
} else {
|
||||
e.resolver, e.resolverErr = newModuleResolver(e, e.ModCache)
|
||||
e.resolver = Resolver(r)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1049,24 +1073,40 @@ func addStdlibCandidates(pass *pass, refs references) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
localbase := func(nm string) string {
|
||||
ans := path.Base(nm)
|
||||
if ans[0] == 'v' {
|
||||
// this is called, for instance, with math/rand/v2 and returns rand/v2
|
||||
if _, err := strconv.Atoi(ans[1:]); err == nil {
|
||||
ix := strings.LastIndex(nm, ans)
|
||||
more := path.Base(nm[:ix])
|
||||
ans = path.Join(more, ans)
|
||||
}
|
||||
}
|
||||
return ans
|
||||
}
|
||||
add := func(pkg string) {
|
||||
// Prevent self-imports.
|
||||
if path.Base(pkg) == pass.f.Name.Name && filepath.Join(goenv["GOROOT"], "src", pkg) == pass.srcDir {
|
||||
return
|
||||
}
|
||||
exports := copyExports(stdlib[pkg])
|
||||
exports := symbolNameSet(stdlib.PackageSymbols[pkg])
|
||||
pass.addCandidate(
|
||||
&ImportInfo{ImportPath: pkg},
|
||||
&packageInfo{name: path.Base(pkg), exports: exports})
|
||||
&packageInfo{name: localbase(pkg), exports: exports})
|
||||
}
|
||||
for left := range refs {
|
||||
if left == "rand" {
|
||||
// Make sure we try crypto/rand before math/rand.
|
||||
// Make sure we try crypto/rand before any version of math/rand as both have Int()
|
||||
// and our policy is to recommend crypto
|
||||
add("crypto/rand")
|
||||
add("math/rand")
|
||||
// if the user's no later than go1.21, this should be "math/rand"
|
||||
// but we have no way of figuring out what the user is using
|
||||
// TODO: investigate using the toolchain version to disambiguate in the stdlib
|
||||
add("math/rand/v2")
|
||||
continue
|
||||
}
|
||||
for importPath := range stdlib {
|
||||
for importPath := range stdlib.PackageSymbols {
|
||||
if path.Base(importPath) == left {
|
||||
add(importPath)
|
||||
}
|
||||
|
@ -1085,7 +1125,7 @@ type Resolver interface {
|
|||
|
||||
// loadExports returns the set of exported symbols in the package at dir.
|
||||
// loadExports may be called concurrently.
|
||||
loadExports(ctx context.Context, pkg *pkg, includeTest bool) (string, []string, error)
|
||||
loadExports(ctx context.Context, pkg *pkg, includeTest bool) (string, []stdlib.Symbol, error)
|
||||
|
||||
// scoreImportPath returns the relevance for an import path.
|
||||
scoreImportPath(ctx context.Context, path string) float64
|
||||
|
@ -1114,7 +1154,7 @@ type scanCallback struct {
|
|||
// If it returns true, the package's exports will be loaded.
|
||||
packageNameLoaded func(pkg *pkg) bool
|
||||
// exportsLoaded is called when a package's exports have been loaded.
|
||||
exportsLoaded func(pkg *pkg, exports []string)
|
||||
exportsLoaded func(pkg *pkg, exports []stdlib.Symbol)
|
||||
}
|
||||
|
||||
func addExternalCandidates(ctx context.Context, pass *pass, refs references, filename string) error {
|
||||
|
@ -1295,7 +1335,7 @@ func (r *gopathResolver) loadPackageNames(importPaths []string, srcDir string) (
|
|||
// importPathToName finds out the actual package name, as declared in its .go files.
|
||||
func importPathToName(bctx *build.Context, importPath, srcDir string) string {
|
||||
// Fast path for standard library without going to disk.
|
||||
if _, ok := stdlib[importPath]; ok {
|
||||
if stdlib.HasPackage(importPath) {
|
||||
return path.Base(importPath) // stdlib packages always match their paths.
|
||||
}
|
||||
|
||||
|
@ -1493,7 +1533,7 @@ func (r *gopathResolver) scan(ctx context.Context, callback *scanCallback) error
|
|||
}
|
||||
|
||||
func (r *gopathResolver) scoreImportPath(ctx context.Context, path string) float64 {
|
||||
if _, ok := stdlib[path]; ok {
|
||||
if stdlib.HasPackage(path) {
|
||||
return MaxRelevance
|
||||
}
|
||||
return MaxRelevance - 1
|
||||
|
@ -1510,7 +1550,7 @@ func filterRoots(roots []gopathwalk.Root, include func(gopathwalk.Root) bool) []
|
|||
return result
|
||||
}
|
||||
|
||||
func (r *gopathResolver) loadExports(ctx context.Context, pkg *pkg, includeTest bool) (string, []string, error) {
|
||||
func (r *gopathResolver) loadExports(ctx context.Context, pkg *pkg, includeTest bool) (string, []stdlib.Symbol, error) {
|
||||
if info, ok := r.cache.Load(pkg.dir); ok && !includeTest {
|
||||
return r.cache.CacheExports(ctx, r.env, info)
|
||||
}
|
||||
|
@ -1530,7 +1570,7 @@ func VendorlessPath(ipath string) string {
|
|||
return ipath
|
||||
}
|
||||
|
||||
func loadExportsFromFiles(ctx context.Context, env *ProcessEnv, dir string, includeTest bool) (string, []string, error) {
|
||||
func loadExportsFromFiles(ctx context.Context, env *ProcessEnv, dir string, includeTest bool) (string, []stdlib.Symbol, error) {
|
||||
// Look for non-test, buildable .go files which could provide exports.
|
||||
all, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
|
@ -1554,7 +1594,7 @@ func loadExportsFromFiles(ctx context.Context, env *ProcessEnv, dir string, incl
|
|||
}
|
||||
|
||||
var pkgName string
|
||||
var exports []string
|
||||
var exports []stdlib.Symbol
|
||||
fset := token.NewFileSet()
|
||||
for _, fi := range files {
|
||||
select {
|
||||
|
@ -1581,21 +1621,41 @@ func loadExportsFromFiles(ctx context.Context, env *ProcessEnv, dir string, incl
|
|||
continue
|
||||
}
|
||||
pkgName = f.Name.Name
|
||||
for name := range f.Scope.Objects {
|
||||
for name, obj := range f.Scope.Objects {
|
||||
if ast.IsExported(name) {
|
||||
exports = append(exports, name)
|
||||
var kind stdlib.Kind
|
||||
switch obj.Kind {
|
||||
case ast.Con:
|
||||
kind = stdlib.Const
|
||||
case ast.Typ:
|
||||
kind = stdlib.Type
|
||||
case ast.Var:
|
||||
kind = stdlib.Var
|
||||
case ast.Fun:
|
||||
kind = stdlib.Func
|
||||
}
|
||||
exports = append(exports, stdlib.Symbol{
|
||||
Name: name,
|
||||
Kind: kind,
|
||||
Version: 0, // unknown; be permissive
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
sortSymbols(exports)
|
||||
|
||||
if env.Logf != nil {
|
||||
sortedExports := append([]string(nil), exports...)
|
||||
sort.Strings(sortedExports)
|
||||
env.Logf("loaded exports in dir %v (package %v): %v", dir, pkgName, strings.Join(sortedExports, ", "))
|
||||
env.Logf("loaded exports in dir %v (package %v): %v", dir, pkgName, exports)
|
||||
}
|
||||
return pkgName, exports, nil
|
||||
}
|
||||
|
||||
func sortSymbols(syms []stdlib.Symbol) {
|
||||
sort.Slice(syms, func(i, j int) bool {
|
||||
return syms[i].Name < syms[j].Name
|
||||
})
|
||||
}
|
||||
|
||||
// findImport searches for a package with the given symbols.
|
||||
// If no package is found, findImport returns ("", false, nil)
|
||||
func findImport(ctx context.Context, pass *pass, candidates []pkgDistance, pkgName string, symbols map[string]bool) (*pkg, error) {
|
||||
|
@ -1662,7 +1722,7 @@ func findImport(ctx context.Context, pass *pass, candidates []pkgDistance, pkgNa
|
|||
|
||||
exportsMap := make(map[string]bool, len(exports))
|
||||
for _, sym := range exports {
|
||||
exportsMap[sym] = true
|
||||
exportsMap[sym.Name] = true
|
||||
}
|
||||
|
||||
// If it doesn't have the right
|
||||
|
@ -1820,10 +1880,13 @@ func (fn visitFn) Visit(node ast.Node) ast.Visitor {
|
|||
return fn(node)
|
||||
}
|
||||
|
||||
func copyExports(pkg []string) map[string]bool {
|
||||
m := make(map[string]bool, len(pkg))
|
||||
for _, v := range pkg {
|
||||
m[v] = true
|
||||
func symbolNameSet(symbols []stdlib.Symbol) map[string]bool {
|
||||
names := make(map[string]bool)
|
||||
for _, sym := range symbols {
|
||||
switch sym.Kind {
|
||||
case stdlib.Const, stdlib.Var, stdlib.Type, stdlib.Func:
|
||||
names[sym.Name] = true
|
||||
}
|
||||
}
|
||||
return m
|
||||
return names
|
||||
}
|
||||
|
|
|
@ -2,8 +2,6 @@
|
|||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:generate go run mkstdlib.go
|
||||
|
||||
// Package imports implements a Go pretty-printer (like package "go/format")
|
||||
// that also adds or removes import statements as necessary.
|
||||
package imports
|
||||
|
@ -109,7 +107,7 @@ func ApplyFixes(fixes []*ImportFix, filename string, src []byte, opt *Options, e
|
|||
}
|
||||
|
||||
// formatFile formats the file syntax tree.
|
||||
// It may mutate the token.FileSet.
|
||||
// It may mutate the token.FileSet and the ast.File.
|
||||
//
|
||||
// If an adjust function is provided, it is called after formatting
|
||||
// with the original source (formatFile's src parameter) and the
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"golang.org/x/tools/internal/event"
|
||||
"golang.org/x/tools/internal/gocommand"
|
||||
"golang.org/x/tools/internal/gopathwalk"
|
||||
"golang.org/x/tools/internal/stdlib"
|
||||
)
|
||||
|
||||
// Notes(rfindley): ModuleResolver appears to be heavily optimized for scanning
|
||||
|
@ -111,11 +112,11 @@ func newModuleResolver(e *ProcessEnv, moduleCacheCache *DirInfoCache) (*ModuleRe
|
|||
}
|
||||
|
||||
vendorEnabled := false
|
||||
var mainModVendor *gocommand.ModuleJSON
|
||||
var mainModVendor *gocommand.ModuleJSON // for module vendoring
|
||||
var mainModsVendor []*gocommand.ModuleJSON // for workspace vendoring
|
||||
|
||||
// Module vendor directories are ignored in workspace mode:
|
||||
// https://go.googlesource.com/proposal/+/master/design/45713-workspace.md
|
||||
if len(r.env.Env["GOWORK"]) == 0 {
|
||||
goWork := r.env.Env["GOWORK"]
|
||||
if len(goWork) == 0 {
|
||||
// TODO(rfindley): VendorEnabled runs the go command to get GOFLAGS, but
|
||||
// they should be available from the ProcessEnv. Can we avoid the redundant
|
||||
// invocation?
|
||||
|
@ -123,18 +124,35 @@ func newModuleResolver(e *ProcessEnv, moduleCacheCache *DirInfoCache) (*ModuleRe
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
vendorEnabled, mainModsVendor, err = gocommand.WorkspaceVendorEnabled(context.Background(), inv, r.env.GocmdRunner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if mainModVendor != nil && vendorEnabled {
|
||||
// Vendor mode is on, so all the non-Main modules are irrelevant,
|
||||
// and we need to search /vendor for everything.
|
||||
r.mains = []*gocommand.ModuleJSON{mainModVendor}
|
||||
r.dummyVendorMod = &gocommand.ModuleJSON{
|
||||
Path: "",
|
||||
Dir: filepath.Join(mainModVendor.Dir, "vendor"),
|
||||
if vendorEnabled {
|
||||
if mainModVendor != nil {
|
||||
// Module vendor mode is on, so all the non-Main modules are irrelevant,
|
||||
// and we need to search /vendor for everything.
|
||||
r.mains = []*gocommand.ModuleJSON{mainModVendor}
|
||||
r.dummyVendorMod = &gocommand.ModuleJSON{
|
||||
Path: "",
|
||||
Dir: filepath.Join(mainModVendor.Dir, "vendor"),
|
||||
}
|
||||
r.modsByModPath = []*gocommand.ModuleJSON{mainModVendor, r.dummyVendorMod}
|
||||
r.modsByDir = []*gocommand.ModuleJSON{mainModVendor, r.dummyVendorMod}
|
||||
} else {
|
||||
// Workspace vendor mode is on, so all the non-Main modules are irrelevant,
|
||||
// and we need to search /vendor for everything.
|
||||
r.mains = mainModsVendor
|
||||
r.dummyVendorMod = &gocommand.ModuleJSON{
|
||||
Path: "",
|
||||
Dir: filepath.Join(filepath.Dir(goWork), "vendor"),
|
||||
}
|
||||
r.modsByModPath = append(append([]*gocommand.ModuleJSON{}, mainModsVendor...), r.dummyVendorMod)
|
||||
r.modsByDir = append(append([]*gocommand.ModuleJSON{}, mainModsVendor...), r.dummyVendorMod)
|
||||
}
|
||||
r.modsByModPath = []*gocommand.ModuleJSON{mainModVendor, r.dummyVendorMod}
|
||||
r.modsByDir = []*gocommand.ModuleJSON{mainModVendor, r.dummyVendorMod}
|
||||
} else {
|
||||
// Vendor mode is off, so run go list -m ... to find everything.
|
||||
err := r.initAllMods()
|
||||
|
@ -165,8 +183,9 @@ func newModuleResolver(e *ProcessEnv, moduleCacheCache *DirInfoCache) (*ModuleRe
|
|||
return count(j) < count(i) // descending order
|
||||
})
|
||||
|
||||
r.roots = []gopathwalk.Root{
|
||||
{Path: filepath.Join(goenv["GOROOT"], "/src"), Type: gopathwalk.RootGOROOT},
|
||||
r.roots = []gopathwalk.Root{}
|
||||
if goenv["GOROOT"] != "" { // "" happens in tests
|
||||
r.roots = append(r.roots, gopathwalk.Root{Path: filepath.Join(goenv["GOROOT"], "/src"), Type: gopathwalk.RootGOROOT})
|
||||
}
|
||||
r.mainByDir = make(map[string]*gocommand.ModuleJSON)
|
||||
for _, main := range r.mains {
|
||||
|
@ -313,15 +332,19 @@ func (r *ModuleResolver) ClearForNewScan() Resolver {
|
|||
// TODO(rfindley): move this to a new env.go, consolidating ProcessEnv methods.
|
||||
func (e *ProcessEnv) ClearModuleInfo() {
|
||||
if r, ok := e.resolver.(*ModuleResolver); ok {
|
||||
resolver, resolverErr := newModuleResolver(e, e.ModCache)
|
||||
if resolverErr == nil {
|
||||
<-r.scanSema // acquire (guards caches)
|
||||
resolver.moduleCacheCache = r.moduleCacheCache
|
||||
resolver.otherCache = r.otherCache
|
||||
r.scanSema <- struct{}{} // release
|
||||
resolver, err := newModuleResolver(e, e.ModCache)
|
||||
if err != nil {
|
||||
e.resolver = nil
|
||||
e.resolverErr = err
|
||||
return
|
||||
}
|
||||
e.resolver = resolver
|
||||
e.resolverErr = resolverErr
|
||||
|
||||
<-r.scanSema // acquire (guards caches)
|
||||
resolver.moduleCacheCache = r.moduleCacheCache
|
||||
resolver.otherCache = r.otherCache
|
||||
r.scanSema <- struct{}{} // release
|
||||
|
||||
e.UpdateResolver(resolver)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -412,7 +435,7 @@ func (r *ModuleResolver) cachePackageName(info directoryPackageInfo) (string, er
|
|||
return r.otherCache.CachePackageName(info)
|
||||
}
|
||||
|
||||
func (r *ModuleResolver) cacheExports(ctx context.Context, env *ProcessEnv, info directoryPackageInfo) (string, []string, error) {
|
||||
func (r *ModuleResolver) cacheExports(ctx context.Context, env *ProcessEnv, info directoryPackageInfo) (string, []stdlib.Symbol, error) {
|
||||
if info.rootType == gopathwalk.RootModuleCache {
|
||||
return r.moduleCacheCache.CacheExports(ctx, env, info)
|
||||
}
|
||||
|
@ -632,7 +655,7 @@ func (r *ModuleResolver) scan(ctx context.Context, callback *scanCallback) error
|
|||
}
|
||||
|
||||
func (r *ModuleResolver) scoreImportPath(ctx context.Context, path string) float64 {
|
||||
if _, ok := stdlib[path]; ok {
|
||||
if stdlib.HasPackage(path) {
|
||||
return MaxRelevance
|
||||
}
|
||||
mod, _ := r.findPackage(path)
|
||||
|
@ -710,7 +733,7 @@ func (r *ModuleResolver) canonicalize(info directoryPackageInfo) (*pkg, error) {
|
|||
return res, nil
|
||||
}
|
||||
|
||||
func (r *ModuleResolver) loadExports(ctx context.Context, pkg *pkg, includeTest bool) (string, []string, error) {
|
||||
func (r *ModuleResolver) loadExports(ctx context.Context, pkg *pkg, includeTest bool) (string, []stdlib.Symbol, error) {
|
||||
if info, ok := r.cacheLoad(pkg.dir); ok && !includeTest {
|
||||
return r.cacheExports(ctx, r.env, info)
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"golang.org/x/mod/module"
|
||||
"golang.org/x/tools/internal/gopathwalk"
|
||||
"golang.org/x/tools/internal/stdlib"
|
||||
)
|
||||
|
||||
// To find packages to import, the resolver needs to know about all of
|
||||
|
@ -73,7 +74,7 @@ type directoryPackageInfo struct {
|
|||
// the default build context GOOS and GOARCH.
|
||||
//
|
||||
// We can make this explicit, and key exports by GOOS, GOARCH.
|
||||
exports []string
|
||||
exports []stdlib.Symbol
|
||||
}
|
||||
|
||||
// reachedStatus returns true when info has a status at least target and any error associated with
|
||||
|
@ -229,7 +230,7 @@ func (d *DirInfoCache) CachePackageName(info directoryPackageInfo) (string, erro
|
|||
return info.packageName, info.err
|
||||
}
|
||||
|
||||
func (d *DirInfoCache) CacheExports(ctx context.Context, env *ProcessEnv, info directoryPackageInfo) (string, []string, error) {
|
||||
func (d *DirInfoCache) CacheExports(ctx context.Context, env *ProcessEnv, info directoryPackageInfo) (string, []stdlib.Symbol, error) {
|
||||
if reached, _ := info.reachedStatus(exportsLoaded); reached {
|
||||
return info.packageName, info.exports, info.err
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ import (
|
|||
// sortImports sorts runs of consecutive import lines in import blocks in f.
|
||||
// It also removes duplicate imports when it is possible to do so without data loss.
|
||||
//
|
||||
// It may mutate the token.File.
|
||||
// It may mutate the token.File and the ast.File.
|
||||
func sortImports(localPrefix string, tokFile *token.File, f *ast.File) {
|
||||
for i, d := range f.Decls {
|
||||
d, ok := d.(*ast.GenDecl)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -23,6 +23,9 @@ type PkgDecoder struct {
|
|||
// version is the file format version.
|
||||
version uint32
|
||||
|
||||
// aliases determines whether types.Aliases should be created
|
||||
aliases bool
|
||||
|
||||
// sync indicates whether the file uses sync markers.
|
||||
sync bool
|
||||
|
||||
|
@ -73,6 +76,7 @@ func (pr *PkgDecoder) SyncMarkers() bool { return pr.sync }
|
|||
func NewPkgDecoder(pkgPath, input string) PkgDecoder {
|
||||
pr := PkgDecoder{
|
||||
pkgPath: pkgPath,
|
||||
//aliases: aliases.Enabled(),
|
||||
}
|
||||
|
||||
// TODO(mdempsky): Implement direct indexing of input string to
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,97 @@
|
|||
// Copyright 2022 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:generate go run generate.go
|
||||
|
||||
// Package stdlib provides a table of all exported symbols in the
|
||||
// standard library, along with the version at which they first
|
||||
// appeared.
|
||||
package stdlib
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Symbol struct {
|
||||
Name string
|
||||
Kind Kind
|
||||
Version Version // Go version that first included the symbol
|
||||
}
|
||||
|
||||
// A Kind indicates the kind of a symbol:
|
||||
// function, variable, constant, type, and so on.
|
||||
type Kind int8
|
||||
|
||||
const (
|
||||
Invalid Kind = iota // Example name:
|
||||
Type // "Buffer"
|
||||
Func // "Println"
|
||||
Var // "EOF"
|
||||
Const // "Pi"
|
||||
Field // "Point.X"
|
||||
Method // "(*Buffer).Grow"
|
||||
)
|
||||
|
||||
func (kind Kind) String() string {
|
||||
return [...]string{
|
||||
Invalid: "invalid",
|
||||
Type: "type",
|
||||
Func: "func",
|
||||
Var: "var",
|
||||
Const: "const",
|
||||
Field: "field",
|
||||
Method: "method",
|
||||
}[kind]
|
||||
}
|
||||
|
||||
// A Version represents a version of Go of the form "go1.%d".
|
||||
type Version int8
|
||||
|
||||
// String returns a version string of the form "go1.23", without allocating.
|
||||
func (v Version) String() string { return versions[v] }
|
||||
|
||||
var versions [30]string // (increase constant as needed)
|
||||
|
||||
func init() {
|
||||
for i := range versions {
|
||||
versions[i] = fmt.Sprintf("go1.%d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// HasPackage reports whether the specified package path is part of
|
||||
// the standard library's public API.
|
||||
func HasPackage(path string) bool {
|
||||
_, ok := PackageSymbols[path]
|
||||
return ok
|
||||
}
|
||||
|
||||
// SplitField splits the field symbol name into type and field
|
||||
// components. It must be called only on Field symbols.
|
||||
//
|
||||
// Example: "File.Package" -> ("File", "Package")
|
||||
func (sym *Symbol) SplitField() (typename, name string) {
|
||||
if sym.Kind != Field {
|
||||
panic("not a field")
|
||||
}
|
||||
typename, name, _ = strings.Cut(sym.Name, ".")
|
||||
return
|
||||
}
|
||||
|
||||
// SplitMethod splits the method symbol name into pointer, receiver,
|
||||
// and method components. It must be called only on Method symbols.
|
||||
//
|
||||
// Example: "(*Buffer).Grow" -> (true, "Buffer", "Grow")
|
||||
func (sym *Symbol) SplitMethod() (ptr bool, recv, name string) {
|
||||
if sym.Kind != Method {
|
||||
panic("not a method")
|
||||
}
|
||||
recv, name, _ = strings.Cut(sym.Name, ".")
|
||||
recv = recv[len("(") : len(recv)-len(")")]
|
||||
ptr = recv[0] == '*'
|
||||
if ptr {
|
||||
recv = recv[len("*"):]
|
||||
}
|
||||
return
|
||||
}
|
|
@ -1,195 +0,0 @@
|
|||
// Copyright 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package typeparams contains common utilities for writing tools that
|
||||
// interact with generic Go code, as introduced with Go 1.18. It
|
||||
// supplements the standard library APIs. Notably, the StructuralTerms
|
||||
// API computes a minimal representation of the structural
|
||||
// restrictions on a type parameter.
|
||||
//
|
||||
// An external version of these APIs is available in the
|
||||
// golang.org/x/exp/typeparams module.
|
||||
package typeparams
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"go/types"
|
||||
|
||||
"golang.org/x/tools/internal/aliases"
|
||||
"golang.org/x/tools/internal/typesinternal"
|
||||
)
|
||||
|
||||
// UnpackIndexExpr extracts data from AST nodes that represent index
|
||||
// expressions.
|
||||
//
|
||||
// For an ast.IndexExpr, the resulting indices slice will contain exactly one
|
||||
// index expression. For an ast.IndexListExpr (go1.18+), it may have a variable
|
||||
// number of index expressions.
|
||||
//
|
||||
// For nodes that don't represent index expressions, the first return value of
|
||||
// UnpackIndexExpr will be nil.
|
||||
func UnpackIndexExpr(n ast.Node) (x ast.Expr, lbrack token.Pos, indices []ast.Expr, rbrack token.Pos) {
|
||||
switch e := n.(type) {
|
||||
case *ast.IndexExpr:
|
||||
return e.X, e.Lbrack, []ast.Expr{e.Index}, e.Rbrack
|
||||
case *ast.IndexListExpr:
|
||||
return e.X, e.Lbrack, e.Indices, e.Rbrack
|
||||
}
|
||||
return nil, token.NoPos, nil, token.NoPos
|
||||
}
|
||||
|
||||
// PackIndexExpr returns an *ast.IndexExpr or *ast.IndexListExpr, depending on
|
||||
// the cardinality of indices. Calling PackIndexExpr with len(indices) == 0
|
||||
// will panic.
|
||||
func PackIndexExpr(x ast.Expr, lbrack token.Pos, indices []ast.Expr, rbrack token.Pos) ast.Expr {
|
||||
switch len(indices) {
|
||||
case 0:
|
||||
panic("empty indices")
|
||||
case 1:
|
||||
return &ast.IndexExpr{
|
||||
X: x,
|
||||
Lbrack: lbrack,
|
||||
Index: indices[0],
|
||||
Rbrack: rbrack,
|
||||
}
|
||||
default:
|
||||
return &ast.IndexListExpr{
|
||||
X: x,
|
||||
Lbrack: lbrack,
|
||||
Indices: indices,
|
||||
Rbrack: rbrack,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsTypeParam reports whether t is a type parameter (or an alias of one).
|
||||
func IsTypeParam(t types.Type) bool {
|
||||
_, ok := aliases.Unalias(t).(*types.TypeParam)
|
||||
return ok
|
||||
}
|
||||
|
||||
// OriginMethod returns the origin method associated with the method fn.
|
||||
// For methods on a non-generic receiver base type, this is just
|
||||
// fn. However, for methods with a generic receiver, OriginMethod returns the
|
||||
// corresponding method in the method set of the origin type.
|
||||
//
|
||||
// As a special case, if fn is not a method (has no receiver), OriginMethod
|
||||
// returns fn.
|
||||
func OriginMethod(fn *types.Func) *types.Func {
|
||||
recv := fn.Type().(*types.Signature).Recv()
|
||||
if recv == nil {
|
||||
return fn
|
||||
}
|
||||
_, named := typesinternal.ReceiverNamed(recv)
|
||||
if named == nil {
|
||||
// Receiver is a *types.Interface.
|
||||
return fn
|
||||
}
|
||||
if named.TypeParams().Len() == 0 {
|
||||
// Receiver base has no type parameters, so we can avoid the lookup below.
|
||||
return fn
|
||||
}
|
||||
orig := named.Origin()
|
||||
gfn, _, _ := types.LookupFieldOrMethod(orig, true, fn.Pkg(), fn.Name())
|
||||
|
||||
// This is a fix for a gopls crash (#60628) due to a go/types bug (#60634). In:
|
||||
// package p
|
||||
// type T *int
|
||||
// func (*T) f() {}
|
||||
// LookupFieldOrMethod(T, true, p, f)=nil, but NewMethodSet(*T)={(*T).f}.
|
||||
// Here we make them consistent by force.
|
||||
// (The go/types bug is general, but this workaround is reached only
|
||||
// for generic T thanks to the early return above.)
|
||||
if gfn == nil {
|
||||
mset := types.NewMethodSet(types.NewPointer(orig))
|
||||
for i := 0; i < mset.Len(); i++ {
|
||||
m := mset.At(i)
|
||||
if m.Obj().Id() == fn.Id() {
|
||||
gfn = m.Obj()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// In golang/go#61196, we observe another crash, this time inexplicable.
|
||||
if gfn == nil {
|
||||
panic(fmt.Sprintf("missing origin method for %s.%s; named == origin: %t, named.NumMethods(): %d, origin.NumMethods(): %d", named, fn, named == orig, named.NumMethods(), orig.NumMethods()))
|
||||
}
|
||||
|
||||
return gfn.(*types.Func)
|
||||
}
|
||||
|
||||
// GenericAssignableTo is a generalization of types.AssignableTo that
|
||||
// implements the following rule for uninstantiated generic types:
|
||||
//
|
||||
// If V and T are generic named types, then V is considered assignable to T if,
|
||||
// for every possible instantation of V[A_1, ..., A_N], the instantiation
|
||||
// T[A_1, ..., A_N] is valid and V[A_1, ..., A_N] implements T[A_1, ..., A_N].
|
||||
//
|
||||
// If T has structural constraints, they must be satisfied by V.
|
||||
//
|
||||
// For example, consider the following type declarations:
|
||||
//
|
||||
// type Interface[T any] interface {
|
||||
// Accept(T)
|
||||
// }
|
||||
//
|
||||
// type Container[T any] struct {
|
||||
// Element T
|
||||
// }
|
||||
//
|
||||
// func (c Container[T]) Accept(t T) { c.Element = t }
|
||||
//
|
||||
// In this case, GenericAssignableTo reports that instantiations of Container
|
||||
// are assignable to the corresponding instantiation of Interface.
|
||||
func GenericAssignableTo(ctxt *types.Context, V, T types.Type) bool {
|
||||
V = aliases.Unalias(V)
|
||||
T = aliases.Unalias(T)
|
||||
|
||||
// If V and T are not both named, or do not have matching non-empty type
|
||||
// parameter lists, fall back on types.AssignableTo.
|
||||
|
||||
VN, Vnamed := V.(*types.Named)
|
||||
TN, Tnamed := T.(*types.Named)
|
||||
if !Vnamed || !Tnamed {
|
||||
return types.AssignableTo(V, T)
|
||||
}
|
||||
|
||||
vtparams := VN.TypeParams()
|
||||
ttparams := TN.TypeParams()
|
||||
if vtparams.Len() == 0 || vtparams.Len() != ttparams.Len() || VN.TypeArgs().Len() != 0 || TN.TypeArgs().Len() != 0 {
|
||||
return types.AssignableTo(V, T)
|
||||
}
|
||||
|
||||
// V and T have the same (non-zero) number of type params. Instantiate both
|
||||
// with the type parameters of V. This must always succeed for V, and will
|
||||
// succeed for T if and only if the type set of each type parameter of V is a
|
||||
// subset of the type set of the corresponding type parameter of T, meaning
|
||||
// that every instantiation of V corresponds to a valid instantiation of T.
|
||||
|
||||
// Minor optimization: ensure we share a context across the two
|
||||
// instantiations below.
|
||||
if ctxt == nil {
|
||||
ctxt = types.NewContext()
|
||||
}
|
||||
|
||||
var targs []types.Type
|
||||
for i := 0; i < vtparams.Len(); i++ {
|
||||
targs = append(targs, vtparams.At(i))
|
||||
}
|
||||
|
||||
vinst, err := types.Instantiate(ctxt, V, targs, true)
|
||||
if err != nil {
|
||||
panic("type parameters should satisfy their own constraints")
|
||||
}
|
||||
|
||||
tinst, err := types.Instantiate(ctxt, T, targs, true)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return types.AssignableTo(vinst, tinst)
|
||||
}
|
|
@ -1,137 +0,0 @@
|
|||
// Copyright 2022 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package typeparams
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/types"
|
||||
|
||||
"golang.org/x/tools/internal/aliases"
|
||||
)
|
||||
|
||||
// CoreType returns the core type of T or nil if T does not have a core type.
|
||||
//
|
||||
// See https://go.dev/ref/spec#Core_types for the definition of a core type.
|
||||
func CoreType(T types.Type) types.Type {
|
||||
U := T.Underlying()
|
||||
if _, ok := U.(*types.Interface); !ok {
|
||||
return U // for non-interface types,
|
||||
}
|
||||
|
||||
terms, err := _NormalTerms(U)
|
||||
if len(terms) == 0 || err != nil {
|
||||
// len(terms) -> empty type set of interface.
|
||||
// err != nil => U is invalid, exceeds complexity bounds, or has an empty type set.
|
||||
return nil // no core type.
|
||||
}
|
||||
|
||||
U = terms[0].Type().Underlying()
|
||||
var identical int // i in [0,identical) => Identical(U, terms[i].Type().Underlying())
|
||||
for identical = 1; identical < len(terms); identical++ {
|
||||
if !types.Identical(U, terms[identical].Type().Underlying()) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if identical == len(terms) {
|
||||
// https://go.dev/ref/spec#Core_types
|
||||
// "There is a single type U which is the underlying type of all types in the type set of T"
|
||||
return U
|
||||
}
|
||||
ch, ok := U.(*types.Chan)
|
||||
if !ok {
|
||||
return nil // no core type as identical < len(terms) and U is not a channel.
|
||||
}
|
||||
// https://go.dev/ref/spec#Core_types
|
||||
// "the type chan E if T contains only bidirectional channels, or the type chan<- E or
|
||||
// <-chan E depending on the direction of the directional channels present."
|
||||
for chans := identical; chans < len(terms); chans++ {
|
||||
curr, ok := terms[chans].Type().Underlying().(*types.Chan)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if !types.Identical(ch.Elem(), curr.Elem()) {
|
||||
return nil // channel elements are not identical.
|
||||
}
|
||||
if ch.Dir() == types.SendRecv {
|
||||
// ch is bidirectional. We can safely always use curr's direction.
|
||||
ch = curr
|
||||
} else if curr.Dir() != types.SendRecv && ch.Dir() != curr.Dir() {
|
||||
// ch and curr are not bidirectional and not the same direction.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
// _NormalTerms returns a slice of terms representing the normalized structural
|
||||
// type restrictions of a type, if any.
|
||||
//
|
||||
// For all types other than *types.TypeParam, *types.Interface, and
|
||||
// *types.Union, this is just a single term with Tilde() == false and
|
||||
// Type() == typ. For *types.TypeParam, *types.Interface, and *types.Union, see
|
||||
// below.
|
||||
//
|
||||
// Structural type restrictions of a type parameter are created via
|
||||
// non-interface types embedded in its constraint interface (directly, or via a
|
||||
// chain of interface embeddings). For example, in the declaration type
|
||||
// T[P interface{~int; m()}] int the structural restriction of the type
|
||||
// parameter P is ~int.
|
||||
//
|
||||
// With interface embedding and unions, the specification of structural type
|
||||
// restrictions may be arbitrarily complex. For example, consider the
|
||||
// following:
|
||||
//
|
||||
// type A interface{ ~string|~[]byte }
|
||||
//
|
||||
// type B interface{ int|string }
|
||||
//
|
||||
// type C interface { ~string|~int }
|
||||
//
|
||||
// type T[P interface{ A|B; C }] int
|
||||
//
|
||||
// In this example, the structural type restriction of P is ~string|int: A|B
|
||||
// expands to ~string|~[]byte|int|string, which reduces to ~string|~[]byte|int,
|
||||
// which when intersected with C (~string|~int) yields ~string|int.
|
||||
//
|
||||
// _NormalTerms computes these expansions and reductions, producing a
|
||||
// "normalized" form of the embeddings. A structural restriction is normalized
|
||||
// if it is a single union containing no interface terms, and is minimal in the
|
||||
// sense that removing any term changes the set of types satisfying the
|
||||
// constraint. It is left as a proof for the reader that, modulo sorting, there
|
||||
// is exactly one such normalized form.
|
||||
//
|
||||
// Because the minimal representation always takes this form, _NormalTerms
|
||||
// returns a slice of tilde terms corresponding to the terms of the union in
|
||||
// the normalized structural restriction. An error is returned if the type is
|
||||
// invalid, exceeds complexity bounds, or has an empty type set. In the latter
|
||||
// case, _NormalTerms returns ErrEmptyTypeSet.
|
||||
//
|
||||
// _NormalTerms makes no guarantees about the order of terms, except that it
|
||||
// is deterministic.
|
||||
func _NormalTerms(typ types.Type) ([]*types.Term, error) {
|
||||
switch typ := aliases.Unalias(typ).(type) {
|
||||
case *types.TypeParam:
|
||||
return StructuralTerms(typ)
|
||||
case *types.Union:
|
||||
return UnionTermSet(typ)
|
||||
case *types.Interface:
|
||||
return InterfaceTermSet(typ)
|
||||
default:
|
||||
return []*types.Term{types.NewTerm(false, typ)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// MustDeref returns the type of the variable pointed to by t.
|
||||
// It panics if t's core type is not a pointer.
|
||||
//
|
||||
// TODO(adonovan): ideally this would live in typesinternal, but that
|
||||
// creates an import cycle. Move there when we melt this package down.
|
||||
func MustDeref(t types.Type) types.Type {
|
||||
if ptr, ok := CoreType(t).(*types.Pointer); ok {
|
||||
return ptr.Elem()
|
||||
}
|
||||
panic(fmt.Sprintf("%v is not a pointer", t))
|
||||
}
|
|
@ -1,218 +0,0 @@
|
|||
// Copyright 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package typeparams
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/types"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:generate go run copytermlist.go
|
||||
|
||||
const debug = false
|
||||
|
||||
var ErrEmptyTypeSet = errors.New("empty type set")
|
||||
|
||||
// StructuralTerms returns a slice of terms representing the normalized
|
||||
// structural type restrictions of a type parameter, if any.
|
||||
//
|
||||
// Structural type restrictions of a type parameter are created via
|
||||
// non-interface types embedded in its constraint interface (directly, or via a
|
||||
// chain of interface embeddings). For example, in the declaration
|
||||
//
|
||||
// type T[P interface{~int; m()}] int
|
||||
//
|
||||
// the structural restriction of the type parameter P is ~int.
|
||||
//
|
||||
// With interface embedding and unions, the specification of structural type
|
||||
// restrictions may be arbitrarily complex. For example, consider the
|
||||
// following:
|
||||
//
|
||||
// type A interface{ ~string|~[]byte }
|
||||
//
|
||||
// type B interface{ int|string }
|
||||
//
|
||||
// type C interface { ~string|~int }
|
||||
//
|
||||
// type T[P interface{ A|B; C }] int
|
||||
//
|
||||
// In this example, the structural type restriction of P is ~string|int: A|B
|
||||
// expands to ~string|~[]byte|int|string, which reduces to ~string|~[]byte|int,
|
||||
// which when intersected with C (~string|~int) yields ~string|int.
|
||||
//
|
||||
// StructuralTerms computes these expansions and reductions, producing a
|
||||
// "normalized" form of the embeddings. A structural restriction is normalized
|
||||
// if it is a single union containing no interface terms, and is minimal in the
|
||||
// sense that removing any term changes the set of types satisfying the
|
||||
// constraint. It is left as a proof for the reader that, modulo sorting, there
|
||||
// is exactly one such normalized form.
|
||||
//
|
||||
// Because the minimal representation always takes this form, StructuralTerms
|
||||
// returns a slice of tilde terms corresponding to the terms of the union in
|
||||
// the normalized structural restriction. An error is returned if the
|
||||
// constraint interface is invalid, exceeds complexity bounds, or has an empty
|
||||
// type set. In the latter case, StructuralTerms returns ErrEmptyTypeSet.
|
||||
//
|
||||
// StructuralTerms makes no guarantees about the order of terms, except that it
|
||||
// is deterministic.
|
||||
func StructuralTerms(tparam *types.TypeParam) ([]*types.Term, error) {
|
||||
constraint := tparam.Constraint()
|
||||
if constraint == nil {
|
||||
return nil, fmt.Errorf("%s has nil constraint", tparam)
|
||||
}
|
||||
iface, _ := constraint.Underlying().(*types.Interface)
|
||||
if iface == nil {
|
||||
return nil, fmt.Errorf("constraint is %T, not *types.Interface", constraint.Underlying())
|
||||
}
|
||||
return InterfaceTermSet(iface)
|
||||
}
|
||||
|
||||
// InterfaceTermSet computes the normalized terms for a constraint interface,
|
||||
// returning an error if the term set cannot be computed or is empty. In the
|
||||
// latter case, the error will be ErrEmptyTypeSet.
|
||||
//
|
||||
// See the documentation of StructuralTerms for more information on
|
||||
// normalization.
|
||||
func InterfaceTermSet(iface *types.Interface) ([]*types.Term, error) {
|
||||
return computeTermSet(iface)
|
||||
}
|
||||
|
||||
// UnionTermSet computes the normalized terms for a union, returning an error
|
||||
// if the term set cannot be computed or is empty. In the latter case, the
|
||||
// error will be ErrEmptyTypeSet.
|
||||
//
|
||||
// See the documentation of StructuralTerms for more information on
|
||||
// normalization.
|
||||
func UnionTermSet(union *types.Union) ([]*types.Term, error) {
|
||||
return computeTermSet(union)
|
||||
}
|
||||
|
||||
func computeTermSet(typ types.Type) ([]*types.Term, error) {
|
||||
tset, err := computeTermSetInternal(typ, make(map[types.Type]*termSet), 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tset.terms.isEmpty() {
|
||||
return nil, ErrEmptyTypeSet
|
||||
}
|
||||
if tset.terms.isAll() {
|
||||
return nil, nil
|
||||
}
|
||||
var terms []*types.Term
|
||||
for _, term := range tset.terms {
|
||||
terms = append(terms, types.NewTerm(term.tilde, term.typ))
|
||||
}
|
||||
return terms, nil
|
||||
}
|
||||
|
||||
// A termSet holds the normalized set of terms for a given type.
|
||||
//
|
||||
// The name termSet is intentionally distinct from 'type set': a type set is
|
||||
// all types that implement a type (and includes method restrictions), whereas
|
||||
// a term set just represents the structural restrictions on a type.
|
||||
type termSet struct {
|
||||
complete bool
|
||||
terms termlist
|
||||
}
|
||||
|
||||
func indentf(depth int, format string, args ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, strings.Repeat(".", depth)+format+"\n", args...)
|
||||
}
|
||||
|
||||
func computeTermSetInternal(t types.Type, seen map[types.Type]*termSet, depth int) (res *termSet, err error) {
|
||||
if t == nil {
|
||||
panic("nil type")
|
||||
}
|
||||
|
||||
if debug {
|
||||
indentf(depth, "%s", t.String())
|
||||
defer func() {
|
||||
if err != nil {
|
||||
indentf(depth, "=> %s", err)
|
||||
} else {
|
||||
indentf(depth, "=> %s", res.terms.String())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
const maxTermCount = 100
|
||||
if tset, ok := seen[t]; ok {
|
||||
if !tset.complete {
|
||||
return nil, fmt.Errorf("cycle detected in the declaration of %s", t)
|
||||
}
|
||||
return tset, nil
|
||||
}
|
||||
|
||||
// Mark the current type as seen to avoid infinite recursion.
|
||||
tset := new(termSet)
|
||||
defer func() {
|
||||
tset.complete = true
|
||||
}()
|
||||
seen[t] = tset
|
||||
|
||||
switch u := t.Underlying().(type) {
|
||||
case *types.Interface:
|
||||
// The term set of an interface is the intersection of the term sets of its
|
||||
// embedded types.
|
||||
tset.terms = allTermlist
|
||||
for i := 0; i < u.NumEmbeddeds(); i++ {
|
||||
embedded := u.EmbeddedType(i)
|
||||
if _, ok := embedded.Underlying().(*types.TypeParam); ok {
|
||||
return nil, fmt.Errorf("invalid embedded type %T", embedded)
|
||||
}
|
||||
tset2, err := computeTermSetInternal(embedded, seen, depth+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tset.terms = tset.terms.intersect(tset2.terms)
|
||||
}
|
||||
case *types.Union:
|
||||
// The term set of a union is the union of term sets of its terms.
|
||||
tset.terms = nil
|
||||
for i := 0; i < u.Len(); i++ {
|
||||
t := u.Term(i)
|
||||
var terms termlist
|
||||
switch t.Type().Underlying().(type) {
|
||||
case *types.Interface:
|
||||
tset2, err := computeTermSetInternal(t.Type(), seen, depth+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
terms = tset2.terms
|
||||
case *types.TypeParam, *types.Union:
|
||||
// A stand-alone type parameter or union is not permitted as union
|
||||
// term.
|
||||
return nil, fmt.Errorf("invalid union term %T", t)
|
||||
default:
|
||||
if t.Type() == types.Typ[types.Invalid] {
|
||||
continue
|
||||
}
|
||||
terms = termlist{{t.Tilde(), t.Type()}}
|
||||
}
|
||||
tset.terms = tset.terms.union(terms)
|
||||
if len(tset.terms) > maxTermCount {
|
||||
return nil, fmt.Errorf("exceeded max term count %d", maxTermCount)
|
||||
}
|
||||
}
|
||||
case *types.TypeParam:
|
||||
panic("unreachable")
|
||||
default:
|
||||
// For all other types, the term set is just a single non-tilde term
|
||||
// holding the type itself.
|
||||
if u != types.Typ[types.Invalid] {
|
||||
tset.terms = termlist{{false, t}}
|
||||
}
|
||||
}
|
||||
return tset, nil
|
||||
}
|
||||
|
||||
// under is a facade for the go/types internal function of the same name. It is
|
||||
// used by typeterm.go.
|
||||
func under(t types.Type) types.Type {
|
||||
return t.Underlying()
|
||||
}
|
|
@ -1,163 +0,0 @@
|
|||
// Copyright 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Code generated by copytermlist.go DO NOT EDIT.
|
||||
|
||||
package typeparams
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"go/types"
|
||||
)
|
||||
|
||||
// A termlist represents the type set represented by the union
|
||||
// t1 ∪ y2 ∪ ... tn of the type sets of the terms t1 to tn.
|
||||
// A termlist is in normal form if all terms are disjoint.
|
||||
// termlist operations don't require the operands to be in
|
||||
// normal form.
|
||||
type termlist []*term
|
||||
|
||||
// allTermlist represents the set of all types.
|
||||
// It is in normal form.
|
||||
var allTermlist = termlist{new(term)}
|
||||
|
||||
// String prints the termlist exactly (without normalization).
|
||||
func (xl termlist) String() string {
|
||||
if len(xl) == 0 {
|
||||
return "∅"
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
for i, x := range xl {
|
||||
if i > 0 {
|
||||
buf.WriteString(" | ")
|
||||
}
|
||||
buf.WriteString(x.String())
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// isEmpty reports whether the termlist xl represents the empty set of types.
|
||||
func (xl termlist) isEmpty() bool {
|
||||
// If there's a non-nil term, the entire list is not empty.
|
||||
// If the termlist is in normal form, this requires at most
|
||||
// one iteration.
|
||||
for _, x := range xl {
|
||||
if x != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isAll reports whether the termlist xl represents the set of all types.
|
||||
func (xl termlist) isAll() bool {
|
||||
// If there's a 𝓤 term, the entire list is 𝓤.
|
||||
// If the termlist is in normal form, this requires at most
|
||||
// one iteration.
|
||||
for _, x := range xl {
|
||||
if x != nil && x.typ == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// norm returns the normal form of xl.
|
||||
func (xl termlist) norm() termlist {
|
||||
// Quadratic algorithm, but good enough for now.
|
||||
// TODO(gri) fix asymptotic performance
|
||||
used := make([]bool, len(xl))
|
||||
var rl termlist
|
||||
for i, xi := range xl {
|
||||
if xi == nil || used[i] {
|
||||
continue
|
||||
}
|
||||
for j := i + 1; j < len(xl); j++ {
|
||||
xj := xl[j]
|
||||
if xj == nil || used[j] {
|
||||
continue
|
||||
}
|
||||
if u1, u2 := xi.union(xj); u2 == nil {
|
||||
// If we encounter a 𝓤 term, the entire list is 𝓤.
|
||||
// Exit early.
|
||||
// (Note that this is not just an optimization;
|
||||
// if we continue, we may end up with a 𝓤 term
|
||||
// and other terms and the result would not be
|
||||
// in normal form.)
|
||||
if u1.typ == nil {
|
||||
return allTermlist
|
||||
}
|
||||
xi = u1
|
||||
used[j] = true // xj is now unioned into xi - ignore it in future iterations
|
||||
}
|
||||
}
|
||||
rl = append(rl, xi)
|
||||
}
|
||||
return rl
|
||||
}
|
||||
|
||||
// union returns the union xl ∪ yl.
|
||||
func (xl termlist) union(yl termlist) termlist {
|
||||
return append(xl, yl...).norm()
|
||||
}
|
||||
|
||||
// intersect returns the intersection xl ∩ yl.
|
||||
func (xl termlist) intersect(yl termlist) termlist {
|
||||
if xl.isEmpty() || yl.isEmpty() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Quadratic algorithm, but good enough for now.
|
||||
// TODO(gri) fix asymptotic performance
|
||||
var rl termlist
|
||||
for _, x := range xl {
|
||||
for _, y := range yl {
|
||||
if r := x.intersect(y); r != nil {
|
||||
rl = append(rl, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
return rl.norm()
|
||||
}
|
||||
|
||||
// equal reports whether xl and yl represent the same type set.
|
||||
func (xl termlist) equal(yl termlist) bool {
|
||||
// TODO(gri) this should be more efficient
|
||||
return xl.subsetOf(yl) && yl.subsetOf(xl)
|
||||
}
|
||||
|
||||
// includes reports whether t ∈ xl.
|
||||
func (xl termlist) includes(t types.Type) bool {
|
||||
for _, x := range xl {
|
||||
if x.includes(t) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// supersetOf reports whether y ⊆ xl.
|
||||
func (xl termlist) supersetOf(y *term) bool {
|
||||
for _, x := range xl {
|
||||
if y.subsetOf(x) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// subsetOf reports whether xl ⊆ yl.
|
||||
func (xl termlist) subsetOf(yl termlist) bool {
|
||||
if yl.isEmpty() {
|
||||
return xl.isEmpty()
|
||||
}
|
||||
|
||||
// each term x of xl must be a subset of yl
|
||||
for _, x := range xl {
|
||||
if !yl.supersetOf(x) {
|
||||
return false // x is not a subset yl
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -1,169 +0,0 @@
|
|||
// Copyright 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Code generated by copytermlist.go DO NOT EDIT.
|
||||
|
||||
package typeparams
|
||||
|
||||
import "go/types"
|
||||
|
||||
// A term describes elementary type sets:
|
||||
//
|
||||
// ∅: (*term)(nil) == ∅ // set of no types (empty set)
|
||||
// 𝓤: &term{} == 𝓤 // set of all types (𝓤niverse)
|
||||
// T: &term{false, T} == {T} // set of type T
|
||||
// ~t: &term{true, t} == {t' | under(t') == t} // set of types with underlying type t
|
||||
type term struct {
|
||||
tilde bool // valid if typ != nil
|
||||
typ types.Type
|
||||
}
|
||||
|
||||
func (x *term) String() string {
|
||||
switch {
|
||||
case x == nil:
|
||||
return "∅"
|
||||
case x.typ == nil:
|
||||
return "𝓤"
|
||||
case x.tilde:
|
||||
return "~" + x.typ.String()
|
||||
default:
|
||||
return x.typ.String()
|
||||
}
|
||||
}
|
||||
|
||||
// equal reports whether x and y represent the same type set.
|
||||
func (x *term) equal(y *term) bool {
|
||||
// easy cases
|
||||
switch {
|
||||
case x == nil || y == nil:
|
||||
return x == y
|
||||
case x.typ == nil || y.typ == nil:
|
||||
return x.typ == y.typ
|
||||
}
|
||||
// ∅ ⊂ x, y ⊂ 𝓤
|
||||
|
||||
return x.tilde == y.tilde && types.Identical(x.typ, y.typ)
|
||||
}
|
||||
|
||||
// union returns the union x ∪ y: zero, one, or two non-nil terms.
|
||||
func (x *term) union(y *term) (_, _ *term) {
|
||||
// easy cases
|
||||
switch {
|
||||
case x == nil && y == nil:
|
||||
return nil, nil // ∅ ∪ ∅ == ∅
|
||||
case x == nil:
|
||||
return y, nil // ∅ ∪ y == y
|
||||
case y == nil:
|
||||
return x, nil // x ∪ ∅ == x
|
||||
case x.typ == nil:
|
||||
return x, nil // 𝓤 ∪ y == 𝓤
|
||||
case y.typ == nil:
|
||||
return y, nil // x ∪ 𝓤 == 𝓤
|
||||
}
|
||||
// ∅ ⊂ x, y ⊂ 𝓤
|
||||
|
||||
if x.disjoint(y) {
|
||||
return x, y // x ∪ y == (x, y) if x ∩ y == ∅
|
||||
}
|
||||
// x.typ == y.typ
|
||||
|
||||
// ~t ∪ ~t == ~t
|
||||
// ~t ∪ T == ~t
|
||||
// T ∪ ~t == ~t
|
||||
// T ∪ T == T
|
||||
if x.tilde || !y.tilde {
|
||||
return x, nil
|
||||
}
|
||||
return y, nil
|
||||
}
|
||||
|
||||
// intersect returns the intersection x ∩ y.
|
||||
func (x *term) intersect(y *term) *term {
|
||||
// easy cases
|
||||
switch {
|
||||
case x == nil || y == nil:
|
||||
return nil // ∅ ∩ y == ∅ and ∩ ∅ == ∅
|
||||
case x.typ == nil:
|
||||
return y // 𝓤 ∩ y == y
|
||||
case y.typ == nil:
|
||||
return x // x ∩ 𝓤 == x
|
||||
}
|
||||
// ∅ ⊂ x, y ⊂ 𝓤
|
||||
|
||||
if x.disjoint(y) {
|
||||
return nil // x ∩ y == ∅ if x ∩ y == ∅
|
||||
}
|
||||
// x.typ == y.typ
|
||||
|
||||
// ~t ∩ ~t == ~t
|
||||
// ~t ∩ T == T
|
||||
// T ∩ ~t == T
|
||||
// T ∩ T == T
|
||||
if !x.tilde || y.tilde {
|
||||
return x
|
||||
}
|
||||
return y
|
||||
}
|
||||
|
||||
// includes reports whether t ∈ x.
|
||||
func (x *term) includes(t types.Type) bool {
|
||||
// easy cases
|
||||
switch {
|
||||
case x == nil:
|
||||
return false // t ∈ ∅ == false
|
||||
case x.typ == nil:
|
||||
return true // t ∈ 𝓤 == true
|
||||
}
|
||||
// ∅ ⊂ x ⊂ 𝓤
|
||||
|
||||
u := t
|
||||
if x.tilde {
|
||||
u = under(u)
|
||||
}
|
||||
return types.Identical(x.typ, u)
|
||||
}
|
||||
|
||||
// subsetOf reports whether x ⊆ y.
|
||||
func (x *term) subsetOf(y *term) bool {
|
||||
// easy cases
|
||||
switch {
|
||||
case x == nil:
|
||||
return true // ∅ ⊆ y == true
|
||||
case y == nil:
|
||||
return false // x ⊆ ∅ == false since x != ∅
|
||||
case y.typ == nil:
|
||||
return true // x ⊆ 𝓤 == true
|
||||
case x.typ == nil:
|
||||
return false // 𝓤 ⊆ y == false since y != 𝓤
|
||||
}
|
||||
// ∅ ⊂ x, y ⊂ 𝓤
|
||||
|
||||
if x.disjoint(y) {
|
||||
return false // x ⊆ y == false if x ∩ y == ∅
|
||||
}
|
||||
// x.typ == y.typ
|
||||
|
||||
// ~t ⊆ ~t == true
|
||||
// ~t ⊆ T == false
|
||||
// T ⊆ ~t == true
|
||||
// T ⊆ T == true
|
||||
return !x.tilde || y.tilde
|
||||
}
|
||||
|
||||
// disjoint reports whether x ∩ y == ∅.
|
||||
// x.typ and y.typ must not be nil.
|
||||
func (x *term) disjoint(y *term) bool {
|
||||
if debug && (x.typ == nil || y.typ == nil) {
|
||||
panic("invalid argument(s)")
|
||||
}
|
||||
ux := x.typ
|
||||
if y.tilde {
|
||||
ux = under(ux)
|
||||
}
|
||||
uy := y.typ
|
||||
if x.tilde {
|
||||
uy = under(uy)
|
||||
}
|
||||
return !types.Identical(ux, uy)
|
||||
}
|
|
@ -167,7 +167,7 @@ const (
|
|||
UntypedNilUse
|
||||
|
||||
// WrongAssignCount occurs when the number of values on the right-hand side
|
||||
// of an assignment or or initialization expression does not match the number
|
||||
// of an assignment or initialization expression does not match the number
|
||||
// of variables on the left-hand side.
|
||||
//
|
||||
// Example:
|
||||
|
@ -1449,10 +1449,10 @@ const (
|
|||
NotAGenericType
|
||||
|
||||
// WrongTypeArgCount occurs when a type or function is instantiated with an
|
||||
// incorrent number of type arguments, including when a generic type or
|
||||
// incorrect number of type arguments, including when a generic type or
|
||||
// function is used without instantiation.
|
||||
//
|
||||
// Errors inolving failed type inference are assigned other error codes.
|
||||
// Errors involving failed type inference are assigned other error codes.
|
||||
//
|
||||
// Example:
|
||||
// type T[p any] int
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
// Copyright 2024 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package typesinternal
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
|
||||
"golang.org/x/tools/internal/stdlib"
|
||||
"golang.org/x/tools/internal/versions"
|
||||
)
|
||||
|
||||
// TooNewStdSymbols computes the set of package-level symbols
|
||||
// exported by pkg that are not available at the specified version.
|
||||
// The result maps each symbol to its minimum version.
|
||||
//
|
||||
// The pkg is allowed to contain type errors.
|
||||
func TooNewStdSymbols(pkg *types.Package, version string) map[types.Object]string {
|
||||
disallowed := make(map[types.Object]string)
|
||||
|
||||
// Pass 1: package-level symbols.
|
||||
symbols := stdlib.PackageSymbols[pkg.Path()]
|
||||
for _, sym := range symbols {
|
||||
symver := sym.Version.String()
|
||||
if versions.Before(version, symver) {
|
||||
switch sym.Kind {
|
||||
case stdlib.Func, stdlib.Var, stdlib.Const, stdlib.Type:
|
||||
disallowed[pkg.Scope().Lookup(sym.Name)] = symver
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: fields and methods.
|
||||
//
|
||||
// We allow fields and methods if their associated type is
|
||||
// disallowed, as otherwise we would report false positives
|
||||
// for compatibility shims. Consider:
|
||||
//
|
||||
// //go:build go1.22
|
||||
// type T struct { F std.Real } // correct new API
|
||||
//
|
||||
// //go:build !go1.22
|
||||
// type T struct { F fake } // shim
|
||||
// type fake struct { ... }
|
||||
// func (fake) M () {}
|
||||
//
|
||||
// These alternative declarations of T use either the std.Real
|
||||
// type, introduced in go1.22, or a fake type, for the field
|
||||
// F. (The fakery could be arbitrarily deep, involving more
|
||||
// nested fields and methods than are shown here.) Clients
|
||||
// that use the compatibility shim T will compile with any
|
||||
// version of go, whether older or newer than go1.22, but only
|
||||
// the newer version will use the std.Real implementation.
|
||||
//
|
||||
// Now consider a reference to method M in new(T).F.M() in a
|
||||
// module that requires a minimum of go1.21. The analysis may
|
||||
// occur using a version of Go higher than 1.21, selecting the
|
||||
// first version of T, so the method M is Real.M. This would
|
||||
// spuriously cause the analyzer to report a reference to a
|
||||
// too-new symbol even though this expression compiles just
|
||||
// fine (with the fake implementation) using go1.21.
|
||||
for _, sym := range symbols {
|
||||
symVersion := sym.Version.String()
|
||||
if !versions.Before(version, symVersion) {
|
||||
continue // allowed
|
||||
}
|
||||
|
||||
var obj types.Object
|
||||
switch sym.Kind {
|
||||
case stdlib.Field:
|
||||
typename, name := sym.SplitField()
|
||||
if t := pkg.Scope().Lookup(typename); t != nil && disallowed[t] == "" {
|
||||
obj, _, _ = types.LookupFieldOrMethod(t.Type(), false, pkg, name)
|
||||
}
|
||||
|
||||
case stdlib.Method:
|
||||
ptr, recvname, name := sym.SplitMethod()
|
||||
if t := pkg.Scope().Lookup(recvname); t != nil && disallowed[t] == "" {
|
||||
obj, _, _ = types.LookupFieldOrMethod(t.Type(), ptr, pkg, name)
|
||||
}
|
||||
}
|
||||
if obj != nil {
|
||||
disallowed[obj] = symVersion
|
||||
}
|
||||
}
|
||||
|
||||
return disallowed
|
||||
}
|
|
@ -48,5 +48,3 @@ func ReadGo116ErrorData(err types.Error) (code ErrorCode, start, end token.Pos,
|
|||
}
|
||||
return ErrorCode(data[0]), token.Pos(data[1]), token.Pos(data[2]), true
|
||||
}
|
||||
|
||||
var SetGoVersion = func(conf *types.Config, version string) bool { return false }
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
// Copyright 2021 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package typesinternal
|
||||
|
||||
import (
|
||||
"go/types"
|
||||
)
|
||||
|
||||
func init() {
|
||||
SetGoVersion = func(conf *types.Config, version string) bool {
|
||||
conf.GoVersion = version
|
||||
return true
|
||||
}
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
# github.com/BurntSushi/toml v1.3.2
|
||||
## explicit; go 1.16
|
||||
# github.com/BurntSushi/toml v1.4.0
|
||||
## explicit; go 1.18
|
||||
github.com/BurntSushi/toml
|
||||
github.com/BurntSushi/toml/internal
|
||||
# github.com/VividCortex/ewma v1.2.0
|
||||
|
@ -104,7 +104,7 @@ github.com/powerman/deepequal
|
|||
# github.com/quic-go/qpack v0.4.0
|
||||
## explicit; go 1.18
|
||||
github.com/quic-go/qpack
|
||||
# github.com/quic-go/quic-go v0.43.1
|
||||
# github.com/quic-go/quic-go v0.44.0
|
||||
## explicit; go 1.21
|
||||
github.com/quic-go/quic-go
|
||||
github.com/quic-go/quic-go/http3
|
||||
|
@ -145,10 +145,10 @@ golang.org/x/crypto/nacl/box
|
|||
golang.org/x/crypto/nacl/secretbox
|
||||
golang.org/x/crypto/poly1305
|
||||
golang.org/x/crypto/salsa20/salsa
|
||||
# golang.org/x/exp v0.0.0-20221205204356-47842c84f3db
|
||||
## explicit; go 1.18
|
||||
# golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
## explicit; go 1.20
|
||||
golang.org/x/exp/rand
|
||||
# golang.org/x/mod v0.16.0
|
||||
# golang.org/x/mod v0.17.0
|
||||
## explicit; go 1.18
|
||||
golang.org/x/mod/internal/lazyregexp
|
||||
golang.org/x/mod/modfile
|
||||
|
@ -167,6 +167,9 @@ golang.org/x/net/internal/socks
|
|||
golang.org/x/net/ipv4
|
||||
golang.org/x/net/ipv6
|
||||
golang.org/x/net/proxy
|
||||
# golang.org/x/sync v0.7.0
|
||||
## explicit; go 1.18
|
||||
golang.org/x/sync/errgroup
|
||||
# golang.org/x/sys v0.20.0
|
||||
## explicit; go 1.18
|
||||
golang.org/x/sys/cpu
|
||||
|
@ -182,7 +185,7 @@ golang.org/x/text/secure/bidirule
|
|||
golang.org/x/text/transform
|
||||
golang.org/x/text/unicode/bidi
|
||||
golang.org/x/text/unicode/norm
|
||||
# golang.org/x/tools v0.19.0
|
||||
# golang.org/x/tools v0.21.0
|
||||
## explicit; go 1.19
|
||||
golang.org/x/tools/go/ast/astutil
|
||||
golang.org/x/tools/go/ast/inspector
|
||||
|
@ -196,15 +199,14 @@ golang.org/x/tools/internal/event
|
|||
golang.org/x/tools/internal/event/core
|
||||
golang.org/x/tools/internal/event/keys
|
||||
golang.org/x/tools/internal/event/label
|
||||
golang.org/x/tools/internal/event/tag
|
||||
golang.org/x/tools/internal/gcimporter
|
||||
golang.org/x/tools/internal/gocommand
|
||||
golang.org/x/tools/internal/gopathwalk
|
||||
golang.org/x/tools/internal/imports
|
||||
golang.org/x/tools/internal/packagesinternal
|
||||
golang.org/x/tools/internal/pkgbits
|
||||
golang.org/x/tools/internal/stdlib
|
||||
golang.org/x/tools/internal/tokeninternal
|
||||
golang.org/x/tools/internal/typeparams
|
||||
golang.org/x/tools/internal/typesinternal
|
||||
golang.org/x/tools/internal/versions
|
||||
# google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f
|
||||
|
|
Loading…
Reference in New Issue