Upgrade quic-go to v0.36.1

quic-go has made breaking changes since v0.35.0, includes implementing
`CloseIdleConnections`.
Now, the local listener UDPConn are reused, and don't pile up. But,
1 instance (IPv4/IPv6) persists for each connected server.
This commit is contained in:
YX Hao 2023-07-03 18:37:42 +08:00
parent 16b2c84147
commit 89ccc59f0e
112 changed files with 5632 additions and 2620 deletions

View File

@ -57,7 +57,6 @@ type AltSupport struct {
type XTransport struct { type XTransport struct {
transport *http.Transport transport *http.Transport
h3Transport *http3.RoundTripper h3Transport *http3.RoundTripper
h3UDPConn *net.UDPConn
keepAlive time.Duration keepAlive time.Duration
timeout time.Duration timeout time.Duration
cachedIPs CachedIPs cachedIPs CachedIPs
@ -138,15 +137,7 @@ func (xTransport *XTransport) loadCachedIP(host string) (ip net.IP, expired bool
func (xTransport *XTransport) rebuildTransport() { func (xTransport *XTransport) rebuildTransport() {
dlog.Debug("Rebuilding transport") dlog.Debug("Rebuilding transport")
if xTransport.transport != nil { if xTransport.transport != nil {
if xTransport.h3Transport != nil { xTransport.transport.CloseIdleConnections()
xTransport.h3Transport.Close()
if xTransport.h3UDPConn != nil {
xTransport.h3UDPConn.Close()
xTransport.h3UDPConn = nil
}
} else {
xTransport.transport.CloseIdleConnections()
}
} }
timeout := xTransport.timeout timeout := xTransport.timeout
transport := &http.Transport{ transport := &http.Transport{
@ -272,27 +263,35 @@ func (xTransport *XTransport) rebuildTransport() {
host, port := ExtractHostAndPort(addrStr, stamps.DefaultPort) host, port := ExtractHostAndPort(addrStr, stamps.DefaultPort)
ipOnly := host ipOnly := host
cachedIP, _ := xTransport.loadCachedIP(host) cachedIP, _ := xTransport.loadCachedIP(host)
network := "udp4"
if cachedIP != nil { if cachedIP != nil {
if ipv4 := cachedIP.To4(); ipv4 != nil { if ipv4 := cachedIP.To4(); ipv4 != nil {
ipOnly = ipv4.String() ipOnly = ipv4.String()
} else { } else {
ipOnly = "[" + cachedIP.String() + "]" ipOnly = "[" + cachedIP.String() + "]"
network = "udp6"
} }
} else { } else {
dlog.Debugf("[%s] IP address was not cached in H3 DialContext", host) dlog.Debugf("[%s] IP address was not cached in H3 context", host)
if xTransport.useIPv6 {
if xTransport.useIPv4 {
network = "udp"
} else {
network = "udp6"
}
}
} }
addrStr = ipOnly + ":" + strconv.Itoa(port) addrStr = ipOnly + ":" + strconv.Itoa(port)
udpAddr, err := net.ResolveUDPAddr("udp", addrStr) udpAddr, err := net.ResolveUDPAddr(network, addrStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if xTransport.h3UDPConn == nil { udpConn, err := net.ListenUDP(network, nil)
xTransport.h3UDPConn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil {
if err != nil { return nil, err
return nil, err
}
} }
return quic.DialEarlyContext(ctx, xTransport.h3UDPConn, udpAddr, host, tlsCfg, cfg) tlsCfg.ServerName = host
return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
}} }}
xTransport.h3Transport = h3Transport xTransport.h3Transport = h3Transport
} }
@ -545,13 +544,8 @@ func (xTransport *XTransport) Fetch(
err = errors.New(resp.Status) err = errors.New(resp.Status)
} }
} else { } else {
if hasAltSupport { dlog.Debugf("HTTP client error: [%v] - closing idle connections", err)
dlog.Debugf("HTTP client error: [%v] - closing H3 connections", err) xTransport.transport.CloseIdleConnections()
xTransport.h3Transport.Close()
} else {
dlog.Debugf("HTTP client error: [%v] - closing idle connections", err)
xTransport.transport.CloseIdleConnections()
}
} }
statusCode := 503 statusCode := 503
if resp != nil { if resp != nil {

12
go.mod
View File

@ -20,7 +20,7 @@ require (
github.com/kardianos/service v1.2.2 github.com/kardianos/service v1.2.2
github.com/miekg/dns v1.1.55 github.com/miekg/dns v1.1.55
github.com/powerman/check v1.7.0 github.com/powerman/check v1.7.0
github.com/quic-go/quic-go v0.34.0 github.com/quic-go/quic-go v0.36.1
golang.org/x/crypto v0.10.0 golang.org/x/crypto v0.10.0
golang.org/x/net v0.11.0 golang.org/x/net v0.11.0
golang.org/x/sys v0.9.0 golang.org/x/sys v0.9.0
@ -29,12 +29,12 @@ require (
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/golang/mock v1.6.0 // indirect github.com/golang/mock v1.6.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
github.com/hashicorp/go-syslog v1.0.0 // indirect github.com/hashicorp/go-syslog v1.0.0 // indirect
github.com/onsi/ginkgo/v2 v2.2.0 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/powerman/deepequal v0.1.0 // indirect github.com/powerman/deepequal v0.1.0 // indirect
@ -43,9 +43,9 @@ require (
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
github.com/smartystreets/goconvey v1.7.2 // indirect github.com/smartystreets/goconvey v1.7.2 // indirect
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.8.0 // indirect golang.org/x/mod v0.10.0 // indirect
golang.org/x/text v0.10.0 // indirect golang.org/x/text v0.10.0 // indirect
golang.org/x/tools v0.6.0 // indirect golang.org/x/tools v0.9.1 // indirect
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect
google.golang.org/grpc v1.53.0 // indirect google.golang.org/grpc v1.53.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect

36
go.sum
View File

@ -12,13 +12,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185 h1:3T8ZyTDp5QxTx3NU48JVb2u+75xc040fofcBaN+6jPA= github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185 h1:3T8ZyTDp5QxTx3NU48JVb2u+75xc040fofcBaN+6jPA=
github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185/go.mod h1:cFRxtTwTOJkz2x3rQUNCYKWC93yP1VKjR8NUhqFxZNU= github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185/go.mod h1:cFRxtTwTOJkz2x3rQUNCYKWC93yP1VKjR8NUhqFxZNU=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
@ -56,9 +57,9 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo= github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY= github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
github.com/onsi/ginkgo/v2 v2.2.0 h1:3ZNA3L1c5FYDFTTxbFeVGGD8jYvjYauHD30YgLxVsNI= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/gomega v1.20.1 h1:PA/3qinGoukvymdIDV8pii6tiZgC8kbmJO6Z5+b002Q= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@ -73,15 +74,15 @@ github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc8
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU= github.com/quic-go/quic-go v0.36.1 h1:WsG73nVtnDy1TiACxFxhQ3TqaW+DipmqzLEtNlAwZyY=
github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/quic-go/quic-go v0.36.1/go.mod h1:zPetvwDlILVxt15n3hr3Gf/I3mDf7LpLKPhR4Ez0AZQ=
github.com/smartystreets/assertions v1.2.0 h1:42S6lae5dvLc7BrLu/0ugRtcFVjoJNMC/N3yZFZkDFs= github.com/smartystreets/assertions v1.2.0 h1:42S6lae5dvLc7BrLu/0ugRtcFVjoJNMC/N3yZFZkDFs=
github.com/smartystreets/assertions v1.2.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= github.com/smartystreets/assertions v1.2.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg= github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg=
github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM= github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -90,8 +91,8 @@ golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o=
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@ -100,7 +101,7 @@ golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU=
golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190529164535-6a60838ec259/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190529164535-6a60838ec259/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -120,8 +121,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@ -137,6 +138,5 @@ google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"os" "os"
"regexp" "regexp"
"strconv"
"strings" "strings"
) )
@ -50,6 +51,37 @@ func NewWithNoColorBool(noColor bool) Formatter {
} }
func New(colorMode ColorMode) Formatter { func New(colorMode ColorMode) Formatter {
colorAliases := map[string]int{
"black": 0,
"red": 1,
"green": 2,
"yellow": 3,
"blue": 4,
"magenta": 5,
"cyan": 6,
"white": 7,
}
for colorAlias, n := range colorAliases {
colorAliases[fmt.Sprintf("bright-%s", colorAlias)] = n + 8
}
getColor := func(color, defaultEscapeCode string) string {
color = strings.ToUpper(strings.ReplaceAll(color, "-", "_"))
envVar := fmt.Sprintf("GINKGO_CLI_COLOR_%s", color)
envVarColor := os.Getenv(envVar)
if envVarColor == "" {
return defaultEscapeCode
}
if colorCode, ok := colorAliases[envVarColor]; ok {
return fmt.Sprintf("\x1b[38;5;%dm", colorCode)
}
colorCode, err := strconv.Atoi(envVarColor)
if err != nil || colorCode < 0 || colorCode > 255 {
return defaultEscapeCode
}
return fmt.Sprintf("\x1b[38;5;%dm", colorCode)
}
f := Formatter{ f := Formatter{
ColorMode: colorMode, ColorMode: colorMode,
colors: map[string]string{ colors: map[string]string{
@ -57,18 +89,18 @@ func New(colorMode ColorMode) Formatter {
"bold": "\x1b[1m", "bold": "\x1b[1m",
"underline": "\x1b[4m", "underline": "\x1b[4m",
"red": "\x1b[38;5;9m", "red": getColor("red", "\x1b[38;5;9m"),
"orange": "\x1b[38;5;214m", "orange": getColor("orange", "\x1b[38;5;214m"),
"coral": "\x1b[38;5;204m", "coral": getColor("coral", "\x1b[38;5;204m"),
"magenta": "\x1b[38;5;13m", "magenta": getColor("magenta", "\x1b[38;5;13m"),
"green": "\x1b[38;5;10m", "green": getColor("green", "\x1b[38;5;10m"),
"dark-green": "\x1b[38;5;28m", "dark-green": getColor("dark-green", "\x1b[38;5;28m"),
"yellow": "\x1b[38;5;11m", "yellow": getColor("yellow", "\x1b[38;5;11m"),
"light-yellow": "\x1b[38;5;228m", "light-yellow": getColor("light-yellow", "\x1b[38;5;228m"),
"cyan": "\x1b[38;5;14m", "cyan": getColor("cyan", "\x1b[38;5;14m"),
"gray": "\x1b[38;5;243m", "gray": getColor("gray", "\x1b[38;5;243m"),
"light-gray": "\x1b[38;5;246m", "light-gray": getColor("light-gray", "\x1b[38;5;246m"),
"blue": "\x1b[38;5;12m", "blue": getColor("blue", "\x1b[38;5;12m"),
}, },
} }
colors := []string{} colors := []string{}
@ -88,7 +120,10 @@ func (f Formatter) Fi(indentation uint, format string, args ...interface{}) stri
} }
func (f Formatter) Fiw(indentation uint, maxWidth uint, format string, args ...interface{}) string { func (f Formatter) Fiw(indentation uint, maxWidth uint, format string, args ...interface{}) string {
out := fmt.Sprintf(f.style(format), args...) out := f.style(format)
if len(args) > 0 {
out = fmt.Sprintf(out, args...)
}
if indentation == 0 && maxWidth == 0 { if indentation == 0 && maxWidth == 0 {
return out return out

View File

@ -39,6 +39,8 @@ func buildSpecs(args []string, cliConfig types.CLIConfig, goFlagsConfig types.Go
command.AbortWith("Found no test suites") command.AbortWith("Found no test suites")
} }
internal.VerifyCLIAndFrameworkVersion(suites)
opc := internal.NewOrderedParallelCompiler(cliConfig.ComputedNumCompilers()) opc := internal.NewOrderedParallelCompiler(cliConfig.ComputedNumCompilers())
opc.StartCompiling(suites, goFlagsConfig) opc.StartCompiling(suites, goFlagsConfig)

View File

@ -2,6 +2,7 @@ package generators
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"os" "os"
"text/template" "text/template"
@ -25,6 +26,9 @@ func BuildBootstrapCommand() command.Command {
{Name: "template", KeyPath: "CustomTemplate", {Name: "template", KeyPath: "CustomTemplate",
UsageArgument: "template-file", UsageArgument: "template-file",
Usage: "If specified, generate will use the contents of the file passed as the bootstrap template"}, Usage: "If specified, generate will use the contents of the file passed as the bootstrap template"},
{Name: "template-data", KeyPath: "CustomTemplateData",
UsageArgument: "template-data-file",
Usage: "If specified, generate will use the contents of the file passed as data to be rendered in the bootstrap template"},
}, },
&conf, &conf,
types.GinkgoFlagSections{}, types.GinkgoFlagSections{},
@ -57,6 +61,7 @@ type bootstrapData struct {
GomegaImport string GomegaImport string
GinkgoPackage string GinkgoPackage string
GomegaPackage string GomegaPackage string
CustomData map[string]any
} }
func generateBootstrap(conf GeneratorsConfig) { func generateBootstrap(conf GeneratorsConfig) {
@ -95,17 +100,32 @@ func generateBootstrap(conf GeneratorsConfig) {
tpl, err := os.ReadFile(conf.CustomTemplate) tpl, err := os.ReadFile(conf.CustomTemplate)
command.AbortIfError("Failed to read custom bootstrap file:", err) command.AbortIfError("Failed to read custom bootstrap file:", err)
templateText = string(tpl) templateText = string(tpl)
if conf.CustomTemplateData != "" {
var tplCustomDataMap map[string]any
tplCustomData, err := os.ReadFile(conf.CustomTemplateData)
command.AbortIfError("Failed to read custom boostrap data file:", err)
if !json.Valid([]byte(tplCustomData)) {
command.AbortWith("Invalid JSON object in custom data file.")
}
//create map from the custom template data
json.Unmarshal(tplCustomData, &tplCustomDataMap)
data.CustomData = tplCustomDataMap
}
} else if conf.Agouti { } else if conf.Agouti {
templateText = agoutiBootstrapText templateText = agoutiBootstrapText
} else { } else {
templateText = bootstrapText templateText = bootstrapText
} }
bootstrapTemplate, err := template.New("bootstrap").Funcs(sprig.TxtFuncMap()).Parse(templateText) //Setting the option to explicitly fail if template is rendered trying to access missing key
bootstrapTemplate, err := template.New("bootstrap").Funcs(sprig.TxtFuncMap()).Option("missingkey=error").Parse(templateText)
command.AbortIfError("Failed to parse bootstrap template:", err) command.AbortIfError("Failed to parse bootstrap template:", err)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
bootstrapTemplate.Execute(buf, data) //Being explicit about failing sooner during template rendering
//when accessing custom data rather than during the go fmt command
err = bootstrapTemplate.Execute(buf, data)
command.AbortIfError("Failed to render bootstrap template:", err)
buf.WriteTo(f) buf.WriteTo(f)

View File

@ -2,6 +2,7 @@ package generators
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -28,6 +29,9 @@ func BuildGenerateCommand() command.Command {
{Name: "template", KeyPath: "CustomTemplate", {Name: "template", KeyPath: "CustomTemplate",
UsageArgument: "template-file", UsageArgument: "template-file",
Usage: "If specified, generate will use the contents of the file passed as the test file template"}, Usage: "If specified, generate will use the contents of the file passed as the test file template"},
{Name: "template-data", KeyPath: "CustomTemplateData",
UsageArgument: "template-data-file",
Usage: "If specified, generate will use the contents of the file passed as data to be rendered in the test file template"},
}, },
&conf, &conf,
types.GinkgoFlagSections{}, types.GinkgoFlagSections{},
@ -64,6 +68,7 @@ type specData struct {
GomegaImport string GomegaImport string
GinkgoPackage string GinkgoPackage string
GomegaPackage string GomegaPackage string
CustomData map[string]any
} }
func generateTestFiles(conf GeneratorsConfig, args []string) { func generateTestFiles(conf GeneratorsConfig, args []string) {
@ -122,16 +127,31 @@ func generateTestFileForSubject(subject string, conf GeneratorsConfig) {
tpl, err := os.ReadFile(conf.CustomTemplate) tpl, err := os.ReadFile(conf.CustomTemplate)
command.AbortIfError("Failed to read custom template file:", err) command.AbortIfError("Failed to read custom template file:", err)
templateText = string(tpl) templateText = string(tpl)
if conf.CustomTemplateData != "" {
var tplCustomDataMap map[string]any
tplCustomData, err := os.ReadFile(conf.CustomTemplateData)
command.AbortIfError("Failed to read custom template data file:", err)
if !json.Valid([]byte(tplCustomData)) {
command.AbortWith("Invalid JSON object in custom data file.")
}
//create map from the custom template data
json.Unmarshal(tplCustomData, &tplCustomDataMap)
data.CustomData = tplCustomDataMap
}
} else if conf.Agouti { } else if conf.Agouti {
templateText = agoutiSpecText templateText = agoutiSpecText
} else { } else {
templateText = specText templateText = specText
} }
specTemplate, err := template.New("spec").Funcs(sprig.TxtFuncMap()).Parse(templateText) //Setting the option to explicitly fail if template is rendered trying to access missing key
specTemplate, err := template.New("spec").Funcs(sprig.TxtFuncMap()).Option("missingkey=error").Parse(templateText)
command.AbortIfError("Failed to read parse test template:", err) command.AbortIfError("Failed to read parse test template:", err)
specTemplate.Execute(f, data) //Being explicit about failing sooner during template rendering
//when accessing custom data rather than during the go fmt command
err = specTemplate.Execute(f, data)
command.AbortIfError("Failed to render bootstrap template:", err)
internal.GoFmt(targetFile) internal.GoFmt(targetFile)
} }

View File

@ -13,6 +13,7 @@ import (
type GeneratorsConfig struct { type GeneratorsConfig struct {
Agouti, NoDot, Internal bool Agouti, NoDot, Internal bool
CustomTemplate string CustomTemplate string
CustomTemplateData string
} }
func getPackageAndFormattedName() (string, string, string) { func getPackageAndFormattedName() (string, string, string) {

View File

@ -25,7 +25,16 @@ func CompileSuite(suite TestSuite, goFlagsConfig types.GoFlagsConfig) TestSuite
return suite return suite
} }
args, err := types.GenerateGoTestCompileArgs(goFlagsConfig, path, "./") ginkgoInvocationPath, _ := os.Getwd()
ginkgoInvocationPath, _ = filepath.Abs(ginkgoInvocationPath)
packagePath := suite.AbsPath()
pathToInvocationPath, err := filepath.Rel(packagePath, ginkgoInvocationPath)
if err != nil {
suite.State = TestSuiteStateFailedToCompile
suite.CompilationError = fmt.Errorf("Failed to get relative path from package to the current working directory:\n%s", err.Error())
return suite
}
args, err := types.GenerateGoTestCompileArgs(goFlagsConfig, path, "./", pathToInvocationPath)
if err != nil { if err != nil {
suite.State = TestSuiteStateFailedToCompile suite.State = TestSuiteStateFailedToCompile
suite.CompilationError = fmt.Errorf("Failed to generate go test compile flags:\n%s", err.Error()) suite.CompilationError = fmt.Errorf("Failed to generate go test compile flags:\n%s", err.Error())

View File

@ -6,6 +6,7 @@ import (
"io" "io"
"os" "os"
"os/exec" "os/exec"
"path/filepath"
"regexp" "regexp"
"strings" "strings"
"syscall" "syscall"
@ -63,6 +64,12 @@ func checkForNoTestsWarning(buf *bytes.Buffer) bool {
} }
func runGoTest(suite TestSuite, cliConfig types.CLIConfig, goFlagsConfig types.GoFlagsConfig) TestSuite { func runGoTest(suite TestSuite, cliConfig types.CLIConfig, goFlagsConfig types.GoFlagsConfig) TestSuite {
// As we run the go test from the suite directory, make sure the cover profile is absolute
// and placed into the expected output directory when one is configured.
if goFlagsConfig.Cover && !filepath.IsAbs(goFlagsConfig.CoverProfile) {
goFlagsConfig.CoverProfile = AbsPathForGeneratedAsset(goFlagsConfig.CoverProfile, suite, cliConfig, 0)
}
args, err := types.GenerateGoTestRunArgs(goFlagsConfig) args, err := types.GenerateGoTestRunArgs(goFlagsConfig)
command.AbortIfError("Failed to generate test run arguments", err) command.AbortIfError("Failed to generate test run arguments", err)
cmd, buf := buildAndStartCommand(suite, args, true) cmd, buf := buildAndStartCommand(suite, args, true)

View File

@ -0,0 +1,54 @@
package internal
import (
"fmt"
"os/exec"
"regexp"
"strings"
"github.com/onsi/ginkgo/v2/formatter"
"github.com/onsi/ginkgo/v2/types"
)
var versiorRe = regexp.MustCompile(`v(\d+\.\d+\.\d+)`)
func VerifyCLIAndFrameworkVersion(suites TestSuites) {
cliVersion := types.VERSION
mismatches := map[string][]string{}
for _, suite := range suites {
cmd := exec.Command("go", "list", "-m", "github.com/onsi/ginkgo/v2")
cmd.Dir = suite.Path
output, err := cmd.CombinedOutput()
if err != nil {
continue
}
components := strings.Split(string(output), " ")
if len(components) != 2 {
continue
}
matches := versiorRe.FindStringSubmatch(components[1])
if matches == nil || len(matches) != 2 {
continue
}
libraryVersion := matches[1]
if cliVersion != libraryVersion {
mismatches[libraryVersion] = append(mismatches[libraryVersion], suite.PackageName)
}
}
if len(mismatches) == 0 {
return
}
fmt.Println(formatter.F("{{red}}{{bold}}Ginkgo detected a version mismatch between the Ginkgo CLI and the version of Ginkgo imported by your packages:{{/}}"))
fmt.Println(formatter.Fi(1, "Ginkgo CLI Version:"))
fmt.Println(formatter.Fi(2, "{{bold}}%s{{/}}", cliVersion))
fmt.Println(formatter.Fi(1, "Mismatched package versions found:"))
for version, packages := range mismatches {
fmt.Println(formatter.Fi(2, "{{bold}}%s{{/}} used by %s", version, strings.Join(packages, ", ")))
}
fmt.Println("")
fmt.Println(formatter.Fiw(1, formatter.COLS, "{{gray}}Ginkgo will continue to attempt to run but you may see errors (including flag parsing errors) and should either update your go.mod or your version of the Ginkgo CLI to match.\n\nTo install the matching version of the CLI run\n {{bold}}go install github.com/onsi/ginkgo/v2/ginkgo{{/}}{{gray}}\nfrom a path that contains a go.mod file. Alternatively you can use\n {{bold}}go run github.com/onsi/ginkgo/v2/ginkgo{{/}}{{gray}}\nfrom a path that contains a go.mod file to invoke the matching version of the Ginkgo CLI.\n\nIf you are attempting to test multiple packages that each have a different version of the Ginkgo library with a single Ginkgo CLI that is currently unsupported.\n{{/}}"))
}

View File

@ -1,6 +1,7 @@
package outline package outline
import ( import (
"github.com/onsi/ginkgo/v2/types"
"go/ast" "go/ast"
"go/token" "go/token"
"strconv" "strconv"
@ -25,9 +26,10 @@ type ginkgoMetadata struct {
// End is the position of first character immediately after the spec or container block // End is the position of first character immediately after the spec or container block
End int `json:"end"` End int `json:"end"`
Spec bool `json:"spec"` Spec bool `json:"spec"`
Focused bool `json:"focused"` Focused bool `json:"focused"`
Pending bool `json:"pending"` Pending bool `json:"pending"`
Labels []string `json:"labels"`
} }
// ginkgoNode is used to construct the outline as a tree // ginkgoNode is used to construct the outline as a tree
@ -145,27 +147,35 @@ func ginkgoNodeFromCallExpr(fset *token.FileSet, ce *ast.CallExpr, ginkgoPackage
case "It", "Specify", "Entry": case "It", "Specify", "Entry":
n.Spec = true n.Spec = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
n.Pending = pendingFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "FIt", "FSpecify", "FEntry": case "FIt", "FSpecify", "FEntry":
n.Spec = true n.Spec = true
n.Focused = true n.Focused = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "PIt", "PSpecify", "XIt", "XSpecify", "PEntry", "XEntry": case "PIt", "PSpecify", "XIt", "XSpecify", "PEntry", "XEntry":
n.Spec = true n.Spec = true
n.Pending = true n.Pending = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "Context", "Describe", "When", "DescribeTable": case "Context", "Describe", "When", "DescribeTable":
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
n.Pending = pendingFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "FContext", "FDescribe", "FWhen", "FDescribeTable": case "FContext", "FDescribe", "FWhen", "FDescribeTable":
n.Focused = true n.Focused = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "PContext", "PDescribe", "PWhen", "XContext", "XDescribe", "XWhen", "PDescribeTable", "XDescribeTable": case "PContext", "PDescribe", "PWhen", "XContext", "XDescribe", "XWhen", "PDescribeTable", "XDescribeTable":
n.Pending = true n.Pending = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "By": case "By":
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
@ -216,3 +226,77 @@ func textFromCallExpr(ce *ast.CallExpr) (string, bool) {
return text.Value, true return text.Value, true
} }
} }
func labelFromCallExpr(ce *ast.CallExpr) []string {
labels := []string{}
if len(ce.Args) < 2 {
return labels
}
for _, arg := range ce.Args[1:] {
switch expr := arg.(type) {
case *ast.CallExpr:
id, ok := expr.Fun.(*ast.Ident)
if !ok {
// to skip over cases where the expr.Fun. is actually *ast.SelectorExpr
continue
}
if id.Name == "Label" {
ls := extractLabels(expr)
for _, label := range ls {
labels = append(labels, label)
}
}
}
}
return labels
}
func extractLabels(expr *ast.CallExpr) []string {
out := []string{}
for _, arg := range expr.Args {
switch expr := arg.(type) {
case *ast.BasicLit:
if expr.Kind == token.STRING {
unquoted, err := strconv.Unquote(expr.Value)
if err != nil {
unquoted = expr.Value
}
validated, err := types.ValidateAndCleanupLabel(unquoted, types.CodeLocation{})
if err == nil {
out = append(out, validated)
}
}
}
}
return out
}
func pendingFromCallExpr(ce *ast.CallExpr) bool {
pending := false
if len(ce.Args) < 2 {
return pending
}
for _, arg := range ce.Args[1:] {
switch expr := arg.(type) {
case *ast.CallExpr:
id, ok := expr.Fun.(*ast.Ident)
if !ok {
// to skip over cases where the expr.Fun. is actually *ast.SelectorExpr
continue
}
if id.Name == "Pending" {
pending = true
}
case *ast.Ident:
if expr.Name == "Pending" {
pending = true
}
}
}
return pending
}

View File

@ -47,7 +47,7 @@ func packageNameForImport(f *ast.File, path string) *string {
// or nil otherwise. // or nil otherwise.
func importSpec(f *ast.File, path string) *ast.ImportSpec { func importSpec(f *ast.File, path string) *ast.ImportSpec {
for _, s := range f.Imports { for _, s := range f.Imports {
if importPath(s) == path { if strings.HasPrefix(importPath(s), path) {
return s return s
} }
} }

View File

@ -85,12 +85,19 @@ func (o *outline) String() string {
// one 'width' of spaces for every level of nesting. // one 'width' of spaces for every level of nesting.
func (o *outline) StringIndent(width int) string { func (o *outline) StringIndent(width int) string {
var b strings.Builder var b strings.Builder
b.WriteString("Name,Text,Start,End,Spec,Focused,Pending\n") b.WriteString("Name,Text,Start,End,Spec,Focused,Pending,Labels\n")
currentIndent := 0 currentIndent := 0
pre := func(n *ginkgoNode) { pre := func(n *ginkgoNode) {
b.WriteString(fmt.Sprintf("%*s", currentIndent, "")) b.WriteString(fmt.Sprintf("%*s", currentIndent, ""))
b.WriteString(fmt.Sprintf("%s,%s,%d,%d,%t,%t,%t\n", n.Name, n.Text, n.Start, n.End, n.Spec, n.Focused, n.Pending)) var labels string
if len(n.Labels) == 1 {
labels = n.Labels[0]
} else {
labels = strings.Join(n.Labels, ", ")
}
//enclosing labels in a double quoted comma separate listed so that when inmported into a CSV app the Labels column has comma separate strings
b.WriteString(fmt.Sprintf("%s,%s,%d,%d,%t,%t,%t,\"%s\"\n", n.Name, n.Text, n.Start, n.End, n.Spec, n.Focused, n.Pending, labels))
currentIndent += width currentIndent += width
} }
post := func(n *ginkgoNode) { post := func(n *ginkgoNode) {

View File

@ -24,7 +24,7 @@ func BuildRunCommand() command.Command {
panic(err) panic(err)
} }
interruptHandler := interrupt_handler.NewInterruptHandler(0, nil) interruptHandler := interrupt_handler.NewInterruptHandler(nil)
interrupt_handler.SwallowSigQuit() interrupt_handler.SwallowSigQuit()
return command.Command{ return command.Command{
@ -69,6 +69,8 @@ func (r *SpecRunner) RunSpecs(args []string, additionalArgs []string) {
skippedSuites := suites.WithState(internal.TestSuiteStateSkippedByFilter) skippedSuites := suites.WithState(internal.TestSuiteStateSkippedByFilter)
suites = suites.WithoutState(internal.TestSuiteStateSkippedByFilter) suites = suites.WithoutState(internal.TestSuiteStateSkippedByFilter)
internal.VerifyCLIAndFrameworkVersion(suites)
if len(skippedSuites) > 0 { if len(skippedSuites) > 0 {
fmt.Println("Will skip:") fmt.Println("Will skip:")
for _, skippedSuite := range skippedSuites { for _, skippedSuite := range skippedSuites {
@ -115,7 +117,7 @@ OUTER_LOOP:
} }
suites[suiteIdx] = suite suites[suiteIdx] = suite
if r.interruptHandler.Status().Interrupted { if r.interruptHandler.Status().Interrupted() {
opc.StopAndDrain() opc.StopAndDrain()
break OUTER_LOOP break OUTER_LOOP
} }

View File

@ -22,7 +22,7 @@ func BuildWatchCommand() command.Command {
if err != nil { if err != nil {
panic(err) panic(err)
} }
interruptHandler := interrupt_handler.NewInterruptHandler(0, nil) interruptHandler := interrupt_handler.NewInterruptHandler(nil)
interrupt_handler.SwallowSigQuit() interrupt_handler.SwallowSigQuit()
return command.Command{ return command.Command{
@ -65,6 +65,8 @@ type SpecWatcher struct {
func (w *SpecWatcher) WatchSpecs(args []string, additionalArgs []string) { func (w *SpecWatcher) WatchSpecs(args []string, additionalArgs []string) {
suites := internal.FindSuites(args, w.cliConfig, false).WithoutState(internal.TestSuiteStateSkippedByFilter) suites := internal.FindSuites(args, w.cliConfig, false).WithoutState(internal.TestSuiteStateSkippedByFilter)
internal.VerifyCLIAndFrameworkVersion(suites)
if len(suites) == 0 { if len(suites) == 0 {
command.AbortWith("Found no test suites") command.AbortWith("Found no test suites")
} }
@ -127,7 +129,7 @@ func (w *SpecWatcher) WatchSpecs(args []string, additionalArgs []string) {
w.updateSeed() w.updateSeed()
w.computeSuccinctMode(len(suites)) w.computeSuccinctMode(len(suites))
for idx := range suites { for idx := range suites {
if w.interruptHandler.Status().Interrupted { if w.interruptHandler.Status().Interrupted() {
return return
} }
deltaTracker.WillRun(suites[idx]) deltaTracker.WillRun(suites[idx])
@ -156,7 +158,7 @@ func (w *SpecWatcher) compileAndRun(suite internal.TestSuite, additionalArgs []s
fmt.Println(suite.CompilationError.Error()) fmt.Println(suite.CompilationError.Error())
return suite return suite
} }
if w.interruptHandler.Status().Interrupted { if w.interruptHandler.Status().Interrupted() {
return suite return suite
} }
suite = internal.RunCompiledSuite(suite, w.suiteConfig, w.reporterConfig, w.cliConfig, w.goFlagsConfig, additionalArgs) suite = internal.RunCompiledSuite(suite, w.suiteConfig, w.reporterConfig, w.cliConfig, w.goFlagsConfig, additionalArgs)

View File

@ -1,7 +1,6 @@
package interrupt_handler package interrupt_handler
import ( import (
"fmt"
"os" "os"
"os/signal" "os/signal"
"sync" "sync"
@ -11,27 +10,29 @@ import (
"github.com/onsi/ginkgo/v2/internal/parallel_support" "github.com/onsi/ginkgo/v2/internal/parallel_support"
) )
const TIMEOUT_REPEAT_INTERRUPT_MAXIMUM_DURATION = 30 * time.Second var ABORT_POLLING_INTERVAL = 500 * time.Millisecond
const TIMEOUT_REPEAT_INTERRUPT_FRACTION_OF_TIMEOUT = 10
const ABORT_POLLING_INTERVAL = 500 * time.Millisecond
const ABORT_REPEAT_INTERRUPT_DURATION = 30 * time.Second
type InterruptCause uint type InterruptCause uint
const ( const (
InterruptCauseInvalid InterruptCause = iota InterruptCauseInvalid InterruptCause = iota
InterruptCauseSignal InterruptCauseSignal
InterruptCauseTimeout
InterruptCauseAbortByOtherProcess InterruptCauseAbortByOtherProcess
) )
type InterruptLevel uint
const (
InterruptLevelUninterrupted InterruptLevel = iota
InterruptLevelCleanupAndReport
InterruptLevelReportOnly
InterruptLevelBailOut
)
func (ic InterruptCause) String() string { func (ic InterruptCause) String() string {
switch ic { switch ic {
case InterruptCauseSignal: case InterruptCauseSignal:
return "Interrupted by User" return "Interrupted by User"
case InterruptCauseTimeout:
return "Interrupted by Timeout"
case InterruptCauseAbortByOtherProcess: case InterruptCauseAbortByOtherProcess:
return "Interrupted by Other Ginkgo Process" return "Interrupted by Other Ginkgo Process"
} }
@ -39,37 +40,51 @@ func (ic InterruptCause) String() string {
} }
type InterruptStatus struct { type InterruptStatus struct {
Interrupted bool Channel chan interface{}
Channel chan interface{} Level InterruptLevel
Cause InterruptCause Cause InterruptCause
}
func (s InterruptStatus) Interrupted() bool {
return s.Level != InterruptLevelUninterrupted
}
func (s InterruptStatus) Message() string {
return s.Cause.String()
}
func (s InterruptStatus) ShouldIncludeProgressReport() bool {
return s.Cause != InterruptCauseAbortByOtherProcess
} }
type InterruptHandlerInterface interface { type InterruptHandlerInterface interface {
Status() InterruptStatus Status() InterruptStatus
SetInterruptPlaceholderMessage(string)
ClearInterruptPlaceholderMessage()
InterruptMessage() (string, bool)
} }
type InterruptHandler struct { type InterruptHandler struct {
c chan interface{} c chan interface{}
lock *sync.Mutex lock *sync.Mutex
interrupted bool level InterruptLevel
interruptPlaceholderMessage string cause InterruptCause
interruptCause InterruptCause client parallel_support.Client
client parallel_support.Client stop chan interface{}
stop chan interface{} signals []os.Signal
requestAbortCheck chan interface{}
} }
func NewInterruptHandler(timeout time.Duration, client parallel_support.Client) *InterruptHandler { func NewInterruptHandler(client parallel_support.Client, signals ...os.Signal) *InterruptHandler {
handler := &InterruptHandler{ if len(signals) == 0 {
c: make(chan interface{}), signals = []os.Signal{os.Interrupt, syscall.SIGTERM}
lock: &sync.Mutex{},
interrupted: false,
stop: make(chan interface{}),
client: client,
} }
handler.registerForInterrupts(timeout) handler := &InterruptHandler{
c: make(chan interface{}),
lock: &sync.Mutex{},
stop: make(chan interface{}),
requestAbortCheck: make(chan interface{}),
client: client,
signals: signals,
}
handler.registerForInterrupts()
return handler return handler
} }
@ -77,30 +92,28 @@ func (handler *InterruptHandler) Stop() {
close(handler.stop) close(handler.stop)
} }
func (handler *InterruptHandler) registerForInterrupts(timeout time.Duration) { func (handler *InterruptHandler) registerForInterrupts() {
// os signal handling // os signal handling
signalChannel := make(chan os.Signal, 1) signalChannel := make(chan os.Signal, 1)
signal.Notify(signalChannel, os.Interrupt, syscall.SIGTERM) signal.Notify(signalChannel, handler.signals...)
// timeout handling
var timeoutChannel <-chan time.Time
var timeoutTimer *time.Timer
if timeout > 0 {
timeoutTimer = time.NewTimer(timeout)
timeoutChannel = timeoutTimer.C
}
// cross-process abort handling // cross-process abort handling
var abortChannel chan bool var abortChannel chan interface{}
if handler.client != nil { if handler.client != nil {
abortChannel = make(chan bool) abortChannel = make(chan interface{})
go func() { go func() {
pollTicker := time.NewTicker(ABORT_POLLING_INTERVAL) pollTicker := time.NewTicker(ABORT_POLLING_INTERVAL)
for { for {
select { select {
case <-pollTicker.C: case <-pollTicker.C:
if handler.client.ShouldAbort() { if handler.client.ShouldAbort() {
abortChannel <- true close(abortChannel)
pollTicker.Stop()
return
}
case <-handler.requestAbortCheck:
if handler.client.ShouldAbort() {
close(abortChannel)
pollTicker.Stop() pollTicker.Stop()
return return
} }
@ -112,85 +125,53 @@ func (handler *InterruptHandler) registerForInterrupts(timeout time.Duration) {
}() }()
} }
// listen for any interrupt signals go func(abortChannel chan interface{}) {
// note that some (timeouts, cross-process aborts) will only trigger once
// for these we set up a ticker to keep interrupting the suite until it ends
// this ensures any `AfterEach` or `AfterSuite`s that get stuck cleaning up
// get interrupted eventually
go func() {
var interruptCause InterruptCause var interruptCause InterruptCause
var repeatChannel <-chan time.Time
var repeatTicker *time.Ticker
for { for {
select { select {
case <-signalChannel: case <-signalChannel:
interruptCause = InterruptCauseSignal interruptCause = InterruptCauseSignal
case <-timeoutChannel:
interruptCause = InterruptCauseTimeout
repeatInterruptTimeout := timeout / time.Duration(TIMEOUT_REPEAT_INTERRUPT_FRACTION_OF_TIMEOUT)
if repeatInterruptTimeout > TIMEOUT_REPEAT_INTERRUPT_MAXIMUM_DURATION {
repeatInterruptTimeout = TIMEOUT_REPEAT_INTERRUPT_MAXIMUM_DURATION
}
timeoutTimer.Stop()
repeatTicker = time.NewTicker(repeatInterruptTimeout)
repeatChannel = repeatTicker.C
case <-abortChannel: case <-abortChannel:
interruptCause = InterruptCauseAbortByOtherProcess interruptCause = InterruptCauseAbortByOtherProcess
repeatTicker = time.NewTicker(ABORT_REPEAT_INTERRUPT_DURATION)
repeatChannel = repeatTicker.C
case <-repeatChannel:
//do nothing, just interrupt again using the same interruptCause
case <-handler.stop: case <-handler.stop:
if timeoutTimer != nil {
timeoutTimer.Stop()
}
if repeatTicker != nil {
repeatTicker.Stop()
}
signal.Stop(signalChannel) signal.Stop(signalChannel)
return return
} }
abortChannel = nil
handler.lock.Lock() handler.lock.Lock()
handler.interruptCause = interruptCause oldLevel := handler.level
if handler.interruptPlaceholderMessage != "" { handler.cause = interruptCause
fmt.Println(handler.interruptPlaceholderMessage) if handler.level == InterruptLevelUninterrupted {
handler.level = InterruptLevelCleanupAndReport
} else if handler.level == InterruptLevelCleanupAndReport {
handler.level = InterruptLevelReportOnly
} else if handler.level == InterruptLevelReportOnly {
handler.level = InterruptLevelBailOut
}
if handler.level != oldLevel {
close(handler.c)
handler.c = make(chan interface{})
} }
handler.interrupted = true
close(handler.c)
handler.c = make(chan interface{})
handler.lock.Unlock() handler.lock.Unlock()
} }
}() }(abortChannel)
} }
func (handler *InterruptHandler) Status() InterruptStatus { func (handler *InterruptHandler) Status() InterruptStatus {
handler.lock.Lock() handler.lock.Lock()
defer handler.lock.Unlock() status := InterruptStatus{
Level: handler.level,
return InterruptStatus{ Channel: handler.c,
Interrupted: handler.interrupted, Cause: handler.cause,
Channel: handler.c,
Cause: handler.interruptCause,
} }
} handler.lock.Unlock()
func (handler *InterruptHandler) SetInterruptPlaceholderMessage(message string) { if handler.client != nil && handler.client.ShouldAbort() && !status.Interrupted() {
handler.lock.Lock() close(handler.requestAbortCheck)
defer handler.lock.Unlock() <-status.Channel
return handler.Status()
handler.interruptPlaceholderMessage = message }
}
return status
func (handler *InterruptHandler) ClearInterruptPlaceholderMessage() {
handler.lock.Lock()
defer handler.lock.Unlock()
handler.interruptPlaceholderMessage = ""
}
func (handler *InterruptHandler) InterruptMessage() (string, bool) {
handler.lock.Lock()
out := fmt.Sprintf("%s", handler.interruptCause.String())
defer handler.lock.Unlock()
return out, handler.interruptCause != InterruptCauseAbortByOtherProcess
} }

View File

@ -42,6 +42,8 @@ type Client interface {
PostSuiteWillBegin(report types.Report) error PostSuiteWillBegin(report types.Report) error
PostDidRun(report types.SpecReport) error PostDidRun(report types.SpecReport) error
PostSuiteDidEnd(report types.Report) error PostSuiteDidEnd(report types.Report) error
PostReportBeforeSuiteCompleted(state types.SpecState) error
BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error)
PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error
BlockUntilSynchronizedBeforeSuiteData() (types.SpecState, []byte, error) BlockUntilSynchronizedBeforeSuiteData() (types.SpecState, []byte, error)
BlockUntilNonprimaryProcsHaveFinished() error BlockUntilNonprimaryProcsHaveFinished() error

View File

@ -98,6 +98,19 @@ func (client *httpClient) PostEmitProgressReport(report types.ProgressReport) er
return client.post("/progress-report", report) return client.post("/progress-report", report)
} }
func (client *httpClient) PostReportBeforeSuiteCompleted(state types.SpecState) error {
return client.post("/report-before-suite-completed", state)
}
func (client *httpClient) BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error) {
var state types.SpecState
err := client.poll("/report-before-suite-state", &state)
if err == ErrorGone {
return types.SpecStateFailed, nil
}
return state, err
}
func (client *httpClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error { func (client *httpClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error {
beforeSuiteState := BeforeSuiteState{ beforeSuiteState := BeforeSuiteState{
State: state, State: state,

View File

@ -26,7 +26,7 @@ type httpServer struct {
handler *ServerHandler handler *ServerHandler
} }
//Create a new server, automatically selecting a port // Create a new server, automatically selecting a port
func newHttpServer(parallelTotal int, reporter reporters.Reporter) (*httpServer, error) { func newHttpServer(parallelTotal int, reporter reporters.Reporter) (*httpServer, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
@ -38,7 +38,7 @@ func newHttpServer(parallelTotal int, reporter reporters.Reporter) (*httpServer,
}, nil }, nil
} }
//Start the server. You don't need to `go s.Start()`, just `s.Start()` // Start the server. You don't need to `go s.Start()`, just `s.Start()`
func (server *httpServer) Start() { func (server *httpServer) Start() {
httpServer := &http.Server{} httpServer := &http.Server{}
mux := http.NewServeMux() mux := http.NewServeMux()
@ -52,6 +52,8 @@ func (server *httpServer) Start() {
mux.HandleFunc("/progress-report", server.emitProgressReport) mux.HandleFunc("/progress-report", server.emitProgressReport)
//synchronization endpoints //synchronization endpoints
mux.HandleFunc("/report-before-suite-completed", server.handleReportBeforeSuiteCompleted)
mux.HandleFunc("/report-before-suite-state", server.handleReportBeforeSuiteState)
mux.HandleFunc("/before-suite-completed", server.handleBeforeSuiteCompleted) mux.HandleFunc("/before-suite-completed", server.handleBeforeSuiteCompleted)
mux.HandleFunc("/before-suite-state", server.handleBeforeSuiteState) mux.HandleFunc("/before-suite-state", server.handleBeforeSuiteState)
mux.HandleFunc("/have-nonprimary-procs-finished", server.handleHaveNonprimaryProcsFinished) mux.HandleFunc("/have-nonprimary-procs-finished", server.handleHaveNonprimaryProcsFinished)
@ -63,12 +65,12 @@ func (server *httpServer) Start() {
go httpServer.Serve(server.listener) go httpServer.Serve(server.listener)
} }
//Stop the server // Stop the server
func (server *httpServer) Close() { func (server *httpServer) Close() {
server.listener.Close() server.listener.Close()
} }
//The address the server can be reached it. Pass this into the `ForwardingReporter`. // The address the server can be reached it. Pass this into the `ForwardingReporter`.
func (server *httpServer) Address() string { func (server *httpServer) Address() string {
return "http://" + server.listener.Addr().String() return "http://" + server.listener.Addr().String()
} }
@ -93,7 +95,7 @@ func (server *httpServer) RegisterAlive(node int, alive func() bool) {
// Streaming Endpoints // Streaming Endpoints
// //
//The server will forward all received messages to Ginkgo reporters registered with `RegisterReporters` // The server will forward all received messages to Ginkgo reporters registered with `RegisterReporters`
func (server *httpServer) decode(writer http.ResponseWriter, request *http.Request, object interface{}) bool { func (server *httpServer) decode(writer http.ResponseWriter, request *http.Request, object interface{}) bool {
defer request.Body.Close() defer request.Body.Close()
if json.NewDecoder(request.Body).Decode(object) != nil { if json.NewDecoder(request.Body).Decode(object) != nil {
@ -164,6 +166,23 @@ func (server *httpServer) emitProgressReport(writer http.ResponseWriter, request
server.handleError(server.handler.EmitProgressReport(report, voidReceiver), writer) server.handleError(server.handler.EmitProgressReport(report, voidReceiver), writer)
} }
func (server *httpServer) handleReportBeforeSuiteCompleted(writer http.ResponseWriter, request *http.Request) {
var state types.SpecState
if !server.decode(writer, request, &state) {
return
}
server.handleError(server.handler.ReportBeforeSuiteCompleted(state, voidReceiver), writer)
}
func (server *httpServer) handleReportBeforeSuiteState(writer http.ResponseWriter, request *http.Request) {
var state types.SpecState
if server.handleError(server.handler.ReportBeforeSuiteState(voidSender, &state), writer) {
return
}
json.NewEncoder(writer).Encode(state)
}
func (server *httpServer) handleBeforeSuiteCompleted(writer http.ResponseWriter, request *http.Request) { func (server *httpServer) handleBeforeSuiteCompleted(writer http.ResponseWriter, request *http.Request) {
var beforeSuiteState BeforeSuiteState var beforeSuiteState BeforeSuiteState
if !server.decode(writer, request, &beforeSuiteState) { if !server.decode(writer, request, &beforeSuiteState) {

View File

@ -76,6 +76,19 @@ func (client *rpcClient) PostEmitProgressReport(report types.ProgressReport) err
return client.client.Call("Server.EmitProgressReport", report, voidReceiver) return client.client.Call("Server.EmitProgressReport", report, voidReceiver)
} }
func (client *rpcClient) PostReportBeforeSuiteCompleted(state types.SpecState) error {
return client.client.Call("Server.ReportBeforeSuiteCompleted", state, voidReceiver)
}
func (client *rpcClient) BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error) {
var state types.SpecState
err := client.poll("Server.ReportBeforeSuiteState", &state)
if err == ErrorGone {
return types.SpecStateFailed, nil
}
return state, err
}
func (client *rpcClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error { func (client *rpcClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error {
beforeSuiteState := BeforeSuiteState{ beforeSuiteState := BeforeSuiteState{
State: state, State: state,

View File

@ -18,16 +18,17 @@ var voidSender Void
// It handles all the business logic to avoid duplication between the two servers // It handles all the business logic to avoid duplication between the two servers
type ServerHandler struct { type ServerHandler struct {
done chan interface{} done chan interface{}
outputDestination io.Writer outputDestination io.Writer
reporter reporters.Reporter reporter reporters.Reporter
alives []func() bool alives []func() bool
lock *sync.Mutex lock *sync.Mutex
beforeSuiteState BeforeSuiteState beforeSuiteState BeforeSuiteState
parallelTotal int reportBeforeSuiteState types.SpecState
counter int parallelTotal int
counterLock *sync.Mutex counter int
shouldAbort bool counterLock *sync.Mutex
shouldAbort bool
numSuiteDidBegins int numSuiteDidBegins int
numSuiteDidEnds int numSuiteDidEnds int
@ -37,11 +38,12 @@ type ServerHandler struct {
func newServerHandler(parallelTotal int, reporter reporters.Reporter) *ServerHandler { func newServerHandler(parallelTotal int, reporter reporters.Reporter) *ServerHandler {
return &ServerHandler{ return &ServerHandler{
reporter: reporter, reporter: reporter,
lock: &sync.Mutex{}, lock: &sync.Mutex{},
counterLock: &sync.Mutex{}, counterLock: &sync.Mutex{},
alives: make([]func() bool, parallelTotal), alives: make([]func() bool, parallelTotal),
beforeSuiteState: BeforeSuiteState{Data: nil, State: types.SpecStateInvalid}, beforeSuiteState: BeforeSuiteState{Data: nil, State: types.SpecStateInvalid},
parallelTotal: parallelTotal, parallelTotal: parallelTotal,
outputDestination: os.Stdout, outputDestination: os.Stdout,
done: make(chan interface{}), done: make(chan interface{}),
@ -140,6 +142,29 @@ func (handler *ServerHandler) haveNonprimaryProcsFinished() bool {
return true return true
} }
func (handler *ServerHandler) ReportBeforeSuiteCompleted(reportBeforeSuiteState types.SpecState, _ *Void) error {
handler.lock.Lock()
defer handler.lock.Unlock()
handler.reportBeforeSuiteState = reportBeforeSuiteState
return nil
}
func (handler *ServerHandler) ReportBeforeSuiteState(_ Void, reportBeforeSuiteState *types.SpecState) error {
proc1IsAlive := handler.procIsAlive(1)
handler.lock.Lock()
defer handler.lock.Unlock()
if handler.reportBeforeSuiteState == types.SpecStateInvalid {
if proc1IsAlive {
return ErrorEarly
} else {
return ErrorGone
}
}
*reportBeforeSuiteState = handler.reportBeforeSuiteState
return nil
}
func (handler *ServerHandler) BeforeSuiteCompleted(beforeSuiteState BeforeSuiteState, _ *Void) error { func (handler *ServerHandler) BeforeSuiteCompleted(beforeSuiteState BeforeSuiteState, _ *Void) error {
handler.lock.Lock() handler.lock.Lock()
defer handler.lock.Unlock() defer handler.lock.Unlock()

View File

@ -12,6 +12,7 @@ import (
"io" "io"
"runtime" "runtime"
"strings" "strings"
"sync"
"time" "time"
"github.com/onsi/ginkgo/v2/formatter" "github.com/onsi/ginkgo/v2/formatter"
@ -23,13 +24,16 @@ type DefaultReporter struct {
writer io.Writer writer io.Writer
// managing the emission stream // managing the emission stream
lastChar string lastCharWasNewline bool
lastEmissionWasDelimiter bool lastEmissionWasDelimiter bool
// rendering // rendering
specDenoter string specDenoter string
retryDenoter string retryDenoter string
formatter formatter.Formatter formatter formatter.Formatter
runningInParallel bool
lock *sync.Mutex
} }
func NewDefaultReporterUnderTest(conf types.ReporterConfig, writer io.Writer) *DefaultReporter { func NewDefaultReporterUnderTest(conf types.ReporterConfig, writer io.Writer) *DefaultReporter {
@ -44,12 +48,13 @@ func NewDefaultReporter(conf types.ReporterConfig, writer io.Writer) *DefaultRep
conf: conf, conf: conf,
writer: writer, writer: writer,
lastChar: "\n", lastCharWasNewline: true,
lastEmissionWasDelimiter: false, lastEmissionWasDelimiter: false,
specDenoter: "•", specDenoter: "•",
retryDenoter: "↺", retryDenoter: "↺",
formatter: formatter.NewWithNoColorBool(conf.NoColor), formatter: formatter.NewWithNoColorBool(conf.NoColor),
lock: &sync.Mutex{},
} }
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
reporter.specDenoter = "+" reporter.specDenoter = "+"
@ -97,166 +102,10 @@ func (r *DefaultReporter) SuiteWillBegin(report types.Report) {
} }
} }
func (r *DefaultReporter) WillRun(report types.SpecReport) {
if r.conf.Verbosity().LT(types.VerbosityLevelVerbose) || report.State.Is(types.SpecStatePending|types.SpecStateSkipped) {
return
}
r.emitDelimiter()
indentation := uint(0)
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
r.emitBlock(r.f("{{bold}}[%s] %s{{/}}", report.LeafNodeType.String(), report.LeafNodeText))
} else {
if len(report.ContainerHierarchyTexts) > 0 {
r.emitBlock(r.cycleJoin(report.ContainerHierarchyTexts, " "))
indentation = 1
}
line := r.fi(indentation, "{{bold}}%s{{/}}", report.LeafNodeText)
labels := report.Labels()
if len(labels) > 0 {
line += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels, ", "))
}
r.emitBlock(line)
}
r.emitBlock(r.fi(indentation, "{{gray}}%s{{/}}", report.LeafNodeLocation))
}
func (r *DefaultReporter) DidRun(report types.SpecReport) {
v := r.conf.Verbosity()
var header, highlightColor string
includeRuntime, emitGinkgoWriterOutput, stream, denoter := true, true, false, r.specDenoter
succinctLocationBlock := v.Is(types.VerbosityLevelSuccinct)
hasGW := report.CapturedGinkgoWriterOutput != ""
hasStd := report.CapturedStdOutErr != ""
hasEmittableReports := report.ReportEntries.HasVisibility(types.ReportEntryVisibilityAlways) || (report.ReportEntries.HasVisibility(types.ReportEntryVisibilityFailureOrVerbose) && (!report.Failure.IsZero() || v.GTE(types.VerbosityLevelVerbose)))
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
denoter = fmt.Sprintf("[%s]", report.LeafNodeType)
}
switch report.State {
case types.SpecStatePassed:
highlightColor, succinctLocationBlock = "{{green}}", v.LT(types.VerbosityLevelVerbose)
emitGinkgoWriterOutput = (r.conf.AlwaysEmitGinkgoWriter || v.GTE(types.VerbosityLevelVerbose)) && hasGW
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
if v.GTE(types.VerbosityLevelVerbose) || hasStd || hasEmittableReports {
header = fmt.Sprintf("%s PASSED", denoter)
} else {
return
}
} else {
header, stream = denoter, true
if report.NumAttempts > 1 {
header, stream = fmt.Sprintf("%s [FLAKEY TEST - TOOK %d ATTEMPTS TO PASS]", r.retryDenoter, report.NumAttempts), false
}
if report.RunTime > r.conf.SlowSpecThreshold {
header, stream = fmt.Sprintf("%s [SLOW TEST]", header), false
}
}
if hasStd || emitGinkgoWriterOutput || hasEmittableReports {
stream = false
}
case types.SpecStatePending:
highlightColor = "{{yellow}}"
includeRuntime, emitGinkgoWriterOutput = false, false
if v.Is(types.VerbosityLevelSuccinct) {
header, stream = "P", true
} else {
header, succinctLocationBlock = "P [PENDING]", v.LT(types.VerbosityLevelVeryVerbose)
}
case types.SpecStateSkipped:
highlightColor = "{{cyan}}"
if report.Failure.Message != "" || v.Is(types.VerbosityLevelVeryVerbose) {
header = "S [SKIPPED]"
} else {
header, stream = "S", true
}
case types.SpecStateFailed:
highlightColor, header = "{{red}}", fmt.Sprintf("%s [FAILED]", denoter)
case types.SpecStatePanicked:
highlightColor, header = "{{magenta}}", fmt.Sprintf("%s! [PANICKED]", denoter)
case types.SpecStateInterrupted:
highlightColor, header = "{{orange}}", fmt.Sprintf("%s! [INTERRUPTED]", denoter)
case types.SpecStateAborted:
highlightColor, header = "{{coral}}", fmt.Sprintf("%s! [ABORTED]", denoter)
}
// Emit stream and return
if stream {
r.emit(r.f(highlightColor + header + "{{/}}"))
return
}
// Emit header
r.emitDelimiter()
if includeRuntime {
header = r.f("%s [%.3f seconds]", header, report.RunTime.Seconds())
}
r.emitBlock(r.f(highlightColor + header + "{{/}}"))
// Emit Code Location Block
r.emitBlock(r.codeLocationBlock(report, highlightColor, succinctLocationBlock, false))
//Emit Stdout/Stderr Output
if hasStd {
r.emitBlock("\n")
r.emitBlock(r.fi(1, "{{gray}}Begin Captured StdOut/StdErr Output >>{{/}}"))
r.emitBlock(r.fi(2, "%s", report.CapturedStdOutErr))
r.emitBlock(r.fi(1, "{{gray}}<< End Captured StdOut/StdErr Output{{/}}"))
}
//Emit Captured GinkgoWriter Output
if emitGinkgoWriterOutput && hasGW {
r.emitBlock("\n")
r.emitGinkgoWriterOutput(1, report.CapturedGinkgoWriterOutput, 0)
}
if hasEmittableReports {
r.emitBlock("\n")
r.emitBlock(r.fi(1, "{{gray}}Begin Report Entries >>{{/}}"))
reportEntries := report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways)
if !report.Failure.IsZero() || v.GTE(types.VerbosityLevelVerbose) {
reportEntries = report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways, types.ReportEntryVisibilityFailureOrVerbose)
}
for _, entry := range reportEntries {
r.emitBlock(r.fi(2, "{{bold}}"+entry.Name+"{{gray}} - %s @ %s{{/}}", entry.Location, entry.Time.Format(types.GINKGO_TIME_FORMAT)))
if representation := entry.StringRepresentation(); representation != "" {
r.emitBlock(r.fi(3, representation))
}
}
r.emitBlock(r.fi(1, "{{gray}}<< End Report Entries{{/}}"))
}
// Emit Failure Message
if !report.Failure.IsZero() {
r.emitBlock("\n")
r.emitBlock(r.fi(1, highlightColor+"%s{{/}}", report.Failure.Message))
r.emitBlock(r.fi(1, highlightColor+"In {{bold}}[%s]{{/}}"+highlightColor+" at: {{bold}}%s{{/}}\n", report.Failure.FailureNodeType, report.Failure.Location))
if report.Failure.ForwardedPanic != "" {
r.emitBlock("\n")
r.emitBlock(r.fi(1, highlightColor+"%s{{/}}", report.Failure.ForwardedPanic))
}
if r.conf.FullTrace || report.Failure.ForwardedPanic != "" {
r.emitBlock("\n")
r.emitBlock(r.fi(1, highlightColor+"Full Stack Trace{{/}}"))
r.emitBlock(r.fi(2, "%s", report.Failure.Location.FullStackTrace))
}
if !report.Failure.ProgressReport.IsZero() {
r.emitBlock("\n")
r.emitProgressReport(1, false, report.Failure.ProgressReport)
}
}
r.emitDelimiter()
}
func (r *DefaultReporter) SuiteDidEnd(report types.Report) { func (r *DefaultReporter) SuiteDidEnd(report types.Report) {
failures := report.SpecReports.WithState(types.SpecStateFailureStates) failures := report.SpecReports.WithState(types.SpecStateFailureStates)
if len(failures) > 0 { if len(failures) > 0 {
r.emitBlock("\n\n") r.emitBlock("\n")
if len(failures) > 1 { if len(failures) > 1 {
r.emitBlock(r.f("{{red}}{{bold}}Summarizing %d Failures:{{/}}", len(failures))) r.emitBlock(r.f("{{red}}{{bold}}Summarizing %d Failures:{{/}}", len(failures)))
} else { } else {
@ -269,10 +118,12 @@ func (r *DefaultReporter) SuiteDidEnd(report types.Report) {
highlightColor, heading = "{{magenta}}", "[PANICKED!]" highlightColor, heading = "{{magenta}}", "[PANICKED!]"
case types.SpecStateAborted: case types.SpecStateAborted:
highlightColor, heading = "{{coral}}", "[ABORTED]" highlightColor, heading = "{{coral}}", "[ABORTED]"
case types.SpecStateTimedout:
highlightColor, heading = "{{orange}}", "[TIMEDOUT]"
case types.SpecStateInterrupted: case types.SpecStateInterrupted:
highlightColor, heading = "{{orange}}", "[INTERRUPTED]" highlightColor, heading = "{{orange}}", "[INTERRUPTED]"
} }
locationBlock := r.codeLocationBlock(specReport, highlightColor, true, true) locationBlock := r.codeLocationBlock(specReport, highlightColor, false, true)
r.emitBlock(r.fi(1, highlightColor+"%s{{/}} %s", heading, locationBlock)) r.emitBlock(r.fi(1, highlightColor+"%s{{/}} %s", heading, locationBlock))
} }
} }
@ -313,28 +164,294 @@ func (r *DefaultReporter) SuiteDidEnd(report types.Report) {
if specs.CountOfFlakedSpecs() > 0 { if specs.CountOfFlakedSpecs() > 0 {
r.emit(r.f("{{light-yellow}}{{bold}}%d Flaked{{/}} | ", specs.CountOfFlakedSpecs())) r.emit(r.f("{{light-yellow}}{{bold}}%d Flaked{{/}} | ", specs.CountOfFlakedSpecs()))
} }
if specs.CountOfRepeatedSpecs() > 0 {
r.emit(r.f("{{light-yellow}}{{bold}}%d Repeated{{/}} | ", specs.CountOfRepeatedSpecs()))
}
r.emit(r.f("{{yellow}}{{bold}}%d Pending{{/}} | ", specs.CountWithState(types.SpecStatePending))) r.emit(r.f("{{yellow}}{{bold}}%d Pending{{/}} | ", specs.CountWithState(types.SpecStatePending)))
r.emit(r.f("{{cyan}}{{bold}}%d Skipped{{/}}\n", specs.CountWithState(types.SpecStateSkipped))) r.emit(r.f("{{cyan}}{{bold}}%d Skipped{{/}}\n", specs.CountWithState(types.SpecStateSkipped)))
} }
} }
func (r *DefaultReporter) WillRun(report types.SpecReport) {
v := r.conf.Verbosity()
if v.LT(types.VerbosityLevelVerbose) || report.State.Is(types.SpecStatePending|types.SpecStateSkipped) || report.RunningInParallel {
return
}
r.emitDelimiter(0)
r.emitBlock(r.f(r.codeLocationBlock(report, "{{/}}", v.Is(types.VerbosityLevelVeryVerbose), false)))
}
func (r *DefaultReporter) DidRun(report types.SpecReport) {
v := r.conf.Verbosity()
inParallel := report.RunningInParallel
header := r.specDenoter
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
header = fmt.Sprintf("[%s]", report.LeafNodeType)
}
highlightColor := r.highlightColorForState(report.State)
// have we already been streaming the timeline?
timelineHasBeenStreaming := v.GTE(types.VerbosityLevelVerbose) && !inParallel
// should we show the timeline?
var timeline types.Timeline
showTimeline := !timelineHasBeenStreaming && (v.GTE(types.VerbosityLevelVerbose) || report.Failed())
if showTimeline {
timeline = report.Timeline().WithoutHiddenReportEntries()
keepVeryVerboseSpecEvents := v.Is(types.VerbosityLevelVeryVerbose) ||
(v.Is(types.VerbosityLevelVerbose) && r.conf.ShowNodeEvents) ||
(report.Failed() && r.conf.ShowNodeEvents)
if !keepVeryVerboseSpecEvents {
timeline = timeline.WithoutVeryVerboseSpecEvents()
}
if len(timeline) == 0 && report.CapturedGinkgoWriterOutput == "" {
// the timeline is completely empty - don't show it
showTimeline = false
}
if v.LT(types.VerbosityLevelVeryVerbose) && report.CapturedGinkgoWriterOutput == "" && len(timeline) > 0 {
//if we aren't -vv and the timeline only has a single failure, don't show it as it will appear at the end of the report
failure, isFailure := timeline[0].(types.Failure)
if isFailure && (len(timeline) == 1 || (len(timeline) == 2 && failure.AdditionalFailure != nil)) {
showTimeline = false
}
}
}
// should we have a separate section for always-visible reports?
showSeparateVisibilityAlwaysReportsSection := !timelineHasBeenStreaming && !showTimeline && report.ReportEntries.HasVisibility(types.ReportEntryVisibilityAlways)
// should we have a separate section for captured stdout/stderr
showSeparateStdSection := inParallel && (report.CapturedStdOutErr != "")
// given all that - do we have any actual content to show? or are we a single denoter in a stream?
reportHasContent := v.Is(types.VerbosityLevelVeryVerbose) || showTimeline || showSeparateVisibilityAlwaysReportsSection || showSeparateStdSection || report.Failed() || (v.Is(types.VerbosityLevelVerbose) && !report.State.Is(types.SpecStateSkipped))
// should we show a runtime?
includeRuntime := !report.State.Is(types.SpecStateSkipped|types.SpecStatePending) || (report.State.Is(types.SpecStateSkipped) && report.Failure.Message != "")
// should we show the codelocation block?
showCodeLocation := !timelineHasBeenStreaming || !report.State.Is(types.SpecStatePassed)
switch report.State {
case types.SpecStatePassed:
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) && !reportHasContent {
return
}
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
header = fmt.Sprintf("%s PASSED", header)
}
if report.NumAttempts > 1 && report.MaxFlakeAttempts > 1 {
header, reportHasContent = fmt.Sprintf("%s [FLAKEY TEST - TOOK %d ATTEMPTS TO PASS]", r.retryDenoter, report.NumAttempts), true
}
case types.SpecStatePending:
header = "P"
if v.GT(types.VerbosityLevelSuccinct) {
header, reportHasContent = "P [PENDING]", true
}
case types.SpecStateSkipped:
header = "S"
if v.Is(types.VerbosityLevelVeryVerbose) || (v.Is(types.VerbosityLevelVerbose) && report.Failure.Message != "") {
header, reportHasContent = "S [SKIPPED]", true
}
default:
header = fmt.Sprintf("%s [%s]", header, r.humanReadableState(report.State))
if report.MaxMustPassRepeatedly > 1 {
header = fmt.Sprintf("%s DURING REPETITION #%d", header, report.NumAttempts)
}
}
// If we have no content to show, jsut emit the header and return
if !reportHasContent {
r.emit(r.f(highlightColor + header + "{{/}}"))
return
}
if includeRuntime {
header = r.f("%s [%.3f seconds]", header, report.RunTime.Seconds())
}
// Emit header
if !timelineHasBeenStreaming {
r.emitDelimiter(0)
}
r.emitBlock(r.f(highlightColor + header + "{{/}}"))
if showCodeLocation {
r.emitBlock(r.codeLocationBlock(report, highlightColor, v.Is(types.VerbosityLevelVeryVerbose), false))
}
//Emit Stdout/Stderr Output
if showSeparateStdSection {
r.emitBlock("\n")
r.emitBlock(r.fi(1, "{{gray}}Captured StdOut/StdErr Output >>{{/}}"))
r.emitBlock(r.fi(1, "%s", report.CapturedStdOutErr))
r.emitBlock(r.fi(1, "{{gray}}<< Captured StdOut/StdErr Output{{/}}"))
}
if showSeparateVisibilityAlwaysReportsSection {
r.emitBlock("\n")
r.emitBlock(r.fi(1, "{{gray}}Report Entries >>{{/}}"))
for _, entry := range report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways) {
r.emitReportEntry(1, entry)
}
r.emitBlock(r.fi(1, "{{gray}}<< Report Entries{{/}}"))
}
if showTimeline {
r.emitBlock("\n")
r.emitBlock(r.fi(1, "{{gray}}Timeline >>{{/}}"))
r.emitTimeline(1, report, timeline)
r.emitBlock(r.fi(1, "{{gray}}<< Timeline{{/}}"))
}
// Emit Failure Message
if !report.Failure.IsZero() && !v.Is(types.VerbosityLevelVeryVerbose) {
r.emitBlock("\n")
r.emitFailure(1, report.State, report.Failure, true)
if len(report.AdditionalFailures) > 0 {
r.emitBlock(r.fi(1, "\nThere were {{bold}}{{red}}additional failures{{/}} detected. To view them in detail run {{bold}}ginkgo -vv{{/}}"))
}
}
r.emitDelimiter(0)
}
func (r *DefaultReporter) highlightColorForState(state types.SpecState) string {
switch state {
case types.SpecStatePassed:
return "{{green}}"
case types.SpecStatePending:
return "{{yellow}}"
case types.SpecStateSkipped:
return "{{cyan}}"
case types.SpecStateFailed:
return "{{red}}"
case types.SpecStateTimedout:
return "{{orange}}"
case types.SpecStatePanicked:
return "{{magenta}}"
case types.SpecStateInterrupted:
return "{{orange}}"
case types.SpecStateAborted:
return "{{coral}}"
default:
return "{{gray}}"
}
}
func (r *DefaultReporter) humanReadableState(state types.SpecState) string {
return strings.ToUpper(state.String())
}
func (r *DefaultReporter) emitTimeline(indent uint, report types.SpecReport, timeline types.Timeline) {
isVeryVerbose := r.conf.Verbosity().Is(types.VerbosityLevelVeryVerbose)
gw := report.CapturedGinkgoWriterOutput
cursor := 0
for _, entry := range timeline {
tl := entry.GetTimelineLocation()
if tl.Offset < len(gw) {
r.emit(r.fi(indent, "%s", gw[cursor:tl.Offset]))
cursor = tl.Offset
} else if cursor < len(gw) {
r.emit(r.fi(indent, "%s", gw[cursor:]))
cursor = len(gw)
}
switch x := entry.(type) {
case types.Failure:
if isVeryVerbose {
r.emitFailure(indent, report.State, x, false)
} else {
r.emitShortFailure(indent, report.State, x)
}
case types.AdditionalFailure:
if isVeryVerbose {
r.emitFailure(indent, x.State, x.Failure, true)
} else {
r.emitShortFailure(indent, x.State, x.Failure)
}
case types.ReportEntry:
r.emitReportEntry(indent, x)
case types.ProgressReport:
r.emitProgressReport(indent, false, x)
case types.SpecEvent:
if isVeryVerbose || !x.IsOnlyVisibleAtVeryVerbose() || r.conf.ShowNodeEvents {
r.emitSpecEvent(indent, x, isVeryVerbose)
}
}
}
if cursor < len(gw) {
r.emit(r.fi(indent, "%s", gw[cursor:]))
}
}
func (r *DefaultReporter) EmitFailure(state types.SpecState, failure types.Failure) {
if r.conf.Verbosity().Is(types.VerbosityLevelVerbose) {
r.emitShortFailure(1, state, failure)
} else if r.conf.Verbosity().Is(types.VerbosityLevelVeryVerbose) {
r.emitFailure(1, state, failure, true)
}
}
func (r *DefaultReporter) emitShortFailure(indent uint, state types.SpecState, failure types.Failure) {
r.emitBlock(r.fi(indent, r.highlightColorForState(state)+"[%s]{{/}} in [%s] - %s {{gray}}@ %s{{/}}",
r.humanReadableState(state),
failure.FailureNodeType,
failure.Location,
failure.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT),
))
}
func (r *DefaultReporter) emitFailure(indent uint, state types.SpecState, failure types.Failure, includeAdditionalFailure bool) {
highlightColor := r.highlightColorForState(state)
r.emitBlock(r.fi(indent, highlightColor+"[%s] %s{{/}}", r.humanReadableState(state), failure.Message))
r.emitBlock(r.fi(indent, highlightColor+"In {{bold}}[%s]{{/}}"+highlightColor+" at: {{bold}}%s{{/}} {{gray}}@ %s{{/}}\n", failure.FailureNodeType, failure.Location, failure.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
if failure.ForwardedPanic != "" {
r.emitBlock("\n")
r.emitBlock(r.fi(indent, highlightColor+"%s{{/}}", failure.ForwardedPanic))
}
if r.conf.FullTrace || failure.ForwardedPanic != "" {
r.emitBlock("\n")
r.emitBlock(r.fi(indent, highlightColor+"Full Stack Trace{{/}}"))
r.emitBlock(r.fi(indent+1, "%s", failure.Location.FullStackTrace))
}
if !failure.ProgressReport.IsZero() {
r.emitBlock("\n")
r.emitProgressReport(indent, false, failure.ProgressReport)
}
if failure.AdditionalFailure != nil && includeAdditionalFailure {
r.emitBlock("\n")
r.emitFailure(indent, failure.AdditionalFailure.State, failure.AdditionalFailure.Failure, true)
}
}
func (r *DefaultReporter) EmitProgressReport(report types.ProgressReport) { func (r *DefaultReporter) EmitProgressReport(report types.ProgressReport) {
r.emitDelimiter() r.emitDelimiter(1)
if report.RunningInParallel { if report.RunningInParallel {
r.emit(r.f("{{coral}}Progress Report for Ginkgo Process #{{bold}}%d{{/}}\n", report.ParallelProcess)) r.emit(r.fi(1, "{{coral}}Progress Report for Ginkgo Process #{{bold}}%d{{/}}\n", report.ParallelProcess))
} }
r.emitProgressReport(0, true, report) shouldEmitGW := report.RunningInParallel || r.conf.Verbosity().LT(types.VerbosityLevelVerbose)
r.emitDelimiter() r.emitProgressReport(1, shouldEmitGW, report)
r.emitDelimiter(1)
} }
func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput bool, report types.ProgressReport) { func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput bool, report types.ProgressReport) {
if report.Message != "" {
r.emitBlock(r.fi(indent, report.Message+"\n"))
indent += 1
}
if report.LeafNodeText != "" { if report.LeafNodeText != "" {
subjectIndent := indent
if len(report.ContainerHierarchyTexts) > 0 { if len(report.ContainerHierarchyTexts) > 0 {
r.emit(r.fi(indent, r.cycleJoin(report.ContainerHierarchyTexts, " "))) r.emit(r.fi(indent, r.cycleJoin(report.ContainerHierarchyTexts, " ")))
r.emit(" ") r.emit(" ")
subjectIndent = 0
} }
r.emit(r.f("{{bold}}{{orange}}%s{{/}} (Spec Runtime: %s)\n", report.LeafNodeText, report.Time.Sub(report.SpecStartTime).Round(time.Millisecond))) r.emit(r.fi(subjectIndent, "{{bold}}{{orange}}%s{{/}} (Spec Runtime: %s)\n", report.LeafNodeText, report.Time().Sub(report.SpecStartTime).Round(time.Millisecond)))
r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.LeafNodeLocation)) r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.LeafNodeLocation))
indent += 1 indent += 1
} }
@ -344,12 +461,12 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
r.emit(r.f(" {{bold}}{{orange}}%s{{/}}", report.CurrentNodeText)) r.emit(r.f(" {{bold}}{{orange}}%s{{/}}", report.CurrentNodeText))
} }
r.emit(r.f(" (Node Runtime: %s)\n", report.Time.Sub(report.CurrentNodeStartTime).Round(time.Millisecond))) r.emit(r.f(" (Node Runtime: %s)\n", report.Time().Sub(report.CurrentNodeStartTime).Round(time.Millisecond)))
r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentNodeLocation)) r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentNodeLocation))
indent += 1 indent += 1
} }
if report.CurrentStepText != "" { if report.CurrentStepText != "" {
r.emit(r.fi(indent, "At {{bold}}{{orange}}[By Step] %s{{/}} (Step Runtime: %s)\n", report.CurrentStepText, report.Time.Sub(report.CurrentStepStartTime).Round(time.Millisecond))) r.emit(r.fi(indent, "At {{bold}}{{orange}}[By Step] %s{{/}} (Step Runtime: %s)\n", report.CurrentStepText, report.Time().Sub(report.CurrentStepStartTime).Round(time.Millisecond)))
r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentStepLocation)) r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentStepLocation))
indent += 1 indent += 1
} }
@ -358,9 +475,19 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
indent -= 1 indent -= 1
} }
if emitGinkgoWriterOutput && report.CapturedGinkgoWriterOutput != "" && (report.RunningInParallel || r.conf.Verbosity().LT(types.VerbosityLevelVerbose)) { if emitGinkgoWriterOutput && report.CapturedGinkgoWriterOutput != "" {
r.emit("\n") r.emit("\n")
r.emitGinkgoWriterOutput(indent, report.CapturedGinkgoWriterOutput, 10) r.emitBlock(r.fi(indent, "{{gray}}Begin Captured GinkgoWriter Output >>{{/}}"))
limit, lines := 10, strings.Split(report.CapturedGinkgoWriterOutput, "\n")
if len(lines) <= limit {
r.emitBlock(r.fi(indent+1, "%s", report.CapturedGinkgoWriterOutput))
} else {
r.emitBlock(r.fi(indent+1, "{{gray}}...{{/}}"))
for _, line := range lines[len(lines)-limit-1:] {
r.emitBlock(r.fi(indent+1, "%s", line))
}
}
r.emitBlock(r.fi(indent, "{{gray}}<< End Captured GinkgoWriter Output{{/}}"))
} }
if !report.SpecGoroutine().IsZero() { if !report.SpecGoroutine().IsZero() {
@ -369,6 +496,18 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
r.emitGoroutines(indent, report.SpecGoroutine()) r.emitGoroutines(indent, report.SpecGoroutine())
} }
if len(report.AdditionalReports) > 0 {
r.emit("\n")
r.emitBlock(r.fi(indent, "{{gray}}Begin Additional Progress Reports >>{{/}}"))
for i, additionalReport := range report.AdditionalReports {
r.emit(r.fi(indent+1, additionalReport))
if i < len(report.AdditionalReports)-1 {
r.emitBlock(r.fi(indent+1, "{{gray}}%s{{/}}", strings.Repeat("-", 10)))
}
}
r.emitBlock(r.fi(indent, "{{gray}}<< End Additional Progress Reports{{/}}"))
}
highlightedGoroutines := report.HighlightedGoroutines() highlightedGoroutines := report.HighlightedGoroutines()
if len(highlightedGoroutines) > 0 { if len(highlightedGoroutines) > 0 {
r.emit("\n") r.emit("\n")
@ -384,22 +523,48 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
} }
} }
func (r *DefaultReporter) emitGinkgoWriterOutput(indent uint, output string, limit int) { func (r *DefaultReporter) EmitReportEntry(entry types.ReportEntry) {
r.emitBlock(r.fi(indent, "{{gray}}Begin Captured GinkgoWriter Output >>{{/}}")) if r.conf.Verbosity().LT(types.VerbosityLevelVerbose) || entry.Visibility == types.ReportEntryVisibilityNever {
if limit == 0 { return
r.emitBlock(r.fi(indent+1, "%s", output)) }
} else { r.emitReportEntry(1, entry)
lines := strings.Split(output, "\n") }
if len(lines) <= limit {
r.emitBlock(r.fi(indent+1, "%s", output)) func (r *DefaultReporter) emitReportEntry(indent uint, entry types.ReportEntry) {
} else { r.emitBlock(r.fi(indent, "{{bold}}"+entry.Name+"{{gray}} "+fmt.Sprintf("- %s @ %s{{/}}", entry.Location, entry.Time.Format(types.GINKGO_TIME_FORMAT))))
r.emitBlock(r.fi(indent+1, "{{gray}}...{{/}}")) if representation := entry.StringRepresentation(); representation != "" {
for _, line := range lines[len(lines)-limit-1:] { r.emitBlock(r.fi(indent+1, representation))
r.emitBlock(r.fi(indent+1, "%s", line)) }
} }
}
func (r *DefaultReporter) EmitSpecEvent(event types.SpecEvent) {
v := r.conf.Verbosity()
if v.Is(types.VerbosityLevelVeryVerbose) || (v.Is(types.VerbosityLevelVerbose) && (r.conf.ShowNodeEvents || !event.IsOnlyVisibleAtVeryVerbose())) {
r.emitSpecEvent(1, event, r.conf.Verbosity().Is(types.VerbosityLevelVeryVerbose))
}
}
func (r *DefaultReporter) emitSpecEvent(indent uint, event types.SpecEvent, includeLocation bool) {
location := ""
if includeLocation {
location = fmt.Sprintf("- %s ", event.CodeLocation.String())
}
switch event.SpecEventType {
case types.SpecEventInvalid:
return
case types.SpecEventByStart:
r.emitBlock(r.fi(indent, "{{bold}}STEP:{{/}} %s {{gray}}%s@ %s{{/}}", event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
case types.SpecEventByEnd:
r.emitBlock(r.fi(indent, "{{bold}}END STEP:{{/}} %s {{gray}}%s@ %s (%s){{/}}", event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT), event.Duration.Round(time.Millisecond)))
case types.SpecEventNodeStart:
r.emitBlock(r.fi(indent, "> Enter {{bold}}[%s]{{/}} %s {{gray}}%s@ %s{{/}}", event.NodeType.String(), event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
case types.SpecEventNodeEnd:
r.emitBlock(r.fi(indent, "< Exit {{bold}}[%s]{{/}} %s {{gray}}%s@ %s (%s){{/}}", event.NodeType.String(), event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT), event.Duration.Round(time.Millisecond)))
case types.SpecEventSpecRepeat:
r.emitBlock(r.fi(indent, "\n{{bold}}Attempt #%d {{green}}Passed{{/}}{{bold}}. Repeating %s{{/}} {{gray}}@ %s{{/}}\n\n", event.Attempt, r.retryDenoter, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
case types.SpecEventSpecRetry:
r.emitBlock(r.fi(indent, "\n{{bold}}Attempt #%d {{red}}Failed{{/}}{{bold}}. Retrying %s{{/}} {{gray}}@ %s{{/}}\n\n", event.Attempt, r.retryDenoter, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
} }
r.emitBlock(r.fi(indent, "{{gray}}<< End Captured GinkgoWriter Output{{/}}"))
} }
func (r *DefaultReporter) emitGoroutines(indent uint, goroutines ...types.Goroutine) { func (r *DefaultReporter) emitGoroutines(indent uint, goroutines ...types.Goroutine) {
@ -457,31 +622,37 @@ func (r *DefaultReporter) emitSource(indent uint, fc types.FunctionCall) {
/* Emitting to the writer */ /* Emitting to the writer */
func (r *DefaultReporter) emit(s string) { func (r *DefaultReporter) emit(s string) {
if len(s) > 0 { r._emit(s, false, false)
r.lastChar = s[len(s)-1:]
r.lastEmissionWasDelimiter = false
r.writer.Write([]byte(s))
}
} }
func (r *DefaultReporter) emitBlock(s string) { func (r *DefaultReporter) emitBlock(s string) {
if len(s) > 0 { r._emit(s, true, false)
if r.lastChar != "\n" {
r.emit("\n")
}
r.emit(s)
if r.lastChar != "\n" {
r.emit("\n")
}
}
} }
func (r *DefaultReporter) emitDelimiter() { func (r *DefaultReporter) emitDelimiter(indent uint) {
if r.lastEmissionWasDelimiter { r._emit(r.fi(indent, "{{gray}}%s{{/}}", strings.Repeat("-", 30)), true, true)
}
// a bit ugly - but we're trying to minimize locking on this hot codepath
func (r *DefaultReporter) _emit(s string, block bool, isDelimiter bool) {
if len(s) == 0 {
return return
} }
r.emitBlock(r.f("{{gray}}%s{{/}}", strings.Repeat("-", 30))) r.lock.Lock()
r.lastEmissionWasDelimiter = true defer r.lock.Unlock()
if isDelimiter && r.lastEmissionWasDelimiter {
return
}
if block && !r.lastCharWasNewline {
r.writer.Write([]byte("\n"))
}
r.lastCharWasNewline = (s[len(s)-1:] == "\n")
r.writer.Write([]byte(s))
if block && !r.lastCharWasNewline {
r.writer.Write([]byte("\n"))
r.lastCharWasNewline = true
}
r.lastEmissionWasDelimiter = isDelimiter
} }
/* Rendering text */ /* Rendering text */
@ -497,13 +668,14 @@ func (r *DefaultReporter) cycleJoin(elements []string, joiner string) string {
return r.formatter.CycleJoin(elements, joiner, []string{"{{/}}", "{{gray}}"}) return r.formatter.CycleJoin(elements, joiner, []string{"{{/}}", "{{gray}}"})
} }
func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightColor string, succinct bool, usePreciseFailureLocation bool) string { func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightColor string, veryVerbose bool, usePreciseFailureLocation bool) string {
texts, locations, labels := []string{}, []types.CodeLocation{}, [][]string{} texts, locations, labels := []string{}, []types.CodeLocation{}, [][]string{}
texts, locations, labels = append(texts, report.ContainerHierarchyTexts...), append(locations, report.ContainerHierarchyLocations...), append(labels, report.ContainerHierarchyLabels...) texts, locations, labels = append(texts, report.ContainerHierarchyTexts...), append(locations, report.ContainerHierarchyLocations...), append(labels, report.ContainerHierarchyLabels...)
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) { if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
texts = append(texts, r.f("[%s] %s", report.LeafNodeType, report.LeafNodeText)) texts = append(texts, r.f("[%s] %s", report.LeafNodeType, report.LeafNodeText))
} else { } else {
texts = append(texts, report.LeafNodeText) texts = append(texts, r.f(report.LeafNodeText))
} }
labels = append(labels, report.LeafNodeLabels) labels = append(labels, report.LeafNodeLabels)
locations = append(locations, report.LeafNodeLocation) locations = append(locations, report.LeafNodeLocation)
@ -513,24 +685,58 @@ func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightCo
failureLocation = report.Failure.Location failureLocation = report.Failure.Location
} }
highlightIndex := -1
switch report.Failure.FailureNodeContext { switch report.Failure.FailureNodeContext {
case types.FailureNodeAtTopLevel: case types.FailureNodeAtTopLevel:
texts = append([]string{r.f(highlightColor+"{{bold}}TOP-LEVEL [%s]{{/}}", report.Failure.FailureNodeType)}, texts...) texts = append([]string{fmt.Sprintf("TOP-LEVEL [%s]", report.Failure.FailureNodeType)}, texts...)
locations = append([]types.CodeLocation{failureLocation}, locations...) locations = append([]types.CodeLocation{failureLocation}, locations...)
labels = append([][]string{{}}, labels...) labels = append([][]string{{}}, labels...)
highlightIndex = 0
case types.FailureNodeInContainer: case types.FailureNodeInContainer:
i := report.Failure.FailureNodeContainerIndex i := report.Failure.FailureNodeContainerIndex
texts[i] = r.f(highlightColor+"{{bold}}%s [%s]{{/}}", texts[i], report.Failure.FailureNodeType) texts[i] = fmt.Sprintf("%s [%s]", texts[i], report.Failure.FailureNodeType)
locations[i] = failureLocation locations[i] = failureLocation
highlightIndex = i
case types.FailureNodeIsLeafNode: case types.FailureNodeIsLeafNode:
i := len(texts) - 1 i := len(texts) - 1
texts[i] = r.f(highlightColor+"{{bold}}[%s] %s{{/}}", report.LeafNodeType, report.LeafNodeText) texts[i] = fmt.Sprintf("[%s] %s", report.LeafNodeType, report.LeafNodeText)
locations[i] = failureLocation locations[i] = failureLocation
highlightIndex = i
default:
//there is no failure, so we highlight the leaf ndoe
highlightIndex = len(texts) - 1
} }
out := "" out := ""
if succinct { if veryVerbose {
out += r.f("%s", r.cycleJoin(texts, " ")) for i := range texts {
if i == highlightIndex {
out += r.fi(uint(i), highlightColor+"{{bold}}%s{{/}}", texts[i])
} else {
out += r.fi(uint(i), "%s", texts[i])
}
if len(labels[i]) > 0 {
out += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels[i], ", "))
}
out += "\n"
out += r.fi(uint(i), "{{gray}}%s{{/}}\n", locations[i])
}
} else {
for i := range texts {
style := "{{/}}"
if i%2 == 1 {
style = "{{gray}}"
}
if i == highlightIndex {
style = highlightColor + "{{bold}}"
}
out += r.f(style+"%s", texts[i])
if i < len(texts)-1 {
out += " "
} else {
out += r.f("{{/}}")
}
}
flattenedLabels := report.Labels() flattenedLabels := report.Labels()
if len(flattenedLabels) > 0 { if len(flattenedLabels) > 0 {
out += r.f(" {{coral}}[%s]{{/}}", strings.Join(flattenedLabels, ", ")) out += r.f(" {{coral}}[%s]{{/}}", strings.Join(flattenedLabels, ", "))
@ -539,17 +745,15 @@ func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightCo
if usePreciseFailureLocation { if usePreciseFailureLocation {
out += r.f("{{gray}}%s{{/}}", failureLocation) out += r.f("{{gray}}%s{{/}}", failureLocation)
} else { } else {
out += r.f("{{gray}}%s{{/}}", locations[len(locations)-1]) leafLocation := locations[len(locations)-1]
} if (report.Failure.FailureNodeLocation != types.CodeLocation{}) && (report.Failure.FailureNodeLocation != leafLocation) {
} else { out += r.fi(1, highlightColor+"[%s]{{/}} {{gray}}%s{{/}}\n", report.Failure.FailureNodeType, report.Failure.FailureNodeLocation)
for i := range texts { out += r.fi(1, "{{gray}}[%s] %s{{/}}", report.LeafNodeType, leafLocation)
out += r.fi(uint(i), "%s", texts[i]) } else {
if len(labels[i]) > 0 { out += r.f("{{gray}}%s{{/}}", leafLocation)
out += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels[i], ", "))
} }
out += "\n"
out += r.fi(uint(i), "{{gray}}%s{{/}}\n", locations[i])
} }
} }
return out return out
} }

View File

@ -35,7 +35,7 @@ func ReportViaDeprecatedReporter(reporter DeprecatedReporter, report types.Repor
FailOnPending: report.SuiteConfig.FailOnPending, FailOnPending: report.SuiteConfig.FailOnPending,
FailFast: report.SuiteConfig.FailFast, FailFast: report.SuiteConfig.FailFast,
FlakeAttempts: report.SuiteConfig.FlakeAttempts, FlakeAttempts: report.SuiteConfig.FlakeAttempts,
EmitSpecProgress: report.SuiteConfig.EmitSpecProgress, EmitSpecProgress: false,
DryRun: report.SuiteConfig.DryRun, DryRun: report.SuiteConfig.DryRun,
ParallelNode: report.SuiteConfig.ParallelProcess, ParallelNode: report.SuiteConfig.ParallelProcess,
ParallelTotal: report.SuiteConfig.ParallelTotal, ParallelTotal: report.SuiteConfig.ParallelTotal,

View File

@ -15,12 +15,32 @@ import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
"time"
"github.com/onsi/ginkgo/v2/config" "github.com/onsi/ginkgo/v2/config"
"github.com/onsi/ginkgo/v2/types" "github.com/onsi/ginkgo/v2/types"
) )
type JunitReportConfig struct {
// Spec States for which no timeline should be emitted for system-err
// set this to types.SpecStatePassed|types.SpecStateSkipped|types.SpecStatePending to only match failing specs
OmitTimelinesForSpecState types.SpecState
// Enable OmitFailureMessageAttr to prevent failure messages appearing in the "message" attribute of the Failure and Error tags
OmitFailureMessageAttr bool
//Enable OmitCapturedStdOutErr to prevent captured stdout/stderr appearing in system-out
OmitCapturedStdOutErr bool
// Enable OmitSpecLabels to prevent labels from appearing in the spec name
OmitSpecLabels bool
// Enable OmitLeafNodeType to prevent the spec leaf node type from appearing in the spec name
OmitLeafNodeType bool
// Enable OmitSuiteSetupNodes to prevent the creation of testcase entries for setup nodes
OmitSuiteSetupNodes bool
}
type JUnitTestSuites struct { type JUnitTestSuites struct {
XMLName xml.Name `xml:"testsuites"` XMLName xml.Name `xml:"testsuites"`
// Tests maps onto the total number of specs in all test suites (this includes any suite nodes such as BeforeSuite) // Tests maps onto the total number of specs in all test suites (this includes any suite nodes such as BeforeSuite)
@ -128,6 +148,10 @@ type JUnitFailure struct {
} }
func GenerateJUnitReport(report types.Report, dst string) error { func GenerateJUnitReport(report types.Report, dst string) error {
return GenerateJUnitReportWithConfig(report, dst, JunitReportConfig{})
}
func GenerateJUnitReportWithConfig(report types.Report, dst string, config JunitReportConfig) error {
suite := JUnitTestSuite{ suite := JUnitTestSuite{
Name: report.SuiteDescription, Name: report.SuiteDescription,
Package: report.SuitePath, Package: report.SuitePath,
@ -149,7 +173,6 @@ func GenerateJUnitReport(report types.Report, dst string) error {
{"FailOnPending", fmt.Sprintf("%t", report.SuiteConfig.FailOnPending)}, {"FailOnPending", fmt.Sprintf("%t", report.SuiteConfig.FailOnPending)},
{"FailFast", fmt.Sprintf("%t", report.SuiteConfig.FailFast)}, {"FailFast", fmt.Sprintf("%t", report.SuiteConfig.FailFast)},
{"FlakeAttempts", fmt.Sprintf("%d", report.SuiteConfig.FlakeAttempts)}, {"FlakeAttempts", fmt.Sprintf("%d", report.SuiteConfig.FlakeAttempts)},
{"EmitSpecProgress", fmt.Sprintf("%t", report.SuiteConfig.EmitSpecProgress)},
{"DryRun", fmt.Sprintf("%t", report.SuiteConfig.DryRun)}, {"DryRun", fmt.Sprintf("%t", report.SuiteConfig.DryRun)},
{"ParallelTotal", fmt.Sprintf("%d", report.SuiteConfig.ParallelTotal)}, {"ParallelTotal", fmt.Sprintf("%d", report.SuiteConfig.ParallelTotal)},
{"OutputInterceptorMode", report.SuiteConfig.OutputInterceptorMode}, {"OutputInterceptorMode", report.SuiteConfig.OutputInterceptorMode},
@ -157,22 +180,33 @@ func GenerateJUnitReport(report types.Report, dst string) error {
}, },
} }
for _, spec := range report.SpecReports { for _, spec := range report.SpecReports {
if config.OmitSuiteSetupNodes && spec.LeafNodeType != types.NodeTypeIt {
continue
}
name := fmt.Sprintf("[%s]", spec.LeafNodeType) name := fmt.Sprintf("[%s]", spec.LeafNodeType)
if config.OmitLeafNodeType {
name = ""
}
if spec.FullText() != "" { if spec.FullText() != "" {
name = name + " " + spec.FullText() name = name + " " + spec.FullText()
} }
labels := spec.Labels() labels := spec.Labels()
if len(labels) > 0 { if len(labels) > 0 && !config.OmitSpecLabels {
name = name + " [" + strings.Join(labels, ", ") + "]" name = name + " [" + strings.Join(labels, ", ") + "]"
} }
name = strings.TrimSpace(name)
test := JUnitTestCase{ test := JUnitTestCase{
Name: name, Name: name,
Classname: report.SuiteDescription, Classname: report.SuiteDescription,
Status: spec.State.String(), Status: spec.State.String(),
Time: spec.RunTime.Seconds(), Time: spec.RunTime.Seconds(),
SystemOut: systemOutForUnstructuredReporters(spec), }
SystemErr: systemErrForUnstructuredReporters(spec), if !spec.State.Is(config.OmitTimelinesForSpecState) {
test.SystemErr = systemErrForUnstructuredReporters(spec)
}
if !config.OmitCapturedStdOutErr {
test.SystemOut = systemOutForUnstructuredReporters(spec)
} }
suite.Tests += 1 suite.Tests += 1
@ -191,28 +225,50 @@ func GenerateJUnitReport(report types.Report, dst string) error {
test.Failure = &JUnitFailure{ test.Failure = &JUnitFailure{
Message: spec.Failure.Message, Message: spec.Failure.Message,
Type: "failed", Type: "failed",
Description: fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace), Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Failure.Message = ""
}
suite.Failures += 1
case types.SpecStateTimedout:
test.Failure = &JUnitFailure{
Message: spec.Failure.Message,
Type: "timedout",
Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Failure.Message = ""
} }
suite.Failures += 1 suite.Failures += 1
case types.SpecStateInterrupted: case types.SpecStateInterrupted:
test.Error = &JUnitError{ test.Error = &JUnitError{
Message: "interrupted", Message: spec.Failure.Message,
Type: "interrupted", Type: "interrupted",
Description: interruptDescriptionForUnstructuredReporters(spec.Failure), Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Error.Message = ""
} }
suite.Errors += 1 suite.Errors += 1
case types.SpecStateAborted: case types.SpecStateAborted:
test.Failure = &JUnitFailure{ test.Failure = &JUnitFailure{
Message: spec.Failure.Message, Message: spec.Failure.Message,
Type: "aborted", Type: "aborted",
Description: fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace), Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Failure.Message = ""
} }
suite.Errors += 1 suite.Errors += 1
case types.SpecStatePanicked: case types.SpecStatePanicked:
test.Error = &JUnitError{ test.Error = &JUnitError{
Message: spec.Failure.ForwardedPanic, Message: spec.Failure.ForwardedPanic,
Type: "panicked", Type: "panicked",
Description: fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace), Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Error.Message = ""
} }
suite.Errors += 1 suite.Errors += 1
} }
@ -278,52 +334,27 @@ func MergeAndCleanupJUnitReports(sources []string, dst string) ([]string, error)
return messages, f.Close() return messages, f.Close()
} }
func interruptDescriptionForUnstructuredReporters(failure types.Failure) string { func failureDescriptionForUnstructuredReporters(spec types.SpecReport) string {
out := &strings.Builder{} out := &strings.Builder{}
out.WriteString(failure.Message + "\n") NewDefaultReporter(types.ReporterConfig{NoColor: true, VeryVerbose: true}, out).emitFailure(0, spec.State, spec.Failure, true)
NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitProgressReport(failure.ProgressReport) if len(spec.AdditionalFailures) > 0 {
out.WriteString("\nThere were additional failures detected after the initial failure. These are visible in the timeline\n")
}
return out.String() return out.String()
} }
func systemErrForUnstructuredReporters(spec types.SpecReport) string { func systemErrForUnstructuredReporters(spec types.SpecReport) string {
return RenderTimeline(spec, true)
}
func RenderTimeline(spec types.SpecReport, noColor bool) string {
out := &strings.Builder{} out := &strings.Builder{}
gw := spec.CapturedGinkgoWriterOutput NewDefaultReporter(types.ReporterConfig{NoColor: noColor, VeryVerbose: true}, out).emitTimeline(0, spec, spec.Timeline())
cursor := 0
for _, pr := range spec.ProgressReports {
if cursor < pr.GinkgoWriterOffset {
if pr.GinkgoWriterOffset < len(gw) {
out.WriteString(gw[cursor:pr.GinkgoWriterOffset])
cursor = pr.GinkgoWriterOffset
} else if cursor < len(gw) {
out.WriteString(gw[cursor:])
cursor = len(gw)
}
}
NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitProgressReport(pr)
}
if cursor < len(gw) {
out.WriteString(gw[cursor:])
}
return out.String() return out.String()
} }
func systemOutForUnstructuredReporters(spec types.SpecReport) string { func systemOutForUnstructuredReporters(spec types.SpecReport) string {
systemOut := spec.CapturedStdOutErr return spec.CapturedStdOutErr
if len(spec.ReportEntries) > 0 {
systemOut += "\nReport Entries:\n"
for i, entry := range spec.ReportEntries {
systemOut += fmt.Sprintf("%s\n%s\n%s\n", entry.Name, entry.Location, entry.Time.Format(time.RFC3339Nano))
if representation := entry.StringRepresentation(); representation != "" {
systemOut += representation + "\n"
}
if i+1 < len(spec.ReportEntries) {
systemOut += "--\n"
}
}
}
return systemOut
} }
// Deprecated JUnitReporter (so folks can still compile their suites) // Deprecated JUnitReporter (so folks can still compile their suites)

View File

@ -9,13 +9,21 @@ type Reporter interface {
WillRun(report types.SpecReport) WillRun(report types.SpecReport)
DidRun(report types.SpecReport) DidRun(report types.SpecReport)
SuiteDidEnd(report types.Report) SuiteDidEnd(report types.Report)
//Timeline emission
EmitFailure(state types.SpecState, failure types.Failure)
EmitProgressReport(progressReport types.ProgressReport) EmitProgressReport(progressReport types.ProgressReport)
EmitReportEntry(entry types.ReportEntry)
EmitSpecEvent(event types.SpecEvent)
} }
type NoopReporter struct{} type NoopReporter struct{}
func (n NoopReporter) SuiteWillBegin(report types.Report) {} func (n NoopReporter) SuiteWillBegin(report types.Report) {}
func (n NoopReporter) WillRun(report types.SpecReport) {} func (n NoopReporter) WillRun(report types.SpecReport) {}
func (n NoopReporter) DidRun(report types.SpecReport) {} func (n NoopReporter) DidRun(report types.SpecReport) {}
func (n NoopReporter) SuiteDidEnd(report types.Report) {} func (n NoopReporter) SuiteDidEnd(report types.Report) {}
func (n NoopReporter) EmitProgressReport(progressReport types.ProgressReport) {} func (n NoopReporter) EmitFailure(state types.SpecState, failure types.Failure) {}
func (n NoopReporter) EmitProgressReport(progressReport types.ProgressReport) {}
func (n NoopReporter) EmitReportEntry(entry types.ReportEntry) {}
func (n NoopReporter) EmitSpecEvent(event types.SpecEvent) {}

View File

@ -60,15 +60,19 @@ func GenerateTeamcityReport(report types.Report, dst string) error {
} }
fmt.Fprintf(f, "##teamcity[testIgnored name='%s' message='%s']\n", name, tcEscape(message)) fmt.Fprintf(f, "##teamcity[testIgnored name='%s' message='%s']\n", name, tcEscape(message))
case types.SpecStateFailed: case types.SpecStateFailed:
details := fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace) details := failureDescriptionForUnstructuredReporters(spec)
fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='failed - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details)) fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='failed - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details))
case types.SpecStatePanicked: case types.SpecStatePanicked:
details := fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace) details := failureDescriptionForUnstructuredReporters(spec)
fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='panicked - %s' details='%s']\n", name, tcEscape(spec.Failure.ForwardedPanic), tcEscape(details)) fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='panicked - %s' details='%s']\n", name, tcEscape(spec.Failure.ForwardedPanic), tcEscape(details))
case types.SpecStateTimedout:
details := failureDescriptionForUnstructuredReporters(spec)
fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='timedout - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details))
case types.SpecStateInterrupted: case types.SpecStateInterrupted:
fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='interrupted' details='%s']\n", name, tcEscape(interruptDescriptionForUnstructuredReporters(spec.Failure))) details := failureDescriptionForUnstructuredReporters(spec)
fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='interrupted - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details))
case types.SpecStateAborted: case types.SpecStateAborted:
details := fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace) details := failureDescriptionForUnstructuredReporters(spec)
fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='aborted - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details)) fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='aborted - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details))
} }

View File

@ -7,6 +7,7 @@ import (
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
"sync"
) )
type CodeLocation struct { type CodeLocation struct {
@ -38,6 +39,73 @@ func (codeLocation CodeLocation) ContentsOfLine() string {
return lines[codeLocation.LineNumber-1] return lines[codeLocation.LineNumber-1]
} }
type codeLocationLocator struct {
pcs map[uintptr]bool
helpers map[string]bool
lock *sync.Mutex
}
func (c *codeLocationLocator) addHelper(pc uintptr) {
c.lock.Lock()
defer c.lock.Unlock()
if c.pcs[pc] {
return
}
c.lock.Unlock()
f := runtime.FuncForPC(pc)
c.lock.Lock()
if f == nil {
return
}
c.helpers[f.Name()] = true
c.pcs[pc] = true
}
func (c *codeLocationLocator) hasHelper(name string) bool {
c.lock.Lock()
defer c.lock.Unlock()
return c.helpers[name]
}
func (c *codeLocationLocator) getCodeLocation(skip int) CodeLocation {
pc := make([]uintptr, 40)
n := runtime.Callers(skip+2, pc)
if n == 0 {
return CodeLocation{}
}
pc = pc[:n]
frames := runtime.CallersFrames(pc)
for {
frame, more := frames.Next()
if !c.hasHelper(frame.Function) {
return CodeLocation{FileName: frame.File, LineNumber: frame.Line}
}
if !more {
break
}
}
return CodeLocation{}
}
var clLocator = &codeLocationLocator{
pcs: map[uintptr]bool{},
helpers: map[string]bool{},
lock: &sync.Mutex{},
}
// MarkAsHelper is used by GinkgoHelper to mark the caller (appropriately offset by skip)as a helper. You can use this directly if you need to provide an optional `skip` to mark functions further up the call stack as helpers.
func MarkAsHelper(optionalSkip ...int) {
skip := 1
if len(optionalSkip) > 0 {
skip += optionalSkip[0]
}
pc, _, _, ok := runtime.Caller(skip)
if ok {
clLocator.addHelper(pc)
}
}
func NewCustomCodeLocation(message string) CodeLocation { func NewCustomCodeLocation(message string) CodeLocation {
return CodeLocation{ return CodeLocation{
CustomMessage: message, CustomMessage: message,
@ -45,14 +113,13 @@ func NewCustomCodeLocation(message string) CodeLocation {
} }
func NewCodeLocation(skip int) CodeLocation { func NewCodeLocation(skip int) CodeLocation {
_, file, line, _ := runtime.Caller(skip + 1) return clLocator.getCodeLocation(skip + 1)
return CodeLocation{FileName: file, LineNumber: line}
} }
func NewCodeLocationWithStackTrace(skip int) CodeLocation { func NewCodeLocationWithStackTrace(skip int) CodeLocation {
_, file, line, _ := runtime.Caller(skip + 1) cl := clLocator.getCodeLocation(skip + 1)
stackTrace := PruneStack(string(debug.Stack()), skip+1) cl.FullStackTrace = PruneStack(string(debug.Stack()), skip+1)
return CodeLocation{FileName: file, LineNumber: line, FullStackTrace: stackTrace} return cl
} }
// PruneStack removes references to functions that are internal to Ginkgo // PruneStack removes references to functions that are internal to Ginkgo

View File

@ -8,6 +8,7 @@ package types
import ( import (
"flag" "flag"
"os" "os"
"path/filepath"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@ -26,13 +27,14 @@ type SuiteConfig struct {
FailOnPending bool FailOnPending bool
FailFast bool FailFast bool
FlakeAttempts int FlakeAttempts int
EmitSpecProgress bool
DryRun bool DryRun bool
PollProgressAfter time.Duration PollProgressAfter time.Duration
PollProgressInterval time.Duration PollProgressInterval time.Duration
Timeout time.Duration Timeout time.Duration
EmitSpecProgress bool // this is deprecated but its removal is causing compile issue for some users that were setting it manually
OutputInterceptorMode string OutputInterceptorMode string
SourceRoots []string SourceRoots []string
GracePeriod time.Duration
ParallelProcess int ParallelProcess int
ParallelTotal int ParallelTotal int
@ -45,6 +47,7 @@ func NewDefaultSuiteConfig() SuiteConfig {
Timeout: time.Hour, Timeout: time.Hour,
ParallelProcess: 1, ParallelProcess: 1,
ParallelTotal: 1, ParallelTotal: 1,
GracePeriod: 30 * time.Second,
} }
} }
@ -79,13 +82,12 @@ func (vl VerbosityLevel) LT(comp VerbosityLevel) bool {
// Configuration for Ginkgo's reporter // Configuration for Ginkgo's reporter
type ReporterConfig struct { type ReporterConfig struct {
NoColor bool NoColor bool
SlowSpecThreshold time.Duration Succinct bool
Succinct bool Verbose bool
Verbose bool VeryVerbose bool
VeryVerbose bool FullTrace bool
FullTrace bool ShowNodeEvents bool
AlwaysEmitGinkgoWriter bool
JSONReport string JSONReport string
JUnitReport string JUnitReport string
@ -108,9 +110,7 @@ func (rc ReporterConfig) WillGenerateReport() bool {
} }
func NewDefaultReporterConfig() ReporterConfig { func NewDefaultReporterConfig() ReporterConfig {
return ReporterConfig{ return ReporterConfig{}
SlowSpecThreshold: 5 * time.Second,
}
} }
// Configuration for the Ginkgo CLI // Configuration for the Ginkgo CLI
@ -233,6 +233,9 @@ type deprecatedConfig struct {
SlowSpecThresholdWithFLoatUnits float64 SlowSpecThresholdWithFLoatUnits float64
Stream bool Stream bool
Notify bool Notify bool
EmitSpecProgress bool
SlowSpecThreshold time.Duration
AlwaysEmitGinkgoWriter bool
} }
// Flags // Flags
@ -273,8 +276,6 @@ var SuiteConfigFlags = GinkgoFlags{
{KeyPath: "S.DryRun", Name: "dry-run", SectionKey: "debug", DeprecatedName: "dryRun", DeprecatedDocLink: "changed-command-line-flags", {KeyPath: "S.DryRun", Name: "dry-run", SectionKey: "debug", DeprecatedName: "dryRun", DeprecatedDocLink: "changed-command-line-flags",
Usage: "If set, ginkgo will walk the test hierarchy without actually running anything. Best paired with -v."}, Usage: "If set, ginkgo will walk the test hierarchy without actually running anything. Best paired with -v."},
{KeyPath: "S.EmitSpecProgress", Name: "progress", SectionKey: "debug",
Usage: "If set, ginkgo will emit progress information as each spec runs to the GinkgoWriter."},
{KeyPath: "S.PollProgressAfter", Name: "poll-progress-after", SectionKey: "debug", UsageDefaultValue: "0", {KeyPath: "S.PollProgressAfter", Name: "poll-progress-after", SectionKey: "debug", UsageDefaultValue: "0",
Usage: "Emit node progress reports periodically if node hasn't completed after this duration."}, Usage: "Emit node progress reports periodically if node hasn't completed after this duration."},
{KeyPath: "S.PollProgressInterval", Name: "poll-progress-interval", SectionKey: "debug", UsageDefaultValue: "10s", {KeyPath: "S.PollProgressInterval", Name: "poll-progress-interval", SectionKey: "debug", UsageDefaultValue: "10s",
@ -283,6 +284,8 @@ var SuiteConfigFlags = GinkgoFlags{
Usage: "The location to look for source code when generating progress reports. You can pass multiple --source-root flags."}, Usage: "The location to look for source code when generating progress reports. You can pass multiple --source-root flags."},
{KeyPath: "S.Timeout", Name: "timeout", SectionKey: "debug", UsageDefaultValue: "1h", {KeyPath: "S.Timeout", Name: "timeout", SectionKey: "debug", UsageDefaultValue: "1h",
Usage: "Test suite fails if it does not complete within the specified timeout."}, Usage: "Test suite fails if it does not complete within the specified timeout."},
{KeyPath: "S.GracePeriod", Name: "grace-period", SectionKey: "debug", UsageDefaultValue: "30s",
Usage: "When interrupted, Ginkgo will wait for GracePeriod for the current running node to exit before moving on to the next one."},
{KeyPath: "S.OutputInterceptorMode", Name: "output-interceptor-mode", SectionKey: "debug", UsageArgument: "dup, swap, or none", {KeyPath: "S.OutputInterceptorMode", Name: "output-interceptor-mode", SectionKey: "debug", UsageArgument: "dup, swap, or none",
Usage: "If set, ginkgo will use the specified output interception strategy when running in parallel. Defaults to dup on unix and swap on windows."}, Usage: "If set, ginkgo will use the specified output interception strategy when running in parallel. Defaults to dup on unix and swap on windows."},
@ -299,6 +302,8 @@ var SuiteConfigFlags = GinkgoFlags{
{KeyPath: "D.RegexScansFilePath", DeprecatedName: "regexScansFilePath", DeprecatedDocLink: "removed--regexscansfilepath", DeprecatedVersion: "2.0.0"}, {KeyPath: "D.RegexScansFilePath", DeprecatedName: "regexScansFilePath", DeprecatedDocLink: "removed--regexscansfilepath", DeprecatedVersion: "2.0.0"},
{KeyPath: "D.DebugParallel", DeprecatedName: "debug", DeprecatedDocLink: "removed--debug", DeprecatedVersion: "2.0.0"}, {KeyPath: "D.DebugParallel", DeprecatedName: "debug", DeprecatedDocLink: "removed--debug", DeprecatedVersion: "2.0.0"},
{KeyPath: "D.EmitSpecProgress", DeprecatedName: "progress", SectionKey: "debug",
DeprecatedVersion: "2.5.0", Usage: ". The functionality provided by --progress was confusing and is no longer needed. Use --show-node-events instead to see node entry and exit events included in the timeline of failed and verbose specs. Or you can run with -vv to always see all node events. Lastly, --poll-progress-after and the PollProgressAfter decorator now provide a better mechanism for debugging specs that tend to get stuck."},
} }
// ParallelConfigFlags provides flags for the Ginkgo test process (not the CLI) // ParallelConfigFlags provides flags for the Ginkgo test process (not the CLI)
@ -315,8 +320,6 @@ var ParallelConfigFlags = GinkgoFlags{
var ReporterConfigFlags = GinkgoFlags{ var ReporterConfigFlags = GinkgoFlags{
{KeyPath: "R.NoColor", Name: "no-color", SectionKey: "output", DeprecatedName: "noColor", DeprecatedDocLink: "changed-command-line-flags", {KeyPath: "R.NoColor", Name: "no-color", SectionKey: "output", DeprecatedName: "noColor", DeprecatedDocLink: "changed-command-line-flags",
Usage: "If set, suppress color output in default reporter."}, Usage: "If set, suppress color output in default reporter."},
{KeyPath: "R.SlowSpecThreshold", Name: "slow-spec-threshold", SectionKey: "output", UsageArgument: "duration", UsageDefaultValue: "5s",
Usage: "Specs that take longer to run than this threshold are flagged as slow by the default reporter."},
{KeyPath: "R.Verbose", Name: "v", SectionKey: "output", {KeyPath: "R.Verbose", Name: "v", SectionKey: "output",
Usage: "If set, emits more output including GinkgoWriter contents."}, Usage: "If set, emits more output including GinkgoWriter contents."},
{KeyPath: "R.VeryVerbose", Name: "vv", SectionKey: "output", {KeyPath: "R.VeryVerbose", Name: "vv", SectionKey: "output",
@ -325,8 +328,8 @@ var ReporterConfigFlags = GinkgoFlags{
Usage: "If set, default reporter prints out a very succinct report"}, Usage: "If set, default reporter prints out a very succinct report"},
{KeyPath: "R.FullTrace", Name: "trace", SectionKey: "output", {KeyPath: "R.FullTrace", Name: "trace", SectionKey: "output",
Usage: "If set, default reporter prints out the full stack trace when a failure occurs"}, Usage: "If set, default reporter prints out the full stack trace when a failure occurs"},
{KeyPath: "R.AlwaysEmitGinkgoWriter", Name: "always-emit-ginkgo-writer", SectionKey: "output", DeprecatedName: "reportPassed", DeprecatedDocLink: "renamed--reportpassed", {KeyPath: "R.ShowNodeEvents", Name: "show-node-events", SectionKey: "output",
Usage: "If set, default reporter prints out captured output of passed tests."}, Usage: "If set, default reporter prints node > Enter and < Exit events when specs fail"},
{KeyPath: "R.JSONReport", Name: "json-report", UsageArgument: "filename.json", SectionKey: "output", {KeyPath: "R.JSONReport", Name: "json-report", UsageArgument: "filename.json", SectionKey: "output",
Usage: "If set, Ginkgo will generate a JSON-formatted test report at the specified location."}, Usage: "If set, Ginkgo will generate a JSON-formatted test report at the specified location."},
@ -339,6 +342,8 @@ var ReporterConfigFlags = GinkgoFlags{
Usage: "use --slow-spec-threshold instead and pass in a duration string (e.g. '5s', not '5.0')"}, Usage: "use --slow-spec-threshold instead and pass in a duration string (e.g. '5s', not '5.0')"},
{KeyPath: "D.NoisyPendings", DeprecatedName: "noisyPendings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"}, {KeyPath: "D.NoisyPendings", DeprecatedName: "noisyPendings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"},
{KeyPath: "D.NoisySkippings", DeprecatedName: "noisySkippings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"}, {KeyPath: "D.NoisySkippings", DeprecatedName: "noisySkippings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"},
{KeyPath: "D.SlowSpecThreshold", DeprecatedName: "slow-spec-threshold", SectionKey: "output", Usage: "--slow-spec-threshold has been deprecated and will be removed in a future version of Ginkgo. This feature has proved to be more noisy than useful. You can use --poll-progress-after, instead, to get more actionable feedback about potentially slow specs and understand where they might be getting stuck.", DeprecatedVersion: "2.5.0"},
{KeyPath: "D.AlwaysEmitGinkgoWriter", DeprecatedName: "always-emit-ginkgo-writer", SectionKey: "output", Usage: " - use -v instead, or one of Ginkgo's machine-readable report formats to get GinkgoWriter output for passing specs."},
} }
// BuildTestSuiteFlagSet attaches to the CommandLine flagset and provides flags for the Ginkgo test process // BuildTestSuiteFlagSet attaches to the CommandLine flagset and provides flags for the Ginkgo test process
@ -390,6 +395,10 @@ func VetConfig(flagSet GinkgoFlagSet, suiteConfig SuiteConfig, reporterConfig Re
errors = append(errors, GinkgoErrors.DryRunInParallelConfiguration()) errors = append(errors, GinkgoErrors.DryRunInParallelConfiguration())
} }
if suiteConfig.GracePeriod <= 0 {
errors = append(errors, GinkgoErrors.GracePeriodCannotBeZero())
}
if len(suiteConfig.FocusFiles) > 0 { if len(suiteConfig.FocusFiles) > 0 {
_, err := ParseFileFilters(suiteConfig.FocusFiles) _, err := ParseFileFilters(suiteConfig.FocusFiles)
if err != nil { if err != nil {
@ -592,13 +601,29 @@ func VetAndInitializeCLIAndGoConfig(cliConfig CLIConfig, goFlagsConfig GoFlagsCo
} }
// GenerateGoTestCompileArgs is used by the Ginkgo CLI to generate command line arguments to pass to the go test -c command when compiling the test // GenerateGoTestCompileArgs is used by the Ginkgo CLI to generate command line arguments to pass to the go test -c command when compiling the test
func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, destination string, packageToBuild string) ([]string, error) { func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, destination string, packageToBuild string, pathToInvocationPath string) ([]string, error) {
// if the user has set the CoverProfile run-time flag make sure to set the build-time cover flag to make sure // if the user has set the CoverProfile run-time flag make sure to set the build-time cover flag to make sure
// the built test binary can generate a coverprofile // the built test binary can generate a coverprofile
if goFlagsConfig.CoverProfile != "" { if goFlagsConfig.CoverProfile != "" {
goFlagsConfig.Cover = true goFlagsConfig.Cover = true
} }
if goFlagsConfig.CoverPkg != "" {
coverPkgs := strings.Split(goFlagsConfig.CoverPkg, ",")
adjustedCoverPkgs := make([]string, len(coverPkgs))
for i, coverPkg := range coverPkgs {
coverPkg = strings.Trim(coverPkg, " ")
if strings.HasPrefix(coverPkg, "./") {
// this is a relative coverPkg - we need to reroot it
adjustedCoverPkgs[i] = "./" + filepath.Join(pathToInvocationPath, strings.TrimPrefix(coverPkg, "./"))
} else {
// this is a package name - don't touch it
adjustedCoverPkgs[i] = coverPkg
}
}
goFlagsConfig.CoverPkg = strings.Join(adjustedCoverPkgs, ",")
}
args := []string{"test", "-c", "-o", destination, packageToBuild} args := []string{"test", "-c", "-o", destination, packageToBuild}
goArgs, err := GenerateFlagArgs( goArgs, err := GenerateFlagArgs(
GoBuildFlags, GoBuildFlags,

View File

@ -38,7 +38,7 @@ func (d deprecations) Async() Deprecation {
func (d deprecations) Measure() Deprecation { func (d deprecations) Measure() Deprecation {
return Deprecation{ return Deprecation{
Message: "Measure is deprecated and will be removed in Ginkgo V2. Please migrate to gomega/gmeasure.", Message: "Measure is deprecated and has been removed from Ginkgo V2. Any Measure tests in your spec will not run. Please migrate to gomega/gmeasure.",
DocLink: "removed-measure", DocLink: "removed-measure",
Version: "1.16.3", Version: "1.16.3",
} }
@ -83,6 +83,13 @@ func (d deprecations) Nodot() Deprecation {
} }
} }
func (d deprecations) SuppressProgressReporting() Deprecation {
return Deprecation{
Message: "Improvements to how reporters emit timeline information means that SuppressProgressReporting is no longer necessary and has been deprecated.",
Version: "2.5.0",
}
}
type DeprecationTracker struct { type DeprecationTracker struct {
deprecations map[Deprecation][]CodeLocation deprecations map[Deprecation][]CodeLocation
lock *sync.Mutex lock *sync.Mutex

View File

@ -108,8 +108,8 @@ Please ensure all assertions are inside leaf nodes such as {{bold}}BeforeEach{{/
func (g ginkgoErrors) SuiteNodeInNestedContext(nodeType NodeType, cl CodeLocation) error { func (g ginkgoErrors) SuiteNodeInNestedContext(nodeType NodeType, cl CodeLocation) error {
docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite" docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite"
if nodeType.Is(NodeTypeReportAfterSuite) { if nodeType.Is(NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite) {
docLink = "reporting-nodes---reportaftersuite" docLink = "reporting-nodes---reportbeforesuite-and-reportaftersuite"
} }
return GinkgoError{ return GinkgoError{
@ -125,8 +125,8 @@ func (g ginkgoErrors) SuiteNodeInNestedContext(nodeType NodeType, cl CodeLocatio
func (g ginkgoErrors) SuiteNodeDuringRunPhase(nodeType NodeType, cl CodeLocation) error { func (g ginkgoErrors) SuiteNodeDuringRunPhase(nodeType NodeType, cl CodeLocation) error {
docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite" docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite"
if nodeType.Is(NodeTypeReportAfterSuite) { if nodeType.Is(NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite) {
docLink = "reporting-nodes---reportaftersuite" docLink = "reporting-nodes---reportbeforesuite-and-reportaftersuite"
} }
return GinkgoError{ return GinkgoError{
@ -180,6 +180,15 @@ func (g ginkgoErrors) InvalidDeclarationOfFocusedAndPending(cl CodeLocation, nod
} }
} }
func (g ginkgoErrors) InvalidDeclarationOfFlakeAttemptsAndMustPassRepeatedly(cl CodeLocation, nodeType NodeType) error {
return GinkgoError{
Heading: "Invalid Combination of Decorators: FlakeAttempts and MustPassRepeatedly",
Message: formatter.F(`[%s] node was decorated with both FlakeAttempts and MustPassRepeatedly. At most one is allowed.`, nodeType),
CodeLocation: cl,
DocLink: "node-decorators-overview",
}
}
func (g ginkgoErrors) UnknownDecorator(cl CodeLocation, nodeType NodeType, decorator interface{}) error { func (g ginkgoErrors) UnknownDecorator(cl CodeLocation, nodeType NodeType, decorator interface{}) error {
return GinkgoError{ return GinkgoError{
Heading: "Unknown Decorator", Heading: "Unknown Decorator",
@ -189,20 +198,55 @@ func (g ginkgoErrors) UnknownDecorator(cl CodeLocation, nodeType NodeType, decor
} }
} }
func (g ginkgoErrors) InvalidBodyTypeForContainer(t reflect.Type, cl CodeLocation, nodeType NodeType) error {
return GinkgoError{
Heading: "Invalid Function",
Message: formatter.F(`[%s] node must be passed {{bold}}func(){{/}} - i.e. functions that take nothing and return nothing. You passed {{bold}}%s{{/}} instead.`, nodeType, t),
CodeLocation: cl,
DocLink: "node-decorators-overview",
}
}
func (g ginkgoErrors) InvalidBodyType(t reflect.Type, cl CodeLocation, nodeType NodeType) error { func (g ginkgoErrors) InvalidBodyType(t reflect.Type, cl CodeLocation, nodeType NodeType) error {
mustGet := "{{bold}}func(){{/}}, {{bold}}func(ctx SpecContext){{/}}, or {{bold}}func(ctx context.Context){{/}}"
if nodeType.Is(NodeTypeContainer) {
mustGet = "{{bold}}func(){{/}}"
}
return GinkgoError{ return GinkgoError{
Heading: "Invalid Function", Heading: "Invalid Function",
Message: formatter.F(`[%s] node must be passed {{bold}}func(){{/}} - i.e. functions that take nothing and return nothing. Message: formatter.F(`[%s] node must be passed `+mustGet+`.
You passed {{bold}}%s{{/}} instead.`, nodeType, t), You passed {{bold}}%s{{/}} instead.`, nodeType, t),
CodeLocation: cl, CodeLocation: cl,
DocLink: "node-decorators-overview", DocLink: "node-decorators-overview",
} }
} }
func (g ginkgoErrors) InvalidBodyTypeForSynchronizedBeforeSuiteProc1(t reflect.Type, cl CodeLocation) error {
mustGet := "{{bold}}func() []byte{{/}}, {{bold}}func(ctx SpecContext) []byte{{/}}, or {{bold}}func(ctx context.Context) []byte{{/}}, {{bold}}func(){{/}}, {{bold}}func(ctx SpecContext){{/}}, or {{bold}}func(ctx context.Context){{/}}"
return GinkgoError{
Heading: "Invalid Function",
Message: formatter.F(`[SynchronizedBeforeSuite] node must be passed `+mustGet+` for its first function.
You passed {{bold}}%s{{/}} instead.`, t),
CodeLocation: cl,
DocLink: "node-decorators-overview",
}
}
func (g ginkgoErrors) InvalidBodyTypeForSynchronizedBeforeSuiteAllProcs(t reflect.Type, cl CodeLocation) error {
mustGet := "{{bold}}func(){{/}}, {{bold}}func(ctx SpecContext){{/}}, or {{bold}}func(ctx context.Context){{/}}, {{bold}}func([]byte){{/}}, {{bold}}func(ctx SpecContext, []byte){{/}}, or {{bold}}func(ctx context.Context, []byte){{/}}"
return GinkgoError{
Heading: "Invalid Function",
Message: formatter.F(`[SynchronizedBeforeSuite] node must be passed `+mustGet+` for its second function.
You passed {{bold}}%s{{/}} instead.`, t),
CodeLocation: cl,
DocLink: "node-decorators-overview",
}
}
func (g ginkgoErrors) MultipleBodyFunctions(cl CodeLocation, nodeType NodeType) error { func (g ginkgoErrors) MultipleBodyFunctions(cl CodeLocation, nodeType NodeType) error {
return GinkgoError{ return GinkgoError{
Heading: "Multiple Functions", Heading: "Multiple Functions",
Message: formatter.F(`[%s] node must be passed a single {{bold}}func(){{/}} - but more than one was passed in.`, nodeType), Message: formatter.F(`[%s] node must be passed a single function - but more than one was passed in.`, nodeType),
CodeLocation: cl, CodeLocation: cl,
DocLink: "node-decorators-overview", DocLink: "node-decorators-overview",
} }
@ -211,12 +255,30 @@ func (g ginkgoErrors) MultipleBodyFunctions(cl CodeLocation, nodeType NodeType)
func (g ginkgoErrors) MissingBodyFunction(cl CodeLocation, nodeType NodeType) error { func (g ginkgoErrors) MissingBodyFunction(cl CodeLocation, nodeType NodeType) error {
return GinkgoError{ return GinkgoError{
Heading: "Missing Functions", Heading: "Missing Functions",
Message: formatter.F(`[%s] node must be passed a single {{bold}}func(){{/}} - but none was passed in.`, nodeType), Message: formatter.F(`[%s] node must be passed a single function - but none was passed in.`, nodeType),
CodeLocation: cl, CodeLocation: cl,
DocLink: "node-decorators-overview", DocLink: "node-decorators-overview",
} }
} }
func (g ginkgoErrors) InvalidTimeoutOrGracePeriodForNonContextNode(cl CodeLocation, nodeType NodeType) error {
return GinkgoError{
Heading: "Invalid NodeTimeout SpecTimeout, or GracePeriod",
Message: formatter.F(`[%s] was passed NodeTimeout, SpecTimeout, or GracePeriod but does not have a callback that accepts a {{bold}}SpecContext{{/}} or {{bold}}context.Context{{/}}. You must accept a context to enable timeouts and grace periods`, nodeType),
CodeLocation: cl,
DocLink: "spec-timeouts-and-interruptible-nodes",
}
}
func (g ginkgoErrors) InvalidTimeoutOrGracePeriodForNonContextCleanupNode(cl CodeLocation) error {
return GinkgoError{
Heading: "Invalid NodeTimeout SpecTimeout, or GracePeriod",
Message: formatter.F(`[DeferCleanup] was passed NodeTimeout or GracePeriod but does not have a callback that accepts a {{bold}}SpecContext{{/}} or {{bold}}context.Context{{/}}. You must accept a context to enable timeouts and grace periods`),
CodeLocation: cl,
DocLink: "spec-timeouts-and-interruptible-nodes",
}
}
/* Ordered Container errors */ /* Ordered Container errors */
func (g ginkgoErrors) InvalidSerialNodeInNonSerialOrderedContainer(cl CodeLocation, nodeType NodeType) error { func (g ginkgoErrors) InvalidSerialNodeInNonSerialOrderedContainer(cl CodeLocation, nodeType NodeType) error {
return GinkgoError{ return GinkgoError{
@ -236,6 +298,15 @@ func (g ginkgoErrors) SetupNodeNotInOrderedContainer(cl CodeLocation, nodeType N
} }
} }
func (g ginkgoErrors) InvalidContinueOnFailureDecoration(cl CodeLocation) error {
return GinkgoError{
Heading: "ContinueOnFailure not decorating an outermost Ordered Container",
Message: "ContinueOnFailure can only decorate an Ordered container, and this Ordered container must be the outermost Ordered container.",
CodeLocation: cl,
DocLink: "ordered-containers",
}
}
/* DeferCleanup errors */ /* DeferCleanup errors */
func (g ginkgoErrors) DeferCleanupInvalidFunction(cl CodeLocation) error { func (g ginkgoErrors) DeferCleanupInvalidFunction(cl CodeLocation) error {
return GinkgoError{ return GinkgoError{
@ -258,7 +329,7 @@ func (g ginkgoErrors) PushingCleanupNodeDuringTreeConstruction(cl CodeLocation)
func (g ginkgoErrors) PushingCleanupInReportingNode(cl CodeLocation, nodeType NodeType) error { func (g ginkgoErrors) PushingCleanupInReportingNode(cl CodeLocation, nodeType NodeType) error {
return GinkgoError{ return GinkgoError{
Heading: fmt.Sprintf("DeferCleanup cannot be called in %s", nodeType), Heading: fmt.Sprintf("DeferCleanup cannot be called in %s", nodeType),
Message: "Please inline your cleanup code - Ginkgo won't run cleanup code after a ReportAfterEach or ReportAfterSuite.", Message: "Please inline your cleanup code - Ginkgo won't run cleanup code after a Reporting node.",
CodeLocation: cl, CodeLocation: cl,
DocLink: "cleaning-up-our-cleanup-code-defercleanup", DocLink: "cleaning-up-our-cleanup-code-defercleanup",
} }
@ -380,6 +451,15 @@ func (g ginkgoErrors) InvalidEntryDescription(cl CodeLocation) error {
} }
} }
func (g ginkgoErrors) MissingParametersForTableFunction(cl CodeLocation) error {
return GinkgoError{
Heading: fmt.Sprintf("No parameters have been passed to the Table Function"),
Message: fmt.Sprintf("The Table Function expected at least 1 parameter"),
CodeLocation: cl,
DocLink: "table-specs",
}
}
func (g ginkgoErrors) IncorrectParameterTypeForTable(i int, name string, cl CodeLocation) error { func (g ginkgoErrors) IncorrectParameterTypeForTable(i int, name string, cl CodeLocation) error {
return GinkgoError{ return GinkgoError{
Heading: "DescribeTable passed incorrect parameter type", Heading: "DescribeTable passed incorrect parameter type",
@ -498,6 +578,13 @@ func (g ginkgoErrors) DryRunInParallelConfiguration() error {
} }
} }
func (g ginkgoErrors) GracePeriodCannotBeZero() error {
return GinkgoError{
Heading: "Ginkgo requires a positive --grace-period.",
Message: "Please set --grace-period to a positive duration. The default is 30s.",
}
}
func (g ginkgoErrors) ConflictingVerbosityConfiguration() error { func (g ginkgoErrors) ConflictingVerbosityConfiguration() error {
return GinkgoError{ return GinkgoError{
Heading: "Conflicting reporter verbosity settings.", Heading: "Conflicting reporter verbosity settings.",

View File

@ -272,12 +272,23 @@ func tokenize(input string) func() (*treeNode, error) {
} }
} }
func MustParseLabelFilter(input string) LabelFilter {
filter, err := ParseLabelFilter(input)
if err != nil {
panic(err)
}
return filter
}
func ParseLabelFilter(input string) (LabelFilter, error) { func ParseLabelFilter(input string) (LabelFilter, error) {
if DEBUG_LABEL_FILTER_PARSING { if DEBUG_LABEL_FILTER_PARSING {
fmt.Println("\n==============") fmt.Println("\n==============")
fmt.Println("Input: ", input) fmt.Println("Input: ", input)
fmt.Print("Tokens: ") fmt.Print("Tokens: ")
} }
if input == "" {
return func(_ []string) bool { return true }, nil
}
nextToken := tokenize(input) nextToken := tokenize(input)
root := &treeNode{token: lfTokenRoot} root := &treeNode{token: lfTokenRoot}

View File

@ -6,8 +6,8 @@ import (
"time" "time"
) )
//ReportEntryValue wraps a report entry's value ensuring it can be encoded and decoded safely into reports // ReportEntryValue wraps a report entry's value ensuring it can be encoded and decoded safely into reports
//and across the network connection when running in parallel // and across the network connection when running in parallel
type ReportEntryValue struct { type ReportEntryValue struct {
raw interface{} //unexported to prevent gob from freaking out about unregistered structs raw interface{} //unexported to prevent gob from freaking out about unregistered structs
AsJSON string AsJSON string
@ -85,10 +85,12 @@ func (rev *ReportEntryValue) GobDecode(data []byte) error {
type ReportEntry struct { type ReportEntry struct {
// Visibility captures the visibility policy for this ReportEntry // Visibility captures the visibility policy for this ReportEntry
Visibility ReportEntryVisibility Visibility ReportEntryVisibility
// Time captures the time the AddReportEntry was called
Time time.Time
// Location captures the location of the AddReportEntry call // Location captures the location of the AddReportEntry call
Location CodeLocation Location CodeLocation
Time time.Time //need this for backwards compatibility
TimelineLocation TimelineLocation
// Name captures the name of this report // Name captures the name of this report
Name string Name string
// Value captures the (optional) object passed into AddReportEntry - this can be // Value captures the (optional) object passed into AddReportEntry - this can be
@ -120,7 +122,9 @@ func (entry ReportEntry) GetRawValue() interface{} {
return entry.Value.GetRawValue() return entry.Value.GetRawValue()
} }
func (entry ReportEntry) GetTimelineLocation() TimelineLocation {
return entry.TimelineLocation
}
type ReportEntries []ReportEntry type ReportEntries []ReportEntry

View File

@ -2,6 +2,8 @@ package types
import ( import (
"encoding/json" "encoding/json"
"fmt"
"sort"
"strings" "strings"
"time" "time"
) )
@ -56,19 +58,20 @@ type Report struct {
SuiteConfig SuiteConfig SuiteConfig SuiteConfig
//SpecReports is a list of all SpecReports generated by this test run //SpecReports is a list of all SpecReports generated by this test run
//It is empty when the SuiteReport is provided to ReportBeforeSuite
SpecReports SpecReports SpecReports SpecReports
} }
//PreRunStats contains a set of stats captured before the test run begins. This is primarily used // PreRunStats contains a set of stats captured before the test run begins. This is primarily used
//by Ginkgo's reporter to tell the user how many specs are in the current suite (PreRunStats.TotalSpecs) // by Ginkgo's reporter to tell the user how many specs are in the current suite (PreRunStats.TotalSpecs)
//and how many it intends to run (PreRunStats.SpecsThatWillRun) after applying any relevant focus or skip filters. // and how many it intends to run (PreRunStats.SpecsThatWillRun) after applying any relevant focus or skip filters.
type PreRunStats struct { type PreRunStats struct {
TotalSpecs int TotalSpecs int
SpecsThatWillRun int SpecsThatWillRun int
} }
//Add is ued by Ginkgo's parallel aggregation mechanisms to combine test run reports form individual parallel processes // Add is used by Ginkgo's parallel aggregation mechanisms to combine test run reports form individual parallel processes
//to form a complete final report. // to form a complete final report.
func (report Report) Add(other Report) Report { func (report Report) Add(other Report) Report {
report.SuiteSucceeded = report.SuiteSucceeded && other.SuiteSucceeded report.SuiteSucceeded = report.SuiteSucceeded && other.SuiteSucceeded
@ -147,14 +150,24 @@ type SpecReport struct {
// ParallelProcess captures the parallel process that this spec ran on // ParallelProcess captures the parallel process that this spec ran on
ParallelProcess int ParallelProcess int
// RunningInParallel captures whether this spec is part of a suite that ran in parallel
RunningInParallel bool
//Failure is populated if a spec has failed, panicked, been interrupted, or skipped by the user (e.g. calling Skip()) //Failure is populated if a spec has failed, panicked, been interrupted, or skipped by the user (e.g. calling Skip())
//It includes detailed information about the Failure //It includes detailed information about the Failure
Failure Failure Failure Failure
// NumAttempts captures the number of times this Spec was run. Flakey specs can be retried with // NumAttempts captures the number of times this Spec was run.
// ginkgo --flake-attempts=N // Flakey specs can be retried with ginkgo --flake-attempts=N or the use of the FlakeAttempts decorator.
// Repeated specs can be retried with the use of the MustPassRepeatedly decorator
NumAttempts int NumAttempts int
// MaxFlakeAttempts captures whether the spec has been retried with ginkgo --flake-attempts=N or the use of the FlakeAttempts decorator.
MaxFlakeAttempts int
// MaxMustPassRepeatedly captures whether the spec has the MustPassRepeatedly decorator
MaxMustPassRepeatedly int
// CapturedGinkgoWriterOutput contains text printed to the GinkgoWriter // CapturedGinkgoWriterOutput contains text printed to the GinkgoWriter
CapturedGinkgoWriterOutput string CapturedGinkgoWriterOutput string
@ -168,6 +181,12 @@ type SpecReport struct {
// ProgressReports contains any progress reports generated during this spec. These can either be manually triggered, or automatically generated by Ginkgo via the PollProgressAfter() decorator // ProgressReports contains any progress reports generated during this spec. These can either be manually triggered, or automatically generated by Ginkgo via the PollProgressAfter() decorator
ProgressReports []ProgressReport ProgressReports []ProgressReport
// AdditionalFailures contains any failures that occurred after the initial spec failure. These typically occur in cleanup nodes after the initial failure and are only emitted when running in verbose mode.
AdditionalFailures []AdditionalFailure
// SpecEvents capture additional events that occur during the spec run
SpecEvents SpecEvents
} }
func (report SpecReport) MarshalJSON() ([]byte, error) { func (report SpecReport) MarshalJSON() ([]byte, error) {
@ -187,10 +206,14 @@ func (report SpecReport) MarshalJSON() ([]byte, error) {
ParallelProcess int ParallelProcess int
Failure *Failure `json:",omitempty"` Failure *Failure `json:",omitempty"`
NumAttempts int NumAttempts int
CapturedGinkgoWriterOutput string `json:",omitempty"` MaxFlakeAttempts int
CapturedStdOutErr string `json:",omitempty"` MaxMustPassRepeatedly int
ReportEntries ReportEntries `json:",omitempty"` CapturedGinkgoWriterOutput string `json:",omitempty"`
ProgressReports []ProgressReport `json:",omitempty"` CapturedStdOutErr string `json:",omitempty"`
ReportEntries ReportEntries `json:",omitempty"`
ProgressReports []ProgressReport `json:",omitempty"`
AdditionalFailures []AdditionalFailure `json:",omitempty"`
SpecEvents SpecEvents `json:",omitempty"`
}{ }{
ContainerHierarchyTexts: report.ContainerHierarchyTexts, ContainerHierarchyTexts: report.ContainerHierarchyTexts,
ContainerHierarchyLocations: report.ContainerHierarchyLocations, ContainerHierarchyLocations: report.ContainerHierarchyLocations,
@ -207,6 +230,8 @@ func (report SpecReport) MarshalJSON() ([]byte, error) {
Failure: nil, Failure: nil,
ReportEntries: nil, ReportEntries: nil,
NumAttempts: report.NumAttempts, NumAttempts: report.NumAttempts,
MaxFlakeAttempts: report.MaxFlakeAttempts,
MaxMustPassRepeatedly: report.MaxMustPassRepeatedly,
CapturedGinkgoWriterOutput: report.CapturedGinkgoWriterOutput, CapturedGinkgoWriterOutput: report.CapturedGinkgoWriterOutput,
CapturedStdOutErr: report.CapturedStdOutErr, CapturedStdOutErr: report.CapturedStdOutErr,
} }
@ -220,6 +245,12 @@ func (report SpecReport) MarshalJSON() ([]byte, error) {
if len(report.ProgressReports) > 0 { if len(report.ProgressReports) > 0 {
out.ProgressReports = report.ProgressReports out.ProgressReports = report.ProgressReports
} }
if len(report.AdditionalFailures) > 0 {
out.AdditionalFailures = report.AdditionalFailures
}
if len(report.SpecEvents) > 0 {
out.SpecEvents = report.SpecEvents
}
return json.Marshal(out) return json.Marshal(out)
} }
@ -237,13 +268,13 @@ func (report SpecReport) CombinedOutput() string {
return report.CapturedStdOutErr + "\n" + report.CapturedGinkgoWriterOutput return report.CapturedStdOutErr + "\n" + report.CapturedGinkgoWriterOutput
} }
//Failed returns true if report.State is one of the SpecStateFailureStates // Failed returns true if report.State is one of the SpecStateFailureStates
// (SpecStateFailed, SpecStatePanicked, SpecStateinterrupted, SpecStateAborted) // (SpecStateFailed, SpecStatePanicked, SpecStateinterrupted, SpecStateAborted)
func (report SpecReport) Failed() bool { func (report SpecReport) Failed() bool {
return report.State.Is(SpecStateFailureStates) return report.State.Is(SpecStateFailureStates)
} }
//FullText returns a concatenation of all the report.ContainerHierarchyTexts and report.LeafNodeText // FullText returns a concatenation of all the report.ContainerHierarchyTexts and report.LeafNodeText
func (report SpecReport) FullText() string { func (report SpecReport) FullText() string {
texts := []string{} texts := []string{}
texts = append(texts, report.ContainerHierarchyTexts...) texts = append(texts, report.ContainerHierarchyTexts...)
@ -253,7 +284,7 @@ func (report SpecReport) FullText() string {
return strings.Join(texts, " ") return strings.Join(texts, " ")
} }
//Labels returns a deduped set of all the spec's Labels. // Labels returns a deduped set of all the spec's Labels.
func (report SpecReport) Labels() []string { func (report SpecReport) Labels() []string {
out := []string{} out := []string{}
seen := map[string]bool{} seen := map[string]bool{}
@ -275,7 +306,7 @@ func (report SpecReport) Labels() []string {
return out return out
} }
//MatchesLabelFilter returns true if the spec satisfies the passed in label filter query // MatchesLabelFilter returns true if the spec satisfies the passed in label filter query
func (report SpecReport) MatchesLabelFilter(query string) (bool, error) { func (report SpecReport) MatchesLabelFilter(query string) (bool, error) {
filter, err := ParseLabelFilter(query) filter, err := ParseLabelFilter(query)
if err != nil { if err != nil {
@ -284,29 +315,54 @@ func (report SpecReport) MatchesLabelFilter(query string) (bool, error) {
return filter(report.Labels()), nil return filter(report.Labels()), nil
} }
//FileName() returns the name of the file containing the spec // FileName() returns the name of the file containing the spec
func (report SpecReport) FileName() string { func (report SpecReport) FileName() string {
return report.LeafNodeLocation.FileName return report.LeafNodeLocation.FileName
} }
//LineNumber() returns the line number of the leaf node // LineNumber() returns the line number of the leaf node
func (report SpecReport) LineNumber() int { func (report SpecReport) LineNumber() int {
return report.LeafNodeLocation.LineNumber return report.LeafNodeLocation.LineNumber
} }
//FailureMessage() returns the failure message (or empty string if the test hasn't failed) // FailureMessage() returns the failure message (or empty string if the test hasn't failed)
func (report SpecReport) FailureMessage() string { func (report SpecReport) FailureMessage() string {
return report.Failure.Message return report.Failure.Message
} }
//FailureLocation() returns the location of the failure (or an empty CodeLocation if the test hasn't failed) // FailureLocation() returns the location of the failure (or an empty CodeLocation if the test hasn't failed)
func (report SpecReport) FailureLocation() CodeLocation { func (report SpecReport) FailureLocation() CodeLocation {
return report.Failure.Location return report.Failure.Location
} }
// Timeline() returns a timeline view of the report
func (report SpecReport) Timeline() Timeline {
timeline := Timeline{}
if !report.Failure.IsZero() {
timeline = append(timeline, report.Failure)
if report.Failure.AdditionalFailure != nil {
timeline = append(timeline, *(report.Failure.AdditionalFailure))
}
}
for _, additionalFailure := range report.AdditionalFailures {
timeline = append(timeline, additionalFailure)
}
for _, reportEntry := range report.ReportEntries {
timeline = append(timeline, reportEntry)
}
for _, progressReport := range report.ProgressReports {
timeline = append(timeline, progressReport)
}
for _, specEvent := range report.SpecEvents {
timeline = append(timeline, specEvent)
}
sort.Sort(timeline)
return timeline
}
type SpecReports []SpecReport type SpecReports []SpecReport
//WithLeafNodeType returns the subset of SpecReports with LeafNodeType matching one of the requested NodeTypes // WithLeafNodeType returns the subset of SpecReports with LeafNodeType matching one of the requested NodeTypes
func (reports SpecReports) WithLeafNodeType(nodeTypes NodeType) SpecReports { func (reports SpecReports) WithLeafNodeType(nodeTypes NodeType) SpecReports {
count := 0 count := 0
for i := range reports { for i := range reports {
@ -326,7 +382,7 @@ func (reports SpecReports) WithLeafNodeType(nodeTypes NodeType) SpecReports {
return out return out
} }
//WithState returns the subset of SpecReports with State matching one of the requested SpecStates // WithState returns the subset of SpecReports with State matching one of the requested SpecStates
func (reports SpecReports) WithState(states SpecState) SpecReports { func (reports SpecReports) WithState(states SpecState) SpecReports {
count := 0 count := 0
for i := range reports { for i := range reports {
@ -345,7 +401,7 @@ func (reports SpecReports) WithState(states SpecState) SpecReports {
return out return out
} }
//CountWithState returns the number of SpecReports with State matching one of the requested SpecStates // CountWithState returns the number of SpecReports with State matching one of the requested SpecStates
func (reports SpecReports) CountWithState(states SpecState) int { func (reports SpecReports) CountWithState(states SpecState) int {
n := 0 n := 0
for i := range reports { for i := range reports {
@ -356,17 +412,75 @@ func (reports SpecReports) CountWithState(states SpecState) int {
return n return n
} }
//CountWithState returns the number of SpecReports that passed after multiple attempts // If the Spec passes, CountOfFlakedSpecs returns the number of SpecReports that failed after multiple attempts.
func (reports SpecReports) CountOfFlakedSpecs() int { func (reports SpecReports) CountOfFlakedSpecs() int {
n := 0 n := 0
for i := range reports { for i := range reports {
if reports[i].State.Is(SpecStatePassed) && reports[i].NumAttempts > 1 { if reports[i].MaxFlakeAttempts > 1 && reports[i].State.Is(SpecStatePassed) && reports[i].NumAttempts > 1 {
n += 1 n += 1
} }
} }
return n return n
} }
// If the Spec fails, CountOfRepeatedSpecs returns the number of SpecReports that passed after multiple attempts
func (reports SpecReports) CountOfRepeatedSpecs() int {
n := 0
for i := range reports {
if reports[i].MaxMustPassRepeatedly > 1 && reports[i].State.Is(SpecStateFailureStates) && reports[i].NumAttempts > 1 {
n += 1
}
}
return n
}
// TimelineLocation captures the location of an event in the spec's timeline
type TimelineLocation struct {
//Offset is the offset (in bytes) of the event relative to the GinkgoWriter stream
Offset int `json:",omitempty"`
//Order is the order of the event with respect to other events. The absolute value of Order
//is irrelevant. All that matters is that an event with a lower Order occurs before ane vent with a higher Order
Order int `json:",omitempty"`
Time time.Time
}
// TimelineEvent represent an event on the timeline
// consumers of Timeline will need to check the concrete type of each entry to determine how to handle it
type TimelineEvent interface {
GetTimelineLocation() TimelineLocation
}
type Timeline []TimelineEvent
func (t Timeline) Len() int { return len(t) }
func (t Timeline) Less(i, j int) bool {
return t[i].GetTimelineLocation().Order < t[j].GetTimelineLocation().Order
}
func (t Timeline) Swap(i, j int) { t[i], t[j] = t[j], t[i] }
func (t Timeline) WithoutHiddenReportEntries() Timeline {
out := Timeline{}
for _, event := range t {
if reportEntry, isReportEntry := event.(ReportEntry); isReportEntry && reportEntry.Visibility == ReportEntryVisibilityNever {
continue
}
out = append(out, event)
}
return out
}
func (t Timeline) WithoutVeryVerboseSpecEvents() Timeline {
out := Timeline{}
for _, event := range t {
if specEvent, isSpecEvent := event.(SpecEvent); isSpecEvent && specEvent.IsOnlyVisibleAtVeryVerbose() {
continue
}
out = append(out, event)
}
return out
}
// Failure captures failure information for an individual test // Failure captures failure information for an individual test
type Failure struct { type Failure struct {
// Message - the failure message passed into Fail(...). When using a matcher library // Message - the failure message passed into Fail(...). When using a matcher library
@ -379,6 +493,8 @@ type Failure struct {
// This CodeLocation will include a fully-populated StackTrace // This CodeLocation will include a fully-populated StackTrace
Location CodeLocation Location CodeLocation
TimelineLocation TimelineLocation
// ForwardedPanic - if the failure represents a captured panic (i.e. Summary.State == SpecStatePanicked) // ForwardedPanic - if the failure represents a captured panic (i.e. Summary.State == SpecStatePanicked)
// then ForwardedPanic will be populated with a string representation of the captured panic. // then ForwardedPanic will be populated with a string representation of the captured panic.
ForwardedPanic string `json:",omitempty"` ForwardedPanic string `json:",omitempty"`
@ -391,19 +507,29 @@ type Failure struct {
// FailureNodeType will contain the NodeType of the node in which the failure occurred. // FailureNodeType will contain the NodeType of the node in which the failure occurred.
// FailureNodeLocation will contain the CodeLocation of the node in which the failure occurred. // FailureNodeLocation will contain the CodeLocation of the node in which the failure occurred.
// If populated, FailureNodeContainerIndex will be the index into SpecReport.ContainerHierarchyTexts and SpecReport.ContainerHierarchyLocations that represents the parent container of the node in which the failure occurred. // If populated, FailureNodeContainerIndex will be the index into SpecReport.ContainerHierarchyTexts and SpecReport.ContainerHierarchyLocations that represents the parent container of the node in which the failure occurred.
FailureNodeContext FailureNodeContext FailureNodeContext FailureNodeContext `json:",omitempty"`
FailureNodeType NodeType
FailureNodeLocation CodeLocation FailureNodeType NodeType `json:",omitempty"`
FailureNodeContainerIndex int
FailureNodeLocation CodeLocation `json:",omitempty"`
FailureNodeContainerIndex int `json:",omitempty"`
//ProgressReport is populated if the spec was interrupted or timed out //ProgressReport is populated if the spec was interrupted or timed out
ProgressReport ProgressReport ProgressReport ProgressReport `json:",omitempty"`
//AdditionalFailure is non-nil if a follow-on failure occurred within the same node after the primary failure. This only happens when a node has timed out or been interrupted. In such cases the AdditionalFailure can include information about where/why the spec was stuck.
AdditionalFailure *AdditionalFailure `json:",omitempty"`
} }
func (f Failure) IsZero() bool { func (f Failure) IsZero() bool {
return f.Message == "" && (f.Location == CodeLocation{}) return f.Message == "" && (f.Location == CodeLocation{})
} }
func (f Failure) GetTimelineLocation() TimelineLocation {
return f.TimelineLocation
}
// FailureNodeContext captures the location context for the node containing the failing line of code // FailureNodeContext captures the location context for the node containing the failing line of code
type FailureNodeContext uint type FailureNodeContext uint
@ -434,6 +560,18 @@ func (fnc FailureNodeContext) MarshalJSON() ([]byte, error) {
return fncEnumSupport.MarshJSON(uint(fnc)) return fncEnumSupport.MarshJSON(uint(fnc))
} }
// AdditionalFailure capturs any additional failures that occur after the initial failure of a psec
// these typically occur in clean up nodes after the spec has failed.
// We can't simply use Failure as we want to track the SpecState to know what kind of failure this is
type AdditionalFailure struct {
State SpecState
Failure Failure
}
func (f AdditionalFailure) GetTimelineLocation() TimelineLocation {
return f.Failure.TimelineLocation
}
// SpecState captures the state of a spec // SpecState captures the state of a spec
// To determine if a given `state` represents a failure state, use `state.Is(SpecStateFailureStates)` // To determine if a given `state` represents a failure state, use `state.Is(SpecStateFailureStates)`
type SpecState uint type SpecState uint
@ -448,6 +586,7 @@ const (
SpecStateAborted SpecStateAborted
SpecStatePanicked SpecStatePanicked
SpecStateInterrupted SpecStateInterrupted
SpecStateTimedout
) )
var ssEnumSupport = NewEnumSupport(map[uint]string{ var ssEnumSupport = NewEnumSupport(map[uint]string{
@ -459,11 +598,15 @@ var ssEnumSupport = NewEnumSupport(map[uint]string{
uint(SpecStateAborted): "aborted", uint(SpecStateAborted): "aborted",
uint(SpecStatePanicked): "panicked", uint(SpecStatePanicked): "panicked",
uint(SpecStateInterrupted): "interrupted", uint(SpecStateInterrupted): "interrupted",
uint(SpecStateTimedout): "timedout",
}) })
func (ss SpecState) String() string { func (ss SpecState) String() string {
return ssEnumSupport.String(uint(ss)) return ssEnumSupport.String(uint(ss))
} }
func (ss SpecState) GomegaString() string {
return ssEnumSupport.String(uint(ss))
}
func (ss *SpecState) UnmarshalJSON(b []byte) error { func (ss *SpecState) UnmarshalJSON(b []byte) error {
out, err := ssEnumSupport.UnmarshJSON(b) out, err := ssEnumSupport.UnmarshJSON(b)
*ss = SpecState(out) *ss = SpecState(out)
@ -473,7 +616,7 @@ func (ss SpecState) MarshalJSON() ([]byte, error) {
return ssEnumSupport.MarshJSON(uint(ss)) return ssEnumSupport.MarshJSON(uint(ss))
} }
var SpecStateFailureStates = SpecStateFailed | SpecStateAborted | SpecStatePanicked | SpecStateInterrupted var SpecStateFailureStates = SpecStateFailed | SpecStateTimedout | SpecStateAborted | SpecStatePanicked | SpecStateInterrupted
func (ss SpecState) Is(states SpecState) bool { func (ss SpecState) Is(states SpecState) bool {
return ss&states != 0 return ss&states != 0
@ -481,35 +624,40 @@ func (ss SpecState) Is(states SpecState) bool {
// ProgressReport captures the progress of the current spec. It is, effectively, a structured Ginkgo-aware stack trace // ProgressReport captures the progress of the current spec. It is, effectively, a structured Ginkgo-aware stack trace
type ProgressReport struct { type ProgressReport struct {
ParallelProcess int Message string `json:",omitempty"`
RunningInParallel bool ParallelProcess int `json:",omitempty"`
RunningInParallel bool `json:",omitempty"`
Time time.Time ContainerHierarchyTexts []string `json:",omitempty"`
LeafNodeText string `json:",omitempty"`
LeafNodeLocation CodeLocation `json:",omitempty"`
SpecStartTime time.Time `json:",omitempty"`
ContainerHierarchyTexts []string CurrentNodeType NodeType `json:",omitempty"`
LeafNodeText string CurrentNodeText string `json:",omitempty"`
LeafNodeLocation CodeLocation CurrentNodeLocation CodeLocation `json:",omitempty"`
SpecStartTime time.Time CurrentNodeStartTime time.Time `json:",omitempty"`
CurrentNodeType NodeType CurrentStepText string `json:",omitempty"`
CurrentNodeText string CurrentStepLocation CodeLocation `json:",omitempty"`
CurrentNodeLocation CodeLocation CurrentStepStartTime time.Time `json:",omitempty"`
CurrentNodeStartTime time.Time
CurrentStepText string AdditionalReports []string `json:",omitempty"`
CurrentStepLocation CodeLocation
CurrentStepStartTime time.Time
CapturedGinkgoWriterOutput string `json:",omitempty"` CapturedGinkgoWriterOutput string `json:",omitempty"`
GinkgoWriterOffset int TimelineLocation TimelineLocation `json:",omitempty"`
Goroutines []Goroutine Goroutines []Goroutine `json:",omitempty"`
} }
func (pr ProgressReport) IsZero() bool { func (pr ProgressReport) IsZero() bool {
return pr.CurrentNodeType == NodeTypeInvalid return pr.CurrentNodeType == NodeTypeInvalid
} }
func (pr ProgressReport) Time() time.Time {
return pr.TimelineLocation.Time
}
func (pr ProgressReport) SpecGoroutine() Goroutine { func (pr ProgressReport) SpecGoroutine() Goroutine {
for _, goroutine := range pr.Goroutines { for _, goroutine := range pr.Goroutines {
if goroutine.IsSpecGoroutine { if goroutine.IsSpecGoroutine {
@ -547,6 +695,22 @@ func (pr ProgressReport) WithoutCapturedGinkgoWriterOutput() ProgressReport {
return out return out
} }
func (pr ProgressReport) WithoutOtherGoroutines() ProgressReport {
out := pr
filteredGoroutines := []Goroutine{}
for _, goroutine := range pr.Goroutines {
if goroutine.IsSpecGoroutine || goroutine.HasHighlights() {
filteredGoroutines = append(filteredGoroutines, goroutine)
}
}
out.Goroutines = filteredGoroutines
return out
}
func (pr ProgressReport) GetTimelineLocation() TimelineLocation {
return pr.TimelineLocation
}
type Goroutine struct { type Goroutine struct {
ID uint64 ID uint64
State string State string
@ -601,6 +765,7 @@ const (
NodeTypeReportBeforeEach NodeTypeReportBeforeEach
NodeTypeReportAfterEach NodeTypeReportAfterEach
NodeTypeReportBeforeSuite
NodeTypeReportAfterSuite NodeTypeReportAfterSuite
NodeTypeCleanupInvalid NodeTypeCleanupInvalid
@ -610,7 +775,9 @@ const (
) )
var NodeTypesForContainerAndIt = NodeTypeContainer | NodeTypeIt var NodeTypesForContainerAndIt = NodeTypeContainer | NodeTypeIt
var NodeTypesForSuiteLevelNodes = NodeTypeBeforeSuite | NodeTypeSynchronizedBeforeSuite | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeReportAfterSuite | NodeTypeCleanupAfterSuite var NodeTypesForSuiteLevelNodes = NodeTypeBeforeSuite | NodeTypeSynchronizedBeforeSuite | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite | NodeTypeCleanupAfterSuite
var NodeTypesAllowedDuringCleanupInterrupt = NodeTypeAfterEach | NodeTypeJustAfterEach | NodeTypeAfterAll | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeCleanupAfterEach | NodeTypeCleanupAfterAll | NodeTypeCleanupAfterSuite
var NodeTypesAllowedDuringReportInterrupt = NodeTypeReportBeforeEach | NodeTypeReportAfterEach | NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite
var ntEnumSupport = NewEnumSupport(map[uint]string{ var ntEnumSupport = NewEnumSupport(map[uint]string{
uint(NodeTypeInvalid): "INVALID NODE TYPE", uint(NodeTypeInvalid): "INVALID NODE TYPE",
@ -628,9 +795,10 @@ var ntEnumSupport = NewEnumSupport(map[uint]string{
uint(NodeTypeSynchronizedAfterSuite): "SynchronizedAfterSuite", uint(NodeTypeSynchronizedAfterSuite): "SynchronizedAfterSuite",
uint(NodeTypeReportBeforeEach): "ReportBeforeEach", uint(NodeTypeReportBeforeEach): "ReportBeforeEach",
uint(NodeTypeReportAfterEach): "ReportAfterEach", uint(NodeTypeReportAfterEach): "ReportAfterEach",
uint(NodeTypeReportBeforeSuite): "ReportBeforeSuite",
uint(NodeTypeReportAfterSuite): "ReportAfterSuite", uint(NodeTypeReportAfterSuite): "ReportAfterSuite",
uint(NodeTypeCleanupInvalid): "INVALID CLEANUP NODE", uint(NodeTypeCleanupInvalid): "DeferCleanup",
uint(NodeTypeCleanupAfterEach): "DeferCleanup", uint(NodeTypeCleanupAfterEach): "DeferCleanup (Each)",
uint(NodeTypeCleanupAfterAll): "DeferCleanup (All)", uint(NodeTypeCleanupAfterAll): "DeferCleanup (All)",
uint(NodeTypeCleanupAfterSuite): "DeferCleanup (Suite)", uint(NodeTypeCleanupAfterSuite): "DeferCleanup (Suite)",
}) })
@ -650,3 +818,99 @@ func (nt NodeType) MarshalJSON() ([]byte, error) {
func (nt NodeType) Is(nodeTypes NodeType) bool { func (nt NodeType) Is(nodeTypes NodeType) bool {
return nt&nodeTypes != 0 return nt&nodeTypes != 0
} }
/*
SpecEvent captures a vareity of events that can occur when specs run. See SpecEventType for the list of available events.
*/
type SpecEvent struct {
SpecEventType SpecEventType
CodeLocation CodeLocation
TimelineLocation TimelineLocation
Message string `json:",omitempty"`
Duration time.Duration `json:",omitempty"`
NodeType NodeType `json:",omitempty"`
Attempt int `json:",omitempty"`
}
func (se SpecEvent) GetTimelineLocation() TimelineLocation {
return se.TimelineLocation
}
func (se SpecEvent) IsOnlyVisibleAtVeryVerbose() bool {
return se.SpecEventType.Is(SpecEventByEnd | SpecEventNodeStart | SpecEventNodeEnd)
}
func (se SpecEvent) GomegaString() string {
out := &strings.Builder{}
out.WriteString("[" + se.SpecEventType.String() + " SpecEvent] ")
if se.Message != "" {
out.WriteString("Message=")
out.WriteString(`"` + se.Message + `",`)
}
if se.Duration != 0 {
out.WriteString("Duration=" + se.Duration.String() + ",")
}
if se.NodeType != NodeTypeInvalid {
out.WriteString("NodeType=" + se.NodeType.String() + ",")
}
if se.Attempt != 0 {
out.WriteString(fmt.Sprintf("Attempt=%d", se.Attempt) + ",")
}
out.WriteString("CL=" + se.CodeLocation.String() + ",")
out.WriteString(fmt.Sprintf("TL.Offset=%d", se.TimelineLocation.Offset))
return out.String()
}
type SpecEvents []SpecEvent
func (se SpecEvents) WithType(seType SpecEventType) SpecEvents {
out := SpecEvents{}
for _, event := range se {
if event.SpecEventType.Is(seType) {
out = append(out, event)
}
}
return out
}
type SpecEventType uint
const (
SpecEventInvalid SpecEventType = 0
SpecEventByStart SpecEventType = 1 << iota
SpecEventByEnd
SpecEventNodeStart
SpecEventNodeEnd
SpecEventSpecRepeat
SpecEventSpecRetry
)
var seEnumSupport = NewEnumSupport(map[uint]string{
uint(SpecEventInvalid): "INVALID SPEC EVENT",
uint(SpecEventByStart): "By",
uint(SpecEventByEnd): "By (End)",
uint(SpecEventNodeStart): "Node",
uint(SpecEventNodeEnd): "Node (End)",
uint(SpecEventSpecRepeat): "Repeat",
uint(SpecEventSpecRetry): "Retry",
})
func (se SpecEventType) String() string {
return seEnumSupport.String(uint(se))
}
func (se *SpecEventType) UnmarshalJSON(b []byte) error {
out, err := seEnumSupport.UnmarshJSON(b)
*se = SpecEventType(out)
return err
}
func (se SpecEventType) MarshalJSON() ([]byte, error) {
return seEnumSupport.MarshJSON(uint(se))
}
func (se SpecEventType) Is(specEventTypes SpecEventType) bool {
return se&specEventTypes != 0
}

View File

@ -1,3 +1,3 @@
package types package types
const VERSION = "2.2.0" const VERSION = "2.9.5"

View File

@ -5,28 +5,185 @@
[![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go)](https://pkg.go.dev/github.com/quic-go/quic-go) [![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go)](https://pkg.go.dev/github.com/quic-go/quic-go)
[![Code Coverage](https://img.shields.io/codecov/c/github/quic-go/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/quic-go/quic-go/) [![Code Coverage](https://img.shields.io/codecov/c/github/quic-go/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/quic-go/quic-go/)
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go, including the Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) and Datagram Packetization Layer Path MTU quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899)). It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
In addition to these base RFCs, it also implements the following RFCs:
* Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221))
* Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899))
* QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369))
In addition to the RFCs listed above, it currently implements the [IETF QUIC draft-29](https://tools.ietf.org/html/draft-ietf-quic-transport-29). Support for draft-29 will eventually be dropped, as it is phased out of the ecosystem. In addition to the RFCs listed above, it currently implements the [IETF QUIC draft-29](https://tools.ietf.org/html/draft-ietf-quic-transport-29). Support for draft-29 will eventually be dropped, as it is phased out of the ecosystem.
## Guides This repository provides both a QUIC implementation, located in the `quic` package, as well as an HTTP/3 implementation, located in the `http3` package.
*We currently support Go 1.19.x and Go 1.20.x* ## Using QUIC
Running tests: ### Running a Server
go test ./... The central entry point is the `quic.Transport`. A transport manages QUIC connections running on a single UDP socket. Since QUIC uses Connection IDs, it can demultiplex a listener (accepting incoming connections) and an arbitrary number of outgoing QUIC connections on the same UDP socket.
### QUIC without HTTP/3 ```go
udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 1234})
// ... error handling
tr := quic.Transport{
Conn: udpConn,
}
ln, err := tr.Listen(tlsConf, quicConf)
// ... error handling
go func() {
for {
conn, err := ln.Accept()
// ... error handling
// handle the connection, usually in a new Go routine
}
}
```
Take a look at [this echo example](example/echo/echo.go). The listener `ln` can now be used to accept incoming QUIC connections by (repeatedly) calling the `Accept` method (see below for more information on the `quic.Connection`).
## Usage As a shortcut, `quic.Listen` and `quic.ListenAddr` can be used without explicitly initializing a `quic.Transport`:
```
ln, err := quic.Listen(udpConn, tlsConf, quicConf)
```
When using the shortcut, it's not possible to reuse the same UDP socket for outgoing connections.
### Running a Client
As mentioned above, multiple outgoing connections can share a single UDP socket, since QUIC uses Connection IDs to demultiplex connections.
```go
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
defer cancel()
conn, err := tr.Dial(ctx, <server address>, <tls.Config>, <quic.Config>)
// ... error handling
```
As a shortcut, `quic.Dial` and `quic.DialAddr` can be used without explictly initializing a `quic.Transport`:
```go
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
defer cancel()
conn, err := quic.Dial(ctx, conn, <server address>, <tls.Config>, <quic.Config>)
```
Just as we saw before when used a similar shortcut to run a server, it's also not possible to reuse the same UDP socket for other outgoing connections, or to listen for incoming connections.
### Using a QUIC Connection
#### Accepting Streams
QUIC is a stream-multiplexed transport. A `quic.Connection` fundamentally differs from the `net.Conn` and the `net.PacketConn` interface defined in the standard library. Data is sent and received on (unidirectional and bidirectional) streams (and, if supported, in [datagrams](#quic-datagrams)), not on the connection itself. The stream state machine is described in detail in [Section 3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-3).
Note: A unidirectional stream is a stream that the initiator can only write to (`quic.SendStream`), and the receiver can only read from (`quic.ReceiveStream`). A bidirectional stream (`quic.Stream`) allows reading from and writing to for both sides.
On the receiver side, streams are accepted using the `AcceptStream` (for bidirectional) and `AcceptUniStream` functions. For most user cases, it makes sense to call these functions in a loop:
```go
for {
str, err := conn.AcceptStream(context.Background()) // for bidirectional streams
// ... error handling
// handle the stream, usually in a new Go routine
}
```
These functions return an error when the underlying QUIC connection is closed.
#### Opening Streams
There are two slightly different ways to open streams, one synchronous and one (potentially) asynchronous. This API is necessary since the receiver grants us a certain number of streams that we're allowed to open. It may grant us additional streams later on (typically when existing streams are closed), but it means that at the time we want to open a new stream, we might not be able to do so.
Using the synchronous method `OpenStreamSync` for bidirectional streams, and `OpenUniStreamSync` for unidirectional streams, an application can block until the peer allows opening additional streams. In case that we're allowed to open a new stream, these methods return right away:
```go
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
str, err := conn.OpenStreamSync(ctx) // wait up to 5s to open a new bidirectional stream
```
The asynchronous version never blocks. If it's currently not possible to open a new stream, it returns a `net.Error` timeout error:
```go
str, err := conn.OpenStream()
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
// It's currently not possible to open another stream,
// but it might be possible later, once the peer allowed us to do so.
}
```
These functions return an error when the underlying QUIC connection is closed.
#### Using Streams
Using QUIC streams is pretty straightforward. The `quic.ReceiveStream` implements the `io.Reader` interface, and the `quic.SendStream` implements the `io.Writer` interface. A bidirectional stream (`quic.Stream`) implements both these interfaces. Conceptually, a bidirectional stream can be thought of as the composition of two unidirectional streams in opposite directions.
Calling `Close` on a `quic.SendStream` or a `quic.Stream` closes the send side of the stream. On the receiver side, this will be surfaced as an `io.EOF` returned from the `io.Reader` once all data has been consumed. Note that for bidirectional streams, `Close` _only_ closes the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
In case the application wishes to abort sending on a `quic.SendStream` or a `quic.Stream` , it can reset the send side by calling `CancelWrite` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Reader`. Note that for bidirectional streams, `CancelWrite` _only_ resets the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
Conversely, in case the application wishes to abort receiving from a `quic.ReceiveStream` or a `quic.Stream`, it can ask the sender to abort data transmission by calling `CancelRead` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Writer`. Note that for bidirectional streams, `CancelWrite` _only_ resets the receive side of the stream. It is still possible to write to the stream.
A bidirectional stream is only closed once both the read and the write side of the stream have been either closed and reset. Only then the peer is granted a new stream according to the maximum number of concurrent streams configured via `quic.Config.MaxIncomingStreams`.
### Configuring QUIC
The `quic.Config` struct passed to both the listen and dial calls (see above) contains a wide range of configuration options for QUIC connections, incl. the ability to fine-tune flow control limits, the number of streams that the peer is allowed to open concurrently, keep-alives, idle timeouts, and many more. Please refer to the documentation for the `quic.Config` for details.
The `quic.Transport` contains a few configuration options that don't apply to any single QUIC connection, but to all connections handled by that transport. It is highly recommend to set the `StatelessResetToken`, which allows endpoints to quickly recover from crashes / reboots of our node (see [Section 10.3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-10.3)).
### Closing a Connection
#### When the remote Peer closes the Connection
In case the peer closes the QUIC connection, all calls to open streams, accept streams, as well as all methods on streams immediately return an error. Users can use errors assertions to find out what exactly went wrong:
* `quic.VersionNegotiationError`: Happens during the handshake, if there is no overlap between our and the remote's supported QUIC versions.
* `quic.HandshakeTimeoutError`: Happens if the QUIC handshake doesn't complete within the time specified in `quic.Config.HandshakeTimeout`.
* `quic.IdleTimeoutError`: Happens after completion of the handshake if the connection is idle for longer than the minimum of both peers idle timeouts (as configured by `quic.Config.IdleTimeout`). The connection is considered idle when no stream data (and datagrams, if applicable) are exchanged for that period. The QUIC connection can be instructed to regularly send a packet to prevent a connection from going idle by setting `quic.Config.KeepAlive`. However, this is no guarantee that the peer doesn't suddenly go away (e.g. by abruptly shutting down the node or by crashing), or by a NAT binding expiring, in which case this error might still occur.
* `quic.StatelessResetError`: Happens when the remote peer lost the state required to decrypt the packet. This requires the `quic.Transport.StatelessResetToken` to be configured by the peer.
* `quic.TransportError`: Happens if when the QUIC protocol is violated. Unless the error code is `APPLICATION_ERROR`, this will not happen unless one of the QUIC stacks involved is misbehaving. Please open an issue if you encounter this error.
* `quic.ApplicationError`: Happens when the remote decides to close the connection, see below.
#### Initiated by the Application
A `quic.Connection` can be closed using `CloseWithError`:
```go
conn.CloseWithError(0x42, "error 0x42 occurred")
```
Applications can transmit both an error code (an unsigned 62-bit number) as well as a UTF-8 encoded human-readable reason. The error code allows the receiver to learn why the connection was closed, and the reason can be useful for debugging purposes.
On the receiver side, this is surfaced as a `quic.ApplicationError`.
### QUIC Datagrams
Unreliable datagrams are a QUIC extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) that is negotiated during the handshake. Support can be enabled by setting the `quic.Config.EnableDatagram` flag. Note that this doesn't guarantee that the peer also supports datagrams. Whether or not the feature negotiation succeeded can be learned from the `quic.ConnectionState.SupportsDatagrams` obtained from `quic.Connection.ConnectionState()`.
QUIC DATAGRAMs are a new QUIC frame type sent in QUIC 1-RTT packets (i.e. after completion of the handshake). Therefore, they're end-to-end encrypted and congestion-controlled. However, if a DATAGRAM frame is deemed lost by QUIC's loss detection mechanism, they are not automatically retransmitted.
Datagrams are sent using the `SendMessage` method on the `quic.Connection`:
```go
conn.SendMessage([]byte("foobar"))
```
And received using `ReceiveMessage`:
```go
msg, err := conn.ReceiveMessage()
```
Note that this code path is currently not optimized. It works for datagrams that are sent occasionally, but it doesn't achieve the same throughput as writing data on a stream. Please get in touch on issue #3766 if your use case relies on high datagram throughput, or if you'd like to help fix this issue. There are also some restrictions regarding the maximum message size (see #3599).
## Using HTTP/3
### As a server ### As a server
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go: See the [example server](example/main.go). Starting a QUIC server is very similar to the standard library http package in Go:
```go ```go
http.Handle("/", http.FileServer(http.Dir(wwwDir))) http.Handle("/", http.FileServer(http.Dir(wwwDir)))
@ -45,18 +202,29 @@ http.Client{
## Projects using quic-go ## Projects using quic-go
| Project | Description | Stars | | Project | Description | Stars |
|-----------------------------------------------------------|---------------------------------------------------------------------------------------------------------|-------| |-----------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------|
| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) | | [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) |
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) | | [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) |
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) | | [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) |
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) | | [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) |
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) | | [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) |
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) | | [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) |
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) | | [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) |
| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) | | [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) |
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) | | [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) |
| [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System | ![GitHub Repo stars](https://img.shields.io/github/stars/yomorun/yomo?style=flat-square) | | [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) |
| [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System | ![GitHub Repo stars](https://img.shields.io/github/stars/yomorun/yomo?style=flat-square) |
If you'd like to see your project added to this list, please send us a PR.
## Release Policy
quic-go always aims to support the latest two Go releases.
### Dependency on forked crypto/tls
Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20) and [qtls for Go 1.19](https://github.com/quic-go/qtls-go1-19). This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward.
## Contributing ## Contributing

View File

@ -51,18 +51,22 @@ func (b *packetBuffer) Release() {
} }
// Len returns the length of Data // Len returns the length of Data
func (b *packetBuffer) Len() protocol.ByteCount { func (b *packetBuffer) Len() protocol.ByteCount { return protocol.ByteCount(len(b.Data)) }
return protocol.ByteCount(len(b.Data)) func (b *packetBuffer) Cap() protocol.ByteCount { return protocol.ByteCount(cap(b.Data)) }
}
func (b *packetBuffer) putBack() { func (b *packetBuffer) putBack() {
if cap(b.Data) != int(protocol.MaxPacketBufferSize) { if cap(b.Data) == protocol.MaxPacketBufferSize {
panic("putPacketBuffer called with packet of wrong size!") bufferPool.Put(b)
return
} }
bufferPool.Put(b) if cap(b.Data) == protocol.MaxLargePacketBufferSize {
largeBufferPool.Put(b)
return
}
panic("putPacketBuffer called with packet of wrong size!")
} }
var bufferPool sync.Pool var bufferPool, largeBufferPool sync.Pool
func getPacketBuffer() *packetBuffer { func getPacketBuffer() *packetBuffer {
buf := bufferPool.Get().(*packetBuffer) buf := bufferPool.Get().(*packetBuffer)
@ -71,10 +75,18 @@ func getPacketBuffer() *packetBuffer {
return buf return buf
} }
func getLargePacketBuffer() *packetBuffer {
buf := largeBufferPool.Get().(*packetBuffer)
buf.refCount = 1
buf.Data = buf.Data[:0]
return buf
}
func init() { func init() {
bufferPool.New = func() interface{} { bufferPool.New = func() any {
return &packetBuffer{ return &packetBuffer{Data: make([]byte, 0, protocol.MaxPacketBufferSize)}
Data: make([]byte, 0, protocol.MaxPacketBufferSize), }
} largeBufferPool.New = func() any {
return &packetBuffer{Data: make([]byte, 0, protocol.MaxLargePacketBufferSize)}
} }
} }

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"net" "net"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
@ -13,20 +12,19 @@ import (
) )
type client struct { type client struct {
sconn sendConn sendConn sendConn
// If the client is created with DialAddr, we create a packet conn.
// If it is started with Dial, we take a packet conn as a parameter.
createdPacketConn bool
use0RTT bool use0RTT bool
packetHandlers packetHandlerManager packetHandlers packetHandlerManager
onClose func()
tlsConf *tls.Config tlsConf *tls.Config
config *Config config *Config
srcConnID protocol.ConnectionID connIDGenerator ConnectionIDGenerator
destConnID protocol.ConnectionID srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
initialPacketNumber protocol.PacketNumber initialPacketNumber protocol.PacketNumber
hasNegotiatedVersion bool hasNegotiatedVersion bool
@ -45,153 +43,107 @@ type client struct {
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
// DialAddr establishes a new QUIC connection to a server. // DialAddr establishes a new QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed. // It resolves the address, and then creates a new UDP connection to dial the QUIC server.
// The hostname for SNI is taken from the given address. // When the QUIC connection is closed, this UDP connection is closed.
func DialAddr( // See Dial for more details.
addr string, func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
tlsConf *tls.Config,
config *Config,
) (Connection, error) {
return DialAddrContext(context.Background(), addr, tlsConf, config)
}
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
// The hostname for SNI is taken from the given address.
func DialAddrEarly(
addr string,
tlsConf *tls.Config,
config *Config,
) (EarlyConnection, error) {
return DialAddrEarlyContext(context.Background(), addr, tlsConf, config)
}
// DialAddrEarlyContext establishes a new 0-RTT QUIC connection to a server using provided context.
// See DialAddrEarly for details
func DialAddrEarlyContext(
ctx context.Context,
addr string,
tlsConf *tls.Config,
config *Config,
) (EarlyConnection, error) {
conn, err := dialAddrContext(ctx, addr, tlsConf, config, true)
if err != nil {
return nil, err
}
utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection")
return conn, nil
}
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
// See DialAddr for details.
func DialAddrContext(
ctx context.Context,
addr string,
tlsConf *tls.Config,
config *Config,
) (Connection, error) {
return dialAddrContext(ctx, addr, tlsConf, config, false)
}
func dialAddrContext(
ctx context.Context,
addr string,
tlsConf *tls.Config,
config *Config,
use0RTT bool,
) (quicConn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil { if err != nil {
return nil, err return nil, err
} }
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true) udpAddr, err := net.ResolveUDPAddr("udp", addr)
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn. If
// the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
// packets. The same PacketConn can be used for multiple calls to Dial and
// Listen, QUIC connection IDs are used for demultiplexing the different
// connections. The host parameter is used for SNI. The tls.Config must define
// an application protocol (using NextProtos).
func Dial(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Connection, error) {
return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false)
}
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
// The same PacketConn can be used for multiple calls to Dial and Listen,
// QUIC connection IDs are used for demultiplexing the different connections.
// The host parameter is used for SNI.
// The tls.Config must define an application protocol (using NextProtos).
func DialEarly(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (EarlyConnection, error) {
return DialEarlyContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
}
// DialEarlyContext establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
// See DialEarly for details.
func DialEarlyContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (EarlyConnection, error) {
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, true, false)
}
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
// See Dial for details.
func DialContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Connection, error) {
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false)
}
func dialContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
use0RTT bool,
createdPacketConn bool,
) (quicConn, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
if err := validateConfig(config); err != nil {
return nil, err
}
config = populateClientConfig(config, createdPacketConn)
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, use0RTT, createdPacketConn) dl, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
return dl.Dial(ctx, udpAddr, tlsConf, conf)
}
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// See DialAddr for more details.
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
dl, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf)
if err != nil {
dl.Close()
return nil, err
}
return conn, nil
}
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
// See Dial for more details.
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
dl, err := setupTransport(c, tlsConf, false)
if err != nil {
return nil, err
}
conn, err := dl.DialEarly(ctx, addr, tlsConf, conf)
if err != nil {
dl.Close()
return nil, err
}
return conn, nil
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
// ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
// will be used instead of ReadFrom and WriteTo to read/write packets.
// The tls.Config must define an application protocol (using NextProtos).
//
// This is a convenience function. More advanced use cases should instantiate a Transport,
// which offers configuration options for a more fine-grained control of the connection establishment,
// including reusing the underlying UDP socket for multiple QUIC connections.
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
dl, err := setupTransport(c, tlsConf, false)
if err != nil {
return nil, err
}
conn, err := dl.Dial(ctx, addr, tlsConf, conf)
if err != nil {
dl.Close()
return nil, err
}
return conn, nil
}
func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
return &Transport{
Conn: c,
createdConn: createdPacketConn,
isSingleUse: true,
}, nil
}
func dial(
ctx context.Context,
conn sendConn,
connIDGenerator ConnectionIDGenerator,
packetHandlers packetHandlerManager,
tlsConf *tls.Config,
config *Config,
onClose func(),
use0RTT bool,
) (quicConn, error) {
c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -199,14 +151,10 @@ func dialContext(
c.tracingID = nextConnTracingID() c.tracingID = nextConnTracingID()
if c.config.Tracer != nil { if c.config.Tracer != nil {
c.tracer = c.config.Tracer.TracerForConnection( c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
context.WithValue(ctx, ConnectionTracingKey, c.tracingID),
protocol.PerspectiveClient,
c.destConnID,
)
} }
if c.tracer != nil { if c.tracer != nil {
c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID) c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
} }
if err := c.dial(ctx); err != nil { if err := c.dial(ctx); err != nil {
return nil, err return nil, err
@ -214,40 +162,14 @@ func dialContext(
return c.conn, nil return c.conn, nil
} }
func newClient( func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
pconn net.PacketConn,
remoteAddr net.Addr,
config *Config,
tlsConf *tls.Config,
host string,
use0RTT bool,
createdPacketConn bool,
) (*client, error) {
if tlsConf == nil { if tlsConf == nil {
tlsConf = &tls.Config{} tlsConf = &tls.Config{}
} else { } else {
tlsConf = tlsConf.Clone() tlsConf = tlsConf.Clone()
} }
if tlsConf.ServerName == "" {
sni, _, err := net.SplitHostPort(host)
if err != nil {
// It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
sni = host
}
tlsConf.ServerName = sni srcConnID, err := connIDGenerator.GenerateConnectionID()
}
// check that all versions are actually supported
if config != nil {
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
}
srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -256,28 +178,30 @@ func newClient(
return nil, err return nil, err
} }
c := &client{ c := &client{
srcConnID: srcConnID, connIDGenerator: connIDGenerator,
destConnID: destConnID, srcConnID: srcConnID,
sconn: newSendPconn(pconn, remoteAddr), destConnID: destConnID,
createdPacketConn: createdPacketConn, sendConn: sendConn,
use0RTT: use0RTT, use0RTT: use0RTT,
tlsConf: tlsConf, onClose: onClose,
config: config, tlsConf: tlsConf,
version: config.Versions[0], config: config,
handshakeChan: make(chan struct{}), version: config.Versions[0],
logger: utils.DefaultLogger.WithPrefix("client"), handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
} }
return c, nil return c, nil
} }
func (c *client) dial(ctx context.Context) error { func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.conn = newClientConnection( c.conn = newClientConnection(
c.sconn, c.sendConn,
c.packetHandlers, c.packetHandlers,
c.destConnID, c.destConnID,
c.srcConnID, c.srcConnID,
c.connIDGenerator,
c.config, c.config,
c.tlsConf, c.tlsConf,
c.initialPacketNumber, c.initialPacketNumber,
@ -291,13 +215,18 @@ func (c *client) dial(ctx context.Context) error {
c.packetHandlers.Add(c.srcConnID, c.conn) c.packetHandlers.Add(c.srcConnID, c.conn)
errorChan := make(chan error, 1) errorChan := make(chan error, 1)
recreateChan := make(chan errCloseForRecreating)
go func() { go func() {
err := c.conn.run() // returns as soon as the connection is closed err := c.conn.run()
var recreateErr *errCloseForRecreating
if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn { if errors.As(err, &recreateErr) {
c.packetHandlers.Destroy() recreateChan <- *recreateErr
return
} }
errorChan <- err if c.onClose != nil {
c.onClose()
}
errorChan <- err // returns as soon as the connection is closed
}() }()
// only set when we're using 0-RTT // only set when we're using 0-RTT
@ -312,14 +241,12 @@ func (c *client) dial(ctx context.Context) error {
c.conn.shutdown() c.conn.shutdown()
return ctx.Err() return ctx.Err()
case err := <-errorChan: case err := <-errorChan:
var recreateErr *errCloseForRecreating
if errors.As(err, &recreateErr) {
c.initialPacketNumber = recreateErr.nextPacketNumber
c.version = recreateErr.nextVersion
c.hasNegotiatedVersion = true
return c.dial(ctx)
}
return err return err
case recreateErr := <-recreateChan:
c.initialPacketNumber = recreateErr.nextPacketNumber
c.version = recreateErr.nextVersion
c.hasNegotiatedVersion = true
return c.dial(ctx)
case <-earlyConnChan: case <-earlyConnChan:
// ready to send 0-RTT data // ready to send 0-RTT data
return nil return nil

View File

@ -16,13 +16,13 @@ type closedLocalConn struct {
perspective protocol.Perspective perspective protocol.Perspective
logger utils.Logger logger utils.Logger
sendPacket func(net.Addr, *packetInfo) sendPacket func(net.Addr, packetInfo)
} }
var _ packetHandler = &closedLocalConn{} var _ packetHandler = &closedLocalConn{}
// newClosedLocalConn creates a new closedLocalConn and runs it. // newClosedLocalConn creates a new closedLocalConn and runs it.
func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler { func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
return &closedLocalConn{ return &closedLocalConn{
sendPacket: sendPacket, sendPacket: sendPacket,
perspective: pers, perspective: pers,
@ -30,7 +30,7 @@ func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Pe
} }
} }
func (c *closedLocalConn) handlePacket(p *receivedPacket) { func (c *closedLocalConn) handlePacket(p receivedPacket) {
c.counter++ c.counter++
// exponential backoff // exponential backoff
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
@ -58,7 +58,7 @@ func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
return &closedRemoteConn{perspective: pers} return &closedRemoteConn{perspective: pers}
} }
func (s *closedRemoteConn) handlePacket(*receivedPacket) {} func (s *closedRemoteConn) handlePacket(receivedPacket) {}
func (s *closedRemoteConn) shutdown() {} func (s *closedRemoteConn) shutdown() {}
func (s *closedRemoteConn) destroy(error) {} func (s *closedRemoteConn) destroy(error) {}
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }

View File

@ -1,12 +1,13 @@
package quic package quic
import ( import (
"errors" "fmt"
"net" "net"
"time" "time"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
) )
// Clone clones a Config // Clone clones a Config
@ -23,11 +24,24 @@ func validateConfig(config *Config) error {
if config == nil { if config == nil {
return nil return nil
} }
if config.MaxIncomingStreams > 1<<60 { const maxStreams = 1 << 60
return errors.New("invalid value for Config.MaxIncomingStreams") if config.MaxIncomingStreams > maxStreams {
config.MaxIncomingStreams = maxStreams
} }
if config.MaxIncomingUniStreams > 1<<60 { if config.MaxIncomingUniStreams > maxStreams {
return errors.New("invalid value for Config.MaxIncomingUniStreams") config.MaxIncomingUniStreams = maxStreams
}
if config.MaxStreamReceiveWindow > quicvarint.Max {
config.MaxStreamReceiveWindow = quicvarint.Max
}
if config.MaxConnectionReceiveWindow > quicvarint.Max {
config.MaxConnectionReceiveWindow = quicvarint.Max
}
// check that all QUIC versions are actually supported
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return fmt.Errorf("invalid QUIC version: %s", v)
}
} }
return nil return nil
} }
@ -35,7 +49,7 @@ func validateConfig(config *Config) error {
// populateServerConfig populates fields in the quic.Config with their default values, if none are set // populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil // it may be called with nil
func populateServerConfig(config *Config) *Config { func populateServerConfig(config *Config) *Config {
config = populateConfig(config, protocol.DefaultConnectionIDLength) config = populateConfig(config)
if config.MaxTokenAge == 0 { if config.MaxTokenAge == 0 {
config.MaxTokenAge = protocol.TokenValidity config.MaxTokenAge = protocol.TokenValidity
} }
@ -48,19 +62,9 @@ func populateServerConfig(config *Config) *Config {
return config return config
} }
// populateClientConfig populates fields in the quic.Config with their default values, if none are set // populateConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil // it may be called with nil
func populateClientConfig(config *Config, createdPacketConn bool) *Config { func populateConfig(config *Config) *Config {
defaultConnIDLen := protocol.DefaultConnectionIDLength
if createdPacketConn {
defaultConnIDLen = 0
}
config = populateConfig(config, defaultConnIDLen)
return config
}
func populateConfig(config *Config, defaultConnIDLen int) *Config {
if config == nil { if config == nil {
config = &Config{} config = &Config{}
} }
@ -68,10 +72,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
if len(versions) == 0 { if len(versions) == 0 {
versions = protocol.SupportedVersions versions = protocol.SupportedVersions
} }
conIDLen := config.ConnectionIDLength
if config.ConnectionIDLength == 0 {
conIDLen = defaultConnIDLen
}
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
if config.HandshakeIdleTimeout != 0 { if config.HandshakeIdleTimeout != 0 {
handshakeIdleTimeout = config.HandshakeIdleTimeout handshakeIdleTimeout = config.HandshakeIdleTimeout
@ -108,12 +108,9 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
} else if maxIncomingUniStreams < 0 { } else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0 maxIncomingUniStreams = 0
} }
connIDGenerator := config.ConnectionIDGenerator
if connIDGenerator == nil {
connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen}
}
return &Config{ return &Config{
GetConfigForClient: config.GetConfigForClient,
Versions: versions, Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout, HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout, MaxIdleTimeout: idleTimeout,
@ -128,9 +125,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
MaxIncomingStreams: maxIncomingStreams, MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams, MaxIncomingUniStreams: maxIncomingUniStreams,
ConnectionIDLength: conIDLen,
ConnectionIDGenerator: connIDGenerator,
StatelessResetKey: config.StatelessResetKey,
TokenStore: config.TokenStore, TokenStore: config.TokenStore,
EnableDatagrams: config.EnableDatagrams, EnableDatagrams: config.EnableDatagrams,
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,

View File

@ -61,11 +61,6 @@ type cryptoStreamHandler interface {
ConnectionState() handshake.ConnectionState ConnectionState() handshake.ConnectionState
} }
type packetInfo struct {
addr net.IP
ifIndex uint32
}
type receivedPacket struct { type receivedPacket struct {
buffer *packetBuffer buffer *packetBuffer
@ -75,7 +70,7 @@ type receivedPacket struct {
ecn protocol.ECN ecn protocol.ECN
info *packetInfo info packetInfo // only valid if the contained IP address is valid
} }
func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) } func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) }
@ -173,7 +168,7 @@ type connection struct {
oneRTTStream cryptoStream // only set for the server oneRTTStream cryptoStream // only set for the server
cryptoStreamHandler cryptoStreamHandler cryptoStreamHandler cryptoStreamHandler
receivedPackets chan *receivedPacket receivedPackets chan receivedPacket
sendingScheduled chan struct{} sendingScheduled chan struct{}
closeOnce sync.Once closeOnce sync.Once
@ -185,8 +180,8 @@ type connection struct {
handshakeCtx context.Context handshakeCtx context.Context
handshakeCtxCancel context.CancelFunc handshakeCtxCancel context.CancelFunc
undecryptablePackets []*receivedPacket // undecryptable packets, waiting for a change in encryption level undecryptablePackets []receivedPacket // undecryptable packets, waiting for a change in encryption level
undecryptablePacketsToProcess []*receivedPacket undecryptablePacketsToProcess []receivedPacket
clientHelloWritten <-chan *wire.TransportParameters clientHelloWritten <-chan *wire.TransportParameters
earlyConnReadyChan chan struct{} earlyConnReadyChan chan struct{}
@ -199,6 +194,7 @@ type connection struct {
versionNegotiated bool versionNegotiated bool
receivedFirstPacket bool receivedFirstPacket bool
// the minimum of the max_idle_timeout values advertised by both endpoints
idleTimeout time.Duration idleTimeout time.Duration
creationTime time.Time creationTime time.Time
// The idle timeout is set based on the max of the time we received the last packet... // The idle timeout is set based on the max of the time we received the last packet...
@ -240,6 +236,7 @@ var newConnection = func(
clientDestConnID protocol.ConnectionID, clientDestConnID protocol.ConnectionID,
destConnID protocol.ConnectionID, destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID, srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
statelessResetToken protocol.StatelessResetToken, statelessResetToken protocol.StatelessResetToken,
conf *Config, conf *Config,
tlsConf *tls.Config, tlsConf *tls.Config,
@ -283,7 +280,7 @@ var newConnection = func(
runner.Retire, runner.Retire,
runner.ReplaceWithClosed, runner.ReplaceWithClosed,
s.queueControlFrame, s.queueControlFrame,
s.config.ConnectionIDGenerator, connIDGenerator,
) )
s.preSetup() s.preSetup()
s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
@ -296,6 +293,7 @@ var newConnection = func(
s.tracer, s.tracer,
s.logger, s.logger,
) )
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
initialStream := newCryptoStream() initialStream := newCryptoStream()
handshakeStream := newCryptoStream() handshakeStream := newCryptoStream()
params := &wire.TransportParameters{ params := &wire.TransportParameters{
@ -311,9 +309,14 @@ var newConnection = func(
DisableActiveMigration: true, DisableActiveMigration: true,
StatelessResetToken: &statelessResetToken, StatelessResetToken: &statelessResetToken,
OriginalDestinationConnectionID: origDestConnID, OriginalDestinationConnectionID: origDestConnID,
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, // For interoperability with quic-go versions before May 2023, this value must be set to a value
InitialSourceConnectionID: srcConnID, // different from protocol.DefaultActiveConnectionIDLimit.
RetrySourceConnectionID: retrySrcConnID, // If set to the default value, it will be omitted from the transport parameters, which will make
// old quic-go versions interpret it as 0, instead of the default value of 2.
// See https://github.com/quic-go/quic-go/pull/3806.
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
InitialSourceConnectionID: srcConnID,
RetrySourceConnectionID: retrySrcConnID,
} }
if s.config.EnableDatagrams { if s.config.EnableDatagrams {
params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize
@ -323,10 +326,6 @@ var newConnection = func(
if s.tracer != nil { if s.tracer != nil {
s.tracer.SentTransportParameters(params) s.tracer.SentTransportParameters(params)
} }
var allow0RTT func() bool
if conf.Allow0RTT != nil {
allow0RTT = func() bool { return conf.Allow0RTT(conn.RemoteAddr()) }
}
cs := handshake.NewCryptoSetupServer( cs := handshake.NewCryptoSetupServer(
initialStream, initialStream,
handshakeStream, handshakeStream,
@ -344,14 +343,14 @@ var newConnection = func(
}, },
}, },
tlsConf, tlsConf,
allow0RTT, conf.Allow0RTT,
s.rttStats, s.rttStats,
tracer, tracer,
logger, logger,
s.version, s.version,
) )
s.cryptoStreamHandler = cs s.cryptoStreamHandler = cs
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, s.RemoteAddr(), cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
return s return s
@ -363,6 +362,7 @@ var newClientConnection = func(
runner connRunner, runner connRunner,
destConnID protocol.ConnectionID, destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID, srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
conf *Config, conf *Config,
tlsConf *tls.Config, tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber, initialPacketNumber protocol.PacketNumber,
@ -402,7 +402,7 @@ var newClientConnection = func(
runner.Retire, runner.Retire,
runner.ReplaceWithClosed, runner.ReplaceWithClosed,
s.queueControlFrame, s.queueControlFrame,
s.config.ConnectionIDGenerator, connIDGenerator,
) )
s.preSetup() s.preSetup()
s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
@ -415,6 +415,7 @@ var newClientConnection = func(
s.tracer, s.tracer,
s.logger, s.logger,
) )
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
initialStream := newCryptoStream() initialStream := newCryptoStream()
handshakeStream := newCryptoStream() handshakeStream := newCryptoStream()
params := &wire.TransportParameters{ params := &wire.TransportParameters{
@ -428,8 +429,13 @@ var newClientConnection = func(
MaxAckDelay: protocol.MaxAckDelayInclGranularity, MaxAckDelay: protocol.MaxAckDelayInclGranularity,
AckDelayExponent: protocol.AckDelayExponent, AckDelayExponent: protocol.AckDelayExponent,
DisableActiveMigration: true, DisableActiveMigration: true,
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, // For interoperability with quic-go versions before May 2023, this value must be set to a value
InitialSourceConnectionID: srcConnID, // different from protocol.DefaultActiveConnectionIDLimit.
// If set to the default value, it will be omitted from the transport parameters, which will make
// old quic-go versions interpret it as 0, instead of the default value of 2.
// See https://github.com/quic-go/quic-go/pull/3806.
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
InitialSourceConnectionID: srcConnID,
} }
if s.config.EnableDatagrams { if s.config.EnableDatagrams {
params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize
@ -463,7 +469,7 @@ var newClientConnection = func(
s.cryptoStreamHandler = cs s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream()) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, s.RemoteAddr(), cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
if len(tlsConf.ServerName) > 0 { if len(tlsConf.ServerName) > 0 {
s.tokenStoreKey = tlsConf.ServerName s.tokenStoreKey = tlsConf.ServerName
} else { } else {
@ -504,7 +510,7 @@ func (s *connection) preSetup() {
s.perspective, s.perspective,
) )
s.framer = newFramer(s.streamsMap) s.framer = newFramer(s.streamsMap)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxConnUnprocessedPackets) s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets)
s.closeChan = make(chan closeError, 1) s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1) s.sendingScheduled = make(chan struct{}, 1)
s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background()) s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background())
@ -657,7 +663,7 @@ runLoop:
} else { } else {
idleTimeoutStartTime := s.idleTimeoutStartTime() idleTimeoutStartTime := s.idleTimeoutStartTime()
if (!s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.config.HandshakeIdleTimeout) || if (!s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.config.HandshakeIdleTimeout) ||
(s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.idleTimeout) { (s.handshakeComplete && now.After(s.nextIdleTimeoutTime())) {
s.destroyImpl(qerr.ErrIdleTimeout) s.destroyImpl(qerr.ErrIdleTimeout)
continue continue
} }
@ -669,7 +675,7 @@ runLoop:
sendQueueAvailable = s.sendQueue.Available() sendQueueAvailable = s.sendQueue.Available()
continue continue
} }
if err := s.sendPackets(); err != nil { if err := s.triggerSending(); err != nil {
s.closeLocal(err) s.closeLocal(err)
} }
if s.sendQueue.WouldBlock() { if s.sendQueue.WouldBlock() {
@ -681,12 +687,12 @@ runLoop:
s.cryptoStreamHandler.Close() s.cryptoStreamHandler.Close()
<-handshaking <-handshaking
s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE
s.handleCloseError(&closeErr) s.handleCloseError(&closeErr)
if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil {
s.tracer.Close() s.tracer.Close()
} }
s.logger.Infof("Connection %s closed.", s.logID) s.logger.Infof("Connection %s closed.", s.logID)
s.sendQueue.Close()
s.timer.Stop() s.timer.Stop()
return closeErr.err return closeErr.err
} }
@ -715,13 +721,20 @@ func (s *connection) ConnectionState() ConnectionState {
return s.connState return s.connState
} }
// Time when the connection should time out
func (s *connection) nextIdleTimeoutTime() time.Time {
idleTimeout := utils.Max(s.idleTimeout, s.rttStats.PTO(true)*3)
return s.idleTimeoutStartTime().Add(idleTimeout)
}
// Time when the next keep-alive packet should be sent. // Time when the next keep-alive packet should be sent.
// It returns a zero time if no keep-alive should be sent. // It returns a zero time if no keep-alive should be sent.
func (s *connection) nextKeepAliveTime() time.Time { func (s *connection) nextKeepAliveTime() time.Time {
if s.config.KeepAlivePeriod == 0 || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() { if s.config.KeepAlivePeriod == 0 || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() {
return time.Time{} return time.Time{}
} }
return s.lastPacketReceivedTime.Add(s.keepAliveInterval) keepAliveInterval := utils.Max(s.keepAliveInterval, s.rttStats.PTO(true)*3/2)
return s.lastPacketReceivedTime.Add(keepAliveInterval)
} }
func (s *connection) maybeResetTimer() { func (s *connection) maybeResetTimer() {
@ -735,7 +748,7 @@ func (s *connection) maybeResetTimer() {
if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() { if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() {
deadline = keepAliveTime deadline = keepAliveTime
} else { } else {
deadline = s.idleTimeoutStartTime().Add(s.idleTimeout) deadline = s.nextIdleTimeoutTime()
} }
} }
@ -792,25 +805,16 @@ func (s *connection) handleHandshakeConfirmed() {
s.sentPacketHandler.SetHandshakeConfirmed() s.sentPacketHandler.SetHandshakeConfirmed()
s.cryptoStreamHandler.SetHandshakeConfirmed() s.cryptoStreamHandler.SetHandshakeConfirmed()
if !s.config.DisablePathMTUDiscovery { if !s.config.DisablePathMTUDiscovery && s.conn.capabilities().DF {
maxPacketSize := s.peerParams.MaxUDPPayloadSize maxPacketSize := s.peerParams.MaxUDPPayloadSize
if maxPacketSize == 0 { if maxPacketSize == 0 {
maxPacketSize = protocol.MaxByteCount maxPacketSize = protocol.MaxByteCount
} }
maxPacketSize = utils.Min(maxPacketSize, protocol.MaxPacketBufferSize) s.mtuDiscoverer.Start(utils.Min(maxPacketSize, protocol.MaxPacketBufferSize))
s.mtuDiscoverer = newMTUDiscoverer(
s.rttStats,
getMaxPacketSize(s.conn.RemoteAddr()),
maxPacketSize,
func(size protocol.ByteCount) {
s.sentPacketHandler.SetMaxDatagramSize(size)
s.packer.SetMaxPacketSize(size)
},
)
} }
} }
func (s *connection) handlePacketImpl(rp *receivedPacket) bool { func (s *connection) handlePacketImpl(rp receivedPacket) bool {
s.sentPacketHandler.ReceivedBytes(rp.Size()) s.sentPacketHandler.ReceivedBytes(rp.Size())
if wire.IsVersionNegotiationPacket(rp.data) { if wire.IsVersionNegotiationPacket(rp.data) {
@ -826,7 +830,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
for len(data) > 0 { for len(data) > 0 {
var destConnID protocol.ConnectionID var destConnID protocol.ConnectionID
if counter > 0 { if counter > 0 {
p = p.Clone() p = *(p.Clone())
p.data = data p.data = data
var err error var err error
@ -899,7 +903,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
return processed return processed
} }
func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID protocol.ConnectionID) bool { func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protocol.ConnectionID) bool {
var wasQueued bool var wasQueued bool
defer func() { defer func() {
@ -950,7 +954,7 @@ func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID proto
return true return true
} }
func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ {
var wasQueued bool var wasQueued bool
defer func() { defer func() {
@ -1007,7 +1011,7 @@ func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header)
return true return true
} }
func (s *connection) handleUnpackError(err error, p *receivedPacket, pt logging.PacketType) (wasQueued bool) { func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool) {
switch err { switch err {
case handshake.ErrKeysDropped: case handshake.ErrKeysDropped:
if s.tracer != nil { if s.tracer != nil {
@ -1109,7 +1113,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa
return true return true
} }
func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { func (s *connection) handleVersionNegotiationPacket(p receivedPacket) {
if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets
s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets
if s.tracer != nil { if s.tracer != nil {
@ -1253,7 +1257,11 @@ func (s *connection) handleFrames(
) (isAckEliciting bool, _ error) { ) (isAckEliciting bool, _ error) {
// Only used for tracing. // Only used for tracing.
// If we're not tracing, this slice will always remain empty. // If we're not tracing, this slice will always remain empty.
var frames []wire.Frame var frames []logging.Frame
if log != nil {
frames = make([]logging.Frame, 0, 4)
}
var handleErr error
for len(data) > 0 { for len(data) > 0 {
l, frame, err := s.frameParser.ParseNext(data, encLevel, s.version) l, frame, err := s.frameParser.ParseNext(data, encLevel, s.version)
if err != nil { if err != nil {
@ -1266,27 +1274,27 @@ func (s *connection) handleFrames(
if ackhandler.IsFrameAckEliciting(frame) { if ackhandler.IsFrameAckEliciting(frame) {
isAckEliciting = true isAckEliciting = true
} }
// Only process frames now if we're not logging. if log != nil {
// If we're logging, we need to make sure that the packet_received event is logged first. frames = append(frames, logutils.ConvertFrame(frame))
if log == nil { }
if err := s.handleFrame(frame, encLevel, destConnID); err != nil { // An error occurred handling a previous frame.
// Don't handle the current frame.
if handleErr != nil {
continue
}
if err := s.handleFrame(frame, encLevel, destConnID); err != nil {
if log == nil {
return false, err return false, err
} }
} else { // If we're logging, we need to keep parsing (but not handling) all frames.
frames = append(frames, frame) handleErr = err
} }
} }
if log != nil { if log != nil {
fs := make([]logging.Frame, len(frames)) log(frames)
for i, frame := range frames { if handleErr != nil {
fs[i] = logutils.ConvertFrame(frame) return false, handleErr
}
log(fs)
for _, frame := range frames {
if err := s.handleFrame(frame, encLevel, destConnID); err != nil {
return false, err
}
} }
} }
return return
@ -1302,7 +1310,6 @@ func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel
err = s.handleStreamFrame(frame) err = s.handleStreamFrame(frame)
case *wire.AckFrame: case *wire.AckFrame:
err = s.handleAckFrame(frame, encLevel) err = s.handleAckFrame(frame, encLevel)
wire.PutAckFrame(frame)
case *wire.ConnectionCloseFrame: case *wire.ConnectionCloseFrame:
s.handleConnectionCloseFrame(frame) s.handleConnectionCloseFrame(frame)
case *wire.ResetStreamFrame: case *wire.ResetStreamFrame:
@ -1341,7 +1348,7 @@ func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel
} }
// handlePacket is called by the server with a new packet // handlePacket is called by the server with a new packet
func (s *connection) handlePacket(p *receivedPacket) { func (s *connection) handlePacket(p receivedPacket) {
// Discard packets once the amount of queued packets is larger than // Discard packets once the amount of queued packets is larger than
// the channel size, protocol.MaxConnUnprocessedPackets // the channel size, protocol.MaxConnUnprocessedPackets
select { select {
@ -1715,7 +1722,6 @@ func (s *connection) applyTransportParameters() {
s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout)
s.keepAliveInterval = utils.Min(s.config.KeepAlivePeriod, utils.Min(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) s.keepAliveInterval = utils.Min(s.config.KeepAlivePeriod, utils.Min(s.idleTimeout/2, protocol.MaxKeepAliveInterval))
s.streamsMap.UpdateLimits(params) s.streamsMap.UpdateLimits(params)
s.packer.HandleTransportParameters(params)
s.frameParser.SetAckDelayExponent(params.AckDelayExponent) s.frameParser.SetAckDelayExponent(params.AckDelayExponent)
s.connFlowController.UpdateSendWindow(params.InitialMaxData) s.connFlowController.UpdateSendWindow(params.InitialMaxData)
s.rttStats.SetMaxAckDelay(params.MaxAckDelay) s.rttStats.SetMaxAckDelay(params.MaxAckDelay)
@ -1730,75 +1736,208 @@ func (s *connection) applyTransportParameters() {
} }
} }
func (s *connection) sendPackets() error { func (s *connection) triggerSending() error {
s.pacingDeadline = time.Time{} s.pacingDeadline = time.Time{}
now := time.Now()
var sentPacket bool // only used in for packets sent in send mode SendAny sendMode := s.sentPacketHandler.SendMode(now)
for { //nolint:exhaustive // No need to handle pacing limited here.
sendMode := s.sentPacketHandler.SendMode() switch sendMode {
if sendMode == ackhandler.SendAny && s.handshakeComplete && !s.sentPacketHandler.HasPacingBudget() { case ackhandler.SendAny:
deadline := s.sentPacketHandler.TimeUntilSend() return s.sendPackets(now)
if deadline.IsZero() { case ackhandler.SendNone:
deadline = deadlineSendImmediately return nil
} case ackhandler.SendPacingLimited:
s.pacingDeadline = deadline deadline := s.sentPacketHandler.TimeUntilSend()
// Allow sending of an ACK if we're pacing limit (if we haven't sent out a packet yet). if deadline.IsZero() {
// This makes sure that a peer that is mostly receiving data (and thus has an inaccurate cwnd estimate) deadline = deadlineSendImmediately
// sends enough ACKs to allow its peer to utilize the bandwidth.
if sentPacket {
return nil
}
sendMode = ackhandler.SendAck
} }
switch sendMode { s.pacingDeadline = deadline
case ackhandler.SendNone: // Allow sending of an ACK if we're pacing limit.
// This makes sure that a peer that is mostly receiving data (and thus has an inaccurate cwnd estimate)
// sends enough ACKs to allow its peer to utilize the bandwidth.
fallthrough
case ackhandler.SendAck:
// We can at most send a single ACK only packet.
// There will only be a new ACK after receiving new packets.
// SendAck is only returned when we're congestion limited, so we don't need to set the pacinggs timer.
return s.maybeSendAckOnlyPacket(now)
case ackhandler.SendPTOInitial:
if err := s.sendProbePacket(protocol.EncryptionInitial, now); err != nil {
return err
}
if s.sendQueue.WouldBlock() {
s.scheduleSending()
return nil return nil
case ackhandler.SendAck: }
// If we already sent packets, and the send mode switches to SendAck, return s.triggerSending()
// as we've just become congestion limited. case ackhandler.SendPTOHandshake:
// There's no need to try to send an ACK at this moment. if err := s.sendProbePacket(protocol.EncryptionHandshake, now); err != nil {
if sentPacket { return err
}
if s.sendQueue.WouldBlock() {
s.scheduleSending()
return nil
}
return s.triggerSending()
case ackhandler.SendPTOAppData:
if err := s.sendProbePacket(protocol.Encryption1RTT, now); err != nil {
return err
}
if s.sendQueue.WouldBlock() {
s.scheduleSending()
return nil
}
return s.triggerSending()
default:
return fmt.Errorf("BUG: invalid send mode %d", sendMode)
}
}
func (s *connection) sendPackets(now time.Time) error {
// Path MTU Discovery
// Can't use GSO, since we need to send a single packet that's larger than our current maximum size.
// Performance-wise, this doesn't matter, since we only send a very small (<10) number of
// MTU probe packets per connection.
if s.handshakeConfirmed && s.mtuDiscoverer != nil && s.mtuDiscoverer.ShouldSendProbe(now) {
ping, size := s.mtuDiscoverer.GetPing()
p, buf, err := s.packer.PackMTUProbePacket(ping, size, s.version)
if err != nil {
return err
}
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false)
s.registerPackedShortHeaderPacket(p, now)
s.sendQueue.Send(buf, buf.Len())
// This is kind of a hack. We need to trigger sending again somehow.
s.pacingDeadline = deadlineSendImmediately
return nil
}
if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked {
s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset})
}
s.windowUpdateQueue.QueueAll()
if !s.handshakeConfirmed {
packet, err := s.packer.PackCoalescedPacket(false, s.mtuDiscoverer.CurrentSize(), s.version)
if err != nil || packet == nil {
return err
}
s.sentFirstPacket = true
s.sendPackedCoalescedPacket(packet, now)
sendMode := s.sentPacketHandler.SendMode(now)
if sendMode == ackhandler.SendPacingLimited {
s.resetPacingDeadline()
} else if sendMode == ackhandler.SendAny {
s.pacingDeadline = deadlineSendImmediately
}
return nil
}
if s.conn.capabilities().GSO {
return s.sendPacketsWithGSO(now)
}
return s.sendPacketsWithoutGSO(now)
}
func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
for {
buf := getPacketBuffer()
if _, err := s.appendPacket(buf, s.mtuDiscoverer.CurrentSize(), now); err != nil {
if err == errNothingToPack {
buf.Release()
return nil return nil
} }
// We can at most send a single ACK only packet. return err
// There will only be a new ACK after receiving new packets. }
// SendAck is only returned when we're congestion limited, so we don't need to set the pacinggs timer.
return s.maybeSendAckOnlyPacket() s.sendQueue.Send(buf, buf.Len())
case ackhandler.SendPTOInitial:
if err := s.sendProbePacket(protocol.EncryptionInitial); err != nil { if s.sendQueue.WouldBlock() {
return err return nil
} }
case ackhandler.SendPTOHandshake: sendMode := s.sentPacketHandler.SendMode(now)
if err := s.sendProbePacket(protocol.EncryptionHandshake); err != nil { if sendMode == ackhandler.SendPacingLimited {
return err s.resetPacingDeadline()
} return nil
case ackhandler.SendPTOAppData: }
if err := s.sendProbePacket(protocol.Encryption1RTT); err != nil { if sendMode != ackhandler.SendAny {
return err return nil
}
case ackhandler.SendAny:
sent, err := s.sendPacket()
if err != nil || !sent {
return err
}
sentPacket = true
default:
return fmt.Errorf("BUG: invalid send mode %d", sendMode)
} }
// Prioritize receiving of packets over sending out more packets. // Prioritize receiving of packets over sending out more packets.
if len(s.receivedPackets) > 0 { if len(s.receivedPackets) > 0 {
s.pacingDeadline = deadlineSendImmediately s.pacingDeadline = deadlineSendImmediately
return nil return nil
} }
if s.sendQueue.WouldBlock() {
return nil
}
} }
} }
func (s *connection) maybeSendAckOnlyPacket() error { func (s *connection) sendPacketsWithGSO(now time.Time) error {
buf := getLargePacketBuffer()
maxSize := s.mtuDiscoverer.CurrentSize()
for {
var dontSendMore bool
size, err := s.appendPacket(buf, maxSize, now)
if err != nil {
if err != errNothingToPack {
return err
}
if buf.Len() == 0 {
buf.Release()
return nil
}
dontSendMore = true
}
if !dontSendMore {
sendMode := s.sentPacketHandler.SendMode(now)
if sendMode == ackhandler.SendPacingLimited {
s.resetPacingDeadline()
}
if sendMode != ackhandler.SendAny {
dontSendMore = true
}
}
// Append another packet if
// 1. The congestion controller and pacer allow sending more
// 2. The last packet appended was a full-size packet
// 3. We still have enough space for another full-size packet in the buffer
if !dontSendMore && size == maxSize && buf.Len()+maxSize <= buf.Cap() {
continue
}
s.sendQueue.Send(buf, maxSize)
if dontSendMore {
return nil
}
if s.sendQueue.WouldBlock() {
return nil
}
// Prioritize receiving of packets over sending out more packets.
if len(s.receivedPackets) > 0 {
s.pacingDeadline = deadlineSendImmediately
return nil
}
buf = getLargePacketBuffer()
}
}
func (s *connection) resetPacingDeadline() {
deadline := s.sentPacketHandler.TimeUntilSend()
if deadline.IsZero() {
deadline = deadlineSendImmediately
}
s.pacingDeadline = deadline
}
func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
if !s.handshakeConfirmed { if !s.handshakeConfirmed {
packet, err := s.packer.PackCoalescedPacket(true, s.version) packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version)
if err != nil { if err != nil {
return err return err
} }
@ -1809,20 +1948,20 @@ func (s *connection) maybeSendAckOnlyPacket() error {
return nil return nil
} }
now := time.Now() p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version)
p, buffer, err := s.packer.PackPacket(true, now, s.version)
if err != nil { if err != nil {
if err == errNothingToPack { if err == errNothingToPack {
return nil return nil
} }
return err return err
} }
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false)
s.sendPackedShortHeaderPacket(buffer, p.Packet, now) s.registerPackedShortHeaderPacket(p, now)
s.sendQueue.Send(buf, buf.Len())
return nil return nil
} }
func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time.Time) error {
// Queue probe packets until we actually send out a packet, // Queue probe packets until we actually send out a packet,
// or until there are no more packets to queue. // or until there are no more packets to queue.
var packet *coalescedPacket var packet *coalescedPacket
@ -1831,7 +1970,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error {
break break
} }
var err error var err error
packet, err = s.packer.MaybePackProbePacket(encLevel, s.version) packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
if err != nil { if err != nil {
return err return err
} }
@ -1840,19 +1979,9 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error {
} }
} }
if packet == nil { if packet == nil {
//nolint:exhaustive // Cannot send probe packets for 0-RTT. s.retransmissionQueue.AddPing(encLevel)
switch encLevel {
case protocol.EncryptionInitial:
s.retransmissionQueue.AddInitial(&wire.PingFrame{})
case protocol.EncryptionHandshake:
s.retransmissionQueue.AddHandshake(&wire.PingFrame{})
case protocol.Encryption1RTT:
s.retransmissionQueue.AddAppData(&wire.PingFrame{})
default:
panic("unexpected encryption level")
}
var err error var err error
packet, err = s.packer.MaybePackProbePacket(encLevel, s.version) packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
if err != nil { if err != nil {
return err return err
} }
@ -1860,55 +1989,35 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error {
if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) {
return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel)
} }
s.sendPackedCoalescedPacket(packet, time.Now()) s.sendPackedCoalescedPacket(packet, now)
return nil return nil
} }
func (s *connection) sendPacket() (bool, error) { // appendPacket appends a new packet to the given packetBuffer.
if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { // If there was nothing to pack, the returned size is 0.
s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset}) func (s *connection) appendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time) (protocol.ByteCount, error) {
} startLen := buf.Len()
s.windowUpdateQueue.QueueAll() p, err := s.packer.AppendPacket(buf, maxSize, s.version)
now := time.Now()
if !s.handshakeConfirmed {
packet, err := s.packer.PackCoalescedPacket(false, s.version)
if err != nil || packet == nil {
return false, err
}
s.sentFirstPacket = true
s.sendPackedCoalescedPacket(packet, now)
return true, nil
} else if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) {
ping, size := s.mtuDiscoverer.GetPing()
p, buffer, err := s.packer.PackMTUProbePacket(ping, size, now, s.version)
if err != nil {
return false, err
}
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false)
s.sendPackedShortHeaderPacket(buffer, p.Packet, now)
return true, nil
}
p, buffer, err := s.packer.PackPacket(false, now, s.version)
if err != nil { if err != nil {
if err == errNothingToPack { return 0, err
return false, nil
}
return false, err
} }
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) size := buf.Len() - startLen
s.sendPackedShortHeaderPacket(buffer, p.Packet, now) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, size, false)
return true, nil s.registerPackedShortHeaderPacket(p, now)
return size, nil
} }
func (s *connection) sendPackedShortHeaderPacket(buffer *packetBuffer, p *ackhandler.Packet, now time.Time) { func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now time.Time) {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && ackhandler.HasAckElicitingFrames(p.Frames) { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) {
s.firstAckElicitingPacketAfterIdleSentTime = now s.firstAckElicitingPacketAfterIdleSentTime = now
} }
s.sentPacketHandler.SentPacket(p) largestAcked := protocol.InvalidPacketNumber
if p.Ack != nil {
largestAcked = p.Ack.LargestAcked()
}
s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket)
s.connIDManager.SentPacket() s.connIDManager.SentPacket()
s.sendQueue.Send(buffer)
} }
func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) { func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) {
@ -1917,16 +2026,24 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
s.firstAckElicitingPacketAfterIdleSentTime = now s.firstAckElicitingPacketAfterIdleSentTime = now
} }
s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(now, s.retransmissionQueue)) largestAcked := protocol.InvalidPacketNumber
if p.ack != nil {
largestAcked = p.ack.LargestAcked()
}
s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), p.length, false)
} }
if p := packet.shortHdrPacket; p != nil { if p := packet.shortHdrPacket; p != nil {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
s.firstAckElicitingPacketAfterIdleSentTime = now s.firstAckElicitingPacketAfterIdleSentTime = now
} }
s.sentPacketHandler.SentPacket(p.Packet) largestAcked := protocol.InvalidPacketNumber
if p.Ack != nil {
largestAcked = p.Ack.LargestAcked()
}
s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket)
} }
s.connIDManager.SentPacket() s.connIDManager.SentPacket()
s.sendQueue.Send(packet.buffer) s.sendQueue.Send(packet.buffer, packet.buffer.Len())
} }
func (s *connection) sendConnectionClose(e error) ([]byte, error) { func (s *connection) sendConnectionClose(e error) ([]byte, error) {
@ -1935,20 +2052,20 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
var transportErr *qerr.TransportError var transportErr *qerr.TransportError
var applicationErr *qerr.ApplicationError var applicationErr *qerr.ApplicationError
if errors.As(e, &transportErr) { if errors.As(e, &transportErr) {
packet, err = s.packer.PackConnectionClose(transportErr, s.version) packet, err = s.packer.PackConnectionClose(transportErr, s.mtuDiscoverer.CurrentSize(), s.version)
} else if errors.As(e, &applicationErr) { } else if errors.As(e, &applicationErr) {
packet, err = s.packer.PackApplicationClose(applicationErr, s.version) packet, err = s.packer.PackApplicationClose(applicationErr, s.mtuDiscoverer.CurrentSize(), s.version)
} else { } else {
packet, err = s.packer.PackConnectionClose(&qerr.TransportError{ packet, err = s.packer.PackConnectionClose(&qerr.TransportError{
ErrorCode: qerr.InternalError, ErrorCode: qerr.InternalError,
ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()), ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()),
}, s.version) }, s.mtuDiscoverer.CurrentSize(), s.version)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.logCoalescedPacket(packet) s.logCoalescedPacket(packet)
return packet.buffer.Data, s.conn.Write(packet.buffer.Data) return packet.buffer.Data, s.conn.Write(packet.buffer.Data, packet.buffer.Len())
} }
func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { func (s *connection) logLongHeaderPacket(p *longHeaderPacket) {
@ -1980,7 +2097,8 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket) {
func (s *connection) logShortHeaderPacket( func (s *connection) logShortHeaderPacket(
destConnID protocol.ConnectionID, destConnID protocol.ConnectionID,
ackFrame *wire.AckFrame, ackFrame *wire.AckFrame,
frames []*ackhandler.Frame, frames []ackhandler.Frame,
streamFrames []ackhandler.StreamFrame,
pn protocol.PacketNumber, pn protocol.PacketNumber,
pnLen protocol.PacketNumberLen, pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit, kp protocol.KeyPhaseBit,
@ -1996,17 +2114,23 @@ func (s *connection) logShortHeaderPacket(
if ackFrame != nil { if ackFrame != nil {
wire.LogFrame(s.logger, ackFrame, true) wire.LogFrame(s.logger, ackFrame, true)
} }
for _, frame := range frames { for _, f := range frames {
wire.LogFrame(s.logger, frame.Frame, true) wire.LogFrame(s.logger, f.Frame, true)
}
for _, f := range streamFrames {
wire.LogFrame(s.logger, f.Frame, true)
} }
} }
// tracing // tracing
if s.tracer != nil { if s.tracer != nil {
fs := make([]logging.Frame, 0, len(frames)) fs := make([]logging.Frame, 0, len(frames)+len(streamFrames))
for _, f := range frames { for _, f := range frames {
fs = append(fs, logutils.ConvertFrame(f.Frame)) fs = append(fs, logutils.ConvertFrame(f.Frame))
} }
for _, f := range streamFrames {
fs = append(fs, logutils.ConvertFrame(f.Frame))
}
var ack *logging.AckFrame var ack *logging.AckFrame
if ackFrame != nil { if ackFrame != nil {
ack = logutils.ConvertAckFrame(ackFrame) ack = logutils.ConvertAckFrame(ackFrame)
@ -2034,6 +2158,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) {
packet.shortHdrPacket.DestConnID, packet.shortHdrPacket.DestConnID,
packet.shortHdrPacket.Ack, packet.shortHdrPacket.Ack,
packet.shortHdrPacket.Frames, packet.shortHdrPacket.Frames,
packet.shortHdrPacket.StreamFrames,
packet.shortHdrPacket.PacketNumber, packet.shortHdrPacket.PacketNumber,
packet.shortHdrPacket.PacketNumberLen, packet.shortHdrPacket.PacketNumberLen,
packet.shortHdrPacket.KeyPhase, packet.shortHdrPacket.KeyPhase,
@ -2052,7 +2177,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) {
s.logLongHeaderPacket(p) s.logLongHeaderPacket(p)
} }
if p := packet.shortHdrPacket; p != nil { if p := packet.shortHdrPacket; p != nil {
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true)
} }
} }
@ -2113,7 +2238,7 @@ func (s *connection) scheduleSending() {
// tryQueueingUndecryptablePacket queues a packet for which we're missing the decryption keys. // tryQueueingUndecryptablePacket queues a packet for which we're missing the decryption keys.
// The logging.PacketType is only used for logging purposes. // The logging.PacketType is only used for logging purposes.
func (s *connection) tryQueueingUndecryptablePacket(p *receivedPacket, pt logging.PacketType) { func (s *connection) tryQueueingUndecryptablePacket(p receivedPacket, pt logging.PacketType) {
if s.handshakeComplete { if s.handshakeComplete {
panic("shouldn't queue undecryptable packets after handshake completion") panic("shouldn't queue undecryptable packets after handshake completion")
} }

View File

@ -6,6 +6,7 @@ import (
"github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils/ringbuffer"
"github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/quicvarint" "github.com/quic-go/quic-go/quicvarint"
) )
@ -14,10 +15,10 @@ type framer interface {
HasData() bool HasData() bool
QueueControlFrame(wire.Frame) QueueControlFrame(wire.Frame)
AppendControlFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
AddActiveStream(protocol.StreamID) AddActiveStream(protocol.StreamID)
AppendStreamFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
Handle0RTTRejection() error Handle0RTTRejection() error
} }
@ -28,7 +29,7 @@ type framerI struct {
streamGetter streamGetter streamGetter streamGetter
activeStreams map[protocol.StreamID]struct{} activeStreams map[protocol.StreamID]struct{}
streamQueue []protocol.StreamID streamQueue ringbuffer.RingBuffer[protocol.StreamID]
controlFrameMutex sync.Mutex controlFrameMutex sync.Mutex
controlFrames []wire.Frame controlFrames []wire.Frame
@ -45,7 +46,7 @@ func newFramer(streamGetter streamGetter) framer {
func (f *framerI) HasData() bool { func (f *framerI) HasData() bool {
f.mutex.Lock() f.mutex.Lock()
hasData := len(f.streamQueue) > 0 hasData := !f.streamQueue.Empty()
f.mutex.Unlock() f.mutex.Unlock()
if hasData { if hasData {
return true return true
@ -62,7 +63,7 @@ func (f *framerI) QueueControlFrame(frame wire.Frame) {
f.controlFrameMutex.Unlock() f.controlFrameMutex.Unlock()
} }
func (f *framerI) AppendControlFrames(frames []*ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) {
var length protocol.ByteCount var length protocol.ByteCount
f.controlFrameMutex.Lock() f.controlFrameMutex.Lock()
for len(f.controlFrames) > 0 { for len(f.controlFrames) > 0 {
@ -71,9 +72,7 @@ func (f *framerI) AppendControlFrames(frames []*ackhandler.Frame, maxLen protoco
if length+frameLen > maxLen { if length+frameLen > maxLen {
break break
} }
af := ackhandler.GetFrame() frames = append(frames, ackhandler.Frame{Frame: frame})
af.Frame = frame
frames = append(frames, af)
length += frameLen length += frameLen
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1] f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
} }
@ -84,24 +83,23 @@ func (f *framerI) AppendControlFrames(frames []*ackhandler.Frame, maxLen protoco
func (f *framerI) AddActiveStream(id protocol.StreamID) { func (f *framerI) AddActiveStream(id protocol.StreamID) {
f.mutex.Lock() f.mutex.Lock()
if _, ok := f.activeStreams[id]; !ok { if _, ok := f.activeStreams[id]; !ok {
f.streamQueue = append(f.streamQueue, id) f.streamQueue.PushBack(id)
f.activeStreams[id] = struct{}{} f.activeStreams[id] = struct{}{}
} }
f.mutex.Unlock() f.mutex.Unlock()
} }
func (f *framerI) AppendStreamFrames(frames []*ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) {
startLen := len(frames)
var length protocol.ByteCount var length protocol.ByteCount
var lastFrame *ackhandler.Frame
f.mutex.Lock() f.mutex.Lock()
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
numActiveStreams := len(f.streamQueue) numActiveStreams := f.streamQueue.Len()
for i := 0; i < numActiveStreams; i++ { for i := 0; i < numActiveStreams; i++ {
if protocol.MinStreamFrameSize+length > maxLen { if protocol.MinStreamFrameSize+length > maxLen {
break break
} }
id := f.streamQueue[0] id := f.streamQueue.PopFront()
f.streamQueue = f.streamQueue[1:]
// This should never return an error. Better check it anyway. // This should never return an error. Better check it anyway.
// The stream will only be in the streamQueue, if it enqueued itself there. // The stream will only be in the streamQueue, if it enqueued itself there.
str, err := f.streamGetter.GetOrOpenSendStream(id) str, err := f.streamGetter.GetOrOpenSendStream(id)
@ -115,28 +113,27 @@ func (f *framerI) AppendStreamFrames(frames []*ackhandler.Frame, maxLen protocol
// Therefore, we can pretend to have more bytes available when popping // Therefore, we can pretend to have more bytes available when popping
// the STREAM frame (which will always have the DataLen set). // the STREAM frame (which will always have the DataLen set).
remainingLen += quicvarint.Len(uint64(remainingLen)) remainingLen += quicvarint.Len(uint64(remainingLen))
frame, hasMoreData := str.popStreamFrame(remainingLen, v) frame, ok, hasMoreData := str.popStreamFrame(remainingLen, v)
if hasMoreData { // put the stream back in the queue (at the end) if hasMoreData { // put the stream back in the queue (at the end)
f.streamQueue = append(f.streamQueue, id) f.streamQueue.PushBack(id)
} else { // no more data to send. Stream is not active any more } else { // no more data to send. Stream is not active
delete(f.activeStreams, id) delete(f.activeStreams, id)
} }
// The frame can be nil // The frame can be "nil"
// * if the receiveStream was canceled after it said it had data // * if the receiveStream was canceled after it said it had data
// * the remaining size doesn't allow us to add another STREAM frame // * the remaining size doesn't allow us to add another STREAM frame
if frame == nil { if !ok {
continue continue
} }
frames = append(frames, frame) frames = append(frames, frame)
length += frame.Length(v) length += frame.Frame.Length(v)
lastFrame = frame
} }
f.mutex.Unlock() f.mutex.Unlock()
if lastFrame != nil { if len(frames) > startLen {
lastFrameLen := lastFrame.Length(v) l := frames[len(frames)-1].Frame.Length(v)
// account for the smaller size of the last STREAM frame // account for the smaller size of the last STREAM frame
lastFrame.Frame.(*wire.StreamFrame).DataLenPresent = false frames[len(frames)-1].Frame.DataLenPresent = false
length += lastFrame.Length(v) - lastFrameLen length += frames[len(frames)-1].Frame.Length(v) - l
} }
return frames, length return frames, length
} }
@ -146,7 +143,7 @@ func (f *framerI) Handle0RTTRejection() error {
defer f.mutex.Unlock() defer f.mutex.Unlock()
f.controlFrameMutex.Lock() f.controlFrameMutex.Lock()
f.streamQueue = f.streamQueue[:0] f.streamQueue.Clear()
for id := range f.activeStreams { for id := range f.activeStreams {
delete(f.activeStreams, id) delete(f.activeStreams, id)
} }

View File

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"strconv" "strconv"
"sync" "sync"
@ -33,12 +34,11 @@ const (
var defaultQuicConfig = &quic.Config{ var defaultQuicConfig = &quic.Config{
MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
KeepAlivePeriod: 10 * time.Second, KeepAlivePeriod: 10 * time.Second,
Versions: []protocol.VersionNumber{protocol.Version1},
} }
type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
var dialAddr = quic.DialAddrEarlyContext var dialAddr dialFunc = quic.DialAddrEarly
type roundTripperOpts struct { type roundTripperOpts struct {
DisableCompression bool DisableCompression bool
@ -74,9 +74,10 @@ var _ roundTripCloser = &client{}
func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
if conf == nil { if conf == nil {
conf = defaultQuicConfig.Clone() conf = defaultQuicConfig.Clone()
} else if len(conf.Versions) == 0 { }
if len(conf.Versions) == 0 {
conf = conf.Clone() conf = conf.Clone()
conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]} conf.Versions = []quic.VersionNumber{protocol.SupportedVersions[0]}
} }
if len(conf.Versions) != 1 { if len(conf.Versions) != 1 {
return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
@ -92,6 +93,14 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con
} else { } else {
tlsConf = tlsConf.Clone() tlsConf = tlsConf.Clone()
} }
if tlsConf.ServerName == "" {
sni, _, err := net.SplitHostPort(hostname)
if err != nil {
// It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
sni = hostname
}
tlsConf.ServerName = sni
}
// Replace existing ALPNs by H3 // Replace existing ALPNs by H3
tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}

View File

@ -4,3 +4,5 @@ package http3
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser" //go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser"
type RoundTripCloser = roundTripCloser type RoundTripCloser = roundTripCloser
//go:generate sh -c "go run github.com/golang/mock/mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/quic-go/quic-go/http3 QUICEarlyListener"

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
@ -15,6 +16,7 @@ import (
type responseWriter struct { type responseWriter struct {
conn quic.Connection conn quic.Connection
str quic.Stream
bufferedStr *bufio.Writer bufferedStr *bufio.Writer
buf []byte buf []byte
@ -36,6 +38,7 @@ func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logge
header: http.Header{}, header: http.Header{},
buf: make([]byte, 16), buf: make([]byte, 16),
conn: conn, conn: conn,
str: str,
bufferedStr: bufio.NewWriter(str), bufferedStr: bufio.NewWriter(str),
logger: logger, logger: logger,
} }
@ -121,6 +124,14 @@ func (w *responseWriter) StreamCreator() StreamCreator {
return w.conn return w.conn
} }
func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
return w.str.SetReadDeadline(deadline)
}
func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
return w.str.SetWriteDeadline(deadline)
}
// copied from http2/http2.go // copied from http2/http2.go
// bodyAllowedForStatus reports whether a given response status code // bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 2616, section 4.4. // permits a body. See RFC 2616, section 4.4.

View File

@ -10,21 +10,24 @@ import (
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpguts"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
) )
// declare this as a variable, such that we can it mock it in the tests
var quicDialer = quic.DialEarlyContext
type roundTripCloser interface { type roundTripCloser interface {
RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
HandshakeComplete() bool HandshakeComplete() bool
io.Closer io.Closer
} }
type roundTripCloserWithCount struct {
roundTripCloser
useCount atomic.Int64
}
// RoundTripper implements the http.RoundTripper interface // RoundTripper implements the http.RoundTripper interface
type RoundTripper struct { type RoundTripper struct {
mutex sync.Mutex mutex sync.Mutex
@ -82,8 +85,8 @@ type RoundTripper struct {
MaxResponseHeaderBytes int64 MaxResponseHeaderBytes int64
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
clients map[string]roundTripCloser clients map[string]*roundTripCloserWithCount
udpConn *net.UDPConn transport *quic.Transport
} }
// RoundTripOpt are options for the Transport.RoundTripOpt method. // RoundTripOpt are options for the Transport.RoundTripOpt method.
@ -143,6 +146,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cl.useCount.Add(-1)
rsp, err := cl.RoundTripOpt(req, opt) rsp, err := cl.RoundTripOpt(req, opt)
if err != nil { if err != nil {
r.removeClient(hostname) r.removeClient(hostname)
@ -160,12 +164,12 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{}) return r.RoundTripOpt(req, RoundTripOpt{})
} }
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTripCloser, isReused bool, err error) { func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
if r.clients == nil { if r.clients == nil {
r.clients = make(map[string]roundTripCloser) r.clients = make(map[string]*roundTripCloserWithCount)
} }
client, ok := r.clients[hostname] client, ok := r.clients[hostname]
@ -180,15 +184,16 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri
} }
dial := r.Dial dial := r.Dial
if dial == nil { if dial == nil {
if r.udpConn == nil { if r.transport == nil {
r.udpConn, err = net.ListenUDP("udp", nil) udpConn, err := net.ListenUDP("udp", nil)
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
r.transport = &quic.Transport{Conn: udpConn}
} }
dial = r.makeDialer() dial = r.makeDialer()
} }
client, err = newCl( c, err := newCl(
hostname, hostname,
r.TLSClientConfig, r.TLSClientConfig,
&roundTripperOpts{ &roundTripperOpts{
@ -204,10 +209,12 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
client = &roundTripCloserWithCount{roundTripCloser: c}
r.clients[hostname] = client r.clients[hostname] = client
} else if client.HandshakeComplete() { } else if client.HandshakeComplete() {
isReused = true isReused = true
} }
client.useCount.Add(1)
return client, isReused, nil return client, isReused, nil
} }
@ -231,9 +238,14 @@ func (r *RoundTripper) Close() error {
} }
} }
r.clients = nil r.clients = nil
if r.udpConn != nil { if r.transport != nil {
r.udpConn.Close() if err := r.transport.Close(); err != nil {
r.udpConn = nil return err
}
if err := r.transport.Conn.Close(); err != nil {
return err
}
r.transport = nil
} }
return nil return nil
} }
@ -273,6 +285,17 @@ func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCf
if err != nil { if err != nil {
return nil, err return nil, err
} }
return quicDialer(ctx, r.udpConn, udpAddr, addr, tlsCfg, cfg) return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
}
}
func (r *RoundTripper) CloseIdleConnections() {
r.mutex.Lock()
defer r.mutex.Unlock()
for hostname, client := range r.clients {
if client.useCount.Load() == 0 {
client.Close()
delete(r.clients, hostname)
}
} }
} }

View File

@ -23,8 +23,12 @@ import (
// allows mocking of quic.Listen and quic.ListenAddr // allows mocking of quic.Listen and quic.ListenAddr
var ( var (
quicListen = quic.ListenEarly quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
quicListenAddr = quic.ListenAddrEarly return quic.ListenEarly(conn, tlsConf, config)
}
quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
return quic.ListenAddrEarly(addr, tlsConf, config)
}
) )
const ( const (
@ -44,6 +48,15 @@ const (
streamTypeQPACKDecoderStream = 3 streamTypeQPACKDecoderStream = 3
) )
// A QUICEarlyListener listens for incoming QUIC connections.
type QUICEarlyListener interface {
Accept(context.Context) (quic.EarlyConnection, error)
Addr() net.Addr
io.Closer
}
var _ QUICEarlyListener = &quic.EarlyListener{}
func versionToALPN(v protocol.VersionNumber) string { func versionToALPN(v protocol.VersionNumber) string {
//nolint:exhaustive // These are all the versions we care about. //nolint:exhaustive // These are all the versions we care about.
switch v { switch v {
@ -193,7 +206,7 @@ type Server struct {
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
mutex sync.RWMutex mutex sync.RWMutex
listeners map[*quic.EarlyListener]listenerInfo listeners map[*QUICEarlyListener]listenerInfo
closed bool closed bool
@ -249,13 +262,26 @@ func (s *Server) ServeQUICConn(conn quic.Connection) error {
// Make sure you use http3.ConfigureTLSConfig to configure a tls.Config // Make sure you use http3.ConfigureTLSConfig to configure a tls.Config
// and use it to construct a http3-friendly QUIC listener. // and use it to construct a http3-friendly QUIC listener.
// Closing the server does close the listener. // Closing the server does close the listener.
func (s *Server) ServeListener(ln quic.EarlyListener) error { // ServeListener always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed.
func (s *Server) ServeListener(ln QUICEarlyListener) error {
if err := s.addListener(&ln); err != nil { if err := s.addListener(&ln); err != nil {
return err return err
} }
err := s.serveListener(ln) defer s.removeListener(&ln)
s.removeListener(&ln) for {
return err conn, err := ln.Accept(context.Background())
if err == quic.ErrServerClosed {
return http.ErrServerClosed
}
if err != nil {
return err
}
go func() {
if err := s.handleConn(conn); err != nil {
s.logger.Debugf(err.Error())
}
}()
}
} }
var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig") var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig")
@ -275,7 +301,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
baseConf := ConfigureTLSConfig(tlsConf) baseConf := ConfigureTLSConfig(tlsConf)
quicConf := s.QuicConfig quicConf := s.QuicConfig
if quicConf == nil { if quicConf == nil {
quicConf = &quic.Config{Allow0RTT: func(net.Addr) bool { return true }} quicConf = &quic.Config{Allow0RTT: true}
} else { } else {
quicConf = s.QuicConfig.Clone() quicConf = s.QuicConfig.Clone()
} }
@ -283,7 +309,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
quicConf.EnableDatagrams = true quicConf.EnableDatagrams = true
} }
var ln quic.EarlyListener var ln QUICEarlyListener
var err error var err error
if conn == nil { if conn == nil {
addr := s.Addr addr := s.Addr
@ -297,26 +323,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
if err != nil { if err != nil {
return err return err
} }
if err := s.addListener(&ln); err != nil { return s.ServeListener(ln)
return err
}
err = s.serveListener(ln)
s.removeListener(&ln)
return err
}
func (s *Server) serveListener(ln quic.EarlyListener) error {
for {
conn, err := ln.Accept(context.Background())
if err != nil {
return err
}
go func() {
if err := s.handleConn(conn); err != nil {
s.logger.Debugf(err.Error())
}
}()
}
} }
func extractPort(addr string) (int, error) { func extractPort(addr string) (int, error) {
@ -391,7 +398,7 @@ func (s *Server) generateAltSvcHeader() {
// We store a pointer to interface in the map set. This is safe because we only // We store a pointer to interface in the map set. This is safe because we only
// call trackListener via Serve and can track+defer untrack the same pointer to // call trackListener via Serve and can track+defer untrack the same pointer to
// local variable there. We never need to compare a Listener from another caller. // local variable there. We never need to compare a Listener from another caller.
func (s *Server) addListener(l *quic.EarlyListener) error { func (s *Server) addListener(l *QUICEarlyListener) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
@ -402,25 +409,24 @@ func (s *Server) addListener(l *quic.EarlyListener) error {
s.logger = utils.DefaultLogger.WithPrefix("server") s.logger = utils.DefaultLogger.WithPrefix("server")
} }
if s.listeners == nil { if s.listeners == nil {
s.listeners = make(map[*quic.EarlyListener]listenerInfo) s.listeners = make(map[*QUICEarlyListener]listenerInfo)
} }
if port, err := extractPort((*l).Addr().String()); err == nil { if port, err := extractPort((*l).Addr().String()); err == nil {
s.listeners[l] = listenerInfo{port} s.listeners[l] = listenerInfo{port}
} else { } else {
s.logger.Errorf( s.logger.Errorf("Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err)
"Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err)
s.listeners[l] = listenerInfo{} s.listeners[l] = listenerInfo{}
} }
s.generateAltSvcHeader() s.generateAltSvcHeader()
return nil return nil
} }
func (s *Server) removeListener(l *quic.EarlyListener) { func (s *Server) removeListener(l *QUICEarlyListener) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock()
delete(s.listeners, l) delete(s.listeners, l)
s.generateAltSvcHeader() s.generateAltSvcHeader()
s.mutex.Unlock()
} }
func (s *Server) handleConn(conn quic.Connection) error { func (s *Server) handleConn(conn quic.Connection) error {

View File

@ -198,7 +198,7 @@ type EarlyConnection interface {
// HandshakeComplete blocks until the handshake completes (or fails). // HandshakeComplete blocks until the handshake completes (or fails).
// For the client, data sent before completion of the handshake is encrypted with 0-RTT keys. // For the client, data sent before completion of the handshake is encrypted with 0-RTT keys.
// For the serfer, data sent before completion of the handshake is encrypted with 1-RTT keys, // For the server, data sent before completion of the handshake is encrypted with 1-RTT keys,
// however the client's identity is only verified once the handshake completes. // however the client's identity is only verified once the handshake completes.
HandshakeComplete() <-chan struct{} HandshakeComplete() <-chan struct{}
@ -239,21 +239,12 @@ type ConnectionIDGenerator interface {
// Config contains all configuration data needed for a QUIC server or client. // Config contains all configuration data needed for a QUIC server or client.
type Config struct { type Config struct {
// GetConfigForClient is called for incoming connections.
// If the error is not nil, the connection attempt is refused.
GetConfigForClient func(info *ClientHelloInfo) (*Config, error)
// The QUIC versions that can be negotiated. // The QUIC versions that can be negotiated.
// If not set, it uses all versions available. // If not set, it uses all versions available.
Versions []VersionNumber Versions []VersionNumber
// The length of the connection ID in bytes.
// It can be 0, or any value between 4 and 18.
// If not set, the interpretation depends on where the Config is used:
// If used for dialing an address, a 0 byte connection ID will be used.
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
ConnectionIDLength int
// An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection.
// The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers.
// By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used.
// Otherwise, if one is provided, then ConnectionIDLength is ignored.
ConnectionIDGenerator ConnectionIDGenerator
// HandshakeIdleTimeout is the idle timeout before completion of the handshake. // HandshakeIdleTimeout is the idle timeout before completion of the handshake.
// Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. // Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted.
// If this value is zero, the timeout is set to 5 seconds. // If this value is zero, the timeout is set to 5 seconds.
@ -285,17 +276,21 @@ type Config struct {
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm // If the application is consuming data quickly enough, the flow control auto-tuning algorithm
// will increase the window up to MaxStreamReceiveWindow. // will increase the window up to MaxStreamReceiveWindow.
// If this value is zero, it will default to 512 KB. // If this value is zero, it will default to 512 KB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
InitialStreamReceiveWindow uint64 InitialStreamReceiveWindow uint64
// MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data. // MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data.
// If this value is zero, it will default to 6 MB. // If this value is zero, it will default to 6 MB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
MaxStreamReceiveWindow uint64 MaxStreamReceiveWindow uint64
// InitialConnectionReceiveWindow is the initial size of the stream-level flow control window for receiving data. // InitialConnectionReceiveWindow is the initial size of the stream-level flow control window for receiving data.
// If the application is consuming data quickly enough, the flow control auto-tuning algorithm // If the application is consuming data quickly enough, the flow control auto-tuning algorithm
// will increase the window up to MaxConnectionReceiveWindow. // will increase the window up to MaxConnectionReceiveWindow.
// If this value is zero, it will default to 512 KB. // If this value is zero, it will default to 512 KB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
InitialConnectionReceiveWindow uint64 InitialConnectionReceiveWindow uint64
// MaxConnectionReceiveWindow is the connection-level flow control window for receiving data. // MaxConnectionReceiveWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 15 MB. // If this value is zero, it will default to 15 MB.
// Values larger than the maximum varint (quicvarint.Max) will be clipped to that value.
MaxConnectionReceiveWindow uint64 MaxConnectionReceiveWindow uint64
// AllowConnectionWindowIncrease is called every time the connection flow controller attempts // AllowConnectionWindowIncrease is called every time the connection flow controller attempts
// to increase the connection flow control window. // to increase the connection flow control window.
@ -305,64 +300,49 @@ type Config struct {
// in this callback. // in this callback.
AllowConnectionWindowIncrease func(conn Connection, delta uint64) bool AllowConnectionWindowIncrease func(conn Connection, delta uint64) bool
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open. // MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
// Values above 2^60 are invalid.
// If not set, it will default to 100. // If not set, it will default to 100.
// If set to a negative value, it doesn't allow any bidirectional streams. // If set to a negative value, it doesn't allow any bidirectional streams.
// Values larger than 2^60 will be clipped to that value.
MaxIncomingStreams int64 MaxIncomingStreams int64
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open. // MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
// Values above 2^60 are invalid.
// If not set, it will default to 100. // If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams. // If set to a negative value, it doesn't allow any unidirectional streams.
// Values larger than 2^60 will be clipped to that value.
MaxIncomingUniStreams int64 MaxIncomingUniStreams int64
// The StatelessResetKey is used to generate stateless reset tokens.
// If no key is configured, sending of stateless resets is disabled.
StatelessResetKey *StatelessResetKey
// KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive. // KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive.
// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most // If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
// every half of MaxIdleTimeout, whichever is smaller). // every half of MaxIdleTimeout, whichever is smaller).
KeepAlivePeriod time.Duration KeepAlivePeriod time.Duration
// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899). // DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. // This allows the sending of QUIC packets that fully utilize the available MTU of the path.
// Note that if Path MTU discovery is causing issues on your system, please open a new issue // Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
// If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
DisablePathMTUDiscovery bool DisablePathMTUDiscovery bool
// DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets.
// This can be useful if version information is exchanged out-of-band. // This can be useful if version information is exchanged out-of-band.
// It has no effect for a client. // It has no effect for a client.
DisableVersionNegotiationPackets bool DisableVersionNegotiationPackets bool
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted. // Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
// When set, 0-RTT is enabled. When not set, 0-RTT is disabled.
// Only valid for the server. // Only valid for the server.
// Warning: This API should not be considered stable and might change soon. Allow0RTT bool
Allow0RTT func(net.Addr) bool
// Enable QUIC datagram support (RFC 9221). // Enable QUIC datagram support (RFC 9221).
EnableDatagrams bool EnableDatagrams bool
Tracer logging.Tracer Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer
}
type ClientHelloInfo struct {
RemoteAddr net.Addr
} }
// ConnectionState records basic details about a QUIC connection // ConnectionState records basic details about a QUIC connection
type ConnectionState struct { type ConnectionState struct {
TLS handshake.ConnectionState // TLS contains information about the TLS connection state, incl. the tls.ConnectionState.
TLS handshake.ConnectionState
// SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated.
// This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams).
// If datagram support was negotiated, datagrams can be sent and received using the
// SendMessage and ReceiveMessage methods on the Connection.
SupportsDatagrams bool SupportsDatagrams bool
Version VersionNumber // Version is the QUIC version of the QUIC connection.
} Version VersionNumber
// A Listener for incoming QUIC connections
type Listener interface {
// Close the server. All active connections will be closed.
Close() error
// Addr returns the local network addr that the server is listening on.
Addr() net.Addr
// Accept returns new connections. It should be called in a loop.
Accept(context.Context) (Connection, error)
}
// An EarlyListener listens for incoming QUIC connections,
// and returns them before the handshake completes.
type EarlyListener interface {
// Close the server. All active connections will be closed.
Close() error
// Addr returns the local network addr that the server is listening on.
Addr() net.Addr
// Accept returns new early connections. It should be called in a loop.
Accept(context.Context) (EarlyConnection, error)
} }

View File

@ -10,7 +10,7 @@ func IsFrameAckEliciting(f wire.Frame) bool {
} }
// HasAckElicitingFrames returns true if at least one frame is ack-eliciting. // HasAckElicitingFrames returns true if at least one frame is ack-eliciting.
func HasAckElicitingFrames(fs []*Frame) bool { func HasAckElicitingFrames(fs []Frame) bool {
for _, f := range fs { for _, f := range fs {
if IsFrameAckEliciting(f.Frame) { if IsFrameAckEliciting(f.Frame) {
return true return true

View File

@ -1,29 +1,21 @@
package ackhandler package ackhandler
import ( import (
"sync"
"github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/internal/wire"
) )
// FrameHandler handles the acknowledgement and the loss of a frame.
type FrameHandler interface {
OnAcked(wire.Frame)
OnLost(wire.Frame)
}
type Frame struct { type Frame struct {
wire.Frame // nil if the frame has already been acknowledged in another packet Frame wire.Frame // nil if the frame has already been acknowledged in another packet
OnLost func(wire.Frame) Handler FrameHandler
OnAcked func(wire.Frame)
} }
var framePool = sync.Pool{New: func() any { return &Frame{} }} type StreamFrame struct {
Frame *wire.StreamFrame
func GetFrame() *Frame { Handler FrameHandler
f := framePool.Get().(*Frame)
f.OnLost = nil
f.OnAcked = nil
return f
}
func putFrame(f *Frame) {
f.Frame = nil
f.OnLost = nil
f.OnAcked = nil
framePool.Put(f)
} }

View File

@ -10,20 +10,20 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets // SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface { type SentPacketHandler interface {
// SentPacket may modify the packet // SentPacket may modify the packet
SentPacket(packet *Packet) SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, isPathMTUProbePacket bool)
ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) // ReceivedAck processes an ACK frame.
// It does not store a copy of the frame.
ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error)
ReceivedBytes(protocol.ByteCount) ReceivedBytes(protocol.ByteCount)
DropPackets(protocol.EncryptionLevel) DropPackets(protocol.EncryptionLevel)
ResetForRetry() error ResetForRetry() error
SetHandshakeConfirmed() SetHandshakeConfirmed()
// The SendMode determines if and what kind of packets can be sent. // The SendMode determines if and what kind of packets can be sent.
SendMode() SendMode SendMode(now time.Time) SendMode
// TimeUntilSend is the time when the next packet should be sent. // TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets. // It is used for pacing packets.
TimeUntilSend() time.Time TimeUntilSend() time.Time
// HasPacingBudget says if the pacer allows sending of a (full size) packet at this moment.
HasPacingBudget() bool
SetMaxDatagramSize(count protocol.ByteCount) SetMaxDatagramSize(count protocol.ByteCount)
// only to be called once the handshake is complete // only to be called once the handshake is complete

View File

@ -8,13 +8,14 @@ import (
) )
// A Packet is a packet // A Packet is a packet
type Packet struct { type packet struct {
SendTime time.Time
PacketNumber protocol.PacketNumber PacketNumber protocol.PacketNumber
Frames []*Frame StreamFrames []StreamFrame
Frames []Frame
LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK
Length protocol.ByteCount Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller. IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller.
@ -23,15 +24,16 @@ type Packet struct {
skippedPacket bool skippedPacket bool
} }
func (p *Packet) outstanding() bool { func (p *packet) outstanding() bool {
return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket
} }
var packetPool = sync.Pool{New: func() any { return &Packet{} }} var packetPool = sync.Pool{New: func() any { return &packet{} }}
func GetPacket() *Packet { func getPacket() *packet {
p := packetPool.Get().(*Packet) p := packetPool.Get().(*packet)
p.PacketNumber = 0 p.PacketNumber = 0
p.StreamFrames = nil
p.Frames = nil p.Frames = nil
p.LargestAcked = 0 p.LargestAcked = 0
p.Length = 0 p.Length = 0
@ -46,10 +48,8 @@ func GetPacket() *Packet {
// We currently only return Packets back into the pool when they're acknowledged (not when they're lost). // We currently only return Packets back into the pool when they're acknowledged (not when they're lost).
// This simplifies the code, and gives the vast majority of the performance benefit we can gain from using the pool. // This simplifies the code, and gives the vast majority of the performance benefit we can gain from using the pool.
func putPacket(p *Packet) { func putPacket(p *packet) {
for _, f := range p.Frames {
putFrame(f)
}
p.Frames = nil p.Frames = nil
p.StreamFrames = nil
packetPool.Put(p) packetPool.Put(p)
} }

View File

@ -7,7 +7,10 @@ import (
type packetNumberGenerator interface { type packetNumberGenerator interface {
Peek() protocol.PacketNumber Peek() protocol.PacketNumber
Pop() protocol.PacketNumber // Pop pops the packet number.
// It reports if the packet number (before the one just popped) was skipped.
// It never skips more than one packet number in a row.
Pop() (skipped bool, _ protocol.PacketNumber)
} }
type sequentialPacketNumberGenerator struct { type sequentialPacketNumberGenerator struct {
@ -24,10 +27,10 @@ func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber {
return p.next return p.next
} }
func (p *sequentialPacketNumberGenerator) Pop() protocol.PacketNumber { func (p *sequentialPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) {
next := p.next next := p.next
p.next++ p.next++
return next return false, next
} }
// The skippingPacketNumberGenerator generates the packet number for the next packet // The skippingPacketNumberGenerator generates the packet number for the next packet
@ -56,21 +59,26 @@ func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol
} }
func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber { func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber {
if p.next == p.nextToSkip {
return p.next + 1
}
return p.next return p.next
} }
func (p *skippingPacketNumberGenerator) Pop() protocol.PacketNumber { func (p *skippingPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) {
next := p.next next := p.next
p.next++ // generate a new packet number for the next packet
if p.next == p.nextToSkip { if p.next == p.nextToSkip {
p.next++ next++
p.next += 2
p.generateNewSkip() p.generateNewSkip()
return true, next
} }
return next p.next++ // generate a new packet number for the next packet
return false, next
} }
func (p *skippingPacketNumberGenerator) generateNewSkip() { func (p *skippingPacketNumberGenerator) generateNewSkip() {
// make sure that there are never two consecutive packet numbers that are skipped // make sure that there are never two consecutive packet numbers that are skipped
p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period))) p.nextToSkip = p.next + 3 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period)))
p.period = utils.Min(2*p.period, p.maxPeriod) p.period = utils.Min(2*p.period, p.maxPeriod)
} }

View File

@ -169,16 +169,18 @@ func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
} }
} }
ack := wire.GetAckFrame() // This function always returns the same ACK frame struct, filled with the most recent values.
ack := h.lastAck
if ack == nil {
ack = &wire.AckFrame{}
}
ack.Reset()
ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedReceivedTime)) ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedReceivedTime))
ack.ECT0 = h.ect0 ack.ECT0 = h.ect0
ack.ECT1 = h.ect1 ack.ECT1 = h.ect1
ack.ECNCE = h.ecnce ack.ECNCE = h.ecnce
ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges) ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges)
if h.lastAck != nil {
wire.PutAckFrame(h.lastAck)
}
h.lastAck = ack h.lastAck = ack
h.ackAlarm = time.Time{} h.ackAlarm = time.Time{}
h.ackQueued = false h.ackQueued = false

View File

@ -16,6 +16,10 @@ const (
SendPTOHandshake SendPTOHandshake
// SendPTOAppData means that an Application data probe packet should be sent // SendPTOAppData means that an Application data probe packet should be sent
SendPTOAppData SendPTOAppData
// SendPacingLimited means that the pacer doesn't allow sending of a packet right now,
// but will do in a little while.
// The timestamp when sending is allowed again can be obtained via the SentPacketHandler.TimeUntilSend.
SendPacingLimited
// SendAny means that any packet should be sent // SendAny means that any packet should be sent
SendAny SendAny
) )
@ -34,6 +38,8 @@ func (s SendMode) String() string {
return "pto (Application Data)" return "pto (Application Data)"
case SendAny: case SendAny:
return "any" return "any"
case SendPacingLimited:
return "pacing limited"
default: default:
return fmt.Sprintf("invalid send mode: %d", s) return fmt.Sprintf("invalid send mode: %d", s)
} }

View File

@ -38,7 +38,7 @@ type packetNumberSpace struct {
largestSent protocol.PacketNumber largestSent protocol.PacketNumber
} }
func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStats *utils.RTTStats) *packetNumberSpace { func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool) *packetNumberSpace {
var pns packetNumberGenerator var pns packetNumberGenerator
if skipPNs { if skipPNs {
pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod) pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod)
@ -46,7 +46,7 @@ func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStat
pns = newSequentialPacketNumberGenerator(initialPN) pns = newSequentialPacketNumberGenerator(initialPN)
} }
return &packetNumberSpace{ return &packetNumberSpace{
history: newSentPacketHistory(rttStats), history: newSentPacketHistory(),
pns: pns, pns: pns,
largestSent: protocol.InvalidPacketNumber, largestSent: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber,
@ -75,7 +75,7 @@ type sentPacketHandler struct {
// Only applies to the application-data packet number space. // Only applies to the application-data packet number space.
lowestNotConfirmedAcked protocol.PacketNumber lowestNotConfirmedAcked protocol.PacketNumber
ackedPackets []*Packet // to avoid allocations in detectAndRemoveAckedPackets ackedPackets []*packet // to avoid allocations in detectAndRemoveAckedPackets
bytesInFlight protocol.ByteCount bytesInFlight protocol.ByteCount
@ -125,9 +125,9 @@ func newSentPacketHandler(
return &sentPacketHandler{ return &sentPacketHandler{
peerCompletedAddressValidation: pers == protocol.PerspectiveServer, peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated, peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated,
initialPackets: newPacketNumberSpace(initialPN, false, rttStats), initialPackets: newPacketNumberSpace(initialPN, false),
handshakePackets: newPacketNumberSpace(0, false, rttStats), handshakePackets: newPacketNumberSpace(0, false),
appDataPackets: newPacketNumberSpace(0, true, rttStats), appDataPackets: newPacketNumberSpace(0, true),
rttStats: rttStats, rttStats: rttStats,
congestion: congestion, congestion: congestion,
perspective: pers, perspective: pers,
@ -146,7 +146,7 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
h.dropPackets(encLevel) h.dropPackets(encLevel)
} }
func (h *sentPacketHandler) removeFromBytesInFlight(p *Packet) { func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
if p.includedInBytesInFlight { if p.includedInBytesInFlight {
if p.Length > h.bytesInFlight { if p.Length > h.bytesInFlight {
panic("negative bytes_in_flight") panic("negative bytes_in_flight")
@ -165,7 +165,7 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
// remove outstanding packets from bytes_in_flight // remove outstanding packets from bytes_in_flight
if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake { if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
pnSpace.history.Iterate(func(p *Packet) (bool, error) { pnSpace.history.Iterate(func(p *packet) (bool, error) {
h.removeFromBytesInFlight(p) h.removeFromBytesInFlight(p)
return true, nil return true, nil
}) })
@ -182,8 +182,8 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
// and not when the client drops 0-RTT keys when the handshake completes. // and not when the client drops 0-RTT keys when the handshake completes.
// When 0-RTT is rejected, all application data sent so far becomes invalid. // When 0-RTT is rejected, all application data sent so far becomes invalid.
// Delete the packets from the history and remove them from bytes_in_flight. // Delete the packets from the history and remove them from bytes_in_flight.
h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { h.appDataPackets.history.Iterate(func(p *packet) (bool, error) {
if p.EncryptionLevel != protocol.Encryption0RTT { if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket {
return false, nil return false, nil
} }
h.removeFromBytesInFlight(p) h.removeFromBytesInFlight(p)
@ -228,26 +228,64 @@ func (h *sentPacketHandler) packetsInFlight() int {
return packetsInFlight return packetsInFlight
} }
func (h *sentPacketHandler) SentPacket(p *Packet) { func (h *sentPacketHandler) SentPacket(
h.bytesSent += p.Length t time.Time,
pn, largestAcked protocol.PacketNumber,
streamFrames []StreamFrame,
frames []Frame,
encLevel protocol.EncryptionLevel,
size protocol.ByteCount,
isPathMTUProbePacket bool,
) {
h.bytesSent += size
// For the client, drop the Initial packet number space when the first Handshake packet is sent. // For the client, drop the Initial packet number space when the first Handshake packet is sent.
if h.perspective == protocol.PerspectiveClient && p.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil { if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake && h.initialPackets != nil {
h.dropPackets(protocol.EncryptionInitial) h.dropPackets(protocol.EncryptionInitial)
} }
isAckEliciting := h.sentPacketImpl(p)
if isAckEliciting { pnSpace := h.getPacketNumberSpace(encLevel)
h.getPacketNumberSpace(p.EncryptionLevel).history.SentAckElicitingPacket(p) if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
} else { for p := utils.Max(0, pnSpace.largestSent+1); p < pn; p++ {
h.getPacketNumberSpace(p.EncryptionLevel).history.SentNonAckElicitingPacket(p.PacketNumber, p.EncryptionLevel, p.SendTime) h.logger.Debugf("Skipping packet number %d", p)
putPacket(p) }
p = nil //nolint:ineffassign // This is just to be on the safe side.
} }
if h.tracer != nil && isAckEliciting {
pnSpace.largestSent = pn
isAckEliciting := len(streamFrames) > 0 || len(frames) > 0
if isAckEliciting {
pnSpace.lastAckElicitingPacketTime = t
h.bytesInFlight += size
if h.numProbesToSend > 0 {
h.numProbesToSend--
}
}
h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting)
if !isAckEliciting {
pnSpace.history.SentNonAckElicitingPacket(pn)
if !h.peerCompletedAddressValidation {
h.setLossDetectionTimer()
}
return
}
p := getPacket()
p.SendTime = t
p.PacketNumber = pn
p.EncryptionLevel = encLevel
p.Length = size
p.LargestAcked = largestAcked
p.StreamFrames = streamFrames
p.Frames = frames
p.IsPathMTUProbePacket = isPathMTUProbePacket
p.includedInBytesInFlight = true
pnSpace.history.SentAckElicitingPacket(p)
if h.tracer != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
} }
if isAckEliciting || !h.peerCompletedAddressValidation { h.setLossDetectionTimer()
h.setLossDetectionTimer()
}
} }
func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace { func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace {
@ -263,31 +301,6 @@ func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLev
} }
} }
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-eliciting */ {
pnSpace := h.getPacketNumberSpace(packet.EncryptionLevel)
if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
for p := utils.Max(0, pnSpace.largestSent+1); p < packet.PacketNumber; p++ {
h.logger.Debugf("Skipping packet number %d", p)
}
}
pnSpace.largestSent = packet.PacketNumber
isAckEliciting := len(packet.Frames) > 0
if isAckEliciting {
pnSpace.lastAckElicitingPacketTime = packet.SendTime
packet.includedInBytesInFlight = true
h.bytesInFlight += packet.Length
if h.numProbesToSend > 0 {
h.numProbesToSend--
}
}
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isAckEliciting)
return isAckEliciting
}
func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) { func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
@ -361,7 +374,6 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
} }
pnSpace.history.DeleteOldPackets(rcvTime)
h.setLossDetectionTimer() h.setLossDetectionTimer()
return acked1RTTPacket, nil return acked1RTTPacket, nil
} }
@ -371,13 +383,13 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu
} }
// Packets are returned in ascending packet number order. // Packets are returned in ascending packet number order.
func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*Packet, error) { func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*packet, error) {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
h.ackedPackets = h.ackedPackets[:0] h.ackedPackets = h.ackedPackets[:0]
ackRangeIndex := 0 ackRangeIndex := 0
lowestAcked := ack.LowestAcked() lowestAcked := ack.LowestAcked()
largestAcked := ack.LargestAcked() largestAcked := ack.LargestAcked()
err := pnSpace.history.Iterate(func(p *Packet) (bool, error) { err := pnSpace.history.Iterate(func(p *packet) (bool, error) {
// Ignore packets below the lowest acked // Ignore packets below the lowest acked
if p.PacketNumber < lowestAcked { if p.PacketNumber < lowestAcked {
return true, nil return true, nil
@ -425,8 +437,13 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
} }
for _, f := range p.Frames { for _, f := range p.Frames {
if f.OnAcked != nil { if f.Handler != nil {
f.OnAcked(f.Frame) f.Handler.OnAcked(f.Frame)
}
}
for _, f := range p.StreamFrames {
if f.Handler != nil {
f.Handler.OnAcked(f.Frame)
} }
} }
if err := pnSpace.history.Remove(p.PacketNumber); err != nil { if err := pnSpace.history.Remove(p.PacketNumber); err != nil {
@ -587,30 +604,31 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
lostSendTime := now.Add(-lossDelay) lostSendTime := now.Add(-lossDelay)
priorInFlight := h.bytesInFlight priorInFlight := h.bytesInFlight
return pnSpace.history.Iterate(func(p *Packet) (bool, error) { return pnSpace.history.Iterate(func(p *packet) (bool, error) {
if p.PacketNumber > pnSpace.largestAcked { if p.PacketNumber > pnSpace.largestAcked {
return false, nil return false, nil
} }
if p.declaredLost || p.skippedPacket {
return true, nil
}
var packetLost bool var packetLost bool
if p.SendTime.Before(lostSendTime) { if p.SendTime.Before(lostSendTime) {
packetLost = true packetLost = true
if h.logger.Debug() { if !p.skippedPacket {
h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) if h.logger.Debug() {
} h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber)
if h.tracer != nil { }
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold) if h.tracer != nil {
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold)
}
} }
} else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold { } else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold {
packetLost = true packetLost = true
if h.logger.Debug() { if !p.skippedPacket {
h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) if h.logger.Debug() {
} h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber)
if h.tracer != nil { }
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold) if h.tracer != nil {
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold)
}
} }
} else if pnSpace.lossTime.IsZero() { } else if pnSpace.lossTime.IsZero() {
// Note: This conditional is only entered once per call // Note: This conditional is only entered once per call
@ -621,12 +639,14 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
pnSpace.lossTime = lossTime pnSpace.lossTime = lossTime
} }
if packetLost { if packetLost {
p = pnSpace.history.DeclareLost(p) pnSpace.history.DeclareLost(p.PacketNumber)
// the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted if !p.skippedPacket {
h.removeFromBytesInFlight(p) // the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted
h.queueFramesForRetransmission(p) h.removeFromBytesInFlight(p)
if !p.IsPathMTUProbePacket { h.queueFramesForRetransmission(p)
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) if !p.IsPathMTUProbePacket {
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
} }
} }
return true, nil return true, nil
@ -689,7 +709,8 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error {
h.ptoMode = SendPTOHandshake h.ptoMode = SendPTOHandshake
case protocol.Encryption1RTT: case protocol.Encryption1RTT:
// skip a packet number in order to elicit an immediate ACK // skip a packet number in order to elicit an immediate ACK
_ = h.PopPacketNumber(protocol.Encryption1RTT) pn := h.PopPacketNumber(protocol.Encryption1RTT)
h.getPacketNumberSpace(protocol.Encryption1RTT).history.SkippedPacket(pn)
h.ptoMode = SendPTOAppData h.ptoMode = SendPTOAppData
default: default:
return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel) return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel)
@ -703,23 +724,25 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
var lowestUnacked protocol.PacketNumber
if p := pnSpace.history.FirstOutstanding(); p != nil {
lowestUnacked = p.PacketNumber
} else {
lowestUnacked = pnSpace.largestAcked + 1
}
pn := pnSpace.pns.Peek() pn := pnSpace.pns.Peek()
return pn, protocol.GetPacketNumberLengthForHeader(pn, lowestUnacked) // See section 17.1 of RFC 9000.
return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked)
} }
func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber { func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber {
return h.getPacketNumberSpace(encLevel).pns.Pop() pnSpace := h.getPacketNumberSpace(encLevel)
skipped, pn := pnSpace.pns.Pop()
if skipped {
skippedPN := pn - 1
pnSpace.history.SkippedPacket(skippedPN)
if h.logger.Debug() {
h.logger.Debugf("Skipping packet number %d", skippedPN)
}
}
return pn
} }
func (h *sentPacketHandler) SendMode() SendMode { func (h *sentPacketHandler) SendMode(now time.Time) SendMode {
numTrackedPackets := h.appDataPackets.history.Len() numTrackedPackets := h.appDataPackets.history.Len()
if h.initialPackets != nil { if h.initialPackets != nil {
numTrackedPackets += h.initialPackets.history.Len() numTrackedPackets += h.initialPackets.history.Len()
@ -758,6 +781,9 @@ func (h *sentPacketHandler) SendMode() SendMode {
} }
return SendAck return SendAck
} }
if !h.congestion.HasPacingBudget(now) {
return SendPacingLimited
}
return SendAny return SendAny
} }
@ -765,10 +791,6 @@ func (h *sentPacketHandler) TimeUntilSend() time.Time {
return h.congestion.TimeUntilSend(h.bytesInFlight) return h.congestion.TimeUntilSend(h.bytesInFlight)
} }
func (h *sentPacketHandler) HasPacingBudget() bool {
return h.congestion.HasPacingBudget()
}
func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) { func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) {
h.congestion.SetMaxDatagramSize(s) h.congestion.SetMaxDatagramSize(s)
} }
@ -790,24 +812,32 @@ func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel)
// TODO: don't declare the packet lost here. // TODO: don't declare the packet lost here.
// Keep track of acknowledged frames instead. // Keep track of acknowledged frames instead.
h.removeFromBytesInFlight(p) h.removeFromBytesInFlight(p)
pnSpace.history.DeclareLost(p) pnSpace.history.DeclareLost(p.PacketNumber)
return true return true
} }
func (h *sentPacketHandler) queueFramesForRetransmission(p *Packet) { func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) {
if len(p.Frames) == 0 { if len(p.Frames) == 0 && len(p.StreamFrames) == 0 {
panic("no frames") panic("no frames")
} }
for _, f := range p.Frames { for _, f := range p.Frames {
f.OnLost(f.Frame) if f.Handler != nil {
f.Handler.OnLost(f.Frame)
}
} }
for _, f := range p.StreamFrames {
if f.Handler != nil {
f.Handler.OnLost(f.Frame)
}
}
p.StreamFrames = nil
p.Frames = nil p.Frames = nil
} }
func (h *sentPacketHandler) ResetForRetry() error { func (h *sentPacketHandler) ResetForRetry() error {
h.bytesInFlight = 0 h.bytesInFlight = 0
var firstPacketSendTime time.Time var firstPacketSendTime time.Time
h.initialPackets.history.Iterate(func(p *Packet) (bool, error) { h.initialPackets.history.Iterate(func(p *packet) (bool, error) {
if firstPacketSendTime.IsZero() { if firstPacketSendTime.IsZero() {
firstPacketSendTime = p.SendTime firstPacketSendTime = p.SendTime
} }
@ -819,7 +849,7 @@ func (h *sentPacketHandler) ResetForRetry() error {
}) })
// All application data packets sent at this point are 0-RTT packets. // All application data packets sent at this point are 0-RTT packets.
// In the case of a Retry, we can assume that the server dropped all of them. // In the case of a Retry, we can assume that the server dropped all of them.
h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { h.appDataPackets.history.Iterate(func(p *packet) (bool, error) {
if !p.declaredLost && !p.skippedPacket { if !p.declaredLost && !p.skippedPacket {
h.queueFramesForRetransmission(p) h.queueFramesForRetransmission(p)
} }
@ -839,8 +869,8 @@ func (h *sentPacketHandler) ResetForRetry() error {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
} }
} }
h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), false, h.rttStats) h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Peek(), false)
h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), true, h.rttStats) h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Peek(), true)
oldAlarm := h.alarm oldAlarm := h.alarm
h.alarm = time.Time{} h.alarm = time.Time{}
if h.tracer != nil { if h.tracer != nil {

View File

@ -2,162 +2,176 @@ package ackhandler
import ( import (
"fmt" "fmt"
"sync"
"time"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
) )
type sentPacketHistory struct { type sentPacketHistory struct {
rttStats *utils.RTTStats packets []*packet
outstandingPacketList *list.List[*Packet]
etcPacketList *list.List[*Packet] numOutstanding int
packetMap map[protocol.PacketNumber]*list.Element[*Packet]
highestSent protocol.PacketNumber highestPacketNumber protocol.PacketNumber
} }
var packetElementPool sync.Pool func newSentPacketHistory() *sentPacketHistory {
func init() {
packetElementPool = *list.NewPool[*Packet]()
}
func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory {
return &sentPacketHistory{ return &sentPacketHistory{
rttStats: rttStats, packets: make([]*packet, 0, 32),
outstandingPacketList: list.NewWithPool[*Packet](&packetElementPool), highestPacketNumber: protocol.InvalidPacketNumber,
etcPacketList: list.NewWithPool[*Packet](&packetElementPool),
packetMap: make(map[protocol.PacketNumber]*list.Element[*Packet]),
highestSent: protocol.InvalidPacketNumber,
} }
} }
func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, t time.Time) { func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNumber) {
h.registerSentPacket(pn, encLevel, t) if h.highestPacketNumber != protocol.InvalidPacketNumber {
if pn != h.highestPacketNumber+1 {
panic("non-sequential packet number use")
}
}
} }
func (h *sentPacketHistory) SentAckElicitingPacket(p *Packet) { func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) {
h.registerSentPacket(p.PacketNumber, p.EncryptionLevel, p.SendTime) h.checkSequentialPacketNumberUse(pn)
h.highestPacketNumber = pn
h.packets = append(h.packets, &packet{
PacketNumber: pn,
skippedPacket: true,
})
}
var el *list.Element[*Packet] func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber) {
h.checkSequentialPacketNumberUse(pn)
h.highestPacketNumber = pn
if len(h.packets) > 0 {
h.packets = append(h.packets, nil)
}
}
func (h *sentPacketHistory) SentAckElicitingPacket(p *packet) {
h.checkSequentialPacketNumberUse(p.PacketNumber)
h.highestPacketNumber = p.PacketNumber
h.packets = append(h.packets, p)
if p.outstanding() { if p.outstanding() {
el = h.outstandingPacketList.PushBack(p) h.numOutstanding++
} else {
el = h.etcPacketList.PushBack(p)
} }
h.packetMap[p.PacketNumber] = el
}
func (h *sentPacketHistory) registerSentPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, t time.Time) {
if pn <= h.highestSent {
panic("non-sequential packet number use")
}
// Skipped packet numbers.
for p := h.highestSent + 1; p < pn; p++ {
el := h.etcPacketList.PushBack(&Packet{
PacketNumber: p,
EncryptionLevel: encLevel,
SendTime: t,
skippedPacket: true,
})
h.packetMap[p] = el
}
h.highestSent = pn
} }
// Iterate iterates through all packets. // Iterate iterates through all packets.
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error { func (h *sentPacketHistory) Iterate(cb func(*packet) (cont bool, err error)) error {
cont := true for _, p := range h.packets {
outstandingEl := h.outstandingPacketList.Front() if p == nil {
etcEl := h.etcPacketList.Front() continue
var el *list.Element[*Packet]
// whichever has the next packet number is returned first
for cont {
if outstandingEl == nil || (etcEl != nil && etcEl.Value.PacketNumber < outstandingEl.Value.PacketNumber) {
el = etcEl
} else {
el = outstandingEl
} }
if el == nil { cont, err := cb(p)
return nil
}
if el == outstandingEl {
outstandingEl = outstandingEl.Next()
} else {
etcEl = etcEl.Next()
}
var err error
cont, err = cb(el.Value)
if err != nil { if err != nil {
return err return err
} }
if !cont {
return nil
}
} }
return nil return nil
} }
// FirstOutstanding returns the first outstanding packet. // FirstOutstanding returns the first outstanding packet.
func (h *sentPacketHistory) FirstOutstanding() *Packet { func (h *sentPacketHistory) FirstOutstanding() *packet {
el := h.outstandingPacketList.Front() if !h.HasOutstandingPackets() {
if el == nil {
return nil return nil
} }
return el.Value for _, p := range h.packets {
} if p != nil && p.outstanding() {
return p
func (h *sentPacketHistory) Len() int { }
return len(h.packetMap)
}
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
el, ok := h.packetMap[p]
if !ok {
return fmt.Errorf("packet %d not found in sent packet history", p)
} }
el.List().Remove(el)
delete(h.packetMap, p)
return nil return nil
} }
func (h *sentPacketHistory) HasOutstandingPackets() bool { func (h *sentPacketHistory) Len() int {
return h.outstandingPacketList.Len() > 0 return len(h.packets)
} }
func (h *sentPacketHistory) DeleteOldPackets(now time.Time) { func (h *sentPacketHistory) Remove(pn protocol.PacketNumber) error {
maxAge := 3 * h.rttStats.PTO(false) idx, ok := h.getIndex(pn)
var nextEl *list.Element[*Packet]
// we don't iterate outstandingPacketList, as we should not delete outstanding packets.
// being outstanding for more than 3*PTO should only happen in the case of drastic RTT changes.
for el := h.etcPacketList.Front(); el != nil; el = nextEl {
nextEl = el.Next()
p := el.Value
if p.SendTime.After(now.Add(-maxAge)) {
break
}
delete(h.packetMap, p.PacketNumber)
h.etcPacketList.Remove(el)
}
}
func (h *sentPacketHistory) DeclareLost(p *Packet) *Packet {
el, ok := h.packetMap[p.PacketNumber]
if !ok { if !ok {
return nil return fmt.Errorf("packet %d not found in sent packet history", pn)
} }
el.List().Remove(el) p := h.packets[idx]
p.declaredLost = true if p.outstanding() {
// move it to the correct position in the etc list (based on the packet number) h.numOutstanding--
for el = h.etcPacketList.Back(); el != nil; el = el.Prev() { if h.numOutstanding < 0 {
if el.Value.PacketNumber < p.PacketNumber { panic("negative number of outstanding packets")
break
} }
} }
if el == nil { h.packets[idx] = nil
el = h.etcPacketList.PushFront(p) // clean up all skipped packets directly before this packet number
} else { for idx > 0 {
el = h.etcPacketList.InsertAfter(p, el) idx--
p := h.packets[idx]
if p == nil || !p.skippedPacket {
break
}
h.packets[idx] = nil
}
if idx == 0 {
h.cleanupStart()
}
if len(h.packets) > 0 && h.packets[0] == nil {
panic("remove failed")
}
return nil
}
// getIndex gets the index of packet p in the packets slice.
func (h *sentPacketHistory) getIndex(p protocol.PacketNumber) (int, bool) {
if len(h.packets) == 0 {
return 0, false
}
first := h.packets[0].PacketNumber
if p < first {
return 0, false
}
index := int(p - first)
if index > len(h.packets)-1 {
return 0, false
}
return index, true
}
func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.numOutstanding > 0
}
// delete all nil entries at the beginning of the packets slice
func (h *sentPacketHistory) cleanupStart() {
for i, p := range h.packets {
if p != nil {
h.packets = h.packets[i:]
return
}
}
h.packets = h.packets[:0]
}
func (h *sentPacketHistory) LowestPacketNumber() protocol.PacketNumber {
if len(h.packets) == 0 {
return protocol.InvalidPacketNumber
}
return h.packets[0].PacketNumber
}
func (h *sentPacketHistory) DeclareLost(pn protocol.PacketNumber) {
idx, ok := h.getIndex(pn)
if !ok {
return
}
p := h.packets[idx]
if p.outstanding() {
h.numOutstanding--
if h.numOutstanding < 0 {
panic("negative number of outstanding packets")
}
}
h.packets[idx] = nil
if idx == 0 {
h.cleanupStart()
} }
h.packetMap[p.PacketNumber] = el
return el.Value
} }

View File

@ -120,8 +120,8 @@ func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) time.Time {
return c.pacer.TimeUntilSend() return c.pacer.TimeUntilSend()
} }
func (c *cubicSender) HasPacingBudget() bool { func (c *cubicSender) HasPacingBudget(now time.Time) bool {
return c.pacer.Budget(c.clock.Now()) >= c.maxDatagramSize return c.pacer.Budget(now) >= c.maxDatagramSize
} }
func (c *cubicSender) maxCongestionWindow() protocol.ByteCount { func (c *cubicSender) maxCongestionWindow() protocol.ByteCount {

View File

@ -9,7 +9,7 @@ import (
// A SendAlgorithm performs congestion control // A SendAlgorithm performs congestion control
type SendAlgorithm interface { type SendAlgorithm interface {
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Time TimeUntilSend(bytesInFlight protocol.ByteCount) time.Time
HasPacingBudget() bool HasPacingBudget(now time.Time) bool
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
CanSend(bytesInFlight protocol.ByteCount) bool CanSend(bytesInFlight protocol.ByteCount) bool
MaybeExitSlowStart() MaybeExitSlowStart()

View File

@ -12,16 +12,16 @@ const maxBurstSizePackets = 10
// The pacer implements a token bucket pacing algorithm. // The pacer implements a token bucket pacing algorithm.
type pacer struct { type pacer struct {
budgetAtLastSent protocol.ByteCount budgetAtLastSent protocol.ByteCount
maxDatagramSize protocol.ByteCount maxDatagramSize protocol.ByteCount
lastSentTime time.Time lastSentTime time.Time
getAdjustedBandwidth func() uint64 // in bytes/s adjustedBandwidth func() uint64 // in bytes/s
} }
func newPacer(getBandwidth func() Bandwidth) *pacer { func newPacer(getBandwidth func() Bandwidth) *pacer {
p := &pacer{ p := &pacer{
maxDatagramSize: initialMaxDatagramSize, maxDatagramSize: initialMaxDatagramSize,
getAdjustedBandwidth: func() uint64 { adjustedBandwidth: func() uint64 {
// Bandwidth is in bits/s. We need the value in bytes/s. // Bandwidth is in bits/s. We need the value in bytes/s.
bw := uint64(getBandwidth() / BytesPerSecond) bw := uint64(getBandwidth() / BytesPerSecond)
// Use a slightly higher value than the actual measured bandwidth. // Use a slightly higher value than the actual measured bandwidth.
@ -49,13 +49,16 @@ func (p *pacer) Budget(now time.Time) protocol.ByteCount {
if p.lastSentTime.IsZero() { if p.lastSentTime.IsZero() {
return p.maxBurstSize() return p.maxBurstSize()
} }
budget := p.budgetAtLastSent + (protocol.ByteCount(p.getAdjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 budget := p.budgetAtLastSent + (protocol.ByteCount(p.adjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
if budget < 0 { // protect against overflows
budget = protocol.MaxByteCount
}
return utils.Min(p.maxBurstSize(), budget) return utils.Min(p.maxBurstSize(), budget)
} }
func (p *pacer) maxBurstSize() protocol.ByteCount { func (p *pacer) maxBurstSize() protocol.ByteCount {
return utils.Max( return utils.Max(
protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9, protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.adjustedBandwidth())/1e9,
maxBurstSizePackets*p.maxDatagramSize, maxBurstSizePackets*p.maxDatagramSize,
) )
} }
@ -68,7 +71,7 @@ func (p *pacer) TimeUntilSend() time.Time {
} }
return p.lastSentTime.Add(utils.Max( return p.lastSentTime.Add(utils.Max(
protocol.MinPacingDelay, protocol.MinPacingDelay,
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond, time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.adjustedBandwidth())))*time.Nanosecond,
)) ))
} }

View File

@ -116,7 +116,7 @@ type cryptoSetup struct {
clientHelloWritten bool clientHelloWritten bool
clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written
zeroRTTParametersChan chan<- *wire.TransportParameters zeroRTTParametersChan chan<- *wire.TransportParameters
allow0RTT func() bool allow0RTT bool
rttStats *utils.RTTStats rttStats *utils.RTTStats
@ -197,7 +197,7 @@ func NewCryptoSetupServer(
tp *wire.TransportParameters, tp *wire.TransportParameters,
runner handshakeRunner, runner handshakeRunner,
tlsConf *tls.Config, tlsConf *tls.Config,
allow0RTT func() bool, allow0RTT bool,
rttStats *utils.RTTStats, rttStats *utils.RTTStats,
tracer logging.ConnectionTracer, tracer logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
@ -210,14 +210,13 @@ func NewCryptoSetupServer(
tp, tp,
runner, runner,
tlsConf, tlsConf,
allow0RTT != nil, allow0RTT,
rttStats, rttStats,
tracer, tracer,
logger, logger,
protocol.PerspectiveServer, protocol.PerspectiveServer,
version, version,
) )
cs.allow0RTT = allow0RTT
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf) cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
return cs return cs
} }
@ -253,6 +252,7 @@ func newCryptoSetup(
readEncLevel: protocol.EncryptionInitial, readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial,
runner: runner, runner: runner,
allow0RTT: enable0RTT,
ourParams: tp, ourParams: tp,
paramsChan: extHandler.TransportParameters(), paramsChan: extHandler.TransportParameters(),
rttStats: rttStats, rttStats: rttStats,
@ -503,7 +503,7 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.") h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
return false return false
} }
if !h.allow0RTT() { if !h.allow0RTT {
h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.") h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.")
return false return false
} }

View File

@ -5,6 +5,9 @@ import "time"
// DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use. // DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use.
const DesiredReceiveBufferSize = (1 << 20) * 2 // 2 MB const DesiredReceiveBufferSize = (1 << 20) * 2 // 2 MB
// DesiredSendBufferSize is the kernel UDP send buffer size that we'd like to use.
const DesiredSendBufferSize = (1 << 20) * 2 // 2 MB
// InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets. // InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets.
const InitialPacketSizeIPv4 = 1252 const InitialPacketSizeIPv4 = 1252

View File

@ -59,7 +59,10 @@ type StatelessResetToken [16]byte
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, // ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,
// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes. // UDP adds an additional 8 bytes. This is a total overhead of 48 bytes.
// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452. // Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452.
const MaxPacketBufferSize ByteCount = 1452 const MaxPacketBufferSize = 1452
// MaxLargePacketBufferSize is used when using GSO
const MaxLargePacketBufferSize = 20 * 1024
// MinInitialPacketSize is the minimum size an Initial packet is required to have. // MinInitialPacketSize is the minimum size an Initial packet is required to have.
const MinInitialPacketSize = 1200 const MinInitialPacketSize = 1200
@ -77,6 +80,9 @@ const MinConnectionIDLenInitial = 8
// DefaultAckDelayExponent is the default ack delay exponent // DefaultAckDelayExponent is the default ack delay exponent
const DefaultAckDelayExponent = 3 const DefaultAckDelayExponent = 3
// DefaultActiveConnectionIDLimit is the default active connection ID limit
const DefaultActiveConnectionIDLimit = 2
// MaxAckDelayExponent is the maximum ack delay exponent // MaxAckDelayExponent is the maximum ack delay exponent
const MaxAckDelayExponent = 20 const MaxAckDelayExponent = 20

View File

@ -0,0 +1,86 @@
package ringbuffer
// A RingBuffer is a ring buffer.
// It acts as a heap that doesn't cause any allocations.
type RingBuffer[T any] struct {
ring []T
headPos, tailPos int
full bool
}
// Init preallocs a buffer with a certain size.
func (r *RingBuffer[T]) Init(size int) {
r.ring = make([]T, size)
}
// Len returns the number of elements in the ring buffer.
func (r *RingBuffer[T]) Len() int {
if r.full {
return len(r.ring)
}
if r.tailPos >= r.headPos {
return r.tailPos - r.headPos
}
return r.tailPos - r.headPos + len(r.ring)
}
// Empty says if the ring buffer is empty.
func (r *RingBuffer[T]) Empty() bool {
return !r.full && r.headPos == r.tailPos
}
// PushBack adds a new element.
// If the ring buffer is full, its capacity is increased first.
func (r *RingBuffer[T]) PushBack(t T) {
if r.full || len(r.ring) == 0 {
r.grow()
}
r.ring[r.tailPos] = t
r.tailPos++
if r.tailPos == len(r.ring) {
r.tailPos = 0
}
if r.tailPos == r.headPos {
r.full = true
}
}
// PopFront returns the next element.
// It must not be called when the buffer is empty, that means that
// callers might need to check if there are elements in the buffer first.
func (r *RingBuffer[T]) PopFront() T {
if r.Empty() {
panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: pop from an empty queue")
}
r.full = false
t := r.ring[r.headPos]
r.ring[r.headPos] = *new(T)
r.headPos++
if r.headPos == len(r.ring) {
r.headPos = 0
}
return t
}
// Grow the maximum size of the queue.
// This method assume the queue is full.
func (r *RingBuffer[T]) grow() {
oldRing := r.ring
newSize := len(oldRing) * 2
if newSize == 0 {
newSize = 1
}
r.ring = make([]T, newSize)
headLen := copy(r.ring, oldRing[r.headPos:])
copy(r.ring[headLen:], oldRing[:r.headPos])
r.headPos, r.tailPos, r.full = 0, len(oldRing), false
}
// Clear removes all elements.
func (r *RingBuffer[T]) Clear() {
var zeroValue T
for i := range r.ring {
r.ring[i] = zeroValue
}
r.headPos, r.tailPos, r.full = 0, 0, false
}

View File

@ -103,8 +103,12 @@ func (r *RTTStats) SetMaxAckDelay(mad time.Duration) {
// SetInitialRTT sets the initial RTT. // SetInitialRTT sets the initial RTT.
// It is used during the 0-RTT handshake when restoring the RTT stats from the session state. // It is used during the 0-RTT handshake when restoring the RTT stats from the session state.
func (r *RTTStats) SetInitialRTT(t time.Duration) { func (r *RTTStats) SetInitialRTT(t time.Duration) {
// On the server side, by the time we get to process the session ticket,
// we might already have obtained an RTT measurement.
// This can happen if we received the ClientHello in multiple pieces, and one of those pieces was lost.
// Discard the restored value. A fresh measurement is always better.
if r.hasMeasurement { if r.hasMeasurement {
panic("initial RTT set after first measurement") return
} }
r.smoothedRTT = t r.smoothedRTT = t
r.latestRTT = t r.latestRTT = t

View File

@ -22,19 +22,17 @@ type AckFrame struct {
} }
// parseAckFrame reads an ACK frame // parseAckFrame reads an ACK frame
func parseAckFrame(r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) (*AckFrame, error) { func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) error {
ecn := typ == ackECNFrameType ecn := typ == ackECNFrameType
frame := GetAckFrame()
la, err := quicvarint.Read(r) la, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
largestAcked := protocol.PacketNumber(la) largestAcked := protocol.PacketNumber(la)
delay, err := quicvarint.Read(r) delay, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
@ -46,17 +44,17 @@ func parseAckFrame(r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protoc
numBlocks, err := quicvarint.Read(r) numBlocks, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
// read the first ACK range // read the first ACK range
ab, err := quicvarint.Read(r) ab, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
ackBlock := protocol.PacketNumber(ab) ackBlock := protocol.PacketNumber(ab)
if ackBlock > largestAcked { if ackBlock > largestAcked {
return nil, errors.New("invalid first ACK range") return errors.New("invalid first ACK range")
} }
smallest := largestAcked - ackBlock smallest := largestAcked - ackBlock
@ -65,41 +63,50 @@ func parseAckFrame(r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protoc
for i := uint64(0); i < numBlocks; i++ { for i := uint64(0); i < numBlocks; i++ {
g, err := quicvarint.Read(r) g, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
gap := protocol.PacketNumber(g) gap := protocol.PacketNumber(g)
if smallest < gap+2 { if smallest < gap+2 {
return nil, errInvalidAckRanges return errInvalidAckRanges
} }
largest := smallest - gap - 2 largest := smallest - gap - 2
ab, err := quicvarint.Read(r) ab, err := quicvarint.Read(r)
if err != nil { if err != nil {
return nil, err return err
} }
ackBlock := protocol.PacketNumber(ab) ackBlock := protocol.PacketNumber(ab)
if ackBlock > largest { if ackBlock > largest {
return nil, errInvalidAckRanges return errInvalidAckRanges
} }
smallest = largest - ackBlock smallest = largest - ackBlock
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest}) frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest})
} }
if !frame.validateAckRanges() { if !frame.validateAckRanges() {
return nil, errInvalidAckRanges return errInvalidAckRanges
} }
// parse (and skip) the ECN section
if ecn { if ecn {
for i := 0; i < 3; i++ { ect0, err := quicvarint.Read(r)
if _, err := quicvarint.Read(r); err != nil { if err != nil {
return nil, err return err
}
} }
frame.ECT0 = ect0
ect1, err := quicvarint.Read(r)
if err != nil {
return err
}
frame.ECT1 = ect1
ecnce, err := quicvarint.Read(r)
if err != nil {
return err
}
frame.ECNCE = ecnce
} }
return frame, nil return nil
} }
// Append appends an ACK frame. // Append appends an ACK frame.
@ -242,6 +249,18 @@ func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool {
return p <= f.AckRanges[i].Largest return p <= f.AckRanges[i].Largest
} }
func (f *AckFrame) Reset() {
f.DelayTime = 0
f.ECT0 = 0
f.ECT1 = 0
f.ECNCE = 0
for _, r := range f.AckRanges {
r.Largest = 0
r.Smallest = 0
}
f.AckRanges = f.AckRanges[:0]
}
func encodeAckDelay(delay time.Duration) uint64 { func encodeAckDelay(delay time.Duration) uint64 {
return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent))) return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent)))
} }

View File

@ -1,24 +0,0 @@
package wire
import "sync"
var ackFramePool = sync.Pool{New: func() any {
return &AckFrame{}
}}
func GetAckFrame() *AckFrame {
f := ackFramePool.Get().(*AckFrame)
f.AckRanges = f.AckRanges[:0]
f.ECNCE = 0
f.ECT0 = 0
f.ECT1 = 0
f.DelayTime = 0
return f
}
func PutAckFrame(f *AckFrame) {
if cap(f.AckRanges) > 4 {
return
}
ackFramePool.Put(f)
}

View File

@ -39,9 +39,12 @@ const (
type frameParser struct { type frameParser struct {
r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them
ackDelayExponent uint8 ackDelayExponent uint8
supportsDatagrams bool supportsDatagrams bool
// To avoid allocating when parsing, keep a single ACK frame struct.
// It is used over and over again.
ackFrame *AckFrame
} }
var _ FrameParser = &frameParser{} var _ FrameParser = &frameParser{}
@ -51,6 +54,7 @@ func NewFrameParser(supportsDatagrams bool) *frameParser {
return &frameParser{ return &frameParser{
r: *bytes.NewReader(nil), r: *bytes.NewReader(nil),
supportsDatagrams: supportsDatagrams, supportsDatagrams: supportsDatagrams,
ackFrame: &AckFrame{},
} }
} }
@ -105,7 +109,9 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
if encLevel != protocol.Encryption1RTT { if encLevel != protocol.Encryption1RTT {
ackDelayExponent = protocol.DefaultAckDelayExponent ackDelayExponent = protocol.DefaultAckDelayExponent
} }
frame, err = parseAckFrame(r, typ, ackDelayExponent, v) p.ackFrame.Reset()
err = parseAckFrame(p.ackFrame, r, typ, ackDelayExponent, v)
frame = p.ackFrame
case resetStreamFrameType: case resetStreamFrameType:
frame, err = parseResetStreamFrame(r, v) frame, err = parseResetStreamFrame(r, v)
case stopSendingFrameType: case stopSendingFrameType:

View File

@ -13,8 +13,6 @@ import (
) )
// ParseConnectionID parses the destination connection ID of a packet. // ParseConnectionID parses the destination connection ID of a packet.
// It uses the data slice for the connection ID.
// That means that the connection ID must not be used after the packet buffer is released.
func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) {
if len(data) == 0 { if len(data) == 0 {
return protocol.ConnectionID{}, io.EOF return protocol.ConnectionID{}, io.EOF

View File

@ -2,14 +2,13 @@ package wire
import ( import (
"bytes" "bytes"
"crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math/rand"
"net" "net"
"sort" "sort"
"sync"
"time" "time"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
@ -26,15 +25,6 @@ var AdditionalTransportParametersClient map[uint64][]byte
const transportParameterMarshalingVersion = 1 const transportParameterMarshalingVersion = 1
var (
randomMutex sync.Mutex
random rand.Rand
)
func init() {
random = *rand.New(rand.NewSource(time.Now().UnixNano()))
}
type transportParameterID uint64 type transportParameterID uint64
const ( const (
@ -118,6 +108,7 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
var ( var (
readOriginalDestinationConnectionID bool readOriginalDestinationConnectionID bool
readInitialSourceConnectionID bool readInitialSourceConnectionID bool
readActiveConnectionIDLimit bool
) )
p.AckDelayExponent = protocol.DefaultAckDelayExponent p.AckDelayExponent = protocol.DefaultAckDelayExponent
@ -139,6 +130,9 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
} }
parameterIDs = append(parameterIDs, paramID) parameterIDs = append(parameterIDs, paramID)
switch paramID { switch paramID {
case activeConnectionIDLimitParameterID:
readActiveConnectionIDLimit = true
fallthrough
case maxIdleTimeoutParameterID, case maxIdleTimeoutParameterID,
maxUDPPayloadSizeParameterID, maxUDPPayloadSizeParameterID,
initialMaxDataParameterID, initialMaxDataParameterID,
@ -148,7 +142,6 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
initialMaxStreamsBidiParameterID, initialMaxStreamsBidiParameterID,
initialMaxStreamsUniParameterID, initialMaxStreamsUniParameterID,
maxAckDelayParameterID, maxAckDelayParameterID,
activeConnectionIDLimitParameterID,
maxDatagramFrameSizeParameterID, maxDatagramFrameSizeParameterID,
ackDelayExponentParameterID: ackDelayExponentParameterID:
if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil { if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil {
@ -196,6 +189,9 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
} }
} }
if !readActiveConnectionIDLimit {
p.ActiveConnectionIDLimit = protocol.DefaultActiveConnectionIDLimit
}
if !fromSessionTicket { if !fromSessionTicket {
if sentBy == protocol.PerspectiveServer && !readOriginalDestinationConnectionID { if sentBy == protocol.PerspectiveServer && !readOriginalDestinationConnectionID {
return errors.New("missing original_destination_connection_id") return errors.New("missing original_destination_connection_id")
@ -335,13 +331,12 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
b := make([]byte, 0, 256) b := make([]byte, 0, 256)
// add a greased value // add a greased value
b = quicvarint.Append(b, uint64(27+31*rand.Intn(100))) random := make([]byte, 18)
randomMutex.Lock() rand.Read(random)
length := random.Intn(16) b = quicvarint.Append(b, 27+31*uint64(random[0]))
length := random[1] % 16
b = quicvarint.Append(b, uint64(length)) b = quicvarint.Append(b, uint64(length))
b = b[:len(b)+length] b = append(b, random[2:2+length]...)
random.Read(b[len(b)-length:])
randomMutex.Unlock()
// initial_max_stream_data_bidi_local // initial_max_stream_data_bidi_local
b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal))
@ -402,7 +397,9 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
} }
} }
// active_connection_id_limit // active_connection_id_limit
b = p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) if p.ActiveConnectionIDLimit != protocol.DefaultActiveConnectionIDLimit {
b = p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit)
}
// initial_source_connection_id // initial_source_connection_id
b = quicvarint.Append(b, uint64(initialSourceConnectionIDParameterID)) b = quicvarint.Append(b, uint64(initialSourceConnectionIDParameterID))
b = quicvarint.Append(b, uint64(p.InitialSourceConnectionID.Len())) b = quicvarint.Append(b, uint64(p.InitialSourceConnectionID.Len()))

View File

@ -3,7 +3,6 @@
package logging package logging
import ( import (
"context"
"net" "net"
"time" "time"
@ -101,12 +100,6 @@ type ShortHeader struct {
// A Tracer traces events. // A Tracer traces events.
type Tracer interface { type Tracer interface {
// TracerForConnection requests a new tracer for a connection.
// The ODCID is the original destination connection ID:
// The destination connection ID that the client used on the first Initial packet it sent on this connection.
// If nil is returned, tracing will be disabled for this connection.
TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer
SentPacket(net.Addr, *Header, ByteCount, []Frame) SentPacket(net.Addr, *Header, ByteCount, []Frame)
SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason)

View File

@ -1,7 +1,6 @@
package logging package logging
import ( import (
"context"
"net" "net"
"time" "time"
) )
@ -23,16 +22,6 @@ func NewMultiplexedTracer(tracers ...Tracer) Tracer {
return &tracerMultiplexer{tracers} return &tracerMultiplexer{tracers}
} }
func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer {
var connTracers []ConnectionTracer
for _, t := range m.tracers {
if ct := t.TracerForConnection(ctx, p, odcid); ct != nil {
connTracers = append(connTracers, ct)
}
}
return NewMultiplexedConnectionTracer(connTracers...)
}
func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) {
for _, t := range m.tracers { for _, t := range m.tracers {
t.SentPacket(remote, hdr, size, frames) t.SentPacket(remote, hdr, size, frames)

View File

@ -1,7 +1,6 @@
package logging package logging
import ( import (
"context"
"net" "net"
"time" "time"
) )
@ -12,9 +11,6 @@ type NullTracer struct{}
var _ Tracer = &NullTracer{} var _ Tracer = &NullTracer{}
func (n NullTracer) TracerForConnection(context.Context, Perspective, ConnectionID) ConnectionTracer {
return NullConnectionTracer{}
}
func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {} func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {}
func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) { func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) {
} }

View File

@ -65,9 +65,6 @@ type UnknownPacketHandler = unknownPacketHandler
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" //go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
type PacketHandlerManager = packetHandlerManager type PacketHandlerManager = packetHandlerManager
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_multiplexer_test.go github.com/quic-go/quic-go Multiplexer"
type Multiplexer = multiplexer
// Need to use source mode for the batchConn, since reflect mode follows type aliases. // Need to use source mode for the batchConn, since reflect mode follows type aliases.
// See https://github.com/golang/mock/issues/244 for details. // See https://github.com/golang/mock/issues/244 for details.
// //

View File

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"net"
"time" "time"
"github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/ackhandler"
@ -10,7 +11,11 @@ import (
) )
type mtuDiscoverer interface { type mtuDiscoverer interface {
// Start starts the MTU discovery process.
// It's unnecessary to call ShouldSendProbe before that.
Start(maxPacketSize protocol.ByteCount)
ShouldSendProbe(now time.Time) bool ShouldSendProbe(now time.Time) bool
CurrentSize() protocol.ByteCount
GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount) GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount)
} }
@ -22,25 +27,38 @@ const (
mtuProbeDelay = 5 mtuProbeDelay = 5
) )
func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
maxSize := protocol.ByteCount(protocol.MinInitialPacketSize)
// If this is not a UDP address, we don't know anything about the MTU.
// Use the minimum size of an Initial packet as the max packet size.
if udpAddr, ok := addr.(*net.UDPAddr); ok {
if utils.IsIPv4(udpAddr.IP) {
maxSize = protocol.InitialPacketSizeIPv4
} else {
maxSize = protocol.InitialPacketSizeIPv6
}
}
return maxSize
}
type mtuFinder struct { type mtuFinder struct {
lastProbeTime time.Time lastProbeTime time.Time
probeInFlight bool
mtuIncreased func(protocol.ByteCount) mtuIncreased func(protocol.ByteCount)
rttStats *utils.RTTStats rttStats *utils.RTTStats
inFlight protocol.ByteCount // the size of the probe packet currently in flight. InvalidByteCount if none is in flight
current protocol.ByteCount current protocol.ByteCount
max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer) max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer)
} }
var _ mtuDiscoverer = &mtuFinder{} var _ mtuDiscoverer = &mtuFinder{}
func newMTUDiscoverer(rttStats *utils.RTTStats, start, max protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) mtuDiscoverer { func newMTUDiscoverer(rttStats *utils.RTTStats, start protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuFinder {
return &mtuFinder{ return &mtuFinder{
current: start, inFlight: protocol.InvalidByteCount,
rttStats: rttStats, current: start,
lastProbeTime: time.Now(), // to make sure the first probe packet is not sent immediately rttStats: rttStats,
mtuIncreased: mtuIncreased, mtuIncreased: mtuIncreased,
max: max,
} }
} }
@ -48,8 +66,16 @@ func (f *mtuFinder) done() bool {
return f.max-f.current <= maxMTUDiff+1 return f.max-f.current <= maxMTUDiff+1
} }
func (f *mtuFinder) Start(maxPacketSize protocol.ByteCount) {
f.lastProbeTime = time.Now() // makes sure the first probe packet is not sent immediately
f.max = maxPacketSize
}
func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
if f.probeInFlight || f.done() { if f.max == 0 || f.lastProbeTime.IsZero() {
return false
}
if f.inFlight != protocol.InvalidByteCount || f.done() {
return false return false
} }
return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT())) return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT()))
@ -58,17 +84,36 @@ func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) { func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) {
size := (f.max + f.current) / 2 size := (f.max + f.current) / 2
f.lastProbeTime = time.Now() f.lastProbeTime = time.Now()
f.probeInFlight = true f.inFlight = size
return ackhandler.Frame{ return ackhandler.Frame{
Frame: &wire.PingFrame{}, Frame: &wire.PingFrame{},
OnLost: func(wire.Frame) { Handler: (*mtuFinderAckHandler)(f),
f.probeInFlight = false
f.max = size
},
OnAcked: func(wire.Frame) {
f.probeInFlight = false
f.current = size
f.mtuIncreased(size)
},
}, size }, size
} }
func (f *mtuFinder) CurrentSize() protocol.ByteCount {
return f.current
}
type mtuFinderAckHandler mtuFinder
var _ ackhandler.FrameHandler = &mtuFinderAckHandler{}
func (h *mtuFinderAckHandler) OnAcked(wire.Frame) {
size := h.inFlight
if size == protocol.InvalidByteCount {
panic("OnAcked callback called although there's no MTU probe packet in flight")
}
h.inFlight = protocol.InvalidByteCount
h.current = size
h.mtuIncreased(size)
}
func (h *mtuFinderAckHandler) OnLost(wire.Frame) {
size := h.inFlight
if size == protocol.InvalidByteCount {
panic("OnLost callback called although there's no MTU probe packet in flight")
}
h.max = size
h.inFlight = protocol.InvalidByteCount
}

View File

@ -6,7 +6,6 @@ import (
"sync" "sync"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
) )
var ( var (
@ -14,30 +13,19 @@ var (
connMuxer multiplexer connMuxer multiplexer
) )
type indexableConn interface { type indexableConn interface{ LocalAddr() net.Addr }
LocalAddr() net.Addr
}
type multiplexer interface { type multiplexer interface {
AddConn(c net.PacketConn, connIDLen int, statelessResetKey *StatelessResetKey, tracer logging.Tracer) (packetHandlerManager, error) AddConn(conn indexableConn)
RemoveConn(indexableConn) error RemoveConn(indexableConn) error
} }
type connManager struct {
connIDLen int
statelessResetKey *StatelessResetKey
tracer logging.Tracer
manager packetHandlerManager
}
// The connMultiplexer listens on multiple net.PacketConns and dispatches // The connMultiplexer listens on multiple net.PacketConns and dispatches
// incoming packets to the connection handler. // incoming packets to the connection handler.
type connMultiplexer struct { type connMultiplexer struct {
mutex sync.Mutex mutex sync.Mutex
conns map[string] /* LocalAddr().String() */ connManager conns map[string] /* LocalAddr().String() */ indexableConn
newPacketHandlerManager func(net.PacketConn, int, *StatelessResetKey, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests
logger utils.Logger logger utils.Logger
} }
@ -46,57 +34,38 @@ var _ multiplexer = &connMultiplexer{}
func getMultiplexer() multiplexer { func getMultiplexer() multiplexer {
connMuxerOnce.Do(func() { connMuxerOnce.Do(func() {
connMuxer = &connMultiplexer{ connMuxer = &connMultiplexer{
conns: make(map[string]connManager), conns: make(map[string]indexableConn),
logger: utils.DefaultLogger.WithPrefix("muxer"), logger: utils.DefaultLogger.WithPrefix("muxer"),
newPacketHandlerManager: newPacketHandlerMap,
} }
}) })
return connMuxer return connMuxer
} }
func (m *connMultiplexer) AddConn( func (m *connMultiplexer) index(addr net.Addr) string {
c net.PacketConn, return addr.Network() + " " + addr.String()
connIDLen int, }
statelessResetKey *StatelessResetKey,
tracer logging.Tracer, func (m *connMultiplexer) AddConn(c indexableConn) {
) (packetHandlerManager, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
addr := c.LocalAddr() connIndex := m.index(c.LocalAddr())
connIndex := addr.Network() + " " + addr.String()
p, ok := m.conns[connIndex] p, ok := m.conns[connIndex]
if !ok { if ok {
manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger) // Panics if we're already listening on this connection.
if err != nil { // This is a safeguard because we're introducing a breaking API change, see
return nil, err // https://github.com/quic-go/quic-go/issues/3727 for details.
} // We'll remove this at a later time, when most users of the library have made the switch.
p = connManager{ panic("connection already exists") // TODO: write a nice message
connIDLen: connIDLen,
statelessResetKey: statelessResetKey,
manager: manager,
tracer: tracer,
}
m.conns[connIndex] = p
} else {
if p.connIDLen != connIDLen {
return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen)
}
if statelessResetKey != nil && p.statelessResetKey != statelessResetKey {
return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn")
}
if tracer != p.tracer {
return nil, fmt.Errorf("cannot use different tracers on the same packet conn")
}
} }
return p.manager, nil m.conns[connIndex] = p
} }
func (m *connMultiplexer) RemoveConn(c indexableConn) error { func (m *connMultiplexer) RemoveConn(c indexableConn) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String() connIndex := m.index(c.LocalAddr())
if _, ok := m.conns[connIndex]; !ok { if _, ok := m.conns[connIndex]; !ok {
return fmt.Errorf("cannote remove connection, connection is unknown") return fmt.Errorf("cannote remove connection, connection is unknown")
} }

View File

@ -5,147 +5,86 @@ import (
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"errors" "errors"
"fmt"
"hash" "hash"
"io" "io"
"log"
"net" "net"
"os"
"strconv"
"strings"
"sync" "sync"
"time" "time"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
) )
// rawConn is a connection that allow reading of a receivedPacket. type connCapabilities struct {
// This connection has the Don't Fragment (DF) bit set.
// This means it makes to run DPLPMTUD.
DF bool
// GSO (Generic Segmentation Offload) supported
GSO bool
}
// rawConn is a connection that allow reading of a receivedPackeh.
type rawConn interface { type rawConn interface {
ReadPacket() (*receivedPacket, error) ReadPacket() (receivedPacket, error)
WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) // The size parameter is used for GSO.
// If GSO is not support, len(b) must be equal to size.
WritePacket(b []byte, size uint16, addr net.Addr, oob []byte) (int, error)
LocalAddr() net.Addr LocalAddr() net.Addr
SetReadDeadline(time.Time) error
io.Closer io.Closer
capabilities() connCapabilities
} }
type closePacket struct { type closePacket struct {
payload []byte payload []byte
addr net.Addr addr net.Addr
info *packetInfo info packetInfo
} }
// The packetHandlerMap stores packetHandlers, identified by connection ID. type unknownPacketHandler interface {
// It is used: handlePacket(receivedPacket)
// * by the server to store connections setCloseError(error)
// * when multiplexing outgoing connections to store clients }
var errListenerAlreadySet = errors.New("listener already set")
type packetHandlerMap struct { type packetHandlerMap struct {
mutex sync.Mutex mutex sync.Mutex
handlers map[protocol.ConnectionID]packetHandler
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
conn rawConn
connIDLen int
closeQueue chan closePacket
handlers map[protocol.ConnectionID]packetHandler
resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler
server unknownPacketHandler
numZeroRTTEntries int
listening chan struct{} // is closed when listen returns
closed bool closed bool
closeChan chan struct{}
enqueueClosePacket func(closePacket)
deleteRetiredConnsAfter time.Duration deleteRetiredConnsAfter time.Duration
zeroRTTQueueDuration time.Duration
statelessResetEnabled bool statelessResetMutex sync.Mutex
statelessResetMutex sync.Mutex statelessResetHasher hash.Hash
statelessResetHasher hash.Hash
tracer logging.Tracer
logger utils.Logger logger utils.Logger
} }
var _ packetHandlerManager = &packetHandlerMap{} var _ packetHandlerManager = &packetHandlerMap{}
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap {
conn, ok := c.(interface{ SetReadBuffer(int) error }) h := &packetHandlerMap{
if !ok { closeChan: make(chan struct{}),
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
}
size, err := inspectReadBuffer(c)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if size >= protocol.DesiredReceiveBufferSize {
logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
return nil
}
if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
return fmt.Errorf("failed to increase receive buffer size: %w", err)
}
newSize, err := inspectReadBuffer(c)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if newSize == size {
return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
if newSize < protocol.DesiredReceiveBufferSize {
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
return nil
}
// only print warnings about the UDP receive buffer size once
var receiveBufferWarningOnce sync.Once
func newPacketHandlerMap(
c net.PacketConn,
connIDLen int,
statelessResetKey *StatelessResetKey,
tracer logging.Tracer,
logger utils.Logger,
) (packetHandlerManager, error) {
if err := setReceiveBuffer(c, logger); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
receiveBufferWarningOnce.Do(func() {
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
return
}
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err)
})
}
}
conn, err := wrapConn(c)
if err != nil {
return nil, err
}
m := &packetHandlerMap{
conn: conn,
connIDLen: connIDLen,
listening: make(chan struct{}),
handlers: make(map[protocol.ConnectionID]packetHandler), handlers: make(map[protocol.ConnectionID]packetHandler),
resetTokens: make(map[protocol.StatelessResetToken]packetHandler), resetTokens: make(map[protocol.StatelessResetToken]packetHandler),
deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout,
zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, enqueueClosePacket: enqueueClosePacket,
closeQueue: make(chan closePacket, 4),
statelessResetEnabled: statelessResetKey != nil,
tracer: tracer,
logger: logger, logger: logger,
} }
if m.statelessResetEnabled { if key != nil {
m.statelessResetHasher = hmac.New(sha256.New, statelessResetKey[:]) h.statelessResetHasher = hmac.New(sha256.New, key[:])
} }
go m.listen() if h.logger.Debug() {
go m.runCloseQueue() go h.logUsage()
if logger.Debug() {
go m.logUsage()
} }
return m, nil return h
} }
func (h *packetHandlerMap) logUsage() { func (h *packetHandlerMap) logUsage() {
@ -153,7 +92,7 @@ func (h *packetHandlerMap) logUsage() {
var printedZero bool var printedZero bool
for { for {
select { select {
case <-h.listening: case <-h.closeChan:
return return
case <-ticker.C: case <-ticker.C:
} }
@ -174,6 +113,14 @@ func (h *packetHandlerMap) logUsage() {
} }
} }
func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
h.mutex.Lock()
defer h.mutex.Unlock()
handler, ok := h.handlers[id]
return handler, ok
}
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()
@ -187,26 +134,17 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
return true return true
} }
func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool { func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool {
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()
var q *zeroRTTQueue if _, ok := h.handlers[clientDestConnID]; ok {
if handler, ok := h.handlers[clientDestConnID]; ok { h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
q, ok = handler.(*zeroRTTQueue) return false
if !ok {
h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID)
return false
}
q.retireTimer.Stop()
h.numZeroRTTEntries--
if h.numZeroRTTEntries < 0 {
panic("number of 0-RTT queues < 0")
}
} }
conn := fn() conn, ok := fn()
if q != nil { if !ok {
q.EnqueueAll(conn) return false
} }
h.handlers[clientDestConnID] = conn h.handlers[clientDestConnID] = conn
h.handlers[newConnID] = conn h.handlers[newConnID] = conn
@ -239,13 +177,8 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
var handler packetHandler var handler packetHandler
if connClosePacket != nil { if connClosePacket != nil {
handler = newClosedLocalConn( handler = newClosedLocalConn(
func(addr net.Addr, info *packetInfo) { func(addr net.Addr, info packetInfo) {
select { h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info})
case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}:
default:
// Oops, we're backlogged.
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
}
}, },
pers, pers,
h.logger, h.logger,
@ -272,17 +205,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p
}) })
} }
func (h *packetHandlerMap) runCloseQueue() {
for {
select {
case <-h.listening:
return
case p := <-h.closeQueue:
h.conn.WritePacket(p.payload, p.addr, p.info.OOB())
}
}
}
func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) {
h.mutex.Lock() h.mutex.Lock()
h.resetTokens[token] = handler h.resetTokens[token] = handler
@ -295,19 +217,16 @@ func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken)
h.mutex.Unlock() h.mutex.Unlock()
} }
func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) {
h.mutex.Lock() h.mutex.Lock()
h.server = s defer h.mutex.Unlock()
h.mutex.Unlock()
handler, ok := h.resetTokens[token]
return handler, ok
} }
func (h *packetHandlerMap) CloseServer() { func (h *packetHandlerMap) CloseServer() {
h.mutex.Lock() h.mutex.Lock()
if h.server == nil {
h.mutex.Unlock()
return
}
h.server = nil
var wg sync.WaitGroup var wg sync.WaitGroup
for _, handler := range h.handlers { for _, handler := range h.handlers {
if handler.getPerspective() == protocol.PerspectiveServer { if handler.getPerspective() == protocol.PerspectiveServer {
@ -323,23 +242,16 @@ func (h *packetHandlerMap) CloseServer() {
wg.Wait() wg.Wait()
} }
// Destroy closes the underlying connection and waits until listen() has returned. func (h *packetHandlerMap) Close(e error) {
// It does not close active connections.
func (h *packetHandlerMap) Destroy() error {
if err := h.conn.Close(); err != nil {
return err
}
<-h.listening // wait until listening returns
return nil
}
func (h *packetHandlerMap) close(e error) error {
h.mutex.Lock() h.mutex.Lock()
if h.closed { if h.closed {
h.mutex.Unlock() h.mutex.Unlock()
return nil return
} }
close(h.closeChan)
var wg sync.WaitGroup var wg sync.WaitGroup
for _, handler := range h.handlers { for _, handler := range h.handlers {
wg.Add(1) wg.Add(1)
@ -348,129 +260,14 @@ func (h *packetHandlerMap) close(e error) error {
wg.Done() wg.Done()
}(handler) }(handler)
} }
if h.server != nil {
h.server.setCloseError(e)
}
h.closed = true h.closed = true
h.mutex.Unlock() h.mutex.Unlock()
wg.Wait() wg.Wait()
return getMultiplexer().RemoveConn(h.conn)
}
func (h *packetHandlerMap) listen() {
defer close(h.listening)
for {
p, err := h.conn.ReadPacket()
//nolint:staticcheck // SA1019 ignore this!
// TODO: This code is used to ignore wsa errors on Windows.
// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
// See https://github.com/quic-go/quic-go/issues/1737 for details.
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
h.logger.Debugf("Temporary error reading from conn: %w", err)
continue
}
if err != nil {
h.close(err)
return
}
h.handlePacket(p)
}
}
func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, h.connIDLen)
if err != nil {
h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
if h.tracer != nil {
h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
}
p.buffer.MaybeRelease()
return
}
h.mutex.Lock()
defer h.mutex.Unlock()
if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset {
return
}
if handler, ok := h.handlers[connID]; ok {
if ha, ok := handler.(*zeroRTTQueue); ok { // only enqueue 0-RTT packets in the 0-RTT queue
if wire.Is0RTTPacket(p.data) {
ha.handlePacket(p)
return
}
} else { // existing connection
handler.handlePacket(p)
return
}
}
if !wire.IsLongHeaderPacket(p.data[0]) {
go h.maybeSendStatelessReset(p, connID)
return
}
if h.server == nil { // no server set
h.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
return
}
if wire.Is0RTTPacket(p.data) {
if h.numZeroRTTEntries >= protocol.Max0RTTQueues {
if h.tracer != nil {
h.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return
}
h.numZeroRTTEntries++
queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)}
h.handlers[connID] = queue
queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() {
h.mutex.Lock()
defer h.mutex.Unlock()
// The entry might have been replaced by an actual connection.
// Only delete it if it's still a 0-RTT queue.
if handler, ok := h.handlers[connID]; ok {
if q, ok := handler.(*zeroRTTQueue); ok {
delete(h.handlers, connID)
h.numZeroRTTEntries--
if h.numZeroRTTEntries < 0 {
panic("number of 0-RTT queues < 0")
}
q.Clear()
if h.logger.Debug() {
h.logger.Debugf("Removing 0-RTT queue for %s.", connID)
}
}
}
})
queue.handlePacket(p)
return
}
h.server.handlePacket(p)
}
func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
// stateless resets are always short header packets
if wire.IsLongHeaderPacket(data[0]) {
return false
}
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
return false
}
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
go sess.destroy(&StatelessResetError{Token: token})
return true
}
return false
} }
func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken {
var token protocol.StatelessResetToken var token protocol.StatelessResetToken
if !h.statelessResetEnabled { if h.statelessResetHasher == nil {
// Return a random stateless reset token. // Return a random stateless reset token.
// This token will be sent in the server's transport parameters. // This token will be sent in the server's transport parameters.
// By using a random token, an off-path attacker won't be able to disrupt the connection. // By using a random token, an off-path attacker won't be able to disrupt the connection.
@ -484,24 +281,3 @@ func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID)
h.statelessResetMutex.Unlock() h.statelessResetMutex.Unlock()
return token return token
} }
func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) {
defer p.buffer.Release()
if !h.statelessResetEnabled {
return
}
// Don't send a stateless reset in response to very small packets.
// This includes packets that could be stateless resets.
if len(p.data) <= protocol.MinStatelessResetSize {
return
}
token := h.GetStatelessResetToken(connID)
h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...)
if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
h.logger.Debugf("Error sending Stateless Reset: %s", err)
}
}

View File

@ -3,30 +3,25 @@ package quic
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"time"
"github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/internal/wire"
) )
var errNothingToPack = errors.New("nothing to pack") var errNothingToPack = errors.New("nothing to pack")
type packer interface { type packer interface {
PackCoalescedPacket(onlyAck bool, v protocol.VersionNumber) (*coalescedPacket, error) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error)
PackPacket(onlyAck bool, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
MaybePackProbePacket(protocol.EncryptionLevel, protocol.VersionNumber) (*coalescedPacket, error) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error)
PackConnectionClose(*qerr.TransportError, protocol.VersionNumber) (*coalescedPacket, error) MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
PackApplicationClose(*qerr.ApplicationError, protocol.VersionNumber) (*coalescedPacket, error) PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
SetMaxPacketSize(protocol.ByteCount)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
HandleTransportParameters(*wire.TransportParameters)
SetToken([]byte) SetToken([]byte)
} }
@ -35,26 +30,31 @@ type sealer interface {
} }
type payload struct { type payload struct {
frames []*ackhandler.Frame streamFrames []ackhandler.StreamFrame
ack *wire.AckFrame frames []ackhandler.Frame
length protocol.ByteCount ack *wire.AckFrame
length protocol.ByteCount
} }
type longHeaderPacket struct { type longHeaderPacket struct {
header *wire.ExtendedHeader header *wire.ExtendedHeader
ack *wire.AckFrame ack *wire.AckFrame
frames []*ackhandler.Frame frames []ackhandler.Frame
streamFrames []ackhandler.StreamFrame // only used for 0-RTT packets
length protocol.ByteCount length protocol.ByteCount
isMTUProbePacket bool
} }
type shortHeaderPacket struct { type shortHeaderPacket struct {
*ackhandler.Packet PacketNumber protocol.PacketNumber
Frames []ackhandler.Frame
StreamFrames []ackhandler.StreamFrame
Ack *wire.AckFrame
Length protocol.ByteCount
IsPathMTUProbePacket bool
// used for logging // used for logging
DestConnID protocol.ConnectionID DestConnID protocol.ConnectionID
Ack *wire.AckFrame
PacketNumberLen protocol.PacketNumberLen PacketNumberLen protocol.PacketNumberLen
KeyPhase protocol.KeyPhaseBit KeyPhase protocol.KeyPhaseBit
} }
@ -83,52 +83,6 @@ func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel {
func (p *longHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.frames) } func (p *longHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.frames) }
func (p *longHeaderPacket) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet {
largestAcked := protocol.InvalidPacketNumber
if p.ack != nil {
largestAcked = p.ack.LargestAcked()
}
encLevel := p.EncryptionLevel()
for i := range p.frames {
if p.frames[i].OnLost != nil {
continue
}
//nolint:exhaustive // Short header packets are handled separately.
switch encLevel {
case protocol.EncryptionInitial:
p.frames[i].OnLost = q.AddInitial
case protocol.EncryptionHandshake:
p.frames[i].OnLost = q.AddHandshake
case protocol.Encryption0RTT:
p.frames[i].OnLost = q.AddAppData
}
}
ap := ackhandler.GetPacket()
ap.PacketNumber = p.header.PacketNumber
ap.LargestAcked = largestAcked
ap.Frames = p.frames
ap.Length = p.length
ap.EncryptionLevel = encLevel
ap.SendTime = now
ap.IsPathMTUProbePacket = p.isMTUProbePacket
return ap
}
func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
maxSize := protocol.ByteCount(protocol.MinInitialPacketSize)
// If this is not a UDP address, we don't know anything about the MTU.
// Use the minimum size of an Initial packet as the max packet size.
if udpAddr, ok := addr.(*net.UDPAddr); ok {
if utils.IsIPv4(udpAddr.IP) {
maxSize = protocol.InitialPacketSizeIPv4
} else {
maxSize = protocol.InitialPacketSizeIPv6
}
}
return maxSize
}
type packetNumberManager interface { type packetNumberManager interface {
PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
@ -143,8 +97,8 @@ type sealingManager interface {
type frameSource interface { type frameSource interface {
HasData() bool HasData() bool
AppendStreamFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
AppendControlFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
} }
type ackFrameSource interface { type ackFrameSource interface {
@ -169,13 +123,23 @@ type packetPacker struct {
datagramQueue *datagramQueue datagramQueue *datagramQueue
retransmissionQueue *retransmissionQueue retransmissionQueue *retransmissionQueue
maxPacketSize protocol.ByteCount
numNonAckElicitingAcks int numNonAckElicitingAcks int
} }
var _ packer = &packetPacker{} var _ packer = &packetPacker{}
func newPacketPacker(srcConnID protocol.ConnectionID, getDestConnID func() protocol.ConnectionID, initialStream cryptoStream, handshakeStream cryptoStream, packetNumberManager packetNumberManager, retransmissionQueue *retransmissionQueue, remoteAddr net.Addr, cryptoSetup sealingManager, framer frameSource, acks ackFrameSource, datagramQueue *datagramQueue, perspective protocol.Perspective) *packetPacker { func newPacketPacker(
srcConnID protocol.ConnectionID,
getDestConnID func() protocol.ConnectionID,
initialStream, handshakeStream cryptoStream,
packetNumberManager packetNumberManager,
retransmissionQueue *retransmissionQueue,
cryptoSetup sealingManager,
framer frameSource,
acks ackFrameSource,
datagramQueue *datagramQueue,
perspective protocol.Perspective,
) *packetPacker {
return &packetPacker{ return &packetPacker{
cryptoSetup: cryptoSetup, cryptoSetup: cryptoSetup,
getDestConnID: getDestConnID, getDestConnID: getDestConnID,
@ -188,23 +152,22 @@ func newPacketPacker(srcConnID protocol.ConnectionID, getDestConnID func() proto
framer: framer, framer: framer,
acks: acks, acks: acks,
pnManager: packetNumberManager, pnManager: packetNumberManager,
maxPacketSize: getMaxPacketSize(remoteAddr),
} }
} }
// PackConnectionClose packs a packet that closes the connection with a transport error. // PackConnectionClose packs a packet that closes the connection with a transport error.
func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, v protocol.VersionNumber) (*coalescedPacket, error) { func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
var reason string var reason string
// don't send details of crypto errors // don't send details of crypto errors
if !e.ErrorCode.IsCryptoError() { if !e.ErrorCode.IsCryptoError() {
reason = e.ErrorMessage reason = e.ErrorMessage
} }
return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason, v) return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason, maxPacketSize, v)
} }
// PackApplicationClose packs a packet that closes the connection with an application error. // PackApplicationClose packs a packet that closes the connection with an application error.
func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, v protocol.VersionNumber) (*coalescedPacket, error) { func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, v) return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v)
} }
func (p *packetPacker) packConnectionClose( func (p *packetPacker) packConnectionClose(
@ -212,6 +175,7 @@ func (p *packetPacker) packConnectionClose(
errorCode uint64, errorCode uint64,
frameType uint64, frameType uint64,
reason string, reason string,
maxPacketSize protocol.ByteCount,
v protocol.VersionNumber, v protocol.VersionNumber,
) (*coalescedPacket, error) { ) (*coalescedPacket, error) {
var sealers [4]sealer var sealers [4]sealer
@ -241,7 +205,7 @@ func (p *packetPacker) packConnectionClose(
ccf.ReasonPhrase = "" ccf.ReasonPhrase = ""
} }
pl := payload{ pl := payload{
frames: []*ackhandler.Frame{{Frame: ccf}}, frames: []ackhandler.Frame{{Frame: ccf}},
length: ccf.Length(v), length: ccf.Length(v),
} }
@ -293,20 +257,14 @@ func (p *packetPacker) packConnectionClose(
} }
var paddingLen protocol.ByteCount var paddingLen protocol.ByteCount
if encLevel == protocol.EncryptionInitial { if encLevel == protocol.EncryptionInitial {
paddingLen = p.initialPaddingLen(payloads[i].frames, size) paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize)
} }
if encLevel == protocol.Encryption1RTT { if encLevel == protocol.Encryption1RTT {
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, sealers[i], false, v) shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, maxPacketSize, sealers[i], false, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
packet.shortHdrPacket = &shortHeaderPacket{ packet.shortHdrPacket = &shp
Packet: ap,
DestConnID: connID,
Ack: ack,
PacketNumberLen: oneRTTPacketNumberLen,
KeyPhase: keyPhase,
}
} else { } else {
longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], v) longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], v)
if err != nil { if err != nil {
@ -342,25 +300,21 @@ func (p *packetPacker) shortHeaderPacketLength(connID protocol.ConnectionID, pnL
} }
// size is the expected size of the packet, if no padding was applied. // size is the expected size of the packet, if no padding was applied.
func (p *packetPacker) initialPaddingLen(frames []*ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount { func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize, maxPacketSize protocol.ByteCount) protocol.ByteCount {
// For the server, only ack-eliciting Initial packets need to be padded. // For the server, only ack-eliciting Initial packets need to be padded.
if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) { if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) {
return 0 return 0
} }
if size >= p.maxPacketSize { if currentSize >= maxPacketSize {
return 0 return 0
} }
return p.maxPacketSize - size return maxPacketSize - currentSize
} }
// PackCoalescedPacket packs a new packet. // PackCoalescedPacket packs a new packet.
// It packs an Initial / Handshake if there is data to send in these packet number spaces. // It packs an Initial / Handshake if there is data to send in these packet number spaces.
// It should only be called before the handshake is confirmed. // It should only be called before the handshake is confirmed.
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumber) (*coalescedPacket, error) { func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
maxPacketSize := p.maxPacketSize
if p.perspective == protocol.PerspectiveClient {
maxPacketSize = protocol.MinInitialPacketSize
}
var ( var (
initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader
initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload
@ -417,7 +371,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumbe
if oneRTTPayload.length > 0 { if oneRTTPayload.length > 0 {
size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead()) size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead())
} }
} else if p.perspective == protocol.PerspectiveClient { // 0-RTT } else if p.perspective == protocol.PerspectiveClient && !onlyAck { // 0-RTT packets can't contain ACK frames
var err error var err error
zeroRTTSealer, err = p.cryptoSetup.Get0RTTSealer() zeroRTTSealer, err = p.cryptoSetup.Get0RTTSealer()
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
@ -442,7 +396,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumbe
longHdrPackets: make([]*longHeaderPacket, 0, 3), longHdrPackets: make([]*longHeaderPacket, 0, 3),
} }
if initialPayload.length > 0 { if initialPayload.length > 0 {
padding := p.initialPaddingLen(initialPayload.frames, size) padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize)
cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v) cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v)
if err != nil { if err != nil {
return nil, err return nil, err
@ -463,48 +417,44 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumbe
} }
packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket)
} else if oneRTTPayload.length > 0 { } else if oneRTTPayload.length > 0 {
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, oneRTTSealer, false, v) shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxPacketSize, oneRTTSealer, false, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
packet.shortHdrPacket = &shortHeaderPacket{ packet.shortHdrPacket = &shp
Packet: ap,
DestConnID: connID,
Ack: ack,
PacketNumberLen: oneRTTPacketNumberLen,
KeyPhase: kp,
}
} }
return packet, nil return packet, nil
} }
// PackPacket packs a packet in the application data packet number space. // PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space.
// It should be called after the handshake is confirmed. // It should be called after the handshake is confirmed.
func (p *packetPacker) PackPacket(onlyAck bool, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) {
buf := getPacketBuffer()
packet, err := p.appendPacket(buf, true, maxPacketSize, v)
return packet, buf, err
}
// AppendPacket packs a packet in the application data packet number space.
// It should be called after the handshake is confirmed.
func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) {
return p.appendPacket(buf, false, maxPacketSize, v)
}
func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) {
sealer, err := p.cryptoSetup.Get1RTTSealer() sealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil { if err != nil {
return shortHeaderPacket{}, nil, err return shortHeaderPacket{}, err
} }
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
connID := p.getDestConnID() connID := p.getDestConnID()
hdrLen := wire.ShortHeaderLen(connID, pnLen) hdrLen := wire.ShortHeaderLen(connID, pnLen)
pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, p.maxPacketSize, onlyAck, true, v) pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, maxPacketSize, onlyAck, true, v)
if pl.length == 0 { if pl.length == 0 {
return shortHeaderPacket{}, nil, errNothingToPack return shortHeaderPacket{}, errNothingToPack
} }
kp := sealer.KeyPhase() kp := sealer.KeyPhase()
buffer := getPacketBuffer()
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, sealer, false, v) return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v)
if err != nil {
return shortHeaderPacket{}, nil, err
}
return shortHeaderPacket{
Packet: ap,
DestConnID: connID,
Ack: ack,
PacketNumberLen: pnLen,
KeyPhase: kp,
}, buffer, nil
} }
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) { func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) {
@ -519,14 +469,17 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
} }
var s cryptoStream var s cryptoStream
var handler ackhandler.FrameHandler
var hasRetransmission bool var hasRetransmission bool
//nolint:exhaustive // Initial and Handshake are the only two encryption levels here. //nolint:exhaustive // Initial and Handshake are the only two encryption levels here.
switch encLevel { switch encLevel {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:
s = p.initialStream s = p.initialStream
handler = p.retransmissionQueue.InitialAckHandler()
hasRetransmission = p.retransmissionQueue.HasInitialData() hasRetransmission = p.retransmissionQueue.HasInitialData()
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
s = p.handshakeStream s = p.handshakeStream
handler = p.retransmissionQueue.HandshakeAckHandler()
hasRetransmission = p.retransmissionQueue.HasHandshakeData() hasRetransmission = p.retransmissionQueue.HasHandshakeData()
} }
@ -550,27 +503,27 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
maxPacketSize -= hdr.GetLength(v) maxPacketSize -= hdr.GetLength(v)
if hasRetransmission { if hasRetransmission {
for { for {
var f wire.Frame var f ackhandler.Frame
//nolint:exhaustive // 0-RTT packets can't contain any retransmission.s //nolint:exhaustive // 0-RTT packets can't contain any retransmission.s
switch encLevel { switch encLevel {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:
f = p.retransmissionQueue.GetInitialFrame(maxPacketSize, v) f.Frame = p.retransmissionQueue.GetInitialFrame(maxPacketSize, v)
f.Handler = p.retransmissionQueue.InitialAckHandler()
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize, v) f.Frame = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize, v)
f.Handler = p.retransmissionQueue.HandshakeAckHandler()
} }
if f == nil { if f.Frame == nil {
break break
} }
af := ackhandler.GetFrame() pl.frames = append(pl.frames, f)
af.Frame = f frameLen := f.Frame.Length(v)
pl.frames = append(pl.frames, af)
frameLen := f.Length(v)
pl.length += frameLen pl.length += frameLen
maxPacketSize -= frameLen maxPacketSize -= frameLen
} }
} else if s.HasData() { } else if s.HasData() {
cf := s.PopCryptoFrame(maxPacketSize) cf := s.PopCryptoFrame(maxPacketSize)
pl.frames = []*ackhandler.Frame{{Frame: cf}} pl.frames = []ackhandler.Frame{{Frame: cf, Handler: handler}}
pl.length += cf.Length(v) pl.length += cf.Length(v)
} }
return hdr, pl return hdr, pl
@ -595,18 +548,14 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount,
pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v) pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v)
// check if we have anything to send // check if we have anything to send
if len(pl.frames) == 0 { if len(pl.frames) == 0 && len(pl.streamFrames) == 0 {
if pl.ack == nil { if pl.ack == nil {
return payload{} return payload{}
} }
// the packet only contains an ACK // the packet only contains an ACK
if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks {
ping := &wire.PingFrame{} ping := &wire.PingFrame{}
// don't retransmit the PING frame when it is lost pl.frames = append(pl.frames, ackhandler.Frame{Frame: ping})
af := ackhandler.GetFrame()
af.Frame = ping
af.OnLost = func(wire.Frame) {}
pl.frames = append(pl.frames, af)
pl.length += ping.Length(v) pl.length += ping.Length(v)
p.numNonAckElicitingAcks = 0 p.numNonAckElicitingAcks = 0
} else { } else {
@ -621,15 +570,12 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount,
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
if onlyAck { if onlyAck {
if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil { if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil {
return payload{ return payload{ack: ack, length: ack.Length(v)}
ack: ack,
length: ack.Length(v),
}
} }
return payload{} return payload{}
} }
pl := payload{frames: make([]*ackhandler.Frame, 0, 1)} pl := payload{streamFrames: make([]ackhandler.StreamFrame, 0, 1)}
hasData := p.framer.HasData() hasData := p.framer.HasData()
hasRetransmission := p.retransmissionQueue.HasAppData() hasRetransmission := p.retransmissionQueue.HasAppData()
@ -647,11 +593,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
if f := p.datagramQueue.Peek(); f != nil { if f := p.datagramQueue.Peek(); f != nil {
size := f.Length(v) size := f.Length(v)
if size <= maxFrameSize-pl.length { if size <= maxFrameSize-pl.length {
af := ackhandler.GetFrame() pl.frames = append(pl.frames, ackhandler.Frame{Frame: f})
af.Frame = f
// set it to a no-op. Then we won't set the default callback, which would retransmit the frame.
af.OnLost = func(wire.Frame) {}
pl.frames = append(pl.frames, af)
pl.length += size pl.length += size
p.datagramQueue.Pop() p.datagramQueue.Pop()
} }
@ -672,25 +614,28 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
if f == nil { if f == nil {
break break
} }
af := ackhandler.GetFrame() pl.frames = append(pl.frames, ackhandler.Frame{Frame: f, Handler: p.retransmissionQueue.AppDataAckHandler()})
af.Frame = f
pl.frames = append(pl.frames, af)
pl.length += f.Length(v) pl.length += f.Length(v)
} }
} }
if hasData { if hasData {
var lengthAdded protocol.ByteCount var lengthAdded protocol.ByteCount
startLen := len(pl.frames)
pl.frames, lengthAdded = p.framer.AppendControlFrames(pl.frames, maxFrameSize-pl.length, v) pl.frames, lengthAdded = p.framer.AppendControlFrames(pl.frames, maxFrameSize-pl.length, v)
pl.length += lengthAdded pl.length += lengthAdded
// add handlers for the control frames that were added
for i := startLen; i < len(pl.frames); i++ {
pl.frames[i].Handler = p.retransmissionQueue.AppDataAckHandler()
}
pl.frames, lengthAdded = p.framer.AppendStreamFrames(pl.frames, maxFrameSize-pl.length, v) pl.streamFrames, lengthAdded = p.framer.AppendStreamFrames(pl.streamFrames, maxFrameSize-pl.length, v)
pl.length += lengthAdded pl.length += lengthAdded
} }
return pl return pl
} }
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (*coalescedPacket, error) { func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
if encLevel == protocol.Encryption1RTT { if encLevel == protocol.Encryption1RTT {
s, err := p.cryptoSetup.Get1RTTSealer() s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil { if err != nil {
@ -700,23 +645,17 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v
connID := p.getDestConnID() connID := p.getDestConnID()
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdrLen := wire.ShortHeaderLen(connID, pnLen) hdrLen := wire.ShortHeaderLen(connID, pnLen)
pl := p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, v) pl := p.maybeGetAppDataPacket(maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, v)
if pl.length == 0 { if pl.length == 0 {
return nil, nil return nil, nil
} }
buffer := getPacketBuffer() buffer := getPacketBuffer()
packet := &coalescedPacket{buffer: buffer} packet := &coalescedPacket{buffer: buffer}
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, s, false, v) shp, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, maxPacketSize, s, false, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
packet.shortHdrPacket = &shortHeaderPacket{ packet.shortHdrPacket = &shp
Packet: ap,
DestConnID: connID,
Ack: ack,
PacketNumberLen: pnLen,
KeyPhase: kp,
}
return packet, nil return packet, nil
} }
@ -731,14 +670,14 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v
if err != nil { if err != nil {
return nil, err return nil, err
} }
hdr, pl = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v) hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v)
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
var err error var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer() sealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil { if err != nil {
return nil, err return nil, err
} }
hdr, pl = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true, v) hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true, v)
default: default:
panic("unknown encryption level") panic("unknown encryption level")
} }
@ -751,7 +690,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v
size := p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead()) size := p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead())
var padding protocol.ByteCount var padding protocol.ByteCount
if encLevel == protocol.EncryptionInitial { if encLevel == protocol.EncryptionInitial {
padding = p.initialPaddingLen(pl.frames, size) padding = p.initialPaddingLen(pl.frames, size, maxPacketSize)
} }
longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, pl, padding, encLevel, sealer, v) longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, pl, padding, encLevel, sealer, v)
@ -762,10 +701,10 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v
return packet, nil return packet, nil
} }
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) {
pl := payload{ pl := payload{
frames: []*ackhandler.Frame{&ping}, frames: []ackhandler.Frame{ping},
length: ping.Length(v), length: ping.Frame.Length(v),
} }
buffer := getPacketBuffer() buffer := getPacketBuffer()
s, err := p.cryptoSetup.Get1RTTSealer() s, err := p.cryptoSetup.Get1RTTSealer()
@ -776,17 +715,8 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
padding := size - p.shortHeaderPacketLength(connID, pnLen, pl) - protocol.ByteCount(s.Overhead()) padding := size - p.shortHeaderPacketLength(connID, pnLen, pl) - protocol.ByteCount(s.Overhead())
kp := s.KeyPhase() kp := s.KeyPhase()
ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, padding, s, true, v) packet, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, padding, size, s, true, v)
if err != nil { return packet, buffer, err
return shortHeaderPacket{}, nil, err
}
return shortHeaderPacket{
Packet: ap,
DestConnID: connID,
Ack: ack,
PacketNumberLen: pnLen,
KeyPhase: kp,
}, buffer, nil
} }
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) *wire.ExtendedHeader { func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) *wire.ExtendedHeader {
@ -829,23 +759,22 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
} }
payloadOffset := protocol.ByteCount(len(raw)) payloadOffset := protocol.ByteCount(len(raw))
pn := p.pnManager.PopPacketNumber(encLevel)
if pn != header.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
raw, err = p.appendPacketPayload(raw, pl, paddingLen, v) raw, err = p.appendPacketPayload(raw, pl, paddingLen, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
raw = p.encryptPacket(raw, sealer, pn, payloadOffset, pnLen) raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen)
buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber {
return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber)
}
return &longHeaderPacket{ return &longHeaderPacket{
header: header, header: header,
ack: pl.ack, ack: pl.ack,
frames: pl.frames, frames: pl.frames,
length: protocol.ByteCount(len(raw)), streamFrames: pl.streamFrames,
length: protocol.ByteCount(len(raw)),
}, nil }, nil
} }
@ -856,11 +785,11 @@ func (p *packetPacker) appendShortHeaderPacket(
pnLen protocol.PacketNumberLen, pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit, kp protocol.KeyPhaseBit,
pl payload, pl payload,
padding protocol.ByteCount, padding, maxPacketSize protocol.ByteCount,
sealer sealer, sealer sealer,
isMTUProbePacket bool, isMTUProbePacket bool,
v protocol.VersionNumber, v protocol.VersionNumber,
) (*ackhandler.Packet, *wire.AckFrame, error) { ) (shortHeaderPacket, error) {
var paddingLen protocol.ByteCount var paddingLen protocol.ByteCount
if pl.length < 4-protocol.ByteCount(pnLen) { if pl.length < 4-protocol.ByteCount(pnLen) {
paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length
@ -871,48 +800,36 @@ func (p *packetPacker) appendShortHeaderPacket(
raw := buffer.Data[startLen:] raw := buffer.Data[startLen:]
raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp) raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp)
if err != nil { if err != nil {
return nil, nil, err return shortHeaderPacket{}, err
} }
payloadOffset := protocol.ByteCount(len(raw)) payloadOffset := protocol.ByteCount(len(raw))
if pn != p.pnManager.PopPacketNumber(protocol.Encryption1RTT) {
return nil, nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
raw, err = p.appendPacketPayload(raw, pl, paddingLen, v) raw, err = p.appendPacketPayload(raw, pl, paddingLen, v)
if err != nil { if err != nil {
return nil, nil, err return shortHeaderPacket{}, err
} }
if !isMTUProbePacket { if !isMTUProbePacket {
if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > p.maxPacketSize { if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > maxPacketSize {
return nil, nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) return shortHeaderPacket{}, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, maxPacketSize)
} }
} }
raw = p.encryptPacket(raw, sealer, pn, payloadOffset, protocol.ByteCount(pnLen)) raw = p.encryptPacket(raw, sealer, pn, payloadOffset, protocol.ByteCount(pnLen))
buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
// create the ackhandler.Packet if newPN := p.pnManager.PopPacketNumber(protocol.Encryption1RTT); newPN != pn {
largestAcked := protocol.InvalidPacketNumber return shortHeaderPacket{}, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, newPN)
if pl.ack != nil {
largestAcked = pl.ack.LargestAcked()
} }
for i := range pl.frames { return shortHeaderPacket{
if pl.frames[i].OnLost != nil { PacketNumber: pn,
continue PacketNumberLen: pnLen,
} KeyPhase: kp,
pl.frames[i].OnLost = p.retransmissionQueue.AddAppData StreamFrames: pl.streamFrames,
} Frames: pl.frames,
Ack: pl.ack,
ap := ackhandler.GetPacket() Length: protocol.ByteCount(len(raw)),
ap.PacketNumber = pn DestConnID: connID,
ap.LargestAcked = largestAcked IsPathMTUProbePacket: isMTUProbePacket,
ap.Frames = pl.frames }, nil
ap.Length = protocol.ByteCount(len(raw))
ap.EncryptionLevel = protocol.Encryption1RTT
ap.SendTime = time.Now()
ap.IsPathMTUProbePacket = isMTUProbePacket
return ap, pl.ack, nil
} }
func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) { func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) {
@ -927,9 +844,16 @@ func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen pr
if paddingLen > 0 { if paddingLen > 0 {
raw = append(raw, make([]byte, paddingLen)...) raw = append(raw, make([]byte, paddingLen)...)
} }
for _, frame := range pl.frames { for _, f := range pl.frames {
var err error var err error
raw, err = frame.Append(raw, v) raw, err = f.Frame.Append(raw, v)
if err != nil {
return nil, err
}
}
for _, f := range pl.streamFrames {
var err error
raw, err = f.Frame.Append(raw, v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -953,16 +877,3 @@ func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.Pack
func (p *packetPacker) SetToken(token []byte) { func (p *packetPacker) SetToken(token []byte) {
p.token = token p.token = token
} }
// When a higher MTU is discovered, use it.
func (p *packetPacker) SetMaxPacketSize(s protocol.ByteCount) {
p.maxPacketSize = s
}
// If the peer sets a max_packet_size that's smaller than the size we're currently using,
// we need to reduce the size of packets we send.
func (p *packetPacker) HandleTransportParameters(params *wire.TransportParameters) {
if params.MaxUDPPayloadSize != 0 {
p.maxPacketSize = utils.Min(p.maxPacketSize, params.MaxUDPPayloadSize)
}
}

View File

@ -70,25 +70,6 @@ func Read(r io.ByteReader) (uint64, error) {
return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil
} }
// Write writes i in the QUIC varint format to w.
// Deprecated: use Append instead.
func Write(w Writer, i uint64) {
if i <= maxVarInt1 {
w.WriteByte(uint8(i))
} else if i <= maxVarInt2 {
w.Write([]byte{uint8(i>>8) | 0x40, uint8(i)})
} else if i <= maxVarInt4 {
w.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)})
} else if i <= maxVarInt8 {
w.Write([]byte{
uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
} else {
panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
}
}
// Append appends i in the QUIC varint format. // Append appends i in the QUIC varint format.
func Append(b []byte, i uint64) []byte { func Append(b []byte, i uint64) []byte {
if i <= maxVarInt1 { if i <= maxVarInt1 {

View File

@ -179,6 +179,10 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast { if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
s.finRead = true s.finRead = true
s.currentFrame = nil
if s.currentFrameDone != nil {
s.currentFrameDone()
}
return true, bytesRead, io.EOF return true, bytesRead, io.EOF
} }
} }

View File

@ -3,6 +3,8 @@ package quic
import ( import (
"fmt" "fmt"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/internal/wire"
) )
@ -21,7 +23,23 @@ func newRetransmissionQueue() *retransmissionQueue {
return &retransmissionQueue{} return &retransmissionQueue{}
} }
func (q *retransmissionQueue) AddInitial(f wire.Frame) { // AddPing queues a ping.
// It is used when a probe packet needs to be sent
func (q *retransmissionQueue) AddPing(encLevel protocol.EncryptionLevel) {
//nolint:exhaustive // Cannot send probe packets for 0-RTT.
switch encLevel {
case protocol.EncryptionInitial:
q.addInitial(&wire.PingFrame{})
case protocol.EncryptionHandshake:
q.addHandshake(&wire.PingFrame{})
case protocol.Encryption1RTT:
q.addAppData(&wire.PingFrame{})
default:
panic("unexpected encryption level")
}
}
func (q *retransmissionQueue) addInitial(f wire.Frame) {
if cf, ok := f.(*wire.CryptoFrame); ok { if cf, ok := f.(*wire.CryptoFrame); ok {
q.initialCryptoData = append(q.initialCryptoData, cf) q.initialCryptoData = append(q.initialCryptoData, cf)
return return
@ -29,7 +47,7 @@ func (q *retransmissionQueue) AddInitial(f wire.Frame) {
q.initial = append(q.initial, f) q.initial = append(q.initial, f)
} }
func (q *retransmissionQueue) AddHandshake(f wire.Frame) { func (q *retransmissionQueue) addHandshake(f wire.Frame) {
if cf, ok := f.(*wire.CryptoFrame); ok { if cf, ok := f.(*wire.CryptoFrame); ok {
q.handshakeCryptoData = append(q.handshakeCryptoData, cf) q.handshakeCryptoData = append(q.handshakeCryptoData, cf)
return return
@ -49,7 +67,7 @@ func (q *retransmissionQueue) HasAppData() bool {
return len(q.appData) > 0 return len(q.appData) > 0
} }
func (q *retransmissionQueue) AddAppData(f wire.Frame) { func (q *retransmissionQueue) addAppData(f wire.Frame) {
if _, ok := f.(*wire.StreamFrame); ok { if _, ok := f.(*wire.StreamFrame); ok {
panic("STREAM frames are handled with their respective streams.") panic("STREAM frames are handled with their respective streams.")
} }
@ -127,3 +145,36 @@ func (q *retransmissionQueue) DropPackets(encLevel protocol.EncryptionLevel) {
panic(fmt.Sprintf("unexpected encryption level: %s", encLevel)) panic(fmt.Sprintf("unexpected encryption level: %s", encLevel))
} }
} }
func (q *retransmissionQueue) InitialAckHandler() ackhandler.FrameHandler {
return (*retransmissionQueueInitialAckHandler)(q)
}
func (q *retransmissionQueue) HandshakeAckHandler() ackhandler.FrameHandler {
return (*retransmissionQueueHandshakeAckHandler)(q)
}
func (q *retransmissionQueue) AppDataAckHandler() ackhandler.FrameHandler {
return (*retransmissionQueueAppDataAckHandler)(q)
}
type retransmissionQueueInitialAckHandler retransmissionQueue
func (q *retransmissionQueueInitialAckHandler) OnAcked(wire.Frame) {}
func (q *retransmissionQueueInitialAckHandler) OnLost(f wire.Frame) {
(*retransmissionQueue)(q).addInitial(f)
}
type retransmissionQueueHandshakeAckHandler retransmissionQueue
func (q *retransmissionQueueHandshakeAckHandler) OnAcked(wire.Frame) {}
func (q *retransmissionQueueHandshakeAckHandler) OnLost(f wire.Frame) {
(*retransmissionQueue)(q).addHandshake(f)
}
type retransmissionQueueAppDataAckHandler retransmissionQueue
func (q *retransmissionQueueAppDataAckHandler) OnAcked(wire.Frame) {}
func (q *retransmissionQueueAppDataAckHandler) OnLost(f wire.Frame) {
(*retransmissionQueue)(q).addAppData(f)
}

View File

@ -1,38 +1,65 @@
package quic package quic
import ( import (
"math"
"net" "net"
"github.com/quic-go/quic-go/internal/protocol"
) )
// A sendConn allows sending using a simple Write() on a non-connected packet conn. // A sendConn allows sending using a simple Write() on a non-connected packet conn.
type sendConn interface { type sendConn interface {
Write([]byte) error Write(b []byte, size protocol.ByteCount) error
Close() error Close() error
LocalAddr() net.Addr LocalAddr() net.Addr
RemoteAddr() net.Addr RemoteAddr() net.Addr
capabilities() connCapabilities
} }
type sconn struct { type sconn struct {
rawConn rawConn
remoteAddr net.Addr remoteAddr net.Addr
info *packetInfo info packetInfo
oob []byte oob []byte
} }
var _ sendConn = &sconn{} var _ sendConn = &sconn{}
func newSendConn(c rawConn, remote net.Addr, info *packetInfo) sendConn { func newSendConn(c rawConn, remote net.Addr) *sconn {
sc := &sconn{
rawConn: c,
remoteAddr: remote,
}
if c.capabilities().GSO {
// add 32 bytes, so we can add the UDP_SEGMENT msg
sc.oob = make([]byte, 0, 32)
}
return sc
}
func newSendConnWithPacketInfo(c rawConn, remote net.Addr, info packetInfo) *sconn {
oob := info.OOB()
if c.capabilities().GSO {
// add 32 bytes, so we can add the UDP_SEGMENT msg
l := len(oob)
oob = append(oob, make([]byte, 32)...)
oob = oob[:l]
}
return &sconn{ return &sconn{
rawConn: c, rawConn: c,
remoteAddr: remote, remoteAddr: remote,
info: info, info: info,
oob: info.OOB(), oob: oob,
} }
} }
func (c *sconn) Write(p []byte) error { func (c *sconn) Write(p []byte, size protocol.ByteCount) error {
_, err := c.WritePacket(p, c.remoteAddr, c.oob) if size > math.MaxUint16 {
panic("size overflow")
}
_, err := c.WritePacket(p, uint16(size), c.remoteAddr, c.oob)
return err return err
} }
@ -42,33 +69,12 @@ func (c *sconn) RemoteAddr() net.Addr {
func (c *sconn) LocalAddr() net.Addr { func (c *sconn) LocalAddr() net.Addr {
addr := c.rawConn.LocalAddr() addr := c.rawConn.LocalAddr()
if c.info != nil { if c.info.addr.IsValid() {
if udpAddr, ok := addr.(*net.UDPAddr); ok { if udpAddr, ok := addr.(*net.UDPAddr); ok {
addrCopy := *udpAddr addrCopy := *udpAddr
addrCopy.IP = c.info.addr addrCopy.IP = c.info.addr.AsSlice()
addr = &addrCopy addr = &addrCopy
} }
} }
return addr return addr
} }
type spconn struct {
net.PacketConn
remoteAddr net.Addr
}
var _ sendConn = &spconn{}
func newSendPconn(c net.PacketConn, remote net.Addr) sendConn {
return &spconn{PacketConn: c, remoteAddr: remote}
}
func (c *spconn) Write(p []byte) error {
_, err := c.WriteTo(p, c.remoteAddr)
return err
}
func (c *spconn) RemoteAddr() net.Addr {
return c.remoteAddr
}

View File

@ -1,15 +1,22 @@
package quic package quic
import "github.com/quic-go/quic-go/internal/protocol"
type sender interface { type sender interface {
Send(p *packetBuffer) Send(p *packetBuffer, packetSize protocol.ByteCount)
Run() error Run() error
WouldBlock() bool WouldBlock() bool
Available() <-chan struct{} Available() <-chan struct{}
Close() Close()
} }
type queueEntry struct {
buf *packetBuffer
size protocol.ByteCount
}
type sendQueue struct { type sendQueue struct {
queue chan *packetBuffer queue chan queueEntry
closeCalled chan struct{} // runStopped when Close() is called closeCalled chan struct{} // runStopped when Close() is called
runStopped chan struct{} // runStopped when the run loop returns runStopped chan struct{} // runStopped when the run loop returns
available chan struct{} available chan struct{}
@ -26,16 +33,16 @@ func newSendQueue(conn sendConn) sender {
runStopped: make(chan struct{}), runStopped: make(chan struct{}),
closeCalled: make(chan struct{}), closeCalled: make(chan struct{}),
available: make(chan struct{}, 1), available: make(chan struct{}, 1),
queue: make(chan *packetBuffer, sendQueueCapacity), queue: make(chan queueEntry, sendQueueCapacity),
} }
} }
// Send sends out a packet. It's guaranteed to not block. // Send sends out a packet. It's guaranteed to not block.
// Callers need to make sure that there's actually space in the send queue by calling WouldBlock. // Callers need to make sure that there's actually space in the send queue by calling WouldBlock.
// Otherwise Send will panic. // Otherwise Send will panic.
func (h *sendQueue) Send(p *packetBuffer) { func (h *sendQueue) Send(p *packetBuffer, size protocol.ByteCount) {
select { select {
case h.queue <- p: case h.queue <- queueEntry{buf: p, size: size}:
// clear available channel if we've reached capacity // clear available channel if we've reached capacity
if len(h.queue) == sendQueueCapacity { if len(h.queue) == sendQueueCapacity {
select { select {
@ -69,8 +76,8 @@ func (h *sendQueue) Run() error {
h.closeCalled = nil // prevent this case from being selected again h.closeCalled = nil // prevent this case from being selected again
// make sure that all queued packets are actually sent out // make sure that all queued packets are actually sent out
shouldClose = true shouldClose = true
case p := <-h.queue: case e := <-h.queue:
if err := h.conn.Write(p.Data); err != nil { if err := h.conn.Write(e.buf.Data, e.size); err != nil {
// This additional check enables: // This additional check enables:
// 1. Checking for "datagram too large" message from the kernel, as such, // 1. Checking for "datagram too large" message from the kernel, as such,
// 2. Path MTU discovery,and // 2. Path MTU discovery,and
@ -79,7 +86,7 @@ func (h *sendQueue) Run() error {
return err return err
} }
} }
p.Release() e.buf.Release()
select { select {
case h.available <- struct{}{}: case h.available <- struct{}{}:
default: default:

View File

@ -18,7 +18,7 @@ type sendStreamI interface {
SendStream SendStream
handleStopSendingFrame(*wire.StopSendingFrame) handleStopSendingFrame(*wire.StopSendingFrame)
hasData() bool hasData() bool
popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (frame ackhandler.StreamFrame, ok, hasMore bool)
closeForShutdown(error) closeForShutdown(error)
updateSendWindow(protocol.ByteCount) updateSendWindow(protocol.ByteCount)
} }
@ -198,7 +198,7 @@ func (s *sendStream) canBufferStreamFrame() bool {
// popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream
// maxBytes is the maximum length this frame (including frame header) will have. // maxBytes is the maximum length this frame (including frame header) will have.
func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool /* has more data to send */) { func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (af ackhandler.StreamFrame, ok, hasMore bool) {
s.mutex.Lock() s.mutex.Lock()
f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v) f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v)
if f != nil { if f != nil {
@ -207,13 +207,12 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers
s.mutex.Unlock() s.mutex.Unlock()
if f == nil { if f == nil {
return nil, hasMoreData return ackhandler.StreamFrame{}, false, hasMoreData
} }
af := ackhandler.GetFrame() return ackhandler.StreamFrame{
af.Frame = f Frame: f,
af.OnLost = s.queueRetransmission Handler: (*sendStreamAckHandler)(s),
af.OnAcked = s.frameAcked }, true, hasMoreData
return af, hasMoreData
} }
func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) { func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) {
@ -348,26 +347,6 @@ func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.By
} }
} }
func (s *sendStream) frameAcked(f wire.Frame) {
f.(*wire.StreamFrame).PutBack()
s.mutex.Lock()
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return
}
s.numOutstandingFrames--
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
newlyCompleted := s.isNewlyCompleted()
s.mutex.Unlock()
if newlyCompleted {
s.sender.onStreamCompleted(s.streamID)
}
}
func (s *sendStream) isNewlyCompleted() bool { func (s *sendStream) isNewlyCompleted() bool {
completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0 completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
if completed && !s.completed { if completed && !s.completed {
@ -377,24 +356,6 @@ func (s *sendStream) isNewlyCompleted() bool {
return false return false
} }
func (s *sendStream) queueRetransmission(f wire.Frame) {
sf := f.(*wire.StreamFrame)
sf.DataLenPresent = true
s.mutex.Lock()
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return
}
s.retransmissionQueue = append(s.retransmissionQueue, sf)
s.numOutstandingFrames--
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
s.mutex.Unlock()
s.sender.onHasStreamData(s.streamID)
}
func (s *sendStream) Close() error { func (s *sendStream) Close() error {
s.mutex.Lock() s.mutex.Lock()
if s.closeForShutdownErr != nil { if s.closeForShutdownErr != nil {
@ -487,3 +448,45 @@ func (s *sendStream) signalWrite() {
default: default:
} }
} }
type sendStreamAckHandler sendStream
var _ ackhandler.FrameHandler = &sendStreamAckHandler{}
func (s *sendStreamAckHandler) OnAcked(f wire.Frame) {
sf := f.(*wire.StreamFrame)
sf.PutBack()
s.mutex.Lock()
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return
}
s.numOutstandingFrames--
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
newlyCompleted := (*sendStream)(s).isNewlyCompleted()
s.mutex.Unlock()
if newlyCompleted {
s.sender.onStreamCompleted(s.streamID)
}
}
func (s *sendStreamAckHandler) OnLost(f wire.Frame) {
sf := f.(*wire.StreamFrame)
s.mutex.Lock()
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return
}
sf.DataLenPresent = true
s.retransmissionQueue = append(s.retransmissionQueue, sf)
s.numOutstandingFrames--
if s.numOutstandingFrames < 0 {
panic("numOutStandingFrames negative")
}
s.mutex.Unlock()
s.sender.onHasStreamData(s.streamID)
}

View File

@ -20,33 +20,29 @@ import (
) )
// ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close. // ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close.
var ErrServerClosed = errors.New("quic: Server closed") var ErrServerClosed = errors.New("quic: server closed")
// packetHandler handles packets // packetHandler handles packets
type packetHandler interface { type packetHandler interface {
handlePacket(*receivedPacket) handlePacket(receivedPacket)
shutdown() shutdown()
destroy(error) destroy(error)
getPerspective() protocol.Perspective getPerspective() protocol.Perspective
} }
type unknownPacketHandler interface {
handlePacket(*receivedPacket)
setCloseError(error)
}
type packetHandlerManager interface { type packetHandlerManager interface {
AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool Get(protocol.ConnectionID) (packetHandler, bool)
Destroy() error GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool)
connRunner AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool
SetServer(unknownPacketHandler) Close(error)
CloseServer() CloseServer()
connRunner
} }
type quicConn interface { type quicConn interface {
EarlyConnection EarlyConnection
earlyConnReady() <-chan struct{} earlyConnReady() <-chan struct{}
handlePacket(*receivedPacket) handlePacket(receivedPacket)
GetVersion() protocol.VersionNumber GetVersion() protocol.VersionNumber
getPerspective() protocol.Perspective getPerspective() protocol.Perspective
run() error run() error
@ -54,6 +50,11 @@ type quicConn interface {
shutdown() shutdown()
} }
type zeroRTTQueue struct {
packets []receivedPacket
expiration time.Time
}
// A Listener of QUIC // A Listener of QUIC
type baseServer struct { type baseServer struct {
mutex sync.Mutex mutex sync.Mutex
@ -64,15 +65,17 @@ type baseServer struct {
config *Config config *Config
conn rawConn conn rawConn
// If the server is started with ListenAddr, we create a packet conn.
// If it is started with Listen, we take a packet conn as a parameter.
createdPacketConn bool
tokenGenerator *handshake.TokenGenerator tokenGenerator *handshake.TokenGenerator
connHandler packetHandlerManager connIDGenerator ConnectionIDGenerator
connHandler packetHandlerManager
onClose func()
receivedPackets chan *receivedPacket receivedPackets chan receivedPacket
nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
// set as a member, so they can be set in the tests // set as a member, so they can be set in the tests
newConn func( newConn func(
@ -83,6 +86,7 @@ type baseServer struct {
protocol.ConnectionID, /* client dest connection ID */ protocol.ConnectionID, /* client dest connection ID */
protocol.ConnectionID, /* destination connection ID */ protocol.ConnectionID, /* destination connection ID */
protocol.ConnectionID, /* source connection ID */ protocol.ConnectionID, /* source connection ID */
ConnectionIDGenerator,
protocol.StatelessResetToken, protocol.StatelessResetToken,
*Config, *Config,
*tls.Config, *tls.Config,
@ -94,128 +98,164 @@ type baseServer struct {
protocol.VersionNumber, protocol.VersionNumber,
) quicConn ) quicConn
serverError error serverError error
errorChan chan struct{} errorChan chan struct{}
closed bool closed bool
running chan struct{} // closed as soon as run() returns running chan struct{} // closed as soon as run() returns
versionNegotiationQueue chan receivedPacket
invalidTokenQueue chan receivedPacket
connQueue chan quicConn connQueue chan quicConn
connQueueLen int32 // to be used as an atomic connQueueLen int32 // to be used as an atomic
tracer logging.Tracer
logger utils.Logger logger utils.Logger
} }
var ( // A Listener listens for incoming QUIC connections.
_ Listener = &baseServer{} // It returns connections once the handshake has completed.
_ unknownPacketHandler = &baseServer{} type Listener struct {
) baseServer *baseServer
}
type earlyServer struct{ *baseServer } // Accept returns new connections. It should be called in a loop.
func (l *Listener) Accept(ctx context.Context) (Connection, error) {
return l.baseServer.Accept(ctx)
}
var _ EarlyListener = &earlyServer{} // Close the server. All active connections will be closed.
func (l *Listener) Close() error {
return l.baseServer.Close()
}
func (s *earlyServer) Accept(ctx context.Context) (EarlyConnection, error) { // Addr returns the local network address that the server is listening on.
return s.baseServer.accept(ctx) func (l *Listener) Addr() net.Addr {
return l.baseServer.Addr()
}
// An EarlyListener listens for incoming QUIC connections, and returns them before the handshake completes.
// For connections that don't use 0-RTT, this allows the server to send 0.5-RTT data.
// This data is encrypted with forward-secure keys, however, the client's identity has not yet been verified.
// For connection using 0-RTT, this allows the server to accept and respond to streams that the client opened in the
// 0-RTT data it sent. Note that at this point during the handshake, the live-ness of the
// client has not yet been confirmed, and the 0-RTT data could have been replayed by an attacker.
type EarlyListener struct {
baseServer *baseServer
}
// Accept returns a new connections. It should be called in a loop.
func (l *EarlyListener) Accept(ctx context.Context) (EarlyConnection, error) {
return l.baseServer.accept(ctx)
}
// Close the server. All active connections will be closed.
func (l *EarlyListener) Close() error {
return l.baseServer.Close()
}
// Addr returns the local network addr that the server is listening on.
func (l *EarlyListener) Addr() net.Addr {
return l.baseServer.Addr()
} }
// ListenAddr creates a QUIC server listening on a given address. // ListenAddr creates a QUIC server listening on a given address.
// The tls.Config must not be nil and must contain a certificate configuration. // See Listen for more details.
// The quic.Config may be nil, in that case the default values will be used. func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) {
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { conn, err := listenUDP(addr)
return listenAddr(addr, tlsConf, config, false)
}
// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes.
func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyListener, error) {
s, err := listenAddr(addr, tlsConf, config, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &earlyServer{s}, nil return (&Transport{
Conn: conn,
createdConn: true,
isSingleUse: true,
}).Listen(tlsConf, config)
} }
func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { // ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes.
func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
conn, err := listenUDP(addr)
if err != nil {
return nil, err
}
return (&Transport{
Conn: conn,
createdConn: true,
isSingleUse: true,
}).ListenEarly(tlsConf, config)
}
func listenUDP(addr string) (*net.UDPConn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr) udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn, err := net.ListenUDP("udp", udpAddr) return net.ListenUDP("udp", udpAddr)
if err != nil {
return nil, err
}
serv, err := listen(conn, tlsConf, config, acceptEarly)
if err != nil {
return nil, err
}
serv.createdPacketConn = true
return serv, nil
} }
// Listen listens for QUIC connections on a given net.PacketConn. If the // Listen listens for QUIC connections on a given net.PacketConn.
// PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn // If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP // ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write // will be used instead of ReadFrom and WriteTo to read/write packets.
// packets. A single net.PacketConn only be used for a single call to Listen. // A single net.PacketConn can only be used for a single call to Listen.
// The PacketConn can be used for simultaneous calls to Dial. QUIC connection //
// IDs are used for demultiplexing the different connections.
// The tls.Config must not be nil and must contain a certificate configuration. // The tls.Config must not be nil and must contain a certificate configuration.
// Furthermore, it must define an application control (using NextProtos). // Furthermore, it must define an application control (using NextProtos).
// The quic.Config may be nil, in that case the default values will be used. // The quic.Config may be nil, in that case the default values will be used.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { //
return listen(conn, tlsConf, config, false) // This is a convenience function. More advanced use cases should instantiate a Transport,
// which offers configuration options for a more fine-grained control of the connection establishment,
// including reusing the underlying UDP socket for outgoing QUIC connections.
func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) {
tr := &Transport{Conn: conn, isSingleUse: true}
return tr.Listen(tlsConf, config)
} }
// ListenEarly works like Listen, but it returns connections before the handshake completes. // ListenEarly works like Listen, but it returns connections before the handshake completes.
func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (EarlyListener, error) { func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) {
s, err := listen(conn, tlsConf, config, true) tr := &Transport{Conn: conn, isSingleUse: true}
if err != nil { return tr.ListenEarly(tlsConf, config)
return nil, err
}
return &earlyServer{s}, nil
} }
func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { func newServer(
if tlsConf == nil { conn rawConn,
return nil, errors.New("quic: tls.Config not set") connHandler packetHandlerManager,
} connIDGenerator ConnectionIDGenerator,
if err := validateConfig(config); err != nil { tlsConf *tls.Config,
return nil, err config *Config,
} tracer logging.Tracer,
config = populateServerConfig(config) onClose func(),
for _, v := range config.Versions { acceptEarly bool,
if !protocol.IsValidVersion(v) { ) (*baseServer, error) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
if err != nil {
return nil, err
}
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c, err := wrapConn(conn)
if err != nil {
return nil, err
}
s := &baseServer{ s := &baseServer{
conn: c, conn: conn,
tlsConf: tlsConf, tlsConf: tlsConf,
config: config, config: config,
tokenGenerator: tokenGenerator, tokenGenerator: tokenGenerator,
connHandler: connHandler, connIDGenerator: connIDGenerator,
connQueue: make(chan quicConn), connHandler: connHandler,
errorChan: make(chan struct{}), connQueue: make(chan quicConn),
running: make(chan struct{}), errorChan: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), running: make(chan struct{}),
newConn: newConnection, receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
logger: utils.DefaultLogger.WithPrefix("server"), versionNegotiationQueue: make(chan receivedPacket, 4),
acceptEarlyConns: acceptEarly, invalidTokenQueue: make(chan receivedPacket, 4),
newConn: newConnection,
tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
onClose: onClose,
}
if acceptEarly {
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
} }
go s.run() go s.run()
connHandler.SetServer(s) go s.runSendQueue()
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
return s, nil return s, nil
} }
@ -239,6 +279,19 @@ func (s *baseServer) run() {
} }
} }
func (s *baseServer) runSendQueue() {
for {
select {
case <-s.running:
return
case p := <-s.versionNegotiationQueue:
s.maybeSendVersionNegotiationPacket(p)
case p := <-s.invalidTokenQueue:
s.maybeSendInvalidToken(p)
}
}
}
// Accept returns connections that already completed the handshake. // Accept returns connections that already completed the handshake.
// It is only valid if acceptEarlyConns is false. // It is only valid if acceptEarlyConns is false.
func (s *baseServer) Accept(ctx context.Context) (Connection, error) { func (s *baseServer) Accept(ctx context.Context) (Connection, error) {
@ -267,18 +320,12 @@ func (s *baseServer) Close() error {
if s.serverError == nil { if s.serverError == nil {
s.serverError = ErrServerClosed s.serverError = ErrServerClosed
} }
// If the server was started with ListenAddr, we created the packet conn.
// We need to close it in order to make the go routine reading from that conn return.
createdPacketConn := s.createdPacketConn
s.closed = true s.closed = true
close(s.errorChan) close(s.errorChan)
s.mutex.Unlock() s.mutex.Unlock()
<-s.running <-s.running
s.connHandler.CloseServer() s.onClose()
if createdPacketConn {
return s.connHandler.Destroy()
}
return nil return nil
} }
@ -298,22 +345,26 @@ func (s *baseServer) Addr() net.Addr {
return s.conn.LocalAddr() return s.conn.LocalAddr()
} }
func (s *baseServer) handlePacket(p *receivedPacket) { func (s *baseServer) handlePacket(p receivedPacket) {
select { select {
case s.receivedPackets <- p: case s.receivedPackets <- p:
default: default:
s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size())
if s.config.Tracer != nil { if s.tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
} }
} }
} }
func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ { func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer still in use? */ {
if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) {
defer s.cleanupZeroRTTQueues(p.rcvTime)
}
if wire.IsVersionNegotiationPacket(p.data) { if wire.IsVersionNegotiationPacket(p.data) {
s.logger.Debugf("Dropping Version Negotiation packet.") s.logger.Debugf("Dropping Version Negotiation packet.")
if s.config.Tracer != nil { if s.tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
} }
return false return false
} }
@ -322,42 +373,54 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
panic(fmt.Sprintf("misrouted packet: %#v", p.data)) panic(fmt.Sprintf("misrouted packet: %#v", p.data))
} }
v, err := wire.ParseVersion(p.data) v, err := wire.ParseVersion(p.data)
// send a Version Negotiation Packet if the client is speaking a different protocol version // drop the packet if we failed to parse the protocol version
if err != nil || !protocol.IsSupportedVersion(s.config.Versions, v) { if err != nil {
if err != nil || p.Size() < protocol.MinUnknownVersionPacketSize { s.logger.Debugf("Dropping a packet with an unknown version")
s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) if s.tracer != nil {
if s.config.Tracer != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
_, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data)
if err != nil { // should never happen
s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs")
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
if !s.config.DisableVersionNegotiationPackets {
go s.sendVersionNegotiationPacket(p.remoteAddr, src, dest, p.info.OOB(), v)
} }
return false return false
} }
// send a Version Negotiation Packet if the client is speaking a different protocol version
if !protocol.IsSupportedVersion(s.config.Versions, v) {
if s.config.DisableVersionNegotiationPackets {
return false
}
if p.Size() < protocol.MinUnknownVersionPacketSize {
s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size())
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
return s.enqueueVersionNegotiationPacket(p)
}
if wire.Is0RTTPacket(p.data) {
if !s.acceptEarlyConns {
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false
}
return s.handle0RTTPacket(p)
}
// If we're creating a new connection, the packet will be passed to the connection. // If we're creating a new connection, the packet will be passed to the connection.
// The header will then be parsed again. // The header will then be parsed again.
hdr, _, _, err := wire.ParsePacket(p.data) hdr, _, _, err := wire.ParsePacket(p.data)
if err != nil { if err != nil {
if s.config.Tracer != nil { if s.tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
} }
s.logger.Debugf("Error parsing packet: %s", err) s.logger.Debugf("Error parsing packet: %s", err)
return false return false
} }
if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize {
s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size())
if s.config.Tracer != nil { if s.tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
} }
return false return false
} }
@ -367,8 +430,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
// There's little point in sending a Stateless Reset, since the client // There's little point in sending a Stateless Reset, since the client
// might not have received the token yet. // might not have received the token yet.
s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data))
if s.config.Tracer != nil { if s.tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
} }
return false return false
} }
@ -383,6 +446,74 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
return true return true
} }
func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
connID, err := wire.ParseConnectionID(p.data, 0)
if err != nil {
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
}
return false
}
// check again if we might have a connection now
if handler, ok := s.connHandler.Get(connID); ok {
handler.handlePacket(p)
return true
}
if q, ok := s.zeroRTTQueues[connID]; ok {
if len(q.packets) >= protocol.Max0RTTQueueLen {
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
}
q.packets = append(q.packets, p)
return true
}
if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues {
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
return false
}
queue := &zeroRTTQueue{packets: make([]receivedPacket, 1, 8)}
queue.packets[0] = p
expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration)
queue.expiration = expiration
if s.nextZeroRTTCleanup.IsZero() || s.nextZeroRTTCleanup.After(expiration) {
s.nextZeroRTTCleanup = expiration
}
s.zeroRTTQueues[connID] = queue
return true
}
func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
// Iterate over all queues to find those that are expired.
// This is ok since we're placing a pretty low limit on the number of queues.
var nextCleanup time.Time
for connID, q := range s.zeroRTTQueues {
if q.expiration.After(now) {
if nextCleanup.IsZero() || nextCleanup.After(q.expiration) {
nextCleanup = q.expiration
}
continue
}
for _, p := range q.packets {
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
}
p.buffer.Release()
}
delete(s.zeroRTTQueues, connID)
if s.logger.Debug() {
s.logger.Debugf("Removing 0-RTT queue for %s.", connID)
}
}
s.nextZeroRTTCleanup = nextCleanup
}
// validateToken returns false if: // validateToken returns false if:
// - address is invalid // - address is invalid
// - token is expired // - token is expired
@ -403,15 +534,23 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
return true return true
} }
func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
p.buffer.Release() p.buffer.Release()
if s.config.Tracer != nil { if s.tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
} }
return errors.New("too short connection ID") return errors.New("too short connection ID")
} }
// The server queues packets for a while, and we might already have established a connection by now.
// This results in a second check in the connection map.
// That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets).
if handler, ok := s.connHandler.Get(hdr.DestConnectionID); ok {
handler.handlePacket(p)
return nil
}
var ( var (
token *handshake.Token token *handshake.Token
retrySrcConnID *protocol.ConnectionID retrySrcConnID *protocol.ConnectionID
@ -429,7 +568,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
} }
clientAddrIsValid := s.validateToken(token, p.remoteAddr) clientAddrIsValid := s.validateToken(token, p.remoteAddr)
if token != nil && !clientAddrIsValid { if token != nil && !clientAddrIsValid {
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error. // For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
// We just ignore them, and act as if there was no token on this packet at all. // We just ignore them, and act as if there was no token on this packet at all.
@ -440,16 +578,13 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
// For Retry tokens, we send an INVALID_ERROR if // For Retry tokens, we send an INVALID_ERROR if
// * the token is too old, or // * the token is too old, or
// * the token is invalid, in case of a retry token. // * the token is invalid, in case of a retry token.
go func() { s.enqueueInvalidToken(p)
defer p.buffer.Release()
if err := s.maybeSendInvalidToken(p, hdr); err != nil {
s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
}
}()
return nil return nil
} }
} }
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) { if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
// Retry invalidates all 0-RTT packets sent.
delete(s.zeroRTTQueues, hdr.DestConnectionID)
go func() { go func() {
defer p.buffer.Release() defer p.buffer.Release()
if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil { if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil {
@ -470,37 +605,43 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
return nil return nil
} }
connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil { if err != nil {
return err return err
} }
s.logger.Debugf("Changing connection ID to %s.", connID) s.logger.Debugf("Changing connection ID to %s.", connID)
var conn quicConn var conn quicConn
tracingID := nextConnTracingID() tracingID := nextConnTracingID()
if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) {
config := s.config
if s.config.GetConfigForClient != nil {
conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr})
if err != nil {
s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback")
return nil, false
}
config = populateConfig(conf)
}
var tracer logging.ConnectionTracer var tracer logging.ConnectionTracer
if s.config.Tracer != nil { if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback. // Use the same connection ID that is passed to the client's GetLogWriter callback.
connID := hdr.DestConnectionID connID := hdr.DestConnectionID
if origDestConnID.Len() > 0 { if origDestConnID.Len() > 0 {
connID = origDestConnID connID = origDestConnID
} }
tracer = s.config.Tracer.TracerForConnection( tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
context.WithValue(context.Background(), ConnectionTracingKey, tracingID),
protocol.PerspectiveServer,
connID,
)
} }
conn = s.newConn( conn = s.newConn(
newSendConn(s.conn, p.remoteAddr, p.info), newSendConnWithPacketInfo(s.conn, p.remoteAddr, p.info),
s.connHandler, s.connHandler,
origDestConnID, origDestConnID,
retrySrcConnID, retrySrcConnID,
hdr.DestConnectionID, hdr.DestConnectionID,
hdr.SrcConnectionID, hdr.SrcConnectionID,
connID, connID,
s.connIDGenerator,
s.connHandler.GetStatelessResetToken(connID), s.connHandler.GetStatelessResetToken(connID),
s.config, config,
s.tlsConf, s.tlsConf,
s.tokenGenerator, s.tokenGenerator,
clientAddrIsValid, clientAddrIsValid,
@ -510,8 +651,22 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
hdr.Version, hdr.Version,
) )
conn.handlePacket(p) conn.handlePacket(p)
return conn
if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok {
for _, p := range q.packets {
conn.handlePacket(p)
}
delete(s.zeroRTTQueues, hdr.DestConnectionID)
}
return conn, true
}); !added { }); !added {
go func() {
defer p.buffer.Release()
if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil {
s.logger.Debugf("Error rejecting connection: %s", err)
}
}()
return nil return nil
} }
go conn.run() go conn.run()
@ -551,11 +706,11 @@ func (s *baseServer) handleNewConn(conn quicConn) {
} }
} }
func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packetInfo) error {
// Log the Initial packet now. // Log the Initial packet now.
// If no Retry is sent, the packet will be logged by the connection. // If no Retry is sent, the packet will be logged by the connection.
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() srcConnID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil { if err != nil {
return err return err
} }
@ -584,47 +739,69 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
// append the Retry integrity tag // append the Retry integrity tag
tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version) tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version)
buf.Data = append(buf.Data, tag[:]...) buf.Data = append(buf.Data, tag[:]...)
if s.config.Tracer != nil { if s.tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
} }
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) _, err = s.conn.WritePacket(buf.Data, uint16(len(buf.Data)), remoteAddr, info.OOB())
return err return err
} }
func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) error { func (s *baseServer) enqueueInvalidToken(p receivedPacket) {
select {
case s.invalidTokenQueue <- p:
default:
// it's fine to drop INVALID_TOKEN packets when we are busy
p.buffer.Release()
}
}
func (s *baseServer) maybeSendInvalidToken(p receivedPacket) {
defer p.buffer.Release()
hdr, _, _, err := wire.ParsePacket(p.data)
if err != nil {
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("Error parsing packet: %s", err)
return
}
// Only send INVALID_TOKEN if we can unprotect the packet. // Only send INVALID_TOKEN if we can unprotect the packet.
// This makes sure that we won't send it for packets that were corrupted. // This makes sure that we won't send it for packets that were corrupted.
sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
data := p.data[:hdr.ParsedLen()+hdr.Length] data := p.data[:hdr.ParsedLen()+hdr.Length]
extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version) extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version)
// Only send INVALID_TOKEN if we can unprotect the packet.
// This makes sure that we won't send it for packets that were corrupted.
if err != nil { if err != nil {
if s.config.Tracer != nil { if s.tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
} }
// don't return the error here. Just drop the packet. return
return nil
} }
hdrLen := extHdr.ParsedLen() hdrLen := extHdr.ParsedLen()
if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil {
// don't return the error here. Just drop the packet. if s.tracer != nil {
if s.config.Tracer != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
} }
return nil return
} }
if s.logger.Debug() { if s.logger.Debug() {
s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr) s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr)
} }
return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info) if err := s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info); err != nil {
s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err)
}
} }
func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info packetInfo) error {
sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version)
return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info) return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info)
} }
// sendError sends the error as a response to the packet received with header hdr // sendError sends the error as a response to the packet received with header hdr
func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error { func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info packetInfo) error {
b := getPacketBuffer() b := getPacketBuffer()
defer b.Release() defer b.Release()
@ -661,21 +838,48 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
replyHdr.Log(s.logger) replyHdr.Log(s.logger)
wire.LogFrame(s.logger, ccf, true) wire.LogFrame(s.logger, ccf, true)
if s.config.Tracer != nil { if s.tracer != nil {
s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
} }
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) _, err = s.conn.WritePacket(b.Data, uint16(len(b.Data)), remoteAddr, info.OOB())
return err return err
} }
func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest protocol.ArbitraryLenConnectionID, oob []byte, v protocol.VersionNumber) { func (s *baseServer) enqueueVersionNegotiationPacket(p receivedPacket) (bufferInUse bool) {
select {
case s.versionNegotiationQueue <- p:
return true
default:
// it's fine to not send version negotiation packets when we are busy
}
return false
}
func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
defer p.buffer.Release()
v, err := wire.ParseVersion(p.data)
if err != nil {
s.logger.Debugf("failed to parse version for sending version negotiation packet: %s", err)
return
}
_, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data)
if err != nil { // should never happen
s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs")
if s.tracer != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
}
return
}
s.logger.Debugf("Client offered version %s, sending Version Negotiation", v) s.logger.Debugf("Client offered version %s, sending Version Negotiation", v)
data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions)
if s.config.Tracer != nil { if s.tracer != nil {
s.config.Tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions) s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions)
} }
if _, err := s.conn.WritePacket(data, remote, oob); err != nil { if _, err := s.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err) s.logger.Debugf("Error sending Version Negotiation: %s", err)
} }
} }

View File

@ -60,7 +60,7 @@ type streamI interface {
// for sending // for sending
hasData() bool hasData() bool
handleStopSendingFrame(*wire.StopSendingFrame) handleStopSendingFrame(*wire.StopSendingFrame)
popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool)
updateSendWindow(protocol.ByteCount) updateSendWindow(protocol.ByteCount)
} }

View File

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"fmt"
"net" "net"
"syscall" "syscall"
"time" "time"
@ -15,16 +16,38 @@ import (
type OOBCapablePacketConn interface { type OOBCapablePacketConn interface {
net.PacketConn net.PacketConn
SyscallConn() (syscall.RawConn, error) SyscallConn() (syscall.RawConn, error)
SetReadBuffer(int) error
ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error)
WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error)
} }
var _ OOBCapablePacketConn = &net.UDPConn{} var _ OOBCapablePacketConn = &net.UDPConn{}
func wrapConn(pc net.PacketConn) (rawConn, error) { // OptimizeConn takes a net.PacketConn and attempts to enable various optimizations that will improve QUIC performance:
// 1. It enables the Don't Fragment (DF) bit on the IP header.
// This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899).
// 2. It enables reading of the ECN bits from the IP header.
// This allows the remote node to speed up its loss detection and recovery.
// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket.
// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux).
//
// In order for this to work, the connection needs to implement the OOBCapablePacketConn interface (as a *net.UDPConn does).
//
// It's only necessary to call this function explicitly if the application calls WriteTo
// after passing the connection to the Transport.
func OptimizeConn(c net.PacketConn) (net.PacketConn, error) {
return wrapConn(c)
}
func wrapConn(pc net.PacketConn) (interface {
net.PacketConn
rawConn
}, error,
) {
conn, ok := pc.(interface { conn, ok := pc.(interface {
SyscallConn() (syscall.RawConn, error) SyscallConn() (syscall.RawConn, error)
}) })
var supportsDF bool
if ok { if ok {
rawConn, err := conn.SyscallConn() rawConn, err := conn.SyscallConn()
if err != nil { if err != nil {
@ -33,7 +56,8 @@ func wrapConn(pc net.PacketConn) (rawConn, error) {
if _, ok := pc.LocalAddr().(*net.UDPAddr); ok { if _, ok := pc.LocalAddr().(*net.UDPAddr); ok {
// Only set DF on sockets that we expect to be able to handle that configuration. // Only set DF on sockets that we expect to be able to handle that configuration.
err = setDF(rawConn) var err error
supportsDF, err = setDF(rawConn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -42,32 +66,33 @@ func wrapConn(pc net.PacketConn) (rawConn, error) {
c, ok := pc.(OOBCapablePacketConn) c, ok := pc.(OOBCapablePacketConn)
if !ok { if !ok {
utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.") utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.")
return &basicConn{PacketConn: pc}, nil return &basicConn{PacketConn: pc, supportsDF: supportsDF}, nil
} }
return newConn(c) return newConn(c, supportsDF)
} }
// The basicConn is the most trivial implementation of a connection. // The basicConn is the most trivial implementation of a rawConn.
// It reads a single packet from the underlying net.PacketConn. // It reads a single packet from the underlying net.PacketConn.
// It is used when // It is used when
// * the net.PacketConn is not a OOBCapablePacketConn, and // * the net.PacketConn is not a OOBCapablePacketConn, and
// * when the OS doesn't support OOB. // * when the OS doesn't support OOB.
type basicConn struct { type basicConn struct {
net.PacketConn net.PacketConn
supportsDF bool
} }
var _ rawConn = &basicConn{} var _ rawConn = &basicConn{}
func (c *basicConn) ReadPacket() (*receivedPacket, error) { func (c *basicConn) ReadPacket() (receivedPacket, error) {
buffer := getPacketBuffer() buffer := getPacketBuffer()
// The packet size should not exceed protocol.MaxPacketBufferSize bytes // The packet size should not exceed protocol.MaxPacketBufferSize bytes
// If it does, we only read a truncated packet, which will then end up undecryptable // If it does, we only read a truncated packet, which will then end up undecryptable
buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
n, addr, err := c.PacketConn.ReadFrom(buffer.Data) n, addr, err := c.PacketConn.ReadFrom(buffer.Data)
if err != nil { if err != nil {
return nil, err return receivedPacket{}, err
} }
return &receivedPacket{ return receivedPacket{
remoteAddr: addr, remoteAddr: addr,
rcvTime: time.Now(), rcvTime: time.Now(),
data: buffer.Data[:n], data: buffer.Data[:n],
@ -75,6 +100,11 @@ func (c *basicConn) ReadPacket() (*receivedPacket, error) {
}, nil }, nil
} }
func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { func (c *basicConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, _ []byte) (n int, err error) {
if uint16(len(b)) != packetSize {
panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b)))
}
return c.PacketConn.WriteTo(b, addr) return c.PacketConn.WriteTo(b, addr)
} }
func (c *basicConn) capabilities() connCapabilities { return connCapabilities{DF: c.supportsDF} }

68
vendor/github.com/quic-go/quic-go/sys_conn_buffers.go generated vendored Normal file
View File

@ -0,0 +1,68 @@
package quic
import (
"errors"
"fmt"
"net"
"syscall"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
//go:generate sh -c "echo '// Code generated by go generate. DO NOT EDIT.\n// Source: sys_conn_buffers.go\n' > sys_conn_buffers_write.go && sed -e 's/SetReadBuffer/SetWriteBuffer/g' -e 's/setReceiveBuffer/setSendBuffer/g' -e 's/inspectReadBuffer/inspectWriteBuffer/g' -e 's/protocol\\.DesiredReceiveBufferSize/protocol\\.DesiredSendBufferSize/g' -e 's/forceSetReceiveBuffer/forceSetSendBuffer/g' -e 's/receive buffer/send buffer/g' sys_conn_buffers.go | sed '/^\\/\\/go:generate/d' >> sys_conn_buffers_write.go"
func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error {
conn, ok := c.(interface{ SetReadBuffer(int) error })
if !ok {
return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?")
}
var syscallConn syscall.RawConn
if sc, ok := c.(interface {
SyscallConn() (syscall.RawConn, error)
}); ok {
var err error
syscallConn, err = sc.SyscallConn()
if err != nil {
syscallConn = nil
}
}
// The connection has a SetReadBuffer method, but we couldn't obtain a syscall.RawConn.
// This shouldn't happen for a net.UDPConn, but is possible if the connection just implements the
// net.PacketConn interface and the SetReadBuffer method.
// We have no way of checking if increasing the buffer size actually worked.
if syscallConn == nil {
return conn.SetReadBuffer(protocol.DesiredReceiveBufferSize)
}
size, err := inspectReadBuffer(syscallConn)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if size >= protocol.DesiredReceiveBufferSize {
logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024)
return nil
}
// Ignore the error. We check if we succeeded by querying the buffer size afterward.
_ = conn.SetReadBuffer(protocol.DesiredReceiveBufferSize)
newSize, err := inspectReadBuffer(syscallConn)
if newSize < protocol.DesiredReceiveBufferSize {
// Try again with RCVBUFFORCE on Linux
_ = forceSetReceiveBuffer(syscallConn, protocol.DesiredReceiveBufferSize)
newSize, err = inspectReadBuffer(syscallConn)
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
}
if err != nil {
return fmt.Errorf("failed to determine receive buffer size: %w", err)
}
if newSize == size {
return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
if newSize < protocol.DesiredReceiveBufferSize {
return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024)
}
logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024)
return nil
}

View File

@ -0,0 +1,70 @@
// Code generated by go generate. DO NOT EDIT.
// Source: sys_conn_buffers.go
package quic
import (
"errors"
"fmt"
"net"
"syscall"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
)
func setSendBuffer(c net.PacketConn, logger utils.Logger) error {
conn, ok := c.(interface{ SetWriteBuffer(int) error })
if !ok {
return errors.New("connection doesn't allow setting of send buffer size. Not a *net.UDPConn?")
}
var syscallConn syscall.RawConn
if sc, ok := c.(interface {
SyscallConn() (syscall.RawConn, error)
}); ok {
var err error
syscallConn, err = sc.SyscallConn()
if err != nil {
syscallConn = nil
}
}
// The connection has a SetWriteBuffer method, but we couldn't obtain a syscall.RawConn.
// This shouldn't happen for a net.UDPConn, but is possible if the connection just implements the
// net.PacketConn interface and the SetWriteBuffer method.
// We have no way of checking if increasing the buffer size actually worked.
if syscallConn == nil {
return conn.SetWriteBuffer(protocol.DesiredSendBufferSize)
}
size, err := inspectWriteBuffer(syscallConn)
if err != nil {
return fmt.Errorf("failed to determine send buffer size: %w", err)
}
if size >= protocol.DesiredSendBufferSize {
logger.Debugf("Conn has send buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredSendBufferSize/1024)
return nil
}
// Ignore the error. We check if we succeeded by querying the buffer size afterward.
_ = conn.SetWriteBuffer(protocol.DesiredSendBufferSize)
newSize, err := inspectWriteBuffer(syscallConn)
if newSize < protocol.DesiredSendBufferSize {
// Try again with RCVBUFFORCE on Linux
_ = forceSetSendBuffer(syscallConn, protocol.DesiredSendBufferSize)
newSize, err = inspectWriteBuffer(syscallConn)
if err != nil {
return fmt.Errorf("failed to determine send buffer size: %w", err)
}
}
if err != nil {
return fmt.Errorf("failed to determine send buffer size: %w", err)
}
if newSize == size {
return fmt.Errorf("failed to increase send buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredSendBufferSize/1024, newSize/1024)
}
if newSize < protocol.DesiredSendBufferSize {
return fmt.Errorf("failed to sufficiently increase send buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredSendBufferSize/1024, newSize/1024)
}
logger.Debugf("Increased send buffer size to %d kiB", newSize/1024)
return nil
}

View File

@ -2,11 +2,13 @@
package quic package quic
import "syscall" import (
"syscall"
)
func setDF(rawConn syscall.RawConn) error { func setDF(syscall.RawConn) (bool, error) {
// no-op on unsupported platforms // no-op on unsupported platforms
return nil return false, nil
} }
func isMsgSizeErr(err error) bool { func isMsgSizeErr(err error) bool {

View File

@ -4,14 +4,23 @@ package quic
import ( import (
"errors" "errors"
"log"
"os"
"strconv"
"syscall" "syscall"
"unsafe"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
) )
func setDF(rawConn syscall.RawConn) error { // UDP_SEGMENT controls GSO (Generic Segmentation Offload)
//
//nolint:stylecheck
const UDP_SEGMENT = 103
func setDF(rawConn syscall.RawConn) (bool, error) {
// Enabling IP_MTU_DISCOVER will force the kernel to return "sendto: message too long" // Enabling IP_MTU_DISCOVER will force the kernel to return "sendto: message too long"
// and the datagram will not be fragmented // and the datagram will not be fragmented
var errDFIPv4, errDFIPv6 error var errDFIPv4, errDFIPv6 error
@ -19,7 +28,7 @@ func setDF(rawConn syscall.RawConn) error {
errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO) errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO)
errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_DO) errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_DO)
}); err != nil { }); err != nil {
return err return false, err
} }
switch { switch {
case errDFIPv4 == nil && errDFIPv6 == nil: case errDFIPv4 == nil && errDFIPv6 == nil:
@ -29,12 +38,46 @@ func setDF(rawConn syscall.RawConn) error {
case errDFIPv4 != nil && errDFIPv6 == nil: case errDFIPv4 != nil && errDFIPv6 == nil:
utils.DefaultLogger.Debugf("Setting DF for IPv6.") utils.DefaultLogger.Debugf("Setting DF for IPv6.")
case errDFIPv4 != nil && errDFIPv6 != nil: case errDFIPv4 != nil && errDFIPv6 != nil:
return errors.New("setting DF failed for both IPv4 and IPv6") return false, errors.New("setting DF failed for both IPv4 and IPv6")
} }
return nil return true, nil
}
func maybeSetGSO(rawConn syscall.RawConn) bool {
enable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_ENABLE_GSO"))
if !enable {
return false
}
var setErr error
if err := rawConn.Control(func(fd uintptr) {
setErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_UDP, UDP_SEGMENT, 1)
}); err != nil {
setErr = err
}
if setErr != nil {
log.Println("failed to enable GSO")
return false
}
return true
} }
func isMsgSizeErr(err error) bool { func isMsgSizeErr(err error) bool {
// https://man7.org/linux/man-pages/man7/udp.7.html // https://man7.org/linux/man-pages/man7/udp.7.html
return errors.Is(err, unix.EMSGSIZE) return errors.Is(err, unix.EMSGSIZE)
} }
func appendUDPSegmentSizeMsg(b []byte, size uint16) []byte {
startLen := len(b)
const dataLen = 2 // payload is a uint16
b = append(b, make([]byte, unix.CmsgSpace(dataLen))...)
h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen]))
h.Level = syscall.IPPROTO_UDP
h.Type = UDP_SEGMENT
h.SetLen(unix.CmsgLen(dataLen))
// UnixRights uses the private `data` method, but I *think* this achieves the same goal.
offset := startLen + unix.CmsgSpace(0)
*(*uint16)(unsafe.Pointer(&b[offset])) = size
return b
}

View File

@ -12,20 +12,23 @@ import (
) )
const ( const (
// same for both IPv4 and IPv6 on Windows // IP_DONTFRAGMENT controls the Don't Fragment (DF) bit.
//
// It's the same code point for both IPv4 and IPv6 on Windows.
// https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAG.html // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAG.html
// https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html
//
//nolint:stylecheck
IP_DONTFRAGMENT = 14 IP_DONTFRAGMENT = 14
IPV6_DONTFRAG = 14
) )
func setDF(rawConn syscall.RawConn) error { func setDF(rawConn syscall.RawConn) (bool, error) {
var errDFIPv4, errDFIPv6 error var errDFIPv4, errDFIPv6 error
if err := rawConn.Control(func(fd uintptr) { if err := rawConn.Control(func(fd uintptr) {
errDFIPv4 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_DONTFRAGMENT, 1) errDFIPv4 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_DONTFRAGMENT, 1)
errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_DONTFRAG, 1) errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IP_DONTFRAGMENT, 1)
}); err != nil { }); err != nil {
return err return false, err
} }
switch { switch {
case errDFIPv4 == nil && errDFIPv6 == nil: case errDFIPv4 == nil && errDFIPv6 == nil:
@ -35,9 +38,9 @@ func setDF(rawConn syscall.RawConn) error {
case errDFIPv4 != nil && errDFIPv6 == nil: case errDFIPv4 != nil && errDFIPv6 == nil:
utils.DefaultLogger.Debugf("Setting DF for IPv6.") utils.DefaultLogger.Debugf("Setting DF for IPv6.")
case errDFIPv4 != nil && errDFIPv6 != nil: case errDFIPv4 != nil && errDFIPv6 != nil:
return errors.New("setting DF failed for both IPv4 and IPv6") return false, errors.New("setting DF failed for both IPv4 and IPv6")
} }
return nil return true, nil
} }
func isMsgSizeErr(err error) bool { func isMsgSizeErr(err error) bool {

View File

@ -2,7 +2,11 @@
package quic package quic
import "golang.org/x/sys/unix" import (
"syscall"
"golang.org/x/sys/unix"
)
const msgTypeIPTOS = unix.IP_TOS const msgTypeIPTOS = unix.IP_TOS
@ -17,3 +21,23 @@ const (
) )
const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed) const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed)
func forceSetReceiveBuffer(c syscall.RawConn, bytes int) error {
var serr error
if err := c.Control(func(fd uintptr) {
serr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, bytes)
}); err != nil {
return err
}
return serr
}
func forceSetSendBuffer(c syscall.RawConn, bytes int) error {
var serr error
if err := c.Control(func(fd uintptr) {
serr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, bytes)
}); err != nil {
return err
}
return serr
}

View File

@ -0,0 +1,6 @@
//go:build !linux
package quic
func forceSetReceiveBuffer(c any, bytes int) error { return nil }
func forceSetSendBuffer(c any, bytes int) error { return nil }

8
vendor/github.com/quic-go/quic-go/sys_conn_no_gso.go generated vendored Normal file
View File

@ -0,0 +1,8 @@
//go:build darwin || freebsd
package quic
import "syscall"
func maybeSetGSO(_ syscall.RawConn) bool { return false }
func appendUDPSegmentSizeMsg(_ []byte, _ uint16) []byte { return nil }

View File

@ -2,14 +2,20 @@
package quic package quic
import "net" import (
"net"
"net/netip"
)
func newConn(c net.PacketConn) (rawConn, error) { func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) {
return &basicConn{PacketConn: c}, nil return &basicConn{PacketConn: c, supportsDF: supportsDF}, nil
} }
func inspectReadBuffer(interface{}) (int, error) { func inspectReadBuffer(any) (int, error) { return 0, nil }
return 0, nil func inspectWriteBuffer(any) (int, error) { return 0, nil }
type packetInfo struct {
addr netip.Addr
} }
func (i *packetInfo) OOB() []byte { return nil } func (i *packetInfo) OOB() []byte { return nil }

View File

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"time" "time"
@ -32,20 +33,10 @@ type batchConn interface {
ReadBatch(ms []ipv4.Message, flags int) (int, error) ReadBatch(ms []ipv4.Message, flags int) (int, error)
} }
func inspectReadBuffer(c interface{}) (int, error) { func inspectReadBuffer(c syscall.RawConn) (int, error) {
conn, ok := c.(interface {
SyscallConn() (syscall.RawConn, error)
})
if !ok {
return 0, errors.New("doesn't have a SyscallConn")
}
rawConn, err := conn.SyscallConn()
if err != nil {
return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err)
}
var size int var size int
var serr error var serr error
if err := rawConn.Control(func(fd uintptr) { if err := c.Control(func(fd uintptr) {
size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF) size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
}); err != nil { }); err != nil {
return 0, err return 0, err
@ -53,6 +44,17 @@ func inspectReadBuffer(c interface{}) (int, error) {
return size, serr return size, serr
} }
func inspectWriteBuffer(c syscall.RawConn) (int, error) {
var size int
var serr error
if err := c.Control(func(fd uintptr) {
size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF)
}); err != nil {
return 0, err
}
return size, serr
}
type oobConn struct { type oobConn struct {
OOBCapablePacketConn OOBCapablePacketConn
batchConn batchConn batchConn batchConn
@ -61,11 +63,13 @@ type oobConn struct {
// Packets received from the kernel, but not yet returned by ReadPacket(). // Packets received from the kernel, but not yet returned by ReadPacket().
messages []ipv4.Message messages []ipv4.Message
buffers [batchSize]*packetBuffer buffers [batchSize]*packetBuffer
cap connCapabilities
} }
var _ rawConn = &oobConn{} var _ rawConn = &oobConn{}
func newConn(c OOBCapablePacketConn) (*oobConn, error) { func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) {
rawConn, err := c.SyscallConn() rawConn, err := c.SyscallConn()
if err != nil { if err != nil {
return nil, err return nil, err
@ -122,6 +126,10 @@ func newConn(c OOBCapablePacketConn) (*oobConn, error) {
bc = ipv4.NewPacketConn(c) bc = ipv4.NewPacketConn(c)
} }
// Try enabling GSO.
// This will only succeed on Linux, and only for kernels > 4.18.
supportsGSO := maybeSetGSO(rawConn)
msgs := make([]ipv4.Message, batchSize) msgs := make([]ipv4.Message, batchSize)
for i := range msgs { for i := range msgs {
// preallocate the [][]byte // preallocate the [][]byte
@ -133,13 +141,15 @@ func newConn(c OOBCapablePacketConn) (*oobConn, error) {
messages: msgs, messages: msgs,
readPos: batchSize, readPos: batchSize,
} }
oobConn.cap.DF = supportsDF
oobConn.cap.GSO = supportsGSO
for i := 0; i < batchSize; i++ { for i := 0; i < batchSize; i++ {
oobConn.messages[i].OOB = make([]byte, oobBufferSize) oobConn.messages[i].OOB = make([]byte, oobBufferSize)
} }
return oobConn, nil return oobConn, nil
} }
func (c *oobConn) ReadPacket() (*receivedPacket, error) { func (c *oobConn) ReadPacket() (receivedPacket, error) {
if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages. if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages.
c.messages = c.messages[:batchSize] c.messages = c.messages[:batchSize]
// replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call // replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call
@ -153,7 +163,7 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) {
n, err := c.batchConn.ReadBatch(c.messages, 0) n, err := c.batchConn.ReadBatch(c.messages, 0)
if n == 0 || err != nil { if n == 0 || err != nil {
return nil, err return receivedPacket{}, err
} }
c.messages = c.messages[:n] c.messages = c.messages[:n]
} }
@ -163,18 +173,21 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) {
c.readPos++ c.readPos++
data := msg.OOB[:msg.NN] data := msg.OOB[:msg.NN]
var ecn protocol.ECN p := receivedPacket{
var destIP net.IP remoteAddr: msg.Addr,
var ifIndex uint32 rcvTime: time.Now(),
data: msg.Buffers[0][:msg.N],
buffer: buffer,
}
for len(data) > 0 { for len(data) > 0 {
hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data) hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data)
if err != nil { if err != nil {
return nil, err return receivedPacket{}, err
} }
if hdr.Level == unix.IPPROTO_IP { if hdr.Level == unix.IPPROTO_IP {
switch hdr.Type { switch hdr.Type {
case msgTypeIPTOS: case msgTypeIPTOS:
ecn = protocol.ECN(body[0] & ecnMask) p.ecn = protocol.ECN(body[0] & ecnMask)
case msgTypeIPv4PKTINFO: case msgTypeIPv4PKTINFO:
// struct in_pktinfo { // struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */ // unsigned int ipi_ifindex; /* Interface index */
@ -182,80 +195,94 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) {
// struct in_addr ipi_addr; /* Header Destination // struct in_addr ipi_addr; /* Header Destination
// address */ // address */
// }; // };
ip := make([]byte, 4) var ip [4]byte
if len(body) == 12 { if len(body) == 12 {
ifIndex = binary.LittleEndian.Uint32(body) copy(ip[:], body[8:12])
copy(ip, body[8:12]) p.info.ifIndex = binary.LittleEndian.Uint32(body)
} else if len(body) == 4 { } else if len(body) == 4 {
// FreeBSD // FreeBSD
copy(ip, body) copy(ip[:], body)
} }
destIP = net.IP(ip) p.info.addr = netip.AddrFrom4(ip)
} }
} }
if hdr.Level == unix.IPPROTO_IPV6 { if hdr.Level == unix.IPPROTO_IPV6 {
switch hdr.Type { switch hdr.Type {
case unix.IPV6_TCLASS: case unix.IPV6_TCLASS:
ecn = protocol.ECN(body[0] & ecnMask) p.ecn = protocol.ECN(body[0] & ecnMask)
case msgTypeIPv6PKTINFO: case msgTypeIPv6PKTINFO:
// struct in6_pktinfo { // struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */ // struct in6_addr ipi6_addr; /* src/dst IPv6 address */
// unsigned int ipi6_ifindex; /* send/recv interface index */ // unsigned int ipi6_ifindex; /* send/recv interface index */
// }; // };
if len(body) == 20 { if len(body) == 20 {
ip := make([]byte, 16) var ip [16]byte
copy(ip, body[:16]) copy(ip[:], body[:16])
destIP = net.IP(ip) p.info.addr = netip.AddrFrom16(ip)
ifIndex = binary.LittleEndian.Uint32(body[16:]) p.info.ifIndex = binary.LittleEndian.Uint32(body[16:])
} }
} }
} }
data = remainder data = remainder
} }
var info *packetInfo return p, nil
if destIP != nil {
info = &packetInfo{
addr: destIP,
ifIndex: ifIndex,
}
}
return &receivedPacket{
remoteAddr: msg.Addr,
rcvTime: time.Now(),
data: msg.Buffers[0][:msg.N],
ecn: ecn,
info: info,
buffer: buffer,
}, nil
} }
func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (n int, err error) { // WriteTo (re)implements the net.PacketConn method.
// This is needed for users who call OptimizeConn to be able to send (non-QUIC) packets on the underlying connection.
// With GSO enabled, this would otherwise not be needed, as the kernel requires the UDP_SEGMENT message to be set.
func (c *oobConn) WriteTo(p []byte, addr net.Addr) (int, error) {
return c.WritePacket(p, uint16(len(p)), addr, nil)
}
// WritePacket writes a new packet.
// If the connection supports GSO (and we activated GSO support before),
// it appends the UDP_SEGMENT size message to oob.
// Callers are advised to make sure that oob has a sufficient capacity,
// such that appending the UDP_SEGMENT size message doesn't cause an allocation.
func (c *oobConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, oob []byte) (n int, err error) {
if c.cap.GSO {
oob = appendUDPSegmentSizeMsg(oob, packetSize)
} else if uint16(len(b)) != packetSize {
panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b)))
}
n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr))
return n, err return n, err
} }
func (c *oobConn) capabilities() connCapabilities {
return c.cap
}
type packetInfo struct {
addr netip.Addr
ifIndex uint32
}
func (info *packetInfo) OOB() []byte { func (info *packetInfo) OOB() []byte {
if info == nil { if info == nil {
return nil return nil
} }
if ip4 := info.addr.To4(); ip4 != nil { if info.addr.Is4() {
ip := info.addr.As4()
// struct in_pktinfo { // struct in_pktinfo {
// unsigned int ipi_ifindex; /* Interface index */ // unsigned int ipi_ifindex; /* Interface index */
// struct in_addr ipi_spec_dst; /* Local address */ // struct in_addr ipi_spec_dst; /* Local address */
// struct in_addr ipi_addr; /* Header Destination address */ // struct in_addr ipi_addr; /* Header Destination address */
// }; // };
cm := ipv4.ControlMessage{ cm := ipv4.ControlMessage{
Src: ip4, Src: ip[:],
IfIndex: int(info.ifIndex), IfIndex: int(info.ifIndex),
} }
return cm.Marshal() return cm.Marshal()
} else if len(info.addr) == 16 { } else if info.addr.Is6() {
ip := info.addr.As16()
// struct in6_pktinfo { // struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */ // struct in6_addr ipi6_addr; /* src/dst IPv6 address */
// unsigned int ipi6_ifindex; /* send/recv interface index */ // unsigned int ipi6_ifindex; /* send/recv interface index */
// }; // };
cm := ipv6.ControlMessage{ cm := ipv6.ControlMessage{
Src: info.addr, Src: ip[:],
IfIndex: int(info.ifIndex), IfIndex: int(info.ifIndex),
} }
return cm.Marshal() return cm.Marshal()

View File

@ -3,32 +3,20 @@
package quic package quic
import ( import (
"errors" "net/netip"
"fmt"
"net"
"syscall" "syscall"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
func newConn(c OOBCapablePacketConn) (rawConn, error) { func newConn(c OOBCapablePacketConn, supportsDF bool) (*basicConn, error) {
return &basicConn{PacketConn: c}, nil return &basicConn{PacketConn: c, supportsDF: supportsDF}, nil
} }
func inspectReadBuffer(c net.PacketConn) (int, error) { func inspectReadBuffer(c syscall.RawConn) (int, error) {
conn, ok := c.(interface {
SyscallConn() (syscall.RawConn, error)
})
if !ok {
return 0, errors.New("doesn't have a SyscallConn")
}
rawConn, err := conn.SyscallConn()
if err != nil {
return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err)
}
var size int var size int
var serr error var serr error
if err := rawConn.Control(func(fd uintptr) { if err := c.Control(func(fd uintptr) {
size, serr = windows.GetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF) size, serr = windows.GetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF)
}); err != nil { }); err != nil {
return 0, err return 0, err
@ -36,4 +24,19 @@ func inspectReadBuffer(c net.PacketConn) (int, error) {
return size, serr return size, serr
} }
func inspectWriteBuffer(c syscall.RawConn) (int, error) {
var size int
var serr error
if err := c.Control(func(fd uintptr) {
size, serr = windows.GetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF)
}); err != nil {
return 0, err
}
return size, serr
}
type packetInfo struct {
addr netip.Addr
}
func (i *packetInfo) OOB() []byte { return nil } func (i *packetInfo) OOB() []byte { return nil }

435
vendor/github.com/quic-go/quic-go/transport.go generated vendored Normal file
View File

@ -0,0 +1,435 @@
package quic
import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"log"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
// The Transport is the central point to manage incoming and outgoing QUIC connections.
// QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
// This means that a single UDP socket can be used for listening for incoming connections, as well as
// for dialing an arbitrary number of outgoing connections.
// A Transport handles a single net.PacketConn, and offers a range of configuration options
// compared to the simple helper functions like Listen and Dial that this package provides.
type Transport struct {
// A single net.PacketConn can only be handled by one Transport.
// Bad things will happen if passed to multiple Transports.
//
// If not done by the user, the connection is passed through OptimizeConn to enable a number of optimizations.
// After passing the connection to the Transport, it's invalid to call ReadFrom on the connection.
// Calling WriteTo is only valid on the connection returned by OptimizeConn.
Conn net.PacketConn
// The length of the connection ID in bytes.
// It can be 0, or any value between 4 and 18.
// If unset, a 4 byte connection ID will be used.
ConnectionIDLength int
// Use for generating new connection IDs.
// This allows the application to control of the connection IDs used,
// which allows routing / load balancing based on connection IDs.
// All Connection IDs returned by the ConnectionIDGenerator MUST
// have the same length.
ConnectionIDGenerator ConnectionIDGenerator
// The StatelessResetKey is used to generate stateless reset tokens.
// If no key is configured, sending of stateless resets is disabled.
// It is highly recommended to configure a stateless reset key, as stateless resets
// allow the peer to quickly recover from crashes and reboots of this node.
// See section 10.3 of RFC 9000 for details.
StatelessResetKey *StatelessResetKey
// A Tracer traces events that don't belong to a single QUIC connection.
Tracer logging.Tracer
handlerMap packetHandlerManager
mutex sync.Mutex
initOnce sync.Once
initErr error
// Set in init.
// If no ConnectionIDGenerator is set, this is the ConnectionIDLength.
connIDLen int
// Set in init.
// If no ConnectionIDGenerator is set, this is set to a default.
connIDGenerator ConnectionIDGenerator
server unknownPacketHandler
conn rawConn
closeQueue chan closePacket
statelessResetQueue chan receivedPacket
listening chan struct{} // is closed when listen returns
closed bool
createdConn bool
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
logger utils.Logger
}
// Listen starts listening for incoming QUIC connections.
// There can only be a single listener on any net.PacketConn.
// Listen may only be called again after the current Listener was closed.
func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
if err := validateConfig(conf); err != nil {
return nil, err
}
t.mutex.Lock()
defer t.mutex.Unlock()
if t.server != nil {
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(true); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false)
if err != nil {
return nil, err
}
t.server = s
return &Listener{baseServer: s}, nil
}
// ListenEarly starts listening for incoming QUIC connections.
// There can only be a single listener on any net.PacketConn.
// Listen may only be called again after the current Listener was closed.
func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) {
if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set")
}
if err := validateConfig(conf); err != nil {
return nil, err
}
t.mutex.Lock()
defer t.mutex.Unlock()
if t.server != nil {
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(true); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true)
if err != nil {
return nil, err
}
t.server = s
return &EarlyListener{baseServer: s}, nil
}
// Dial dials a new connection to a remote host (not using 0-RTT).
func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
if err := validateConfig(conf); err != nil {
return nil, err
}
conf = populateConfig(conf)
if err := t.init(false); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
}
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
if err := validateConfig(conf); err != nil {
return nil, err
}
conf = populateConfig(conf)
if err := t.init(false); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
}
func (t *Transport) init(isServer bool) error {
t.initOnce.Do(func() {
getMultiplexer().AddConn(t.Conn)
var conn rawConn
if c, ok := t.Conn.(rawConn); ok {
conn = c
} else {
var err error
conn, err = wrapConn(t.Conn)
if err != nil {
t.initErr = err
return
}
}
t.conn = conn
t.logger = utils.DefaultLogger // TODO: make this configurable
t.conn = conn
t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger)
t.listening = make(chan struct{})
t.closeQueue = make(chan closePacket, 4)
t.statelessResetQueue = make(chan receivedPacket, 4)
if t.ConnectionIDGenerator != nil {
t.connIDGenerator = t.ConnectionIDGenerator
t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
} else {
connIDLen := t.ConnectionIDLength
if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) {
connIDLen = protocol.DefaultConnectionIDLength
}
t.connIDLen = connIDLen
t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen}
}
go t.listen(conn)
go t.runSendQueue()
})
return t.initErr
}
func (t *Transport) enqueueClosePacket(p closePacket) {
select {
case t.closeQueue <- p:
default:
// Oops, we're backlogged.
// Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway.
}
}
func (t *Transport) runSendQueue() {
for {
select {
case <-t.listening:
return
case p := <-t.closeQueue:
t.conn.WritePacket(p.payload, uint16(len(p.payload)), p.addr, p.info.OOB())
case p := <-t.statelessResetQueue:
t.sendStatelessReset(p)
}
}
}
// Close closes the underlying connection and waits until listen has returned.
// It is invalid to start new listeners or connections after that.
func (t *Transport) Close() error {
t.close(errors.New("closing"))
if t.createdConn {
if err := t.Conn.Close(); err != nil {
return err
}
} else if t.conn != nil {
t.conn.SetReadDeadline(time.Now())
defer func() { t.conn.SetReadDeadline(time.Time{}) }()
}
if t.listening != nil {
<-t.listening // wait until listening returns
}
return nil
}
func (t *Transport) closeServer() {
t.handlerMap.CloseServer()
t.mutex.Lock()
t.server = nil
if t.isSingleUse {
t.closed = true
}
t.mutex.Unlock()
if t.createdConn {
t.Conn.Close()
}
if t.isSingleUse {
t.conn.SetReadDeadline(time.Now())
defer func() { t.conn.SetReadDeadline(time.Time{}) }()
<-t.listening // wait until listening returns
}
}
func (t *Transport) close(e error) {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.closed {
return
}
if t.handlerMap != nil {
t.handlerMap.Close(e)
}
if t.server != nil {
t.server.setCloseError(e)
}
t.closed = true
}
// only print warnings about the UDP receive buffer size once
var setBufferWarningOnce sync.Once
func (t *Transport) listen(conn rawConn) {
defer close(t.listening)
defer getMultiplexer().RemoveConn(t.Conn)
if err := setReceiveBuffer(t.Conn, t.logger); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
setBufferWarningOnce.Do(func() {
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
return
}
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err)
})
}
}
if err := setSendBuffer(t.Conn, t.logger); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
setBufferWarningOnce.Do(func() {
if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable {
return
}
log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err)
})
}
}
for {
p, err := conn.ReadPacket()
//nolint:staticcheck // SA1019 ignore this!
// TODO: This code is used to ignore wsa errors on Windows.
// Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution.
// See https://github.com/quic-go/quic-go/issues/1737 for details.
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
t.mutex.Lock()
closed := t.closed
t.mutex.Unlock()
if closed {
return
}
t.logger.Debugf("Temporary error reading from conn: %w", err)
continue
}
if err != nil {
t.close(err)
return
}
t.handlePacket(p)
}
}
func (t *Transport) handlePacket(p receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
if err != nil {
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
if t.Tracer != nil {
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
}
p.buffer.MaybeRelease()
return
}
if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset {
return
}
if handler, ok := t.handlerMap.Get(connID); ok {
handler.handlePacket(p)
return
}
if !wire.IsLongHeaderPacket(p.data[0]) {
t.maybeSendStatelessReset(p)
return
}
t.mutex.Lock()
defer t.mutex.Unlock()
if t.server == nil { // no server set
t.logger.Debugf("received a packet with an unexpected connection ID %s", connID)
return
}
t.server.handlePacket(p)
}
func (t *Transport) maybeSendStatelessReset(p receivedPacket) {
if t.StatelessResetKey == nil {
p.buffer.Release()
return
}
// Don't send a stateless reset in response to very small packets.
// This includes packets that could be stateless resets.
if len(p.data) <= protocol.MinStatelessResetSize {
p.buffer.Release()
return
}
select {
case t.statelessResetQueue <- p:
default:
// it's fine to not send a stateless reset when we're busy
p.buffer.Release()
}
}
func (t *Transport) sendStatelessReset(p receivedPacket) {
defer p.buffer.Release()
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
if err != nil {
t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
return
}
token := t.handlerMap.GetStatelessResetToken(connID)
t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token)
data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize)
rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...)
if _, err := t.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil {
t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err)
}
}
func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
// stateless resets are always short header packets
if wire.IsLongHeaderPacket(data[0]) {
return false
}
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
return false
}
token := *(*protocol.StatelessResetToken)(data[len(data)-16:])
if conn, ok := t.handlerMap.GetByResetToken(token); ok {
t.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token)
go conn.destroy(&StatelessResetError{Token: token})
return true
}
return false
}

View File

@ -1,34 +0,0 @@
package quic
import (
"time"
"github.com/quic-go/quic-go/internal/protocol"
)
type zeroRTTQueue struct {
queue []*receivedPacket
retireTimer *time.Timer
}
var _ packetHandler = &zeroRTTQueue{}
func (h *zeroRTTQueue) handlePacket(p *receivedPacket) {
if len(h.queue) < protocol.Max0RTTQueueLen {
h.queue = append(h.queue, p)
}
}
func (h *zeroRTTQueue) shutdown() {}
func (h *zeroRTTQueue) destroy(error) {}
func (h *zeroRTTQueue) getPerspective() protocol.Perspective { return protocol.PerspectiveClient }
func (h *zeroRTTQueue) EnqueueAll(sess packetHandler) {
for _, p := range h.queue {
sess.handlePacket(p)
}
}
func (h *zeroRTTQueue) Clear() {
for _, p := range h.queue {
p.buffer.Release()
}
}

Some files were not shown because too many files have changed in this diff Show More