From 89ccc59f0e3b812012cfd34bcde44f1c0750b8f6 Mon Sep 17 00:00:00 2001
From: YX Hao
Date: Mon, 3 Jul 2023 18:37:42 +0800
Subject: [PATCH] Upgrade quic-go to v0.36.1

quic-go has made breaking changes since v0.35.0, including implementing
`CloseIdleConnections`. Now the local listener UDPConns are reused and no
longer pile up, but one instance (IPv4/IPv6) persists for each connected
server.
---
 dnscrypt-proxy/xtransport.go | 44 +- go.mod | 12 +- go.sum | 36 +- .../onsi/ginkgo/v2/formatter/formatter.go | 61 +- .../ginkgo/v2/ginkgo/build/build_command.go | 2 + .../v2/ginkgo/generators/bootstrap_command.go | 24 +- .../v2/ginkgo/generators/generate_command.go | 24 +- .../v2/ginkgo/generators/generators_common.go | 1 + .../onsi/ginkgo/v2/ginkgo/internal/compile.go | 11 +- .../onsi/ginkgo/v2/ginkgo/internal/run.go | 7 + .../v2/ginkgo/internal/verify_version.go | 54 ++ .../onsi/ginkgo/v2/ginkgo/outline/ginkgo.go | 90 ++- .../onsi/ginkgo/v2/ginkgo/outline/import.go | 2 +- .../onsi/ginkgo/v2/ginkgo/outline/outline.go | 11 +- .../onsi/ginkgo/v2/ginkgo/run/run_command.go | 6 +- .../ginkgo/v2/ginkgo/watch/watch_command.go | 8 +- .../interrupt_handler/interrupt_handler.go | 189 ++--- .../parallel_support/client_server.go | 2 + .../internal/parallel_support/http_client.go | 13 + .../internal/parallel_support/http_server.go | 29 +- .../internal/parallel_support/rpc_client.go | 13 + .../parallel_support/server_handler.go | 55 +- .../ginkgo/v2/reporters/default_reporter.go | 640 ++++++++++----- .../v2/reporters/deprecated_reporter.go | 2 +- .../onsi/ginkgo/v2/reporters/junit_report.go | 123 +-- .../onsi/ginkgo/v2/reporters/reporter.go | 18 +- .../ginkgo/v2/reporters/teamcity_report.go | 12 +- .../onsi/ginkgo/v2/types/code_location.go | 77 +- .../github.com/onsi/ginkgo/v2/types/config.go | 61 +- .../ginkgo/v2/types/deprecation_support.go | 9 +- .../github.com/onsi/ginkgo/v2/types/errors.go | 103 ++- .../onsi/ginkgo/v2/types/label_filter.go | 11 + .../onsi/ginkgo/v2/types/report_entry.go | 14 +- .../github.com/onsi/ginkgo/v2/types/types.go | 364 +++++++-- .../onsi/ginkgo/v2/types/version.go | 2 +- vendor/github.com/quic-go/quic-go/README.md | 212 ++++- .../github.com/quic-go/quic-go/buffer_pool.go | 34 +- vendor/github.com/quic-go/quic-go/client.go | 337 +++----- .../github.com/quic-go/quic-go/closed_conn.go | 8 +- vendor/github.com/quic-go/quic-go/config.go | 52 +- .../github.com/quic-go/quic-go/connection.go | 499 +++++++----- vendor/github.com/quic-go/quic-go/framer.go | 49 +- .../quic-go/quic-go/http3/client.go | 17 +- .../quic-go/quic-go/http3/mockgen.go | 2 + .../quic-go/quic-go/http3/response_writer.go | 11 + .../quic-go/quic-go/http3/roundtrip.go | 51 +- .../quic-go/quic-go/http3/server.go | 76 +- .../github.com/quic-go/quic-go/interface.go | 74 +- .../internal/ackhandler/ack_eliciting.go | 2 +- .../quic-go/internal/ackhandler/frame.go | 30 +- .../quic-go/internal/ackhandler/interfaces.go | 10 +- .../quic-go/internal/ackhandler/packet.go | 22 +- .../ackhandler/packet_number_generator.go | 24 +- .../ackhandler/received_packet_tracker.go | 10 +- .../quic-go/internal/ackhandler/send_mode.go | 6 + .../ackhandler/sent_packet_handler.go | 226 +++--- .../ackhandler/sent_packet_history.go | 248 +++--- .../internal/congestion/cubic_sender.go | 4 +- .../quic-go/internal/congestion/interface.go | 2 +- .../quic-go/internal/congestion/pacer.go | 19 +- .../internal/handshake/crypto_setup.go | 10 +- .../quic-go/internal/protocol/params.go | 3 + .../quic-go/internal/protocol/protocol.go | 8 +-
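For context, the dnscrypt-proxy/xtransport.go change below reduces to the
following sketch: with quic-go v0.36.x there is no shared h3UDPConn any more,
each HTTP/3 dial opens its own UDP socket on a network ("udp", "udp4" or
"udp6") picked from the cached IP or the resolver's IPv4/IPv6 settings, the
server name is carried in tls.Config.ServerName, and DialEarly now takes the
context as its first argument. This is only a minimal illustration mirroring
the added lines of the patch, not the full Fetch path; the dialH3 helper, its
package name, and the explicit network parameter exist only for this sketch.

    package xtransportsketch

    import (
        "context"
        "crypto/tls"
        "net"

        "github.com/quic-go/quic-go"
    )

    // dialH3 mirrors the Dial callback the patch installs on the
    // http3.RoundTripper: a fresh UDP socket per dial instead of a shared,
    // long-lived listener.
    func dialH3(ctx context.Context, host, addrStr, network string,
        tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
        // network is "udp4", "udp6" or "udp", chosen from the cached IP
        // (or from the configured IPv4/IPv6 preference when nothing is cached).
        udpAddr, err := net.ResolveUDPAddr(network, addrStr)
        if err != nil {
            return nil, err
        }
        // A nil local address lets the kernel pick an ephemeral port for this
        // connection, so no listener is kept around between requests.
        udpConn, err := net.ListenUDP(network, nil)
        if err != nil {
            return nil, err
        }
        // quic-go v0.36 dropped DialEarlyContext's separate host argument;
        // the SNI now comes from tls.Config.ServerName.
        tlsCfg.ServerName = host
        return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
    }

Idle-connection cleanup is then left to the plain
transport.CloseIdleConnections() call that the patch keeps in
rebuildTransport() and in the Fetch() error path.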
.../internal/utils/ringbuffer/ringbuffer.go | 86 ++ .../quic-go/internal/utils/rtt_stats.go | 6 +- .../quic-go/internal/wire/ack_frame.go | 57 +- .../quic-go/internal/wire/ack_frame_pool.go | 24 - .../quic-go/internal/wire/frame_parser.go | 12 +- .../quic-go/quic-go/internal/wire/header.go | 2 - .../internal/wire/transport_parameters.go | 35 +- .../quic-go/quic-go/logging/interface.go | 7 - .../quic-go/quic-go/logging/multiplex.go | 11 - .../quic-go/quic-go/logging/null_tracer.go | 4 - vendor/github.com/quic-go/quic-go/mockgen.go | 3 - .../quic-go/quic-go/mtu_discoverer.go | 83 +- .../github.com/quic-go/quic-go/multiplexer.go | 69 +- .../quic-go/quic-go/packet_handler_map.go | 364 ++------- .../quic-go/quic-go/packet_packer.go | 379 ++++----- .../quic-go/quic-go/quicvarint/varint.go | 19 - .../quic-go/quic-go/receive_stream.go | 4 + .../quic-go/quic-go/retransmission_queue.go | 57 +- .../github.com/quic-go/quic-go/send_conn.go | 64 +- .../github.com/quic-go/quic-go/send_queue.go | 23 +- .../github.com/quic-go/quic-go/send_stream.go | 95 +-- vendor/github.com/quic-go/quic-go/server.go | 570 ++++++++----- vendor/github.com/quic-go/quic-go/stream.go | 2 +- vendor/github.com/quic-go/quic-go/sys_conn.go | 48 +- .../quic-go/quic-go/sys_conn_buffers.go | 68 ++ .../quic-go/quic-go/sys_conn_buffers_write.go | 70 ++ .../github.com/quic-go/quic-go/sys_conn_df.go | 8 +- .../quic-go/quic-go/sys_conn_df_linux.go | 51 +- .../quic-go/quic-go/sys_conn_df_windows.go | 17 +- .../quic-go/quic-go/sys_conn_helper_linux.go | 26 +- .../quic-go/sys_conn_helper_nonlinux.go | 6 + .../quic-go/quic-go/sys_conn_no_gso.go | 8 + .../quic-go/quic-go/sys_conn_no_oob.go | 16 +- .../quic-go/quic-go/sys_conn_oob.go | 127 +-- .../quic-go/quic-go/sys_conn_windows.go | 37 +- .../github.com/quic-go/quic-go/transport.go | 435 ++++++++++ .../quic-go/quic-go/zero_rtt_queue.go | 34 - .../x/tools/go/types/objectpath/objectpath.go | 764 ++++++++++++++++++ .../x/tools/internal/gcimporter/gcimporter.go | 12 + .../x/tools/internal/gcimporter/iexport.go | 8 +- .../x/tools/internal/gcimporter/iimport.go | 34 +- .../tools/internal/gcimporter/ureader_yes.go | 41 +- .../x/tools/internal/gocommand/invoke.go | 128 ++- .../x/tools/internal/gocommand/version.go | 18 +- .../x/tools/internal/imports/fix.go | 28 +- .../internal/tokeninternal/tokeninternal.go | 92 +++ .../x/tools/internal/typeparams/common.go | 1 - .../x/tools/internal/typesinternal/types.go | 9 + vendor/modules.txt | 14 +- 112 files changed, 5632 insertions(+), 2620 deletions(-) create mode 100644 vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/verify_version.go create mode 100644 vendor/github.com/quic-go/quic-go/internal/utils/ringbuffer/ringbuffer.go delete mode 100644 vendor/github.com/quic-go/quic-go/internal/wire/ack_frame_pool.go create mode 100644 vendor/github.com/quic-go/quic-go/sys_conn_buffers.go create mode 100644 vendor/github.com/quic-go/quic-go/sys_conn_buffers_write.go create mode 100644 vendor/github.com/quic-go/quic-go/sys_conn_helper_nonlinux.go create mode 100644 vendor/github.com/quic-go/quic-go/sys_conn_no_gso.go create mode 100644 vendor/github.com/quic-go/quic-go/transport.go delete mode 100644 vendor/github.com/quic-go/quic-go/zero_rtt_queue.go create mode 100644 vendor/golang.org/x/tools/go/types/objectpath/objectpath.go diff --git a/dnscrypt-proxy/xtransport.go b/dnscrypt-proxy/xtransport.go index b30d362e..2adab72a 100644 --- a/dnscrypt-proxy/xtransport.go +++ b/dnscrypt-proxy/xtransport.go @@ -57,7 +57,6 @@ type AltSupport struct { type XTransport 
struct { transport *http.Transport h3Transport *http3.RoundTripper - h3UDPConn *net.UDPConn keepAlive time.Duration timeout time.Duration cachedIPs CachedIPs @@ -138,15 +137,7 @@ func (xTransport *XTransport) loadCachedIP(host string) (ip net.IP, expired bool func (xTransport *XTransport) rebuildTransport() { dlog.Debug("Rebuilding transport") if xTransport.transport != nil { - if xTransport.h3Transport != nil { - xTransport.h3Transport.Close() - if xTransport.h3UDPConn != nil { - xTransport.h3UDPConn.Close() - xTransport.h3UDPConn = nil - } - } else { - xTransport.transport.CloseIdleConnections() - } + xTransport.transport.CloseIdleConnections() } timeout := xTransport.timeout transport := &http.Transport{ @@ -272,27 +263,35 @@ func (xTransport *XTransport) rebuildTransport() { host, port := ExtractHostAndPort(addrStr, stamps.DefaultPort) ipOnly := host cachedIP, _ := xTransport.loadCachedIP(host) + network := "udp4" if cachedIP != nil { if ipv4 := cachedIP.To4(); ipv4 != nil { ipOnly = ipv4.String() } else { ipOnly = "[" + cachedIP.String() + "]" + network = "udp6" } } else { - dlog.Debugf("[%s] IP address was not cached in H3 DialContext", host) + dlog.Debugf("[%s] IP address was not cached in H3 context", host) + if xTransport.useIPv6 { + if xTransport.useIPv4 { + network = "udp" + } else { + network = "udp6" + } + } } addrStr = ipOnly + ":" + strconv.Itoa(port) - udpAddr, err := net.ResolveUDPAddr("udp", addrStr) + udpAddr, err := net.ResolveUDPAddr(network, addrStr) if err != nil { return nil, err } - if xTransport.h3UDPConn == nil { - xTransport.h3UDPConn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return nil, err - } + udpConn, err := net.ListenUDP(network, nil) + if err != nil { + return nil, err } - return quic.DialEarlyContext(ctx, xTransport.h3UDPConn, udpAddr, host, tlsCfg, cfg) + tlsCfg.ServerName = host + return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg) }} xTransport.h3Transport = h3Transport } @@ -545,13 +544,8 @@ func (xTransport *XTransport) Fetch( err = errors.New(resp.Status) } } else { - if hasAltSupport { - dlog.Debugf("HTTP client error: [%v] - closing H3 connections", err) - xTransport.h3Transport.Close() - } else { - dlog.Debugf("HTTP client error: [%v] - closing idle connections", err) - xTransport.transport.CloseIdleConnections() - } + dlog.Debugf("HTTP client error: [%v] - closing idle connections", err) + xTransport.transport.CloseIdleConnections() } statusCode := 503 if resp != nil { diff --git a/go.mod b/go.mod index 7c101bad..46309d53 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/kardianos/service v1.2.2 github.com/miekg/dns v1.1.55 github.com/powerman/check v1.7.0 - github.com/quic-go/quic-go v0.34.0 + github.com/quic-go/quic-go v0.36.1 golang.org/x/crypto v0.10.0 golang.org/x/net v0.11.0 golang.org/x/sys v0.9.0 @@ -29,12 +29,12 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/golang/mock v1.6.0 // indirect - github.com/golang/protobuf v1.5.2 // indirect + github.com/golang/protobuf v1.5.3 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/hashicorp/go-syslog v1.0.0 // indirect - github.com/onsi/ginkgo/v2 v2.2.0 // indirect + github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // 
indirect github.com/powerman/deepequal v0.1.0 // indirect @@ -43,9 +43,9 @@ require ( github.com/quic-go/qtls-go1-20 v0.2.2 // indirect github.com/smartystreets/goconvey v1.7.2 // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect - golang.org/x/mod v0.8.0 // indirect + golang.org/x/mod v0.10.0 // indirect golang.org/x/text v0.10.0 // indirect - golang.org/x/tools v0.6.0 // indirect + golang.org/x/tools v0.9.1 // indirect google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f // indirect google.golang.org/grpc v1.53.0 // indirect google.golang.org/protobuf v1.30.0 // indirect diff --git a/go.sum b/go.sum index 70be956c..3ed4f389 100644 --- a/go.sum +++ b/go.sum @@ -12,13 +12,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185 h1:3T8ZyTDp5QxTx3NU48JVb2u+75xc040fofcBaN+6jPA= github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185/go.mod h1:cFRxtTwTOJkz2x3rQUNCYKWC93yP1VKjR8NUhqFxZNU= -github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= -github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= @@ -56,9 +57,9 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo= github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY= -github.com/onsi/ginkgo/v2 v2.2.0 h1:3ZNA3L1c5FYDFTTxbFeVGGD8jYvjYauHD30YgLxVsNI= -github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= -github.com/onsi/gomega v1.20.1 h1:PA/3qinGoukvymdIDV8pii6tiZgC8kbmJO6Z5+b002Q= +github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= +github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= +github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -73,15 +74,15 @@ github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc8 github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= -github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU= -github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= +github.com/quic-go/quic-go v0.36.1 h1:WsG73nVtnDy1TiACxFxhQ3TqaW+DipmqzLEtNlAwZyY= +github.com/quic-go/quic-go v0.36.1/go.mod h1:zPetvwDlILVxt15n3hr3Gf/I3mDf7LpLKPhR4Ez0AZQ= github.com/smartystreets/assertions v1.2.0 h1:42S6lae5dvLc7BrLu/0ugRtcFVjoJNMC/N3yZFZkDFs= github.com/smartystreets/assertions v1.2.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg= github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -90,8 +91,8 @@ golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45 golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= +golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -100,7 +101,7 @@ golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190529164535-6a60838ec259/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -120,8 +121,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= +golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -137,6 +138,5 @@ google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/vendor/github.com/onsi/ginkgo/v2/formatter/formatter.go b/vendor/github.com/onsi/ginkgo/v2/formatter/formatter.go index 43b16211..743555dd 100644 --- a/vendor/github.com/onsi/ginkgo/v2/formatter/formatter.go +++ b/vendor/github.com/onsi/ginkgo/v2/formatter/formatter.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "regexp" + "strconv" "strings" ) @@ -50,6 +51,37 @@ func NewWithNoColorBool(noColor bool) Formatter { } func New(colorMode ColorMode) Formatter { + colorAliases := map[string]int{ + "black": 0, + "red": 1, + "green": 2, + "yellow": 3, + "blue": 4, + "magenta": 5, + "cyan": 6, + "white": 7, + } + for colorAlias, n := range colorAliases { + colorAliases[fmt.Sprintf("bright-%s", colorAlias)] = n + 8 + } + + getColor := func(color, defaultEscapeCode string) string { + color = strings.ToUpper(strings.ReplaceAll(color, "-", "_")) + envVar := fmt.Sprintf("GINKGO_CLI_COLOR_%s", color) + envVarColor := os.Getenv(envVar) + if envVarColor == "" { + return defaultEscapeCode + } + if colorCode, ok := colorAliases[envVarColor]; ok { + return fmt.Sprintf("\x1b[38;5;%dm", colorCode) + } + colorCode, err := strconv.Atoi(envVarColor) + if err != nil || colorCode < 0 || colorCode > 255 { + return defaultEscapeCode + } + return fmt.Sprintf("\x1b[38;5;%dm", colorCode) + } + f := Formatter{ ColorMode: colorMode, colors: map[string]string{ @@ -57,18 +89,18 @@ func New(colorMode ColorMode) Formatter { "bold": "\x1b[1m", "underline": "\x1b[4m", - "red": "\x1b[38;5;9m", - "orange": "\x1b[38;5;214m", - 
"coral": "\x1b[38;5;204m", - "magenta": "\x1b[38;5;13m", - "green": "\x1b[38;5;10m", - "dark-green": "\x1b[38;5;28m", - "yellow": "\x1b[38;5;11m", - "light-yellow": "\x1b[38;5;228m", - "cyan": "\x1b[38;5;14m", - "gray": "\x1b[38;5;243m", - "light-gray": "\x1b[38;5;246m", - "blue": "\x1b[38;5;12m", + "red": getColor("red", "\x1b[38;5;9m"), + "orange": getColor("orange", "\x1b[38;5;214m"), + "coral": getColor("coral", "\x1b[38;5;204m"), + "magenta": getColor("magenta", "\x1b[38;5;13m"), + "green": getColor("green", "\x1b[38;5;10m"), + "dark-green": getColor("dark-green", "\x1b[38;5;28m"), + "yellow": getColor("yellow", "\x1b[38;5;11m"), + "light-yellow": getColor("light-yellow", "\x1b[38;5;228m"), + "cyan": getColor("cyan", "\x1b[38;5;14m"), + "gray": getColor("gray", "\x1b[38;5;243m"), + "light-gray": getColor("light-gray", "\x1b[38;5;246m"), + "blue": getColor("blue", "\x1b[38;5;12m"), }, } colors := []string{} @@ -88,7 +120,10 @@ func (f Formatter) Fi(indentation uint, format string, args ...interface{}) stri } func (f Formatter) Fiw(indentation uint, maxWidth uint, format string, args ...interface{}) string { - out := fmt.Sprintf(f.style(format), args...) + out := f.style(format) + if len(args) > 0 { + out = fmt.Sprintf(out, args...) + } if indentation == 0 && maxWidth == 0 { return out diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/build/build_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/build/build_command.go index f7d2eaf0..5db5d1a7 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/build/build_command.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/build/build_command.go @@ -39,6 +39,8 @@ func buildSpecs(args []string, cliConfig types.CLIConfig, goFlagsConfig types.Go command.AbortWith("Found no test suites") } + internal.VerifyCLIAndFrameworkVersion(suites) + opc := internal.NewOrderedParallelCompiler(cliConfig.ComputedNumCompilers()) opc.StartCompiling(suites, goFlagsConfig) diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/bootstrap_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/bootstrap_command.go index 0273abe9..73aff0b7 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/bootstrap_command.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/bootstrap_command.go @@ -2,6 +2,7 @@ package generators import ( "bytes" + "encoding/json" "fmt" "os" "text/template" @@ -25,6 +26,9 @@ func BuildBootstrapCommand() command.Command { {Name: "template", KeyPath: "CustomTemplate", UsageArgument: "template-file", Usage: "If specified, generate will use the contents of the file passed as the bootstrap template"}, + {Name: "template-data", KeyPath: "CustomTemplateData", + UsageArgument: "template-data-file", + Usage: "If specified, generate will use the contents of the file passed as data to be rendered in the bootstrap template"}, }, &conf, types.GinkgoFlagSections{}, @@ -57,6 +61,7 @@ type bootstrapData struct { GomegaImport string GinkgoPackage string GomegaPackage string + CustomData map[string]any } func generateBootstrap(conf GeneratorsConfig) { @@ -95,17 +100,32 @@ func generateBootstrap(conf GeneratorsConfig) { tpl, err := os.ReadFile(conf.CustomTemplate) command.AbortIfError("Failed to read custom bootstrap file:", err) templateText = string(tpl) + if conf.CustomTemplateData != "" { + var tplCustomDataMap map[string]any + tplCustomData, err := os.ReadFile(conf.CustomTemplateData) + command.AbortIfError("Failed to read custom boostrap data file:", err) + if !json.Valid([]byte(tplCustomData)) { + command.AbortWith("Invalid 
JSON object in custom data file.") + } + //create map from the custom template data + json.Unmarshal(tplCustomData, &tplCustomDataMap) + data.CustomData = tplCustomDataMap + } } else if conf.Agouti { templateText = agoutiBootstrapText } else { templateText = bootstrapText } - bootstrapTemplate, err := template.New("bootstrap").Funcs(sprig.TxtFuncMap()).Parse(templateText) + //Setting the option to explicitly fail if template is rendered trying to access missing key + bootstrapTemplate, err := template.New("bootstrap").Funcs(sprig.TxtFuncMap()).Option("missingkey=error").Parse(templateText) command.AbortIfError("Failed to parse bootstrap template:", err) buf := &bytes.Buffer{} - bootstrapTemplate.Execute(buf, data) + //Being explicit about failing sooner during template rendering + //when accessing custom data rather than during the go fmt command + err = bootstrapTemplate.Execute(buf, data) + command.AbortIfError("Failed to render bootstrap template:", err) buf.WriteTo(f) diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generate_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generate_command.go index 93b0b4b2..48d23f91 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generate_command.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generate_command.go @@ -2,6 +2,7 @@ package generators import ( "bytes" + "encoding/json" "fmt" "os" "path/filepath" @@ -28,6 +29,9 @@ func BuildGenerateCommand() command.Command { {Name: "template", KeyPath: "CustomTemplate", UsageArgument: "template-file", Usage: "If specified, generate will use the contents of the file passed as the test file template"}, + {Name: "template-data", KeyPath: "CustomTemplateData", + UsageArgument: "template-data-file", + Usage: "If specified, generate will use the contents of the file passed as data to be rendered in the test file template"}, }, &conf, types.GinkgoFlagSections{}, @@ -64,6 +68,7 @@ type specData struct { GomegaImport string GinkgoPackage string GomegaPackage string + CustomData map[string]any } func generateTestFiles(conf GeneratorsConfig, args []string) { @@ -122,16 +127,31 @@ func generateTestFileForSubject(subject string, conf GeneratorsConfig) { tpl, err := os.ReadFile(conf.CustomTemplate) command.AbortIfError("Failed to read custom template file:", err) templateText = string(tpl) + if conf.CustomTemplateData != "" { + var tplCustomDataMap map[string]any + tplCustomData, err := os.ReadFile(conf.CustomTemplateData) + command.AbortIfError("Failed to read custom template data file:", err) + if !json.Valid([]byte(tplCustomData)) { + command.AbortWith("Invalid JSON object in custom data file.") + } + //create map from the custom template data + json.Unmarshal(tplCustomData, &tplCustomDataMap) + data.CustomData = tplCustomDataMap + } } else if conf.Agouti { templateText = agoutiSpecText } else { templateText = specText } - specTemplate, err := template.New("spec").Funcs(sprig.TxtFuncMap()).Parse(templateText) + //Setting the option to explicitly fail if template is rendered trying to access missing key + specTemplate, err := template.New("spec").Funcs(sprig.TxtFuncMap()).Option("missingkey=error").Parse(templateText) command.AbortIfError("Failed to read parse test template:", err) - specTemplate.Execute(f, data) + //Being explicit about failing sooner during template rendering + //when accessing custom data rather than during the go fmt command + err = specTemplate.Execute(f, data) + command.AbortIfError("Failed to render bootstrap template:", err) 
internal.GoFmt(targetFile) } diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generators_common.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generators_common.go index 3086e605..3046a448 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generators_common.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/generators/generators_common.go @@ -13,6 +13,7 @@ import ( type GeneratorsConfig struct { Agouti, NoDot, Internal bool CustomTemplate string + CustomTemplateData string } func getPackageAndFormattedName() (string, string, string) { diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/compile.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/compile.go index 496ec4a2..86da7340 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/compile.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/compile.go @@ -25,7 +25,16 @@ func CompileSuite(suite TestSuite, goFlagsConfig types.GoFlagsConfig) TestSuite return suite } - args, err := types.GenerateGoTestCompileArgs(goFlagsConfig, path, "./") + ginkgoInvocationPath, _ := os.Getwd() + ginkgoInvocationPath, _ = filepath.Abs(ginkgoInvocationPath) + packagePath := suite.AbsPath() + pathToInvocationPath, err := filepath.Rel(packagePath, ginkgoInvocationPath) + if err != nil { + suite.State = TestSuiteStateFailedToCompile + suite.CompilationError = fmt.Errorf("Failed to get relative path from package to the current working directory:\n%s", err.Error()) + return suite + } + args, err := types.GenerateGoTestCompileArgs(goFlagsConfig, path, "./", pathToInvocationPath) if err != nil { suite.State = TestSuiteStateFailedToCompile suite.CompilationError = fmt.Errorf("Failed to generate go test compile flags:\n%s", err.Error()) diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/run.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/run.go index cad23867..41052ea1 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/run.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/run.go @@ -6,6 +6,7 @@ import ( "io" "os" "os/exec" + "path/filepath" "regexp" "strings" "syscall" @@ -63,6 +64,12 @@ func checkForNoTestsWarning(buf *bytes.Buffer) bool { } func runGoTest(suite TestSuite, cliConfig types.CLIConfig, goFlagsConfig types.GoFlagsConfig) TestSuite { + // As we run the go test from the suite directory, make sure the cover profile is absolute + // and placed into the expected output directory when one is configured. 
+ if goFlagsConfig.Cover && !filepath.IsAbs(goFlagsConfig.CoverProfile) { + goFlagsConfig.CoverProfile = AbsPathForGeneratedAsset(goFlagsConfig.CoverProfile, suite, cliConfig, 0) + } + args, err := types.GenerateGoTestRunArgs(goFlagsConfig) command.AbortIfError("Failed to generate test run arguments", err) cmd, buf := buildAndStartCommand(suite, args, true) diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/verify_version.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/verify_version.go new file mode 100644 index 00000000..9da1bab3 --- /dev/null +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/internal/verify_version.go @@ -0,0 +1,54 @@ +package internal + +import ( + "fmt" + "os/exec" + "regexp" + "strings" + + "github.com/onsi/ginkgo/v2/formatter" + "github.com/onsi/ginkgo/v2/types" +) + +var versiorRe = regexp.MustCompile(`v(\d+\.\d+\.\d+)`) + +func VerifyCLIAndFrameworkVersion(suites TestSuites) { + cliVersion := types.VERSION + mismatches := map[string][]string{} + + for _, suite := range suites { + cmd := exec.Command("go", "list", "-m", "github.com/onsi/ginkgo/v2") + cmd.Dir = suite.Path + output, err := cmd.CombinedOutput() + if err != nil { + continue + } + components := strings.Split(string(output), " ") + if len(components) != 2 { + continue + } + matches := versiorRe.FindStringSubmatch(components[1]) + if matches == nil || len(matches) != 2 { + continue + } + libraryVersion := matches[1] + if cliVersion != libraryVersion { + mismatches[libraryVersion] = append(mismatches[libraryVersion], suite.PackageName) + } + } + + if len(mismatches) == 0 { + return + } + + fmt.Println(formatter.F("{{red}}{{bold}}Ginkgo detected a version mismatch between the Ginkgo CLI and the version of Ginkgo imported by your packages:{{/}}")) + + fmt.Println(formatter.Fi(1, "Ginkgo CLI Version:")) + fmt.Println(formatter.Fi(2, "{{bold}}%s{{/}}", cliVersion)) + fmt.Println(formatter.Fi(1, "Mismatched package versions found:")) + for version, packages := range mismatches { + fmt.Println(formatter.Fi(2, "{{bold}}%s{{/}} used by %s", version, strings.Join(packages, ", "))) + } + fmt.Println("") + fmt.Println(formatter.Fiw(1, formatter.COLS, "{{gray}}Ginkgo will continue to attempt to run but you may see errors (including flag parsing errors) and should either update your go.mod or your version of the Ginkgo CLI to match.\n\nTo install the matching version of the CLI run\n {{bold}}go install github.com/onsi/ginkgo/v2/ginkgo{{/}}{{gray}}\nfrom a path that contains a go.mod file. 
Alternatively you can use\n {{bold}}go run github.com/onsi/ginkgo/v2/ginkgo{{/}}{{gray}}\nfrom a path that contains a go.mod file to invoke the matching version of the Ginkgo CLI.\n\nIf you are attempting to test multiple packages that each have a different version of the Ginkgo library with a single Ginkgo CLI that is currently unsupported.\n{{/}}")) +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/ginkgo.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/ginkgo.go index c197bb68..0b9b19fe 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/ginkgo.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/ginkgo.go @@ -1,6 +1,7 @@ package outline import ( + "github.com/onsi/ginkgo/v2/types" "go/ast" "go/token" "strconv" @@ -25,9 +26,10 @@ type ginkgoMetadata struct { // End is the position of first character immediately after the spec or container block End int `json:"end"` - Spec bool `json:"spec"` - Focused bool `json:"focused"` - Pending bool `json:"pending"` + Spec bool `json:"spec"` + Focused bool `json:"focused"` + Pending bool `json:"pending"` + Labels []string `json:"labels"` } // ginkgoNode is used to construct the outline as a tree @@ -145,27 +147,35 @@ func ginkgoNodeFromCallExpr(fset *token.FileSet, ce *ast.CallExpr, ginkgoPackage case "It", "Specify", "Entry": n.Spec = true n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + n.Labels = labelFromCallExpr(ce) + n.Pending = pendingFromCallExpr(ce) return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName case "FIt", "FSpecify", "FEntry": n.Spec = true n.Focused = true n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + n.Labels = labelFromCallExpr(ce) return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName case "PIt", "PSpecify", "XIt", "XSpecify", "PEntry", "XEntry": n.Spec = true n.Pending = true n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + n.Labels = labelFromCallExpr(ce) return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName case "Context", "Describe", "When", "DescribeTable": n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + n.Labels = labelFromCallExpr(ce) + n.Pending = pendingFromCallExpr(ce) return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName case "FContext", "FDescribe", "FWhen", "FDescribeTable": n.Focused = true n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + n.Labels = labelFromCallExpr(ce) return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName case "PContext", "PDescribe", "PWhen", "XContext", "XDescribe", "XWhen", "PDescribeTable", "XDescribeTable": n.Pending = true n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) + n.Labels = labelFromCallExpr(ce) return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName case "By": n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt) @@ -216,3 +226,77 @@ func textFromCallExpr(ce *ast.CallExpr) (string, bool) { return text.Value, true } } + +func labelFromCallExpr(ce *ast.CallExpr) []string { + + labels := []string{} + if len(ce.Args) < 2 { + return labels + } + + for _, arg := range ce.Args[1:] { + switch expr := arg.(type) { + case *ast.CallExpr: + id, ok := expr.Fun.(*ast.Ident) + if !ok { + // to skip over cases where the expr.Fun. 
is actually *ast.SelectorExpr + continue + } + if id.Name == "Label" { + ls := extractLabels(expr) + for _, label := range ls { + labels = append(labels, label) + } + } + } + } + return labels +} + +func extractLabels(expr *ast.CallExpr) []string { + out := []string{} + for _, arg := range expr.Args { + switch expr := arg.(type) { + case *ast.BasicLit: + if expr.Kind == token.STRING { + unquoted, err := strconv.Unquote(expr.Value) + if err != nil { + unquoted = expr.Value + } + validated, err := types.ValidateAndCleanupLabel(unquoted, types.CodeLocation{}) + if err == nil { + out = append(out, validated) + } + } + } + } + + return out +} + +func pendingFromCallExpr(ce *ast.CallExpr) bool { + + pending := false + if len(ce.Args) < 2 { + return pending + } + + for _, arg := range ce.Args[1:] { + switch expr := arg.(type) { + case *ast.CallExpr: + id, ok := expr.Fun.(*ast.Ident) + if !ok { + // to skip over cases where the expr.Fun. is actually *ast.SelectorExpr + continue + } + if id.Name == "Pending" { + pending = true + } + case *ast.Ident: + if expr.Name == "Pending" { + pending = true + } + } + } + return pending +} diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/import.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/import.go index 4328ab39..67ec5ab7 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/import.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/import.go @@ -47,7 +47,7 @@ func packageNameForImport(f *ast.File, path string) *string { // or nil otherwise. func importSpec(f *ast.File, path string) *ast.ImportSpec { for _, s := range f.Imports { - if importPath(s) == path { + if strings.HasPrefix(importPath(s), path) { return s } } diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/outline.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/outline.go index 4b45e762..c2327cda 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/outline.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/outline/outline.go @@ -85,12 +85,19 @@ func (o *outline) String() string { // one 'width' of spaces for every level of nesting. 
func (o *outline) StringIndent(width int) string { var b strings.Builder - b.WriteString("Name,Text,Start,End,Spec,Focused,Pending\n") + b.WriteString("Name,Text,Start,End,Spec,Focused,Pending,Labels\n") currentIndent := 0 pre := func(n *ginkgoNode) { b.WriteString(fmt.Sprintf("%*s", currentIndent, "")) - b.WriteString(fmt.Sprintf("%s,%s,%d,%d,%t,%t,%t\n", n.Name, n.Text, n.Start, n.End, n.Spec, n.Focused, n.Pending)) + var labels string + if len(n.Labels) == 1 { + labels = n.Labels[0] + } else { + labels = strings.Join(n.Labels, ", ") + } + //enclosing labels in a double quoted comma separate listed so that when inmported into a CSV app the Labels column has comma separate strings + b.WriteString(fmt.Sprintf("%s,%s,%d,%d,%t,%t,%t,\"%s\"\n", n.Name, n.Text, n.Start, n.End, n.Spec, n.Focused, n.Pending, labels)) currentIndent += width } post := func(n *ginkgoNode) { diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/run/run_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/run/run_command.go index 8ee0acc8..aaed4d57 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/run/run_command.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/run/run_command.go @@ -24,7 +24,7 @@ func BuildRunCommand() command.Command { panic(err) } - interruptHandler := interrupt_handler.NewInterruptHandler(0, nil) + interruptHandler := interrupt_handler.NewInterruptHandler(nil) interrupt_handler.SwallowSigQuit() return command.Command{ @@ -69,6 +69,8 @@ func (r *SpecRunner) RunSpecs(args []string, additionalArgs []string) { skippedSuites := suites.WithState(internal.TestSuiteStateSkippedByFilter) suites = suites.WithoutState(internal.TestSuiteStateSkippedByFilter) + internal.VerifyCLIAndFrameworkVersion(suites) + if len(skippedSuites) > 0 { fmt.Println("Will skip:") for _, skippedSuite := range skippedSuites { @@ -115,7 +117,7 @@ OUTER_LOOP: } suites[suiteIdx] = suite - if r.interruptHandler.Status().Interrupted { + if r.interruptHandler.Status().Interrupted() { opc.StopAndDrain() break OUTER_LOOP } diff --git a/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/watch_command.go b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/watch_command.go index 83dbeb1e..bde4193c 100644 --- a/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/watch_command.go +++ b/vendor/github.com/onsi/ginkgo/v2/ginkgo/watch/watch_command.go @@ -22,7 +22,7 @@ func BuildWatchCommand() command.Command { if err != nil { panic(err) } - interruptHandler := interrupt_handler.NewInterruptHandler(0, nil) + interruptHandler := interrupt_handler.NewInterruptHandler(nil) interrupt_handler.SwallowSigQuit() return command.Command{ @@ -65,6 +65,8 @@ type SpecWatcher struct { func (w *SpecWatcher) WatchSpecs(args []string, additionalArgs []string) { suites := internal.FindSuites(args, w.cliConfig, false).WithoutState(internal.TestSuiteStateSkippedByFilter) + internal.VerifyCLIAndFrameworkVersion(suites) + if len(suites) == 0 { command.AbortWith("Found no test suites") } @@ -127,7 +129,7 @@ func (w *SpecWatcher) WatchSpecs(args []string, additionalArgs []string) { w.updateSeed() w.computeSuccinctMode(len(suites)) for idx := range suites { - if w.interruptHandler.Status().Interrupted { + if w.interruptHandler.Status().Interrupted() { return } deltaTracker.WillRun(suites[idx]) @@ -156,7 +158,7 @@ func (w *SpecWatcher) compileAndRun(suite internal.TestSuite, additionalArgs []s fmt.Println(suite.CompilationError.Error()) return suite } - if w.interruptHandler.Status().Interrupted { + if w.interruptHandler.Status().Interrupted() { return suite } suite = 
internal.RunCompiledSuite(suite, w.suiteConfig, w.reporterConfig, w.cliConfig, w.goFlagsConfig, additionalArgs) diff --git a/vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/interrupt_handler.go b/vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/interrupt_handler.go index ad224bc5..8ed86111 100644 --- a/vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/interrupt_handler.go +++ b/vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/interrupt_handler.go @@ -1,7 +1,6 @@ package interrupt_handler import ( - "fmt" "os" "os/signal" "sync" @@ -11,27 +10,29 @@ import ( "github.com/onsi/ginkgo/v2/internal/parallel_support" ) -const TIMEOUT_REPEAT_INTERRUPT_MAXIMUM_DURATION = 30 * time.Second -const TIMEOUT_REPEAT_INTERRUPT_FRACTION_OF_TIMEOUT = 10 -const ABORT_POLLING_INTERVAL = 500 * time.Millisecond -const ABORT_REPEAT_INTERRUPT_DURATION = 30 * time.Second +var ABORT_POLLING_INTERVAL = 500 * time.Millisecond type InterruptCause uint const ( InterruptCauseInvalid InterruptCause = iota - InterruptCauseSignal - InterruptCauseTimeout InterruptCauseAbortByOtherProcess ) +type InterruptLevel uint + +const ( + InterruptLevelUninterrupted InterruptLevel = iota + InterruptLevelCleanupAndReport + InterruptLevelReportOnly + InterruptLevelBailOut +) + func (ic InterruptCause) String() string { switch ic { case InterruptCauseSignal: return "Interrupted by User" - case InterruptCauseTimeout: - return "Interrupted by Timeout" case InterruptCauseAbortByOtherProcess: return "Interrupted by Other Ginkgo Process" } @@ -39,37 +40,51 @@ func (ic InterruptCause) String() string { } type InterruptStatus struct { - Interrupted bool - Channel chan interface{} - Cause InterruptCause + Channel chan interface{} + Level InterruptLevel + Cause InterruptCause +} + +func (s InterruptStatus) Interrupted() bool { + return s.Level != InterruptLevelUninterrupted +} + +func (s InterruptStatus) Message() string { + return s.Cause.String() +} + +func (s InterruptStatus) ShouldIncludeProgressReport() bool { + return s.Cause != InterruptCauseAbortByOtherProcess } type InterruptHandlerInterface interface { Status() InterruptStatus - SetInterruptPlaceholderMessage(string) - ClearInterruptPlaceholderMessage() - InterruptMessage() (string, bool) } type InterruptHandler struct { - c chan interface{} - lock *sync.Mutex - interrupted bool - interruptPlaceholderMessage string - interruptCause InterruptCause - client parallel_support.Client - stop chan interface{} + c chan interface{} + lock *sync.Mutex + level InterruptLevel + cause InterruptCause + client parallel_support.Client + stop chan interface{} + signals []os.Signal + requestAbortCheck chan interface{} } -func NewInterruptHandler(timeout time.Duration, client parallel_support.Client) *InterruptHandler { - handler := &InterruptHandler{ - c: make(chan interface{}), - lock: &sync.Mutex{}, - interrupted: false, - stop: make(chan interface{}), - client: client, +func NewInterruptHandler(client parallel_support.Client, signals ...os.Signal) *InterruptHandler { + if len(signals) == 0 { + signals = []os.Signal{os.Interrupt, syscall.SIGTERM} } - handler.registerForInterrupts(timeout) + handler := &InterruptHandler{ + c: make(chan interface{}), + lock: &sync.Mutex{}, + stop: make(chan interface{}), + requestAbortCheck: make(chan interface{}), + client: client, + signals: signals, + } + handler.registerForInterrupts() return handler } @@ -77,30 +92,28 @@ func (handler *InterruptHandler) Stop() { close(handler.stop) } -func (handler *InterruptHandler) 
registerForInterrupts(timeout time.Duration) { +func (handler *InterruptHandler) registerForInterrupts() { // os signal handling signalChannel := make(chan os.Signal, 1) - signal.Notify(signalChannel, os.Interrupt, syscall.SIGTERM) - - // timeout handling - var timeoutChannel <-chan time.Time - var timeoutTimer *time.Timer - if timeout > 0 { - timeoutTimer = time.NewTimer(timeout) - timeoutChannel = timeoutTimer.C - } + signal.Notify(signalChannel, handler.signals...) // cross-process abort handling - var abortChannel chan bool + var abortChannel chan interface{} if handler.client != nil { - abortChannel = make(chan bool) + abortChannel = make(chan interface{}) go func() { pollTicker := time.NewTicker(ABORT_POLLING_INTERVAL) for { select { case <-pollTicker.C: if handler.client.ShouldAbort() { - abortChannel <- true + close(abortChannel) + pollTicker.Stop() + return + } + case <-handler.requestAbortCheck: + if handler.client.ShouldAbort() { + close(abortChannel) pollTicker.Stop() return } @@ -112,85 +125,53 @@ func (handler *InterruptHandler) registerForInterrupts(timeout time.Duration) { }() } - // listen for any interrupt signals - // note that some (timeouts, cross-process aborts) will only trigger once - // for these we set up a ticker to keep interrupting the suite until it ends - // this ensures any `AfterEach` or `AfterSuite`s that get stuck cleaning up - // get interrupted eventually - go func() { + go func(abortChannel chan interface{}) { var interruptCause InterruptCause - var repeatChannel <-chan time.Time - var repeatTicker *time.Ticker for { select { case <-signalChannel: interruptCause = InterruptCauseSignal - case <-timeoutChannel: - interruptCause = InterruptCauseTimeout - repeatInterruptTimeout := timeout / time.Duration(TIMEOUT_REPEAT_INTERRUPT_FRACTION_OF_TIMEOUT) - if repeatInterruptTimeout > TIMEOUT_REPEAT_INTERRUPT_MAXIMUM_DURATION { - repeatInterruptTimeout = TIMEOUT_REPEAT_INTERRUPT_MAXIMUM_DURATION - } - timeoutTimer.Stop() - repeatTicker = time.NewTicker(repeatInterruptTimeout) - repeatChannel = repeatTicker.C case <-abortChannel: interruptCause = InterruptCauseAbortByOtherProcess - repeatTicker = time.NewTicker(ABORT_REPEAT_INTERRUPT_DURATION) - repeatChannel = repeatTicker.C - case <-repeatChannel: - //do nothing, just interrupt again using the same interruptCause case <-handler.stop: - if timeoutTimer != nil { - timeoutTimer.Stop() - } - if repeatTicker != nil { - repeatTicker.Stop() - } signal.Stop(signalChannel) return } + abortChannel = nil + handler.lock.Lock() - handler.interruptCause = interruptCause - if handler.interruptPlaceholderMessage != "" { - fmt.Println(handler.interruptPlaceholderMessage) + oldLevel := handler.level + handler.cause = interruptCause + if handler.level == InterruptLevelUninterrupted { + handler.level = InterruptLevelCleanupAndReport + } else if handler.level == InterruptLevelCleanupAndReport { + handler.level = InterruptLevelReportOnly + } else if handler.level == InterruptLevelReportOnly { + handler.level = InterruptLevelBailOut + } + if handler.level != oldLevel { + close(handler.c) + handler.c = make(chan interface{}) } - handler.interrupted = true - close(handler.c) - handler.c = make(chan interface{}) handler.lock.Unlock() } - }() + }(abortChannel) } func (handler *InterruptHandler) Status() InterruptStatus { handler.lock.Lock() - defer handler.lock.Unlock() - - return InterruptStatus{ - Interrupted: handler.interrupted, - Channel: handler.c, - Cause: handler.interruptCause, + status := InterruptStatus{ + Level: 
handler.level, + Channel: handler.c, + Cause: handler.cause, } -} - -func (handler *InterruptHandler) SetInterruptPlaceholderMessage(message string) { - handler.lock.Lock() - defer handler.lock.Unlock() - - handler.interruptPlaceholderMessage = message -} - -func (handler *InterruptHandler) ClearInterruptPlaceholderMessage() { - handler.lock.Lock() - defer handler.lock.Unlock() - - handler.interruptPlaceholderMessage = "" -} - -func (handler *InterruptHandler) InterruptMessage() (string, bool) { - handler.lock.Lock() - out := fmt.Sprintf("%s", handler.interruptCause.String()) - defer handler.lock.Unlock() - return out, handler.interruptCause != InterruptCauseAbortByOtherProcess + handler.lock.Unlock() + + if handler.client != nil && handler.client.ShouldAbort() && !status.Interrupted() { + close(handler.requestAbortCheck) + <-status.Channel + return handler.Status() + } + + return status } diff --git a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/client_server.go b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/client_server.go index b417bf5b..b3cd6429 100644 --- a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/client_server.go +++ b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/client_server.go @@ -42,6 +42,8 @@ type Client interface { PostSuiteWillBegin(report types.Report) error PostDidRun(report types.SpecReport) error PostSuiteDidEnd(report types.Report) error + PostReportBeforeSuiteCompleted(state types.SpecState) error + BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error BlockUntilSynchronizedBeforeSuiteData() (types.SpecState, []byte, error) BlockUntilNonprimaryProcsHaveFinished() error diff --git a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_client.go b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_client.go index ad9932f2..6547c7a6 100644 --- a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_client.go +++ b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_client.go @@ -98,6 +98,19 @@ func (client *httpClient) PostEmitProgressReport(report types.ProgressReport) er return client.post("/progress-report", report) } +func (client *httpClient) PostReportBeforeSuiteCompleted(state types.SpecState) error { + return client.post("/report-before-suite-completed", state) +} + +func (client *httpClient) BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error) { + var state types.SpecState + err := client.poll("/report-before-suite-state", &state) + if err == ErrorGone { + return types.SpecStateFailed, nil + } + return state, err +} + func (client *httpClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error { beforeSuiteState := BeforeSuiteState{ State: state, diff --git a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_server.go b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_server.go index fa3ac682..d2c71ab1 100644 --- a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_server.go +++ b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/http_server.go @@ -26,7 +26,7 @@ type httpServer struct { handler *ServerHandler } -//Create a new server, automatically selecting a port +// Create a new server, automatically selecting a port func newHttpServer(parallelTotal int, reporter reporters.Reporter) (*httpServer, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -38,7 
+38,7 @@ func newHttpServer(parallelTotal int, reporter reporters.Reporter) (*httpServer, }, nil } -//Start the server. You don't need to `go s.Start()`, just `s.Start()` +// Start the server. You don't need to `go s.Start()`, just `s.Start()` func (server *httpServer) Start() { httpServer := &http.Server{} mux := http.NewServeMux() @@ -52,6 +52,8 @@ func (server *httpServer) Start() { mux.HandleFunc("/progress-report", server.emitProgressReport) //synchronization endpoints + mux.HandleFunc("/report-before-suite-completed", server.handleReportBeforeSuiteCompleted) + mux.HandleFunc("/report-before-suite-state", server.handleReportBeforeSuiteState) mux.HandleFunc("/before-suite-completed", server.handleBeforeSuiteCompleted) mux.HandleFunc("/before-suite-state", server.handleBeforeSuiteState) mux.HandleFunc("/have-nonprimary-procs-finished", server.handleHaveNonprimaryProcsFinished) @@ -63,12 +65,12 @@ func (server *httpServer) Start() { go httpServer.Serve(server.listener) } -//Stop the server +// Stop the server func (server *httpServer) Close() { server.listener.Close() } -//The address the server can be reached it. Pass this into the `ForwardingReporter`. +// The address the server can be reached it. Pass this into the `ForwardingReporter`. func (server *httpServer) Address() string { return "http://" + server.listener.Addr().String() } @@ -93,7 +95,7 @@ func (server *httpServer) RegisterAlive(node int, alive func() bool) { // Streaming Endpoints // -//The server will forward all received messages to Ginkgo reporters registered with `RegisterReporters` +// The server will forward all received messages to Ginkgo reporters registered with `RegisterReporters` func (server *httpServer) decode(writer http.ResponseWriter, request *http.Request, object interface{}) bool { defer request.Body.Close() if json.NewDecoder(request.Body).Decode(object) != nil { @@ -164,6 +166,23 @@ func (server *httpServer) emitProgressReport(writer http.ResponseWriter, request server.handleError(server.handler.EmitProgressReport(report, voidReceiver), writer) } +func (server *httpServer) handleReportBeforeSuiteCompleted(writer http.ResponseWriter, request *http.Request) { + var state types.SpecState + if !server.decode(writer, request, &state) { + return + } + + server.handleError(server.handler.ReportBeforeSuiteCompleted(state, voidReceiver), writer) +} + +func (server *httpServer) handleReportBeforeSuiteState(writer http.ResponseWriter, request *http.Request) { + var state types.SpecState + if server.handleError(server.handler.ReportBeforeSuiteState(voidSender, &state), writer) { + return + } + json.NewEncoder(writer).Encode(state) +} + func (server *httpServer) handleBeforeSuiteCompleted(writer http.ResponseWriter, request *http.Request) { var beforeSuiteState BeforeSuiteState if !server.decode(writer, request, &beforeSuiteState) { diff --git a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/rpc_client.go b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/rpc_client.go index fe93cc2b..59e8e6fd 100644 --- a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/rpc_client.go +++ b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/rpc_client.go @@ -76,6 +76,19 @@ func (client *rpcClient) PostEmitProgressReport(report types.ProgressReport) err return client.client.Call("Server.EmitProgressReport", report, voidReceiver) } +func (client *rpcClient) PostReportBeforeSuiteCompleted(state types.SpecState) error { + return client.client.Call("Server.ReportBeforeSuiteCompleted", state, 
voidReceiver) +} + +func (client *rpcClient) BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error) { + var state types.SpecState + err := client.poll("Server.ReportBeforeSuiteState", &state) + if err == ErrorGone { + return types.SpecStateFailed, nil + } + return state, err +} + func (client *rpcClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error { beforeSuiteState := BeforeSuiteState{ State: state, diff --git a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/server_handler.go b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/server_handler.go index 7c6e67b9..a6d98793 100644 --- a/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/server_handler.go +++ b/vendor/github.com/onsi/ginkgo/v2/internal/parallel_support/server_handler.go @@ -18,16 +18,17 @@ var voidSender Void // It handles all the business logic to avoid duplication between the two servers type ServerHandler struct { - done chan interface{} - outputDestination io.Writer - reporter reporters.Reporter - alives []func() bool - lock *sync.Mutex - beforeSuiteState BeforeSuiteState - parallelTotal int - counter int - counterLock *sync.Mutex - shouldAbort bool + done chan interface{} + outputDestination io.Writer + reporter reporters.Reporter + alives []func() bool + lock *sync.Mutex + beforeSuiteState BeforeSuiteState + reportBeforeSuiteState types.SpecState + parallelTotal int + counter int + counterLock *sync.Mutex + shouldAbort bool numSuiteDidBegins int numSuiteDidEnds int @@ -37,11 +38,12 @@ type ServerHandler struct { func newServerHandler(parallelTotal int, reporter reporters.Reporter) *ServerHandler { return &ServerHandler{ - reporter: reporter, - lock: &sync.Mutex{}, - counterLock: &sync.Mutex{}, - alives: make([]func() bool, parallelTotal), - beforeSuiteState: BeforeSuiteState{Data: nil, State: types.SpecStateInvalid}, + reporter: reporter, + lock: &sync.Mutex{}, + counterLock: &sync.Mutex{}, + alives: make([]func() bool, parallelTotal), + beforeSuiteState: BeforeSuiteState{Data: nil, State: types.SpecStateInvalid}, + parallelTotal: parallelTotal, outputDestination: os.Stdout, done: make(chan interface{}), @@ -140,6 +142,29 @@ func (handler *ServerHandler) haveNonprimaryProcsFinished() bool { return true } +func (handler *ServerHandler) ReportBeforeSuiteCompleted(reportBeforeSuiteState types.SpecState, _ *Void) error { + handler.lock.Lock() + defer handler.lock.Unlock() + handler.reportBeforeSuiteState = reportBeforeSuiteState + + return nil +} + +func (handler *ServerHandler) ReportBeforeSuiteState(_ Void, reportBeforeSuiteState *types.SpecState) error { + proc1IsAlive := handler.procIsAlive(1) + handler.lock.Lock() + defer handler.lock.Unlock() + if handler.reportBeforeSuiteState == types.SpecStateInvalid { + if proc1IsAlive { + return ErrorEarly + } else { + return ErrorGone + } + } + *reportBeforeSuiteState = handler.reportBeforeSuiteState + return nil +} + func (handler *ServerHandler) BeforeSuiteCompleted(beforeSuiteState BeforeSuiteState, _ *Void) error { handler.lock.Lock() defer handler.lock.Unlock() diff --git a/vendor/github.com/onsi/ginkgo/v2/reporters/default_reporter.go b/vendor/github.com/onsi/ginkgo/v2/reporters/default_reporter.go index ccd3317a..56b7be75 100644 --- a/vendor/github.com/onsi/ginkgo/v2/reporters/default_reporter.go +++ b/vendor/github.com/onsi/ginkgo/v2/reporters/default_reporter.go @@ -12,6 +12,7 @@ import ( "io" "runtime" "strings" + "sync" "time" "github.com/onsi/ginkgo/v2/formatter" @@ -23,13 +24,16 @@ type 
DefaultReporter struct { writer io.Writer // managing the emission stream - lastChar string + lastCharWasNewline bool lastEmissionWasDelimiter bool // rendering specDenoter string retryDenoter string formatter formatter.Formatter + + runningInParallel bool + lock *sync.Mutex } func NewDefaultReporterUnderTest(conf types.ReporterConfig, writer io.Writer) *DefaultReporter { @@ -44,12 +48,13 @@ func NewDefaultReporter(conf types.ReporterConfig, writer io.Writer) *DefaultRep conf: conf, writer: writer, - lastChar: "\n", + lastCharWasNewline: true, lastEmissionWasDelimiter: false, specDenoter: "•", retryDenoter: "↺", formatter: formatter.NewWithNoColorBool(conf.NoColor), + lock: &sync.Mutex{}, } if runtime.GOOS == "windows" { reporter.specDenoter = "+" @@ -97,166 +102,10 @@ func (r *DefaultReporter) SuiteWillBegin(report types.Report) { } } -func (r *DefaultReporter) WillRun(report types.SpecReport) { - if r.conf.Verbosity().LT(types.VerbosityLevelVerbose) || report.State.Is(types.SpecStatePending|types.SpecStateSkipped) { - return - } - - r.emitDelimiter() - indentation := uint(0) - if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) { - r.emitBlock(r.f("{{bold}}[%s] %s{{/}}", report.LeafNodeType.String(), report.LeafNodeText)) - } else { - if len(report.ContainerHierarchyTexts) > 0 { - r.emitBlock(r.cycleJoin(report.ContainerHierarchyTexts, " ")) - indentation = 1 - } - line := r.fi(indentation, "{{bold}}%s{{/}}", report.LeafNodeText) - labels := report.Labels() - if len(labels) > 0 { - line += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels, ", ")) - } - r.emitBlock(line) - } - r.emitBlock(r.fi(indentation, "{{gray}}%s{{/}}", report.LeafNodeLocation)) -} - -func (r *DefaultReporter) DidRun(report types.SpecReport) { - v := r.conf.Verbosity() - var header, highlightColor string - includeRuntime, emitGinkgoWriterOutput, stream, denoter := true, true, false, r.specDenoter - succinctLocationBlock := v.Is(types.VerbosityLevelSuccinct) - - hasGW := report.CapturedGinkgoWriterOutput != "" - hasStd := report.CapturedStdOutErr != "" - hasEmittableReports := report.ReportEntries.HasVisibility(types.ReportEntryVisibilityAlways) || (report.ReportEntries.HasVisibility(types.ReportEntryVisibilityFailureOrVerbose) && (!report.Failure.IsZero() || v.GTE(types.VerbosityLevelVerbose))) - - if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) { - denoter = fmt.Sprintf("[%s]", report.LeafNodeType) - } - - switch report.State { - case types.SpecStatePassed: - highlightColor, succinctLocationBlock = "{{green}}", v.LT(types.VerbosityLevelVerbose) - emitGinkgoWriterOutput = (r.conf.AlwaysEmitGinkgoWriter || v.GTE(types.VerbosityLevelVerbose)) && hasGW - if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) { - if v.GTE(types.VerbosityLevelVerbose) || hasStd || hasEmittableReports { - header = fmt.Sprintf("%s PASSED", denoter) - } else { - return - } - } else { - header, stream = denoter, true - if report.NumAttempts > 1 { - header, stream = fmt.Sprintf("%s [FLAKEY TEST - TOOK %d ATTEMPTS TO PASS]", r.retryDenoter, report.NumAttempts), false - } - if report.RunTime > r.conf.SlowSpecThreshold { - header, stream = fmt.Sprintf("%s [SLOW TEST]", header), false - } - } - if hasStd || emitGinkgoWriterOutput || hasEmittableReports { - stream = false - } - case types.SpecStatePending: - highlightColor = "{{yellow}}" - includeRuntime, emitGinkgoWriterOutput = false, false - if v.Is(types.VerbosityLevelSuccinct) { - header, stream = "P", true - } else { - header, succinctLocationBlock = "P [PENDING]", 
v.LT(types.VerbosityLevelVeryVerbose) - } - case types.SpecStateSkipped: - highlightColor = "{{cyan}}" - if report.Failure.Message != "" || v.Is(types.VerbosityLevelVeryVerbose) { - header = "S [SKIPPED]" - } else { - header, stream = "S", true - } - case types.SpecStateFailed: - highlightColor, header = "{{red}}", fmt.Sprintf("%s [FAILED]", denoter) - case types.SpecStatePanicked: - highlightColor, header = "{{magenta}}", fmt.Sprintf("%s! [PANICKED]", denoter) - case types.SpecStateInterrupted: - highlightColor, header = "{{orange}}", fmt.Sprintf("%s! [INTERRUPTED]", denoter) - case types.SpecStateAborted: - highlightColor, header = "{{coral}}", fmt.Sprintf("%s! [ABORTED]", denoter) - } - - // Emit stream and return - if stream { - r.emit(r.f(highlightColor + header + "{{/}}")) - return - } - - // Emit header - r.emitDelimiter() - if includeRuntime { - header = r.f("%s [%.3f seconds]", header, report.RunTime.Seconds()) - } - r.emitBlock(r.f(highlightColor + header + "{{/}}")) - - // Emit Code Location Block - r.emitBlock(r.codeLocationBlock(report, highlightColor, succinctLocationBlock, false)) - - //Emit Stdout/Stderr Output - if hasStd { - r.emitBlock("\n") - r.emitBlock(r.fi(1, "{{gray}}Begin Captured StdOut/StdErr Output >>{{/}}")) - r.emitBlock(r.fi(2, "%s", report.CapturedStdOutErr)) - r.emitBlock(r.fi(1, "{{gray}}<< End Captured StdOut/StdErr Output{{/}}")) - } - - //Emit Captured GinkgoWriter Output - if emitGinkgoWriterOutput && hasGW { - r.emitBlock("\n") - r.emitGinkgoWriterOutput(1, report.CapturedGinkgoWriterOutput, 0) - } - - if hasEmittableReports { - r.emitBlock("\n") - r.emitBlock(r.fi(1, "{{gray}}Begin Report Entries >>{{/}}")) - reportEntries := report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways) - if !report.Failure.IsZero() || v.GTE(types.VerbosityLevelVerbose) { - reportEntries = report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways, types.ReportEntryVisibilityFailureOrVerbose) - } - for _, entry := range reportEntries { - r.emitBlock(r.fi(2, "{{bold}}"+entry.Name+"{{gray}} - %s @ %s{{/}}", entry.Location, entry.Time.Format(types.GINKGO_TIME_FORMAT))) - if representation := entry.StringRepresentation(); representation != "" { - r.emitBlock(r.fi(3, representation)) - } - } - r.emitBlock(r.fi(1, "{{gray}}<< End Report Entries{{/}}")) - } - - // Emit Failure Message - if !report.Failure.IsZero() { - r.emitBlock("\n") - r.emitBlock(r.fi(1, highlightColor+"%s{{/}}", report.Failure.Message)) - r.emitBlock(r.fi(1, highlightColor+"In {{bold}}[%s]{{/}}"+highlightColor+" at: {{bold}}%s{{/}}\n", report.Failure.FailureNodeType, report.Failure.Location)) - if report.Failure.ForwardedPanic != "" { - r.emitBlock("\n") - r.emitBlock(r.fi(1, highlightColor+"%s{{/}}", report.Failure.ForwardedPanic)) - } - - if r.conf.FullTrace || report.Failure.ForwardedPanic != "" { - r.emitBlock("\n") - r.emitBlock(r.fi(1, highlightColor+"Full Stack Trace{{/}}")) - r.emitBlock(r.fi(2, "%s", report.Failure.Location.FullStackTrace)) - } - - if !report.Failure.ProgressReport.IsZero() { - r.emitBlock("\n") - r.emitProgressReport(1, false, report.Failure.ProgressReport) - } - } - - r.emitDelimiter() -} - func (r *DefaultReporter) SuiteDidEnd(report types.Report) { failures := report.SpecReports.WithState(types.SpecStateFailureStates) if len(failures) > 0 { - r.emitBlock("\n\n") + r.emitBlock("\n") if len(failures) > 1 { r.emitBlock(r.f("{{red}}{{bold}}Summarizing %d Failures:{{/}}", len(failures))) } else { @@ -269,10 +118,12 @@ func (r *DefaultReporter) 
SuiteDidEnd(report types.Report) { highlightColor, heading = "{{magenta}}", "[PANICKED!]" case types.SpecStateAborted: highlightColor, heading = "{{coral}}", "[ABORTED]" + case types.SpecStateTimedout: + highlightColor, heading = "{{orange}}", "[TIMEDOUT]" case types.SpecStateInterrupted: highlightColor, heading = "{{orange}}", "[INTERRUPTED]" } - locationBlock := r.codeLocationBlock(specReport, highlightColor, true, true) + locationBlock := r.codeLocationBlock(specReport, highlightColor, false, true) r.emitBlock(r.fi(1, highlightColor+"%s{{/}} %s", heading, locationBlock)) } } @@ -313,28 +164,294 @@ func (r *DefaultReporter) SuiteDidEnd(report types.Report) { if specs.CountOfFlakedSpecs() > 0 { r.emit(r.f("{{light-yellow}}{{bold}}%d Flaked{{/}} | ", specs.CountOfFlakedSpecs())) } + if specs.CountOfRepeatedSpecs() > 0 { + r.emit(r.f("{{light-yellow}}{{bold}}%d Repeated{{/}} | ", specs.CountOfRepeatedSpecs())) + } r.emit(r.f("{{yellow}}{{bold}}%d Pending{{/}} | ", specs.CountWithState(types.SpecStatePending))) r.emit(r.f("{{cyan}}{{bold}}%d Skipped{{/}}\n", specs.CountWithState(types.SpecStateSkipped))) } } +func (r *DefaultReporter) WillRun(report types.SpecReport) { + v := r.conf.Verbosity() + if v.LT(types.VerbosityLevelVerbose) || report.State.Is(types.SpecStatePending|types.SpecStateSkipped) || report.RunningInParallel { + return + } + + r.emitDelimiter(0) + r.emitBlock(r.f(r.codeLocationBlock(report, "{{/}}", v.Is(types.VerbosityLevelVeryVerbose), false))) +} + +func (r *DefaultReporter) DidRun(report types.SpecReport) { + v := r.conf.Verbosity() + inParallel := report.RunningInParallel + + header := r.specDenoter + if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) { + header = fmt.Sprintf("[%s]", report.LeafNodeType) + } + highlightColor := r.highlightColorForState(report.State) + + // have we already been streaming the timeline? + timelineHasBeenStreaming := v.GTE(types.VerbosityLevelVerbose) && !inParallel + + // should we show the timeline? + var timeline types.Timeline + showTimeline := !timelineHasBeenStreaming && (v.GTE(types.VerbosityLevelVerbose) || report.Failed()) + if showTimeline { + timeline = report.Timeline().WithoutHiddenReportEntries() + keepVeryVerboseSpecEvents := v.Is(types.VerbosityLevelVeryVerbose) || + (v.Is(types.VerbosityLevelVerbose) && r.conf.ShowNodeEvents) || + (report.Failed() && r.conf.ShowNodeEvents) + if !keepVeryVerboseSpecEvents { + timeline = timeline.WithoutVeryVerboseSpecEvents() + } + if len(timeline) == 0 && report.CapturedGinkgoWriterOutput == "" { + // the timeline is completely empty - don't show it + showTimeline = false + } + if v.LT(types.VerbosityLevelVeryVerbose) && report.CapturedGinkgoWriterOutput == "" && len(timeline) > 0 { + //if we aren't -vv and the timeline only has a single failure, don't show it as it will appear at the end of the report + failure, isFailure := timeline[0].(types.Failure) + if isFailure && (len(timeline) == 1 || (len(timeline) == 2 && failure.AdditionalFailure != nil)) { + showTimeline = false + } + } + } + + // should we have a separate section for always-visible reports? + showSeparateVisibilityAlwaysReportsSection := !timelineHasBeenStreaming && !showTimeline && report.ReportEntries.HasVisibility(types.ReportEntryVisibilityAlways) + + // should we have a separate section for captured stdout/stderr + showSeparateStdSection := inParallel && (report.CapturedStdOutErr != "") + + // given all that - do we have any actual content to show? or are we a single denoter in a stream? 
+ reportHasContent := v.Is(types.VerbosityLevelVeryVerbose) || showTimeline || showSeparateVisibilityAlwaysReportsSection || showSeparateStdSection || report.Failed() || (v.Is(types.VerbosityLevelVerbose) && !report.State.Is(types.SpecStateSkipped)) + + // should we show a runtime? + includeRuntime := !report.State.Is(types.SpecStateSkipped|types.SpecStatePending) || (report.State.Is(types.SpecStateSkipped) && report.Failure.Message != "") + + // should we show the codelocation block? + showCodeLocation := !timelineHasBeenStreaming || !report.State.Is(types.SpecStatePassed) + + switch report.State { + case types.SpecStatePassed: + if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) && !reportHasContent { + return + } + if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) { + header = fmt.Sprintf("%s PASSED", header) + } + if report.NumAttempts > 1 && report.MaxFlakeAttempts > 1 { + header, reportHasContent = fmt.Sprintf("%s [FLAKEY TEST - TOOK %d ATTEMPTS TO PASS]", r.retryDenoter, report.NumAttempts), true + } + case types.SpecStatePending: + header = "P" + if v.GT(types.VerbosityLevelSuccinct) { + header, reportHasContent = "P [PENDING]", true + } + case types.SpecStateSkipped: + header = "S" + if v.Is(types.VerbosityLevelVeryVerbose) || (v.Is(types.VerbosityLevelVerbose) && report.Failure.Message != "") { + header, reportHasContent = "S [SKIPPED]", true + } + default: + header = fmt.Sprintf("%s [%s]", header, r.humanReadableState(report.State)) + if report.MaxMustPassRepeatedly > 1 { + header = fmt.Sprintf("%s DURING REPETITION #%d", header, report.NumAttempts) + } + } + + // If we have no content to show, jsut emit the header and return + if !reportHasContent { + r.emit(r.f(highlightColor + header + "{{/}}")) + return + } + + if includeRuntime { + header = r.f("%s [%.3f seconds]", header, report.RunTime.Seconds()) + } + + // Emit header + if !timelineHasBeenStreaming { + r.emitDelimiter(0) + } + r.emitBlock(r.f(highlightColor + header + "{{/}}")) + if showCodeLocation { + r.emitBlock(r.codeLocationBlock(report, highlightColor, v.Is(types.VerbosityLevelVeryVerbose), false)) + } + + //Emit Stdout/Stderr Output + if showSeparateStdSection { + r.emitBlock("\n") + r.emitBlock(r.fi(1, "{{gray}}Captured StdOut/StdErr Output >>{{/}}")) + r.emitBlock(r.fi(1, "%s", report.CapturedStdOutErr)) + r.emitBlock(r.fi(1, "{{gray}}<< Captured StdOut/StdErr Output{{/}}")) + } + + if showSeparateVisibilityAlwaysReportsSection { + r.emitBlock("\n") + r.emitBlock(r.fi(1, "{{gray}}Report Entries >>{{/}}")) + for _, entry := range report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways) { + r.emitReportEntry(1, entry) + } + r.emitBlock(r.fi(1, "{{gray}}<< Report Entries{{/}}")) + } + + if showTimeline { + r.emitBlock("\n") + r.emitBlock(r.fi(1, "{{gray}}Timeline >>{{/}}")) + r.emitTimeline(1, report, timeline) + r.emitBlock(r.fi(1, "{{gray}}<< Timeline{{/}}")) + } + + // Emit Failure Message + if !report.Failure.IsZero() && !v.Is(types.VerbosityLevelVeryVerbose) { + r.emitBlock("\n") + r.emitFailure(1, report.State, report.Failure, true) + if len(report.AdditionalFailures) > 0 { + r.emitBlock(r.fi(1, "\nThere were {{bold}}{{red}}additional failures{{/}} detected. 
To view them in detail run {{bold}}ginkgo -vv{{/}}")) + } + } + + r.emitDelimiter(0) +} + +func (r *DefaultReporter) highlightColorForState(state types.SpecState) string { + switch state { + case types.SpecStatePassed: + return "{{green}}" + case types.SpecStatePending: + return "{{yellow}}" + case types.SpecStateSkipped: + return "{{cyan}}" + case types.SpecStateFailed: + return "{{red}}" + case types.SpecStateTimedout: + return "{{orange}}" + case types.SpecStatePanicked: + return "{{magenta}}" + case types.SpecStateInterrupted: + return "{{orange}}" + case types.SpecStateAborted: + return "{{coral}}" + default: + return "{{gray}}" + } +} + +func (r *DefaultReporter) humanReadableState(state types.SpecState) string { + return strings.ToUpper(state.String()) +} + +func (r *DefaultReporter) emitTimeline(indent uint, report types.SpecReport, timeline types.Timeline) { + isVeryVerbose := r.conf.Verbosity().Is(types.VerbosityLevelVeryVerbose) + gw := report.CapturedGinkgoWriterOutput + cursor := 0 + for _, entry := range timeline { + tl := entry.GetTimelineLocation() + if tl.Offset < len(gw) { + r.emit(r.fi(indent, "%s", gw[cursor:tl.Offset])) + cursor = tl.Offset + } else if cursor < len(gw) { + r.emit(r.fi(indent, "%s", gw[cursor:])) + cursor = len(gw) + } + switch x := entry.(type) { + case types.Failure: + if isVeryVerbose { + r.emitFailure(indent, report.State, x, false) + } else { + r.emitShortFailure(indent, report.State, x) + } + case types.AdditionalFailure: + if isVeryVerbose { + r.emitFailure(indent, x.State, x.Failure, true) + } else { + r.emitShortFailure(indent, x.State, x.Failure) + } + case types.ReportEntry: + r.emitReportEntry(indent, x) + case types.ProgressReport: + r.emitProgressReport(indent, false, x) + case types.SpecEvent: + if isVeryVerbose || !x.IsOnlyVisibleAtVeryVerbose() || r.conf.ShowNodeEvents { + r.emitSpecEvent(indent, x, isVeryVerbose) + } + } + } + if cursor < len(gw) { + r.emit(r.fi(indent, "%s", gw[cursor:])) + } +} + +func (r *DefaultReporter) EmitFailure(state types.SpecState, failure types.Failure) { + if r.conf.Verbosity().Is(types.VerbosityLevelVerbose) { + r.emitShortFailure(1, state, failure) + } else if r.conf.Verbosity().Is(types.VerbosityLevelVeryVerbose) { + r.emitFailure(1, state, failure, true) + } +} + +func (r *DefaultReporter) emitShortFailure(indent uint, state types.SpecState, failure types.Failure) { + r.emitBlock(r.fi(indent, r.highlightColorForState(state)+"[%s]{{/}} in [%s] - %s {{gray}}@ %s{{/}}", + r.humanReadableState(state), + failure.FailureNodeType, + failure.Location, + failure.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT), + )) +} + +func (r *DefaultReporter) emitFailure(indent uint, state types.SpecState, failure types.Failure, includeAdditionalFailure bool) { + highlightColor := r.highlightColorForState(state) + r.emitBlock(r.fi(indent, highlightColor+"[%s] %s{{/}}", r.humanReadableState(state), failure.Message)) + r.emitBlock(r.fi(indent, highlightColor+"In {{bold}}[%s]{{/}}"+highlightColor+" at: {{bold}}%s{{/}} {{gray}}@ %s{{/}}\n", failure.FailureNodeType, failure.Location, failure.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT))) + if failure.ForwardedPanic != "" { + r.emitBlock("\n") + r.emitBlock(r.fi(indent, highlightColor+"%s{{/}}", failure.ForwardedPanic)) + } + + if r.conf.FullTrace || failure.ForwardedPanic != "" { + r.emitBlock("\n") + r.emitBlock(r.fi(indent, highlightColor+"Full Stack Trace{{/}}")) + r.emitBlock(r.fi(indent+1, "%s", failure.Location.FullStackTrace)) + } + + if 
!failure.ProgressReport.IsZero() { + r.emitBlock("\n") + r.emitProgressReport(indent, false, failure.ProgressReport) + } + + if failure.AdditionalFailure != nil && includeAdditionalFailure { + r.emitBlock("\n") + r.emitFailure(indent, failure.AdditionalFailure.State, failure.AdditionalFailure.Failure, true) + } +} + func (r *DefaultReporter) EmitProgressReport(report types.ProgressReport) { - r.emitDelimiter() + r.emitDelimiter(1) if report.RunningInParallel { - r.emit(r.f("{{coral}}Progress Report for Ginkgo Process #{{bold}}%d{{/}}\n", report.ParallelProcess)) + r.emit(r.fi(1, "{{coral}}Progress Report for Ginkgo Process #{{bold}}%d{{/}}\n", report.ParallelProcess)) } - r.emitProgressReport(0, true, report) - r.emitDelimiter() + shouldEmitGW := report.RunningInParallel || r.conf.Verbosity().LT(types.VerbosityLevelVerbose) + r.emitProgressReport(1, shouldEmitGW, report) + r.emitDelimiter(1) } func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput bool, report types.ProgressReport) { + if report.Message != "" { + r.emitBlock(r.fi(indent, report.Message+"\n")) + indent += 1 + } if report.LeafNodeText != "" { + subjectIndent := indent if len(report.ContainerHierarchyTexts) > 0 { r.emit(r.fi(indent, r.cycleJoin(report.ContainerHierarchyTexts, " "))) r.emit(" ") + subjectIndent = 0 } - r.emit(r.f("{{bold}}{{orange}}%s{{/}} (Spec Runtime: %s)\n", report.LeafNodeText, report.Time.Sub(report.SpecStartTime).Round(time.Millisecond))) + r.emit(r.fi(subjectIndent, "{{bold}}{{orange}}%s{{/}} (Spec Runtime: %s)\n", report.LeafNodeText, report.Time().Sub(report.SpecStartTime).Round(time.Millisecond))) r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.LeafNodeLocation)) indent += 1 } @@ -344,12 +461,12 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput r.emit(r.f(" {{bold}}{{orange}}%s{{/}}", report.CurrentNodeText)) } - r.emit(r.f(" (Node Runtime: %s)\n", report.Time.Sub(report.CurrentNodeStartTime).Round(time.Millisecond))) + r.emit(r.f(" (Node Runtime: %s)\n", report.Time().Sub(report.CurrentNodeStartTime).Round(time.Millisecond))) r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentNodeLocation)) indent += 1 } if report.CurrentStepText != "" { - r.emit(r.fi(indent, "At {{bold}}{{orange}}[By Step] %s{{/}} (Step Runtime: %s)\n", report.CurrentStepText, report.Time.Sub(report.CurrentStepStartTime).Round(time.Millisecond))) + r.emit(r.fi(indent, "At {{bold}}{{orange}}[By Step] %s{{/}} (Step Runtime: %s)\n", report.CurrentStepText, report.Time().Sub(report.CurrentStepStartTime).Round(time.Millisecond))) r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentStepLocation)) indent += 1 } @@ -358,9 +475,19 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput indent -= 1 } - if emitGinkgoWriterOutput && report.CapturedGinkgoWriterOutput != "" && (report.RunningInParallel || r.conf.Verbosity().LT(types.VerbosityLevelVerbose)) { + if emitGinkgoWriterOutput && report.CapturedGinkgoWriterOutput != "" { r.emit("\n") - r.emitGinkgoWriterOutput(indent, report.CapturedGinkgoWriterOutput, 10) + r.emitBlock(r.fi(indent, "{{gray}}Begin Captured GinkgoWriter Output >>{{/}}")) + limit, lines := 10, strings.Split(report.CapturedGinkgoWriterOutput, "\n") + if len(lines) <= limit { + r.emitBlock(r.fi(indent+1, "%s", report.CapturedGinkgoWriterOutput)) + } else { + r.emitBlock(r.fi(indent+1, "{{gray}}...{{/}}")) + for _, line := range lines[len(lines)-limit-1:] { + r.emitBlock(r.fi(indent+1, "%s", line)) + } + } + 
r.emitBlock(r.fi(indent, "{{gray}}<< End Captured GinkgoWriter Output{{/}}")) } if !report.SpecGoroutine().IsZero() { @@ -369,6 +496,18 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput r.emitGoroutines(indent, report.SpecGoroutine()) } + if len(report.AdditionalReports) > 0 { + r.emit("\n") + r.emitBlock(r.fi(indent, "{{gray}}Begin Additional Progress Reports >>{{/}}")) + for i, additionalReport := range report.AdditionalReports { + r.emit(r.fi(indent+1, additionalReport)) + if i < len(report.AdditionalReports)-1 { + r.emitBlock(r.fi(indent+1, "{{gray}}%s{{/}}", strings.Repeat("-", 10))) + } + } + r.emitBlock(r.fi(indent, "{{gray}}<< End Additional Progress Reports{{/}}")) + } + highlightedGoroutines := report.HighlightedGoroutines() if len(highlightedGoroutines) > 0 { r.emit("\n") @@ -384,22 +523,48 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput } } -func (r *DefaultReporter) emitGinkgoWriterOutput(indent uint, output string, limit int) { - r.emitBlock(r.fi(indent, "{{gray}}Begin Captured GinkgoWriter Output >>{{/}}")) - if limit == 0 { - r.emitBlock(r.fi(indent+1, "%s", output)) - } else { - lines := strings.Split(output, "\n") - if len(lines) <= limit { - r.emitBlock(r.fi(indent+1, "%s", output)) - } else { - r.emitBlock(r.fi(indent+1, "{{gray}}...{{/}}")) - for _, line := range lines[len(lines)-limit-1:] { - r.emitBlock(r.fi(indent+1, "%s", line)) - } - } +func (r *DefaultReporter) EmitReportEntry(entry types.ReportEntry) { + if r.conf.Verbosity().LT(types.VerbosityLevelVerbose) || entry.Visibility == types.ReportEntryVisibilityNever { + return + } + r.emitReportEntry(1, entry) +} + +func (r *DefaultReporter) emitReportEntry(indent uint, entry types.ReportEntry) { + r.emitBlock(r.fi(indent, "{{bold}}"+entry.Name+"{{gray}} "+fmt.Sprintf("- %s @ %s{{/}}", entry.Location, entry.Time.Format(types.GINKGO_TIME_FORMAT)))) + if representation := entry.StringRepresentation(); representation != "" { + r.emitBlock(r.fi(indent+1, representation)) + } +} + +func (r *DefaultReporter) EmitSpecEvent(event types.SpecEvent) { + v := r.conf.Verbosity() + if v.Is(types.VerbosityLevelVeryVerbose) || (v.Is(types.VerbosityLevelVerbose) && (r.conf.ShowNodeEvents || !event.IsOnlyVisibleAtVeryVerbose())) { + r.emitSpecEvent(1, event, r.conf.Verbosity().Is(types.VerbosityLevelVeryVerbose)) + } +} + +func (r *DefaultReporter) emitSpecEvent(indent uint, event types.SpecEvent, includeLocation bool) { + location := "" + if includeLocation { + location = fmt.Sprintf("- %s ", event.CodeLocation.String()) + } + switch event.SpecEventType { + case types.SpecEventInvalid: + return + case types.SpecEventByStart: + r.emitBlock(r.fi(indent, "{{bold}}STEP:{{/}} %s {{gray}}%s@ %s{{/}}", event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT))) + case types.SpecEventByEnd: + r.emitBlock(r.fi(indent, "{{bold}}END STEP:{{/}} %s {{gray}}%s@ %s (%s){{/}}", event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT), event.Duration.Round(time.Millisecond))) + case types.SpecEventNodeStart: + r.emitBlock(r.fi(indent, "> Enter {{bold}}[%s]{{/}} %s {{gray}}%s@ %s{{/}}", event.NodeType.String(), event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT))) + case types.SpecEventNodeEnd: + r.emitBlock(r.fi(indent, "< Exit {{bold}}[%s]{{/}} %s {{gray}}%s@ %s (%s){{/}}", event.NodeType.String(), event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT), 
event.Duration.Round(time.Millisecond))) + case types.SpecEventSpecRepeat: + r.emitBlock(r.fi(indent, "\n{{bold}}Attempt #%d {{green}}Passed{{/}}{{bold}}. Repeating %s{{/}} {{gray}}@ %s{{/}}\n\n", event.Attempt, r.retryDenoter, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT))) + case types.SpecEventSpecRetry: + r.emitBlock(r.fi(indent, "\n{{bold}}Attempt #%d {{red}}Failed{{/}}{{bold}}. Retrying %s{{/}} {{gray}}@ %s{{/}}\n\n", event.Attempt, r.retryDenoter, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT))) } - r.emitBlock(r.fi(indent, "{{gray}}<< End Captured GinkgoWriter Output{{/}}")) } func (r *DefaultReporter) emitGoroutines(indent uint, goroutines ...types.Goroutine) { @@ -457,31 +622,37 @@ func (r *DefaultReporter) emitSource(indent uint, fc types.FunctionCall) { /* Emitting to the writer */ func (r *DefaultReporter) emit(s string) { - if len(s) > 0 { - r.lastChar = s[len(s)-1:] - r.lastEmissionWasDelimiter = false - r.writer.Write([]byte(s)) - } + r._emit(s, false, false) } func (r *DefaultReporter) emitBlock(s string) { - if len(s) > 0 { - if r.lastChar != "\n" { - r.emit("\n") - } - r.emit(s) - if r.lastChar != "\n" { - r.emit("\n") - } - } + r._emit(s, true, false) } -func (r *DefaultReporter) emitDelimiter() { - if r.lastEmissionWasDelimiter { +func (r *DefaultReporter) emitDelimiter(indent uint) { + r._emit(r.fi(indent, "{{gray}}%s{{/}}", strings.Repeat("-", 30)), true, true) +} + +// a bit ugly - but we're trying to minimize locking on this hot codepath +func (r *DefaultReporter) _emit(s string, block bool, isDelimiter bool) { + if len(s) == 0 { return } - r.emitBlock(r.f("{{gray}}%s{{/}}", strings.Repeat("-", 30))) - r.lastEmissionWasDelimiter = true + r.lock.Lock() + defer r.lock.Unlock() + if isDelimiter && r.lastEmissionWasDelimiter { + return + } + if block && !r.lastCharWasNewline { + r.writer.Write([]byte("\n")) + } + r.lastCharWasNewline = (s[len(s)-1:] == "\n") + r.writer.Write([]byte(s)) + if block && !r.lastCharWasNewline { + r.writer.Write([]byte("\n")) + r.lastCharWasNewline = true + } + r.lastEmissionWasDelimiter = isDelimiter } /* Rendering text */ @@ -497,13 +668,14 @@ func (r *DefaultReporter) cycleJoin(elements []string, joiner string) string { return r.formatter.CycleJoin(elements, joiner, []string{"{{/}}", "{{gray}}"}) } -func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightColor string, succinct bool, usePreciseFailureLocation bool) string { +func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightColor string, veryVerbose bool, usePreciseFailureLocation bool) string { texts, locations, labels := []string{}, []types.CodeLocation{}, [][]string{} texts, locations, labels = append(texts, report.ContainerHierarchyTexts...), append(locations, report.ContainerHierarchyLocations...), append(labels, report.ContainerHierarchyLabels...) 
+ if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) { texts = append(texts, r.f("[%s] %s", report.LeafNodeType, report.LeafNodeText)) } else { - texts = append(texts, report.LeafNodeText) + texts = append(texts, r.f(report.LeafNodeText)) } labels = append(labels, report.LeafNodeLabels) locations = append(locations, report.LeafNodeLocation) @@ -513,24 +685,58 @@ func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightCo failureLocation = report.Failure.Location } + highlightIndex := -1 switch report.Failure.FailureNodeContext { case types.FailureNodeAtTopLevel: - texts = append([]string{r.f(highlightColor+"{{bold}}TOP-LEVEL [%s]{{/}}", report.Failure.FailureNodeType)}, texts...) + texts = append([]string{fmt.Sprintf("TOP-LEVEL [%s]", report.Failure.FailureNodeType)}, texts...) locations = append([]types.CodeLocation{failureLocation}, locations...) labels = append([][]string{{}}, labels...) + highlightIndex = 0 case types.FailureNodeInContainer: i := report.Failure.FailureNodeContainerIndex - texts[i] = r.f(highlightColor+"{{bold}}%s [%s]{{/}}", texts[i], report.Failure.FailureNodeType) + texts[i] = fmt.Sprintf("%s [%s]", texts[i], report.Failure.FailureNodeType) locations[i] = failureLocation + highlightIndex = i case types.FailureNodeIsLeafNode: i := len(texts) - 1 - texts[i] = r.f(highlightColor+"{{bold}}[%s] %s{{/}}", report.LeafNodeType, report.LeafNodeText) + texts[i] = fmt.Sprintf("[%s] %s", report.LeafNodeType, report.LeafNodeText) locations[i] = failureLocation + highlightIndex = i + default: + //there is no failure, so we highlight the leaf ndoe + highlightIndex = len(texts) - 1 } out := "" - if succinct { - out += r.f("%s", r.cycleJoin(texts, " ")) + if veryVerbose { + for i := range texts { + if i == highlightIndex { + out += r.fi(uint(i), highlightColor+"{{bold}}%s{{/}}", texts[i]) + } else { + out += r.fi(uint(i), "%s", texts[i]) + } + if len(labels[i]) > 0 { + out += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels[i], ", ")) + } + out += "\n" + out += r.fi(uint(i), "{{gray}}%s{{/}}\n", locations[i]) + } + } else { + for i := range texts { + style := "{{/}}" + if i%2 == 1 { + style = "{{gray}}" + } + if i == highlightIndex { + style = highlightColor + "{{bold}}" + } + out += r.f(style+"%s", texts[i]) + if i < len(texts)-1 { + out += " " + } else { + out += r.f("{{/}}") + } + } flattenedLabels := report.Labels() if len(flattenedLabels) > 0 { out += r.f(" {{coral}}[%s]{{/}}", strings.Join(flattenedLabels, ", ")) @@ -539,17 +745,15 @@ func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightCo if usePreciseFailureLocation { out += r.f("{{gray}}%s{{/}}", failureLocation) } else { - out += r.f("{{gray}}%s{{/}}", locations[len(locations)-1]) - } - } else { - for i := range texts { - out += r.fi(uint(i), "%s", texts[i]) - if len(labels[i]) > 0 { - out += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels[i], ", ")) + leafLocation := locations[len(locations)-1] + if (report.Failure.FailureNodeLocation != types.CodeLocation{}) && (report.Failure.FailureNodeLocation != leafLocation) { + out += r.fi(1, highlightColor+"[%s]{{/}} {{gray}}%s{{/}}\n", report.Failure.FailureNodeType, report.Failure.FailureNodeLocation) + out += r.fi(1, "{{gray}}[%s] %s{{/}}", report.LeafNodeType, leafLocation) + } else { + out += r.f("{{gray}}%s{{/}}", leafLocation) } - out += "\n" - out += r.fi(uint(i), "{{gray}}%s{{/}}\n", locations[i]) } + } return out } diff --git a/vendor/github.com/onsi/ginkgo/v2/reporters/deprecated_reporter.go 
b/vendor/github.com/onsi/ginkgo/v2/reporters/deprecated_reporter.go index 89d30076..613072eb 100644 --- a/vendor/github.com/onsi/ginkgo/v2/reporters/deprecated_reporter.go +++ b/vendor/github.com/onsi/ginkgo/v2/reporters/deprecated_reporter.go @@ -35,7 +35,7 @@ func ReportViaDeprecatedReporter(reporter DeprecatedReporter, report types.Repor FailOnPending: report.SuiteConfig.FailOnPending, FailFast: report.SuiteConfig.FailFast, FlakeAttempts: report.SuiteConfig.FlakeAttempts, - EmitSpecProgress: report.SuiteConfig.EmitSpecProgress, + EmitSpecProgress: false, DryRun: report.SuiteConfig.DryRun, ParallelNode: report.SuiteConfig.ParallelProcess, ParallelTotal: report.SuiteConfig.ParallelTotal, diff --git a/vendor/github.com/onsi/ginkgo/v2/reporters/junit_report.go b/vendor/github.com/onsi/ginkgo/v2/reporters/junit_report.go index 0f165069..592d7f61 100644 --- a/vendor/github.com/onsi/ginkgo/v2/reporters/junit_report.go +++ b/vendor/github.com/onsi/ginkgo/v2/reporters/junit_report.go @@ -15,12 +15,32 @@ import ( "fmt" "os" "strings" - "time" "github.com/onsi/ginkgo/v2/config" "github.com/onsi/ginkgo/v2/types" ) +type JunitReportConfig struct { + // Spec States for which no timeline should be emitted for system-err + // set this to types.SpecStatePassed|types.SpecStateSkipped|types.SpecStatePending to only match failing specs + OmitTimelinesForSpecState types.SpecState + + // Enable OmitFailureMessageAttr to prevent failure messages appearing in the "message" attribute of the Failure and Error tags + OmitFailureMessageAttr bool + + //Enable OmitCapturedStdOutErr to prevent captured stdout/stderr appearing in system-out + OmitCapturedStdOutErr bool + + // Enable OmitSpecLabels to prevent labels from appearing in the spec name + OmitSpecLabels bool + + // Enable OmitLeafNodeType to prevent the spec leaf node type from appearing in the spec name + OmitLeafNodeType bool + + // Enable OmitSuiteSetupNodes to prevent the creation of testcase entries for setup nodes + OmitSuiteSetupNodes bool +} + type JUnitTestSuites struct { XMLName xml.Name `xml:"testsuites"` // Tests maps onto the total number of specs in all test suites (this includes any suite nodes such as BeforeSuite) @@ -128,6 +148,10 @@ type JUnitFailure struct { } func GenerateJUnitReport(report types.Report, dst string) error { + return GenerateJUnitReportWithConfig(report, dst, JunitReportConfig{}) +} + +func GenerateJUnitReportWithConfig(report types.Report, dst string, config JunitReportConfig) error { suite := JUnitTestSuite{ Name: report.SuiteDescription, Package: report.SuitePath, @@ -149,7 +173,6 @@ func GenerateJUnitReport(report types.Report, dst string) error { {"FailOnPending", fmt.Sprintf("%t", report.SuiteConfig.FailOnPending)}, {"FailFast", fmt.Sprintf("%t", report.SuiteConfig.FailFast)}, {"FlakeAttempts", fmt.Sprintf("%d", report.SuiteConfig.FlakeAttempts)}, - {"EmitSpecProgress", fmt.Sprintf("%t", report.SuiteConfig.EmitSpecProgress)}, {"DryRun", fmt.Sprintf("%t", report.SuiteConfig.DryRun)}, {"ParallelTotal", fmt.Sprintf("%d", report.SuiteConfig.ParallelTotal)}, {"OutputInterceptorMode", report.SuiteConfig.OutputInterceptorMode}, @@ -157,22 +180,33 @@ func GenerateJUnitReport(report types.Report, dst string) error { }, } for _, spec := range report.SpecReports { + if config.OmitSuiteSetupNodes && spec.LeafNodeType != types.NodeTypeIt { + continue + } name := fmt.Sprintf("[%s]", spec.LeafNodeType) + if config.OmitLeafNodeType { + name = "" + } if spec.FullText() != "" { name = name + " " + spec.FullText() } labels := 
spec.Labels() - if len(labels) > 0 { + if len(labels) > 0 && !config.OmitSpecLabels { name = name + " [" + strings.Join(labels, ", ") + "]" } + name = strings.TrimSpace(name) test := JUnitTestCase{ Name: name, Classname: report.SuiteDescription, Status: spec.State.String(), Time: spec.RunTime.Seconds(), - SystemOut: systemOutForUnstructuredReporters(spec), - SystemErr: systemErrForUnstructuredReporters(spec), + } + if !spec.State.Is(config.OmitTimelinesForSpecState) { + test.SystemErr = systemErrForUnstructuredReporters(spec) + } + if !config.OmitCapturedStdOutErr { + test.SystemOut = systemOutForUnstructuredReporters(spec) } suite.Tests += 1 @@ -191,28 +225,50 @@ func GenerateJUnitReport(report types.Report, dst string) error { test.Failure = &JUnitFailure{ Message: spec.Failure.Message, Type: "failed", - Description: fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace), + Description: failureDescriptionForUnstructuredReporters(spec), + } + if config.OmitFailureMessageAttr { + test.Failure.Message = "" + } + suite.Failures += 1 + case types.SpecStateTimedout: + test.Failure = &JUnitFailure{ + Message: spec.Failure.Message, + Type: "timedout", + Description: failureDescriptionForUnstructuredReporters(spec), + } + if config.OmitFailureMessageAttr { + test.Failure.Message = "" } suite.Failures += 1 case types.SpecStateInterrupted: test.Error = &JUnitError{ - Message: "interrupted", + Message: spec.Failure.Message, Type: "interrupted", - Description: interruptDescriptionForUnstructuredReporters(spec.Failure), + Description: failureDescriptionForUnstructuredReporters(spec), + } + if config.OmitFailureMessageAttr { + test.Error.Message = "" } suite.Errors += 1 case types.SpecStateAborted: test.Failure = &JUnitFailure{ Message: spec.Failure.Message, Type: "aborted", - Description: fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace), + Description: failureDescriptionForUnstructuredReporters(spec), + } + if config.OmitFailureMessageAttr { + test.Failure.Message = "" } suite.Errors += 1 case types.SpecStatePanicked: test.Error = &JUnitError{ Message: spec.Failure.ForwardedPanic, Type: "panicked", - Description: fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace), + Description: failureDescriptionForUnstructuredReporters(spec), + } + if config.OmitFailureMessageAttr { + test.Error.Message = "" } suite.Errors += 1 } @@ -278,52 +334,27 @@ func MergeAndCleanupJUnitReports(sources []string, dst string) ([]string, error) return messages, f.Close() } -func interruptDescriptionForUnstructuredReporters(failure types.Failure) string { +func failureDescriptionForUnstructuredReporters(spec types.SpecReport) string { out := &strings.Builder{} - out.WriteString(failure.Message + "\n") - NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitProgressReport(failure.ProgressReport) + NewDefaultReporter(types.ReporterConfig{NoColor: true, VeryVerbose: true}, out).emitFailure(0, spec.State, spec.Failure, true) + if len(spec.AdditionalFailures) > 0 { + out.WriteString("\nThere were additional failures detected after the initial failure. 
These are visible in the timeline\n") + } return out.String() } func systemErrForUnstructuredReporters(spec types.SpecReport) string { + return RenderTimeline(spec, true) +} + +func RenderTimeline(spec types.SpecReport, noColor bool) string { out := &strings.Builder{} - gw := spec.CapturedGinkgoWriterOutput - cursor := 0 - for _, pr := range spec.ProgressReports { - if cursor < pr.GinkgoWriterOffset { - if pr.GinkgoWriterOffset < len(gw) { - out.WriteString(gw[cursor:pr.GinkgoWriterOffset]) - cursor = pr.GinkgoWriterOffset - } else if cursor < len(gw) { - out.WriteString(gw[cursor:]) - cursor = len(gw) - } - } - NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitProgressReport(pr) - } - - if cursor < len(gw) { - out.WriteString(gw[cursor:]) - } - + NewDefaultReporter(types.ReporterConfig{NoColor: noColor, VeryVerbose: true}, out).emitTimeline(0, spec, spec.Timeline()) return out.String() } func systemOutForUnstructuredReporters(spec types.SpecReport) string { - systemOut := spec.CapturedStdOutErr - if len(spec.ReportEntries) > 0 { - systemOut += "\nReport Entries:\n" - for i, entry := range spec.ReportEntries { - systemOut += fmt.Sprintf("%s\n%s\n%s\n", entry.Name, entry.Location, entry.Time.Format(time.RFC3339Nano)) - if representation := entry.StringRepresentation(); representation != "" { - systemOut += representation + "\n" - } - if i+1 < len(spec.ReportEntries) { - systemOut += "--\n" - } - } - } - return systemOut + return spec.CapturedStdOutErr } // Deprecated JUnitReporter (so folks can still compile their suites) diff --git a/vendor/github.com/onsi/ginkgo/v2/reporters/reporter.go b/vendor/github.com/onsi/ginkgo/v2/reporters/reporter.go index f79f005d..5e726c46 100644 --- a/vendor/github.com/onsi/ginkgo/v2/reporters/reporter.go +++ b/vendor/github.com/onsi/ginkgo/v2/reporters/reporter.go @@ -9,13 +9,21 @@ type Reporter interface { WillRun(report types.SpecReport) DidRun(report types.SpecReport) SuiteDidEnd(report types.Report) + + //Timeline emission + EmitFailure(state types.SpecState, failure types.Failure) EmitProgressReport(progressReport types.ProgressReport) + EmitReportEntry(entry types.ReportEntry) + EmitSpecEvent(event types.SpecEvent) } type NoopReporter struct{} -func (n NoopReporter) SuiteWillBegin(report types.Report) {} -func (n NoopReporter) WillRun(report types.SpecReport) {} -func (n NoopReporter) DidRun(report types.SpecReport) {} -func (n NoopReporter) SuiteDidEnd(report types.Report) {} -func (n NoopReporter) EmitProgressReport(progressReport types.ProgressReport) {} +func (n NoopReporter) SuiteWillBegin(report types.Report) {} +func (n NoopReporter) WillRun(report types.SpecReport) {} +func (n NoopReporter) DidRun(report types.SpecReport) {} +func (n NoopReporter) SuiteDidEnd(report types.Report) {} +func (n NoopReporter) EmitFailure(state types.SpecState, failure types.Failure) {} +func (n NoopReporter) EmitProgressReport(progressReport types.ProgressReport) {} +func (n NoopReporter) EmitReportEntry(entry types.ReportEntry) {} +func (n NoopReporter) EmitSpecEvent(event types.SpecEvent) {} diff --git a/vendor/github.com/onsi/ginkgo/v2/reporters/teamcity_report.go b/vendor/github.com/onsi/ginkgo/v2/reporters/teamcity_report.go index 00b03876..c1863496 100644 --- a/vendor/github.com/onsi/ginkgo/v2/reporters/teamcity_report.go +++ b/vendor/github.com/onsi/ginkgo/v2/reporters/teamcity_report.go @@ -60,15 +60,19 @@ func GenerateTeamcityReport(report types.Report, dst string) error { } fmt.Fprintf(f, "##teamcity[testIgnored name='%s' message='%s']\n", 
name, tcEscape(message)) case types.SpecStateFailed: - details := fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace) + details := failureDescriptionForUnstructuredReporters(spec) fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='failed - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details)) case types.SpecStatePanicked: - details := fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace) + details := failureDescriptionForUnstructuredReporters(spec) fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='panicked - %s' details='%s']\n", name, tcEscape(spec.Failure.ForwardedPanic), tcEscape(details)) + case types.SpecStateTimedout: + details := failureDescriptionForUnstructuredReporters(spec) + fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='timedout - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details)) case types.SpecStateInterrupted: - fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='interrupted' details='%s']\n", name, tcEscape(interruptDescriptionForUnstructuredReporters(spec.Failure))) + details := failureDescriptionForUnstructuredReporters(spec) + fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='interrupted - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details)) case types.SpecStateAborted: - details := fmt.Sprintf("%s\n%s", spec.Failure.Location.String(), spec.Failure.Location.FullStackTrace) + details := failureDescriptionForUnstructuredReporters(spec) fmt.Fprintf(f, "##teamcity[testFailed name='%s' message='aborted - %s' details='%s']\n", name, tcEscape(spec.Failure.Message), tcEscape(details)) } diff --git a/vendor/github.com/onsi/ginkgo/v2/types/code_location.go b/vendor/github.com/onsi/ginkgo/v2/types/code_location.go index 12910918..9cd57681 100644 --- a/vendor/github.com/onsi/ginkgo/v2/types/code_location.go +++ b/vendor/github.com/onsi/ginkgo/v2/types/code_location.go @@ -7,6 +7,7 @@ import ( "runtime" "runtime/debug" "strings" + "sync" ) type CodeLocation struct { @@ -38,6 +39,73 @@ func (codeLocation CodeLocation) ContentsOfLine() string { return lines[codeLocation.LineNumber-1] } +type codeLocationLocator struct { + pcs map[uintptr]bool + helpers map[string]bool + lock *sync.Mutex +} + +func (c *codeLocationLocator) addHelper(pc uintptr) { + c.lock.Lock() + defer c.lock.Unlock() + + if c.pcs[pc] { + return + } + c.lock.Unlock() + f := runtime.FuncForPC(pc) + c.lock.Lock() + if f == nil { + return + } + c.helpers[f.Name()] = true + c.pcs[pc] = true +} + +func (c *codeLocationLocator) hasHelper(name string) bool { + c.lock.Lock() + defer c.lock.Unlock() + return c.helpers[name] +} + +func (c *codeLocationLocator) getCodeLocation(skip int) CodeLocation { + pc := make([]uintptr, 40) + n := runtime.Callers(skip+2, pc) + if n == 0 { + return CodeLocation{} + } + pc = pc[:n] + frames := runtime.CallersFrames(pc) + for { + frame, more := frames.Next() + if !c.hasHelper(frame.Function) { + return CodeLocation{FileName: frame.File, LineNumber: frame.Line} + } + if !more { + break + } + } + return CodeLocation{} +} + +var clLocator = &codeLocationLocator{ + pcs: map[uintptr]bool{}, + helpers: map[string]bool{}, + lock: &sync.Mutex{}, +} + +// MarkAsHelper is used by GinkgoHelper to mark the caller (appropriately offset by skip)as a helper. You can use this directly if you need to provide an optional `skip` to mark functions further up the call stack as helpers. 
+func MarkAsHelper(optionalSkip ...int) { + skip := 1 + if len(optionalSkip) > 0 { + skip += optionalSkip[0] + } + pc, _, _, ok := runtime.Caller(skip) + if ok { + clLocator.addHelper(pc) + } +} + func NewCustomCodeLocation(message string) CodeLocation { return CodeLocation{ CustomMessage: message, @@ -45,14 +113,13 @@ func NewCustomCodeLocation(message string) CodeLocation { } func NewCodeLocation(skip int) CodeLocation { - _, file, line, _ := runtime.Caller(skip + 1) - return CodeLocation{FileName: file, LineNumber: line} + return clLocator.getCodeLocation(skip + 1) } func NewCodeLocationWithStackTrace(skip int) CodeLocation { - _, file, line, _ := runtime.Caller(skip + 1) - stackTrace := PruneStack(string(debug.Stack()), skip+1) - return CodeLocation{FileName: file, LineNumber: line, FullStackTrace: stackTrace} + cl := clLocator.getCodeLocation(skip + 1) + cl.FullStackTrace = PruneStack(string(debug.Stack()), skip+1) + return cl } // PruneStack removes references to functions that are internal to Ginkgo diff --git a/vendor/github.com/onsi/ginkgo/v2/types/config.go b/vendor/github.com/onsi/ginkgo/v2/types/config.go index 438c947c..1014c7b4 100644 --- a/vendor/github.com/onsi/ginkgo/v2/types/config.go +++ b/vendor/github.com/onsi/ginkgo/v2/types/config.go @@ -8,6 +8,7 @@ package types import ( "flag" "os" + "path/filepath" "runtime" "strconv" "strings" @@ -26,13 +27,14 @@ type SuiteConfig struct { FailOnPending bool FailFast bool FlakeAttempts int - EmitSpecProgress bool DryRun bool PollProgressAfter time.Duration PollProgressInterval time.Duration Timeout time.Duration + EmitSpecProgress bool // this is deprecated but its removal is causing compile issue for some users that were setting it manually OutputInterceptorMode string SourceRoots []string + GracePeriod time.Duration ParallelProcess int ParallelTotal int @@ -45,6 +47,7 @@ func NewDefaultSuiteConfig() SuiteConfig { Timeout: time.Hour, ParallelProcess: 1, ParallelTotal: 1, + GracePeriod: 30 * time.Second, } } @@ -79,13 +82,12 @@ func (vl VerbosityLevel) LT(comp VerbosityLevel) bool { // Configuration for Ginkgo's reporter type ReporterConfig struct { - NoColor bool - SlowSpecThreshold time.Duration - Succinct bool - Verbose bool - VeryVerbose bool - FullTrace bool - AlwaysEmitGinkgoWriter bool + NoColor bool + Succinct bool + Verbose bool + VeryVerbose bool + FullTrace bool + ShowNodeEvents bool JSONReport string JUnitReport string @@ -108,9 +110,7 @@ func (rc ReporterConfig) WillGenerateReport() bool { } func NewDefaultReporterConfig() ReporterConfig { - return ReporterConfig{ - SlowSpecThreshold: 5 * time.Second, - } + return ReporterConfig{} } // Configuration for the Ginkgo CLI @@ -233,6 +233,9 @@ type deprecatedConfig struct { SlowSpecThresholdWithFLoatUnits float64 Stream bool Notify bool + EmitSpecProgress bool + SlowSpecThreshold time.Duration + AlwaysEmitGinkgoWriter bool } // Flags @@ -273,8 +276,6 @@ var SuiteConfigFlags = GinkgoFlags{ {KeyPath: "S.DryRun", Name: "dry-run", SectionKey: "debug", DeprecatedName: "dryRun", DeprecatedDocLink: "changed-command-line-flags", Usage: "If set, ginkgo will walk the test hierarchy without actually running anything. 
Best paired with -v."}, - {KeyPath: "S.EmitSpecProgress", Name: "progress", SectionKey: "debug", - Usage: "If set, ginkgo will emit progress information as each spec runs to the GinkgoWriter."}, {KeyPath: "S.PollProgressAfter", Name: "poll-progress-after", SectionKey: "debug", UsageDefaultValue: "0", Usage: "Emit node progress reports periodically if node hasn't completed after this duration."}, {KeyPath: "S.PollProgressInterval", Name: "poll-progress-interval", SectionKey: "debug", UsageDefaultValue: "10s", @@ -283,6 +284,8 @@ var SuiteConfigFlags = GinkgoFlags{ Usage: "The location to look for source code when generating progress reports. You can pass multiple --source-root flags."}, {KeyPath: "S.Timeout", Name: "timeout", SectionKey: "debug", UsageDefaultValue: "1h", Usage: "Test suite fails if it does not complete within the specified timeout."}, + {KeyPath: "S.GracePeriod", Name: "grace-period", SectionKey: "debug", UsageDefaultValue: "30s", + Usage: "When interrupted, Ginkgo will wait for GracePeriod for the current running node to exit before moving on to the next one."}, {KeyPath: "S.OutputInterceptorMode", Name: "output-interceptor-mode", SectionKey: "debug", UsageArgument: "dup, swap, or none", Usage: "If set, ginkgo will use the specified output interception strategy when running in parallel. Defaults to dup on unix and swap on windows."}, @@ -299,6 +302,8 @@ var SuiteConfigFlags = GinkgoFlags{ {KeyPath: "D.RegexScansFilePath", DeprecatedName: "regexScansFilePath", DeprecatedDocLink: "removed--regexscansfilepath", DeprecatedVersion: "2.0.0"}, {KeyPath: "D.DebugParallel", DeprecatedName: "debug", DeprecatedDocLink: "removed--debug", DeprecatedVersion: "2.0.0"}, + {KeyPath: "D.EmitSpecProgress", DeprecatedName: "progress", SectionKey: "debug", + DeprecatedVersion: "2.5.0", Usage: ". The functionality provided by --progress was confusing and is no longer needed. Use --show-node-events instead to see node entry and exit events included in the timeline of failed and verbose specs. Or you can run with -vv to always see all node events. 
Lastly, --poll-progress-after and the PollProgressAfter decorator now provide a better mechanism for debugging specs that tend to get stuck."}, } // ParallelConfigFlags provides flags for the Ginkgo test process (not the CLI) @@ -315,8 +320,6 @@ var ParallelConfigFlags = GinkgoFlags{ var ReporterConfigFlags = GinkgoFlags{ {KeyPath: "R.NoColor", Name: "no-color", SectionKey: "output", DeprecatedName: "noColor", DeprecatedDocLink: "changed-command-line-flags", Usage: "If set, suppress color output in default reporter."}, - {KeyPath: "R.SlowSpecThreshold", Name: "slow-spec-threshold", SectionKey: "output", UsageArgument: "duration", UsageDefaultValue: "5s", - Usage: "Specs that take longer to run than this threshold are flagged as slow by the default reporter."}, {KeyPath: "R.Verbose", Name: "v", SectionKey: "output", Usage: "If set, emits more output including GinkgoWriter contents."}, {KeyPath: "R.VeryVerbose", Name: "vv", SectionKey: "output", @@ -325,8 +328,8 @@ var ReporterConfigFlags = GinkgoFlags{ Usage: "If set, default reporter prints out a very succinct report"}, {KeyPath: "R.FullTrace", Name: "trace", SectionKey: "output", Usage: "If set, default reporter prints out the full stack trace when a failure occurs"}, - {KeyPath: "R.AlwaysEmitGinkgoWriter", Name: "always-emit-ginkgo-writer", SectionKey: "output", DeprecatedName: "reportPassed", DeprecatedDocLink: "renamed--reportpassed", - Usage: "If set, default reporter prints out captured output of passed tests."}, + {KeyPath: "R.ShowNodeEvents", Name: "show-node-events", SectionKey: "output", + Usage: "If set, default reporter prints node > Enter and < Exit events when specs fail"}, {KeyPath: "R.JSONReport", Name: "json-report", UsageArgument: "filename.json", SectionKey: "output", Usage: "If set, Ginkgo will generate a JSON-formatted test report at the specified location."}, @@ -339,6 +342,8 @@ var ReporterConfigFlags = GinkgoFlags{ Usage: "use --slow-spec-threshold instead and pass in a duration string (e.g. '5s', not '5.0')"}, {KeyPath: "D.NoisyPendings", DeprecatedName: "noisyPendings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"}, {KeyPath: "D.NoisySkippings", DeprecatedName: "noisySkippings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"}, + {KeyPath: "D.SlowSpecThreshold", DeprecatedName: "slow-spec-threshold", SectionKey: "output", Usage: "--slow-spec-threshold has been deprecated and will be removed in a future version of Ginkgo. This feature has proved to be more noisy than useful. 
You can use --poll-progress-after, instead, to get more actionable feedback about potentially slow specs and understand where they might be getting stuck.", DeprecatedVersion: "2.5.0"}, + {KeyPath: "D.AlwaysEmitGinkgoWriter", DeprecatedName: "always-emit-ginkgo-writer", SectionKey: "output", Usage: " - use -v instead, or one of Ginkgo's machine-readable report formats to get GinkgoWriter output for passing specs."}, } // BuildTestSuiteFlagSet attaches to the CommandLine flagset and provides flags for the Ginkgo test process @@ -390,6 +395,10 @@ func VetConfig(flagSet GinkgoFlagSet, suiteConfig SuiteConfig, reporterConfig Re errors = append(errors, GinkgoErrors.DryRunInParallelConfiguration()) } + if suiteConfig.GracePeriod <= 0 { + errors = append(errors, GinkgoErrors.GracePeriodCannotBeZero()) + } + if len(suiteConfig.FocusFiles) > 0 { _, err := ParseFileFilters(suiteConfig.FocusFiles) if err != nil { @@ -592,13 +601,29 @@ func VetAndInitializeCLIAndGoConfig(cliConfig CLIConfig, goFlagsConfig GoFlagsCo } // GenerateGoTestCompileArgs is used by the Ginkgo CLI to generate command line arguments to pass to the go test -c command when compiling the test -func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, destination string, packageToBuild string) ([]string, error) { +func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, destination string, packageToBuild string, pathToInvocationPath string) ([]string, error) { // if the user has set the CoverProfile run-time flag make sure to set the build-time cover flag to make sure // the built test binary can generate a coverprofile if goFlagsConfig.CoverProfile != "" { goFlagsConfig.Cover = true } + if goFlagsConfig.CoverPkg != "" { + coverPkgs := strings.Split(goFlagsConfig.CoverPkg, ",") + adjustedCoverPkgs := make([]string, len(coverPkgs)) + for i, coverPkg := range coverPkgs { + coverPkg = strings.Trim(coverPkg, " ") + if strings.HasPrefix(coverPkg, "./") { + // this is a relative coverPkg - we need to reroot it + adjustedCoverPkgs[i] = "./" + filepath.Join(pathToInvocationPath, strings.TrimPrefix(coverPkg, "./")) + } else { + // this is a package name - don't touch it + adjustedCoverPkgs[i] = coverPkg + } + } + goFlagsConfig.CoverPkg = strings.Join(adjustedCoverPkgs, ",") + } + args := []string{"test", "-c", "-o", destination, packageToBuild} goArgs, err := GenerateFlagArgs( GoBuildFlags, diff --git a/vendor/github.com/onsi/ginkgo/v2/types/deprecation_support.go b/vendor/github.com/onsi/ginkgo/v2/types/deprecation_support.go index 2948dfa0..e2519f67 100644 --- a/vendor/github.com/onsi/ginkgo/v2/types/deprecation_support.go +++ b/vendor/github.com/onsi/ginkgo/v2/types/deprecation_support.go @@ -38,7 +38,7 @@ func (d deprecations) Async() Deprecation { func (d deprecations) Measure() Deprecation { return Deprecation{ - Message: "Measure is deprecated and will be removed in Ginkgo V2. Please migrate to gomega/gmeasure.", + Message: "Measure is deprecated and has been removed from Ginkgo V2. Any Measure tests in your spec will not run. 
Please migrate to gomega/gmeasure.", DocLink: "removed-measure", Version: "1.16.3", } @@ -83,6 +83,13 @@ func (d deprecations) Nodot() Deprecation { } } +func (d deprecations) SuppressProgressReporting() Deprecation { + return Deprecation{ + Message: "Improvements to how reporters emit timeline information means that SuppressProgressReporting is no longer necessary and has been deprecated.", + Version: "2.5.0", + } +} + type DeprecationTracker struct { deprecations map[Deprecation][]CodeLocation lock *sync.Mutex diff --git a/vendor/github.com/onsi/ginkgo/v2/types/errors.go b/vendor/github.com/onsi/ginkgo/v2/types/errors.go index 6806d6af..1e0dbfd9 100644 --- a/vendor/github.com/onsi/ginkgo/v2/types/errors.go +++ b/vendor/github.com/onsi/ginkgo/v2/types/errors.go @@ -108,8 +108,8 @@ Please ensure all assertions are inside leaf nodes such as {{bold}}BeforeEach{{/ func (g ginkgoErrors) SuiteNodeInNestedContext(nodeType NodeType, cl CodeLocation) error { docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite" - if nodeType.Is(NodeTypeReportAfterSuite) { - docLink = "reporting-nodes---reportaftersuite" + if nodeType.Is(NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite) { + docLink = "reporting-nodes---reportbeforesuite-and-reportaftersuite" } return GinkgoError{ @@ -125,8 +125,8 @@ func (g ginkgoErrors) SuiteNodeInNestedContext(nodeType NodeType, cl CodeLocatio func (g ginkgoErrors) SuiteNodeDuringRunPhase(nodeType NodeType, cl CodeLocation) error { docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite" - if nodeType.Is(NodeTypeReportAfterSuite) { - docLink = "reporting-nodes---reportaftersuite" + if nodeType.Is(NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite) { + docLink = "reporting-nodes---reportbeforesuite-and-reportaftersuite" } return GinkgoError{ @@ -180,6 +180,15 @@ func (g ginkgoErrors) InvalidDeclarationOfFocusedAndPending(cl CodeLocation, nod } } +func (g ginkgoErrors) InvalidDeclarationOfFlakeAttemptsAndMustPassRepeatedly(cl CodeLocation, nodeType NodeType) error { + return GinkgoError{ + Heading: "Invalid Combination of Decorators: FlakeAttempts and MustPassRepeatedly", + Message: formatter.F(`[%s] node was decorated with both FlakeAttempts and MustPassRepeatedly. At most one is allowed.`, nodeType), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + func (g ginkgoErrors) UnknownDecorator(cl CodeLocation, nodeType NodeType, decorator interface{}) error { return GinkgoError{ Heading: "Unknown Decorator", @@ -189,20 +198,55 @@ func (g ginkgoErrors) UnknownDecorator(cl CodeLocation, nodeType NodeType, decor } } +func (g ginkgoErrors) InvalidBodyTypeForContainer(t reflect.Type, cl CodeLocation, nodeType NodeType) error { + return GinkgoError{ + Heading: "Invalid Function", + Message: formatter.F(`[%s] node must be passed {{bold}}func(){{/}} - i.e. functions that take nothing and return nothing. You passed {{bold}}%s{{/}} instead.`, nodeType, t), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + func (g ginkgoErrors) InvalidBodyType(t reflect.Type, cl CodeLocation, nodeType NodeType) error { + mustGet := "{{bold}}func(){{/}}, {{bold}}func(ctx SpecContext){{/}}, or {{bold}}func(ctx context.Context){{/}}" + if nodeType.Is(NodeTypeContainer) { + mustGet = "{{bold}}func(){{/}}" + } return GinkgoError{ Heading: "Invalid Function", - Message: formatter.F(`[%s] node must be passed {{bold}}func(){{/}} - i.e. functions that take nothing and return nothing. + Message: formatter.F(`[%s] node must be passed `+mustGet+`. 
You passed {{bold}}%s{{/}} instead.`, nodeType, t), CodeLocation: cl, DocLink: "node-decorators-overview", } } +func (g ginkgoErrors) InvalidBodyTypeForSynchronizedBeforeSuiteProc1(t reflect.Type, cl CodeLocation) error { + mustGet := "{{bold}}func() []byte{{/}}, {{bold}}func(ctx SpecContext) []byte{{/}}, or {{bold}}func(ctx context.Context) []byte{{/}}, {{bold}}func(){{/}}, {{bold}}func(ctx SpecContext){{/}}, or {{bold}}func(ctx context.Context){{/}}" + return GinkgoError{ + Heading: "Invalid Function", + Message: formatter.F(`[SynchronizedBeforeSuite] node must be passed `+mustGet+` for its first function. +You passed {{bold}}%s{{/}} instead.`, t), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + +func (g ginkgoErrors) InvalidBodyTypeForSynchronizedBeforeSuiteAllProcs(t reflect.Type, cl CodeLocation) error { + mustGet := "{{bold}}func(){{/}}, {{bold}}func(ctx SpecContext){{/}}, or {{bold}}func(ctx context.Context){{/}}, {{bold}}func([]byte){{/}}, {{bold}}func(ctx SpecContext, []byte){{/}}, or {{bold}}func(ctx context.Context, []byte){{/}}" + return GinkgoError{ + Heading: "Invalid Function", + Message: formatter.F(`[SynchronizedBeforeSuite] node must be passed `+mustGet+` for its second function. +You passed {{bold}}%s{{/}} instead.`, t), + CodeLocation: cl, + DocLink: "node-decorators-overview", + } +} + func (g ginkgoErrors) MultipleBodyFunctions(cl CodeLocation, nodeType NodeType) error { return GinkgoError{ Heading: "Multiple Functions", - Message: formatter.F(`[%s] node must be passed a single {{bold}}func(){{/}} - but more than one was passed in.`, nodeType), + Message: formatter.F(`[%s] node must be passed a single function - but more than one was passed in.`, nodeType), CodeLocation: cl, DocLink: "node-decorators-overview", } @@ -211,12 +255,30 @@ func (g ginkgoErrors) MultipleBodyFunctions(cl CodeLocation, nodeType NodeType) func (g ginkgoErrors) MissingBodyFunction(cl CodeLocation, nodeType NodeType) error { return GinkgoError{ Heading: "Missing Functions", - Message: formatter.F(`[%s] node must be passed a single {{bold}}func(){{/}} - but none was passed in.`, nodeType), + Message: formatter.F(`[%s] node must be passed a single function - but none was passed in.`, nodeType), CodeLocation: cl, DocLink: "node-decorators-overview", } } +func (g ginkgoErrors) InvalidTimeoutOrGracePeriodForNonContextNode(cl CodeLocation, nodeType NodeType) error { + return GinkgoError{ + Heading: "Invalid NodeTimeout SpecTimeout, or GracePeriod", + Message: formatter.F(`[%s] was passed NodeTimeout, SpecTimeout, or GracePeriod but does not have a callback that accepts a {{bold}}SpecContext{{/}} or {{bold}}context.Context{{/}}. You must accept a context to enable timeouts and grace periods`, nodeType), + CodeLocation: cl, + DocLink: "spec-timeouts-and-interruptible-nodes", + } +} + +func (g ginkgoErrors) InvalidTimeoutOrGracePeriodForNonContextCleanupNode(cl CodeLocation) error { + return GinkgoError{ + Heading: "Invalid NodeTimeout SpecTimeout, or GracePeriod", + Message: formatter.F(`[DeferCleanup] was passed NodeTimeout or GracePeriod but does not have a callback that accepts a {{bold}}SpecContext{{/}} or {{bold}}context.Context{{/}}. 
You must accept a context to enable timeouts and grace periods`), + CodeLocation: cl, + DocLink: "spec-timeouts-and-interruptible-nodes", + } +} + /* Ordered Container errors */ func (g ginkgoErrors) InvalidSerialNodeInNonSerialOrderedContainer(cl CodeLocation, nodeType NodeType) error { return GinkgoError{ @@ -236,6 +298,15 @@ func (g ginkgoErrors) SetupNodeNotInOrderedContainer(cl CodeLocation, nodeType N } } +func (g ginkgoErrors) InvalidContinueOnFailureDecoration(cl CodeLocation) error { + return GinkgoError{ + Heading: "ContinueOnFailure not decorating an outermost Ordered Container", + Message: "ContinueOnFailure can only decorate an Ordered container, and this Ordered container must be the outermost Ordered container.", + CodeLocation: cl, + DocLink: "ordered-containers", + } +} + /* DeferCleanup errors */ func (g ginkgoErrors) DeferCleanupInvalidFunction(cl CodeLocation) error { return GinkgoError{ @@ -258,7 +329,7 @@ func (g ginkgoErrors) PushingCleanupNodeDuringTreeConstruction(cl CodeLocation) func (g ginkgoErrors) PushingCleanupInReportingNode(cl CodeLocation, nodeType NodeType) error { return GinkgoError{ Heading: fmt.Sprintf("DeferCleanup cannot be called in %s", nodeType), - Message: "Please inline your cleanup code - Ginkgo won't run cleanup code after a ReportAfterEach or ReportAfterSuite.", + Message: "Please inline your cleanup code - Ginkgo won't run cleanup code after a Reporting node.", CodeLocation: cl, DocLink: "cleaning-up-our-cleanup-code-defercleanup", } @@ -380,6 +451,15 @@ func (g ginkgoErrors) InvalidEntryDescription(cl CodeLocation) error { } } +func (g ginkgoErrors) MissingParametersForTableFunction(cl CodeLocation) error { + return GinkgoError{ + Heading: fmt.Sprintf("No parameters have been passed to the Table Function"), + Message: fmt.Sprintf("The Table Function expected at least 1 parameter"), + CodeLocation: cl, + DocLink: "table-specs", + } +} + func (g ginkgoErrors) IncorrectParameterTypeForTable(i int, name string, cl CodeLocation) error { return GinkgoError{ Heading: "DescribeTable passed incorrect parameter type", @@ -498,6 +578,13 @@ func (g ginkgoErrors) DryRunInParallelConfiguration() error { } } +func (g ginkgoErrors) GracePeriodCannotBeZero() error { + return GinkgoError{ + Heading: "Ginkgo requires a positive --grace-period.", + Message: "Please set --grace-period to a positive duration. 
The default is 30s.", + } +} + func (g ginkgoErrors) ConflictingVerbosityConfiguration() error { return GinkgoError{ Heading: "Conflicting reporter verbosity settings.", diff --git a/vendor/github.com/onsi/ginkgo/v2/types/label_filter.go b/vendor/github.com/onsi/ginkgo/v2/types/label_filter.go index 0403f9e6..b0d3b651 100644 --- a/vendor/github.com/onsi/ginkgo/v2/types/label_filter.go +++ b/vendor/github.com/onsi/ginkgo/v2/types/label_filter.go @@ -272,12 +272,23 @@ func tokenize(input string) func() (*treeNode, error) { } } +func MustParseLabelFilter(input string) LabelFilter { + filter, err := ParseLabelFilter(input) + if err != nil { + panic(err) + } + return filter +} + func ParseLabelFilter(input string) (LabelFilter, error) { if DEBUG_LABEL_FILTER_PARSING { fmt.Println("\n==============") fmt.Println("Input: ", input) fmt.Print("Tokens: ") } + if input == "" { + return func(_ []string) bool { return true }, nil + } nextToken := tokenize(input) root := &treeNode{token: lfTokenRoot} diff --git a/vendor/github.com/onsi/ginkgo/v2/types/report_entry.go b/vendor/github.com/onsi/ginkgo/v2/types/report_entry.go index 798bedc0..7b1524b5 100644 --- a/vendor/github.com/onsi/ginkgo/v2/types/report_entry.go +++ b/vendor/github.com/onsi/ginkgo/v2/types/report_entry.go @@ -6,8 +6,8 @@ import ( "time" ) -//ReportEntryValue wraps a report entry's value ensuring it can be encoded and decoded safely into reports -//and across the network connection when running in parallel +// ReportEntryValue wraps a report entry's value ensuring it can be encoded and decoded safely into reports +// and across the network connection when running in parallel type ReportEntryValue struct { raw interface{} //unexported to prevent gob from freaking out about unregistered structs AsJSON string @@ -85,10 +85,12 @@ func (rev *ReportEntryValue) GobDecode(data []byte) error { type ReportEntry struct { // Visibility captures the visibility policy for this ReportEntry Visibility ReportEntryVisibility - // Time captures the time the AddReportEntry was called - Time time.Time // Location captures the location of the AddReportEntry call Location CodeLocation + + Time time.Time //need this for backwards compatibility + TimelineLocation TimelineLocation + // Name captures the name of this report Name string // Value captures the (optional) object passed into AddReportEntry - this can be @@ -120,7 +122,9 @@ func (entry ReportEntry) GetRawValue() interface{} { return entry.Value.GetRawValue() } - +func (entry ReportEntry) GetTimelineLocation() TimelineLocation { + return entry.TimelineLocation +} type ReportEntries []ReportEntry diff --git a/vendor/github.com/onsi/ginkgo/v2/types/types.go b/vendor/github.com/onsi/ginkgo/v2/types/types.go index aec9062e..d048a8ad 100644 --- a/vendor/github.com/onsi/ginkgo/v2/types/types.go +++ b/vendor/github.com/onsi/ginkgo/v2/types/types.go @@ -2,6 +2,8 @@ package types import ( "encoding/json" + "fmt" + "sort" "strings" "time" ) @@ -56,19 +58,20 @@ type Report struct { SuiteConfig SuiteConfig //SpecReports is a list of all SpecReports generated by this test run + //It is empty when the SuiteReport is provided to ReportBeforeSuite SpecReports SpecReports } -//PreRunStats contains a set of stats captured before the test run begins. This is primarily used -//by Ginkgo's reporter to tell the user how many specs are in the current suite (PreRunStats.TotalSpecs) -//and how many it intends to run (PreRunStats.SpecsThatWillRun) after applying any relevant focus or skip filters. 
+// PreRunStats contains a set of stats captured before the test run begins. This is primarily used +// by Ginkgo's reporter to tell the user how many specs are in the current suite (PreRunStats.TotalSpecs) +// and how many it intends to run (PreRunStats.SpecsThatWillRun) after applying any relevant focus or skip filters. type PreRunStats struct { TotalSpecs int SpecsThatWillRun int } -//Add is ued by Ginkgo's parallel aggregation mechanisms to combine test run reports form individual parallel processes -//to form a complete final report. +// Add is used by Ginkgo's parallel aggregation mechanisms to combine test run reports form individual parallel processes +// to form a complete final report. func (report Report) Add(other Report) Report { report.SuiteSucceeded = report.SuiteSucceeded && other.SuiteSucceeded @@ -147,14 +150,24 @@ type SpecReport struct { // ParallelProcess captures the parallel process that this spec ran on ParallelProcess int + // RunningInParallel captures whether this spec is part of a suite that ran in parallel + RunningInParallel bool + //Failure is populated if a spec has failed, panicked, been interrupted, or skipped by the user (e.g. calling Skip()) //It includes detailed information about the Failure Failure Failure - // NumAttempts captures the number of times this Spec was run. Flakey specs can be retried with - // ginkgo --flake-attempts=N + // NumAttempts captures the number of times this Spec was run. + // Flakey specs can be retried with ginkgo --flake-attempts=N or the use of the FlakeAttempts decorator. + // Repeated specs can be retried with the use of the MustPassRepeatedly decorator NumAttempts int + // MaxFlakeAttempts captures whether the spec has been retried with ginkgo --flake-attempts=N or the use of the FlakeAttempts decorator. + MaxFlakeAttempts int + + // MaxMustPassRepeatedly captures whether the spec has the MustPassRepeatedly decorator + MaxMustPassRepeatedly int + // CapturedGinkgoWriterOutput contains text printed to the GinkgoWriter CapturedGinkgoWriterOutput string @@ -168,6 +181,12 @@ type SpecReport struct { // ProgressReports contains any progress reports generated during this spec. These can either be manually triggered, or automatically generated by Ginkgo via the PollProgressAfter() decorator ProgressReports []ProgressReport + + // AdditionalFailures contains any failures that occurred after the initial spec failure. These typically occur in cleanup nodes after the initial failure and are only emitted when running in verbose mode. 
+ AdditionalFailures []AdditionalFailure + + // SpecEvents capture additional events that occur during the spec run + SpecEvents SpecEvents } func (report SpecReport) MarshalJSON() ([]byte, error) { @@ -187,10 +206,14 @@ func (report SpecReport) MarshalJSON() ([]byte, error) { ParallelProcess int Failure *Failure `json:",omitempty"` NumAttempts int - CapturedGinkgoWriterOutput string `json:",omitempty"` - CapturedStdOutErr string `json:",omitempty"` - ReportEntries ReportEntries `json:",omitempty"` - ProgressReports []ProgressReport `json:",omitempty"` + MaxFlakeAttempts int + MaxMustPassRepeatedly int + CapturedGinkgoWriterOutput string `json:",omitempty"` + CapturedStdOutErr string `json:",omitempty"` + ReportEntries ReportEntries `json:",omitempty"` + ProgressReports []ProgressReport `json:",omitempty"` + AdditionalFailures []AdditionalFailure `json:",omitempty"` + SpecEvents SpecEvents `json:",omitempty"` }{ ContainerHierarchyTexts: report.ContainerHierarchyTexts, ContainerHierarchyLocations: report.ContainerHierarchyLocations, @@ -207,6 +230,8 @@ func (report SpecReport) MarshalJSON() ([]byte, error) { Failure: nil, ReportEntries: nil, NumAttempts: report.NumAttempts, + MaxFlakeAttempts: report.MaxFlakeAttempts, + MaxMustPassRepeatedly: report.MaxMustPassRepeatedly, CapturedGinkgoWriterOutput: report.CapturedGinkgoWriterOutput, CapturedStdOutErr: report.CapturedStdOutErr, } @@ -220,6 +245,12 @@ func (report SpecReport) MarshalJSON() ([]byte, error) { if len(report.ProgressReports) > 0 { out.ProgressReports = report.ProgressReports } + if len(report.AdditionalFailures) > 0 { + out.AdditionalFailures = report.AdditionalFailures + } + if len(report.SpecEvents) > 0 { + out.SpecEvents = report.SpecEvents + } return json.Marshal(out) } @@ -237,13 +268,13 @@ func (report SpecReport) CombinedOutput() string { return report.CapturedStdOutErr + "\n" + report.CapturedGinkgoWriterOutput } -//Failed returns true if report.State is one of the SpecStateFailureStates +// Failed returns true if report.State is one of the SpecStateFailureStates // (SpecStateFailed, SpecStatePanicked, SpecStateinterrupted, SpecStateAborted) func (report SpecReport) Failed() bool { return report.State.Is(SpecStateFailureStates) } -//FullText returns a concatenation of all the report.ContainerHierarchyTexts and report.LeafNodeText +// FullText returns a concatenation of all the report.ContainerHierarchyTexts and report.LeafNodeText func (report SpecReport) FullText() string { texts := []string{} texts = append(texts, report.ContainerHierarchyTexts...) @@ -253,7 +284,7 @@ func (report SpecReport) FullText() string { return strings.Join(texts, " ") } -//Labels returns a deduped set of all the spec's Labels. +// Labels returns a deduped set of all the spec's Labels. 
func (report SpecReport) Labels() []string { out := []string{} seen := map[string]bool{} @@ -275,7 +306,7 @@ func (report SpecReport) Labels() []string { return out } -//MatchesLabelFilter returns true if the spec satisfies the passed in label filter query +// MatchesLabelFilter returns true if the spec satisfies the passed in label filter query func (report SpecReport) MatchesLabelFilter(query string) (bool, error) { filter, err := ParseLabelFilter(query) if err != nil { @@ -284,29 +315,54 @@ func (report SpecReport) MatchesLabelFilter(query string) (bool, error) { return filter(report.Labels()), nil } -//FileName() returns the name of the file containing the spec +// FileName() returns the name of the file containing the spec func (report SpecReport) FileName() string { return report.LeafNodeLocation.FileName } -//LineNumber() returns the line number of the leaf node +// LineNumber() returns the line number of the leaf node func (report SpecReport) LineNumber() int { return report.LeafNodeLocation.LineNumber } -//FailureMessage() returns the failure message (or empty string if the test hasn't failed) +// FailureMessage() returns the failure message (or empty string if the test hasn't failed) func (report SpecReport) FailureMessage() string { return report.Failure.Message } -//FailureLocation() returns the location of the failure (or an empty CodeLocation if the test hasn't failed) +// FailureLocation() returns the location of the failure (or an empty CodeLocation if the test hasn't failed) func (report SpecReport) FailureLocation() CodeLocation { return report.Failure.Location } +// Timeline() returns a timeline view of the report +func (report SpecReport) Timeline() Timeline { + timeline := Timeline{} + if !report.Failure.IsZero() { + timeline = append(timeline, report.Failure) + if report.Failure.AdditionalFailure != nil { + timeline = append(timeline, *(report.Failure.AdditionalFailure)) + } + } + for _, additionalFailure := range report.AdditionalFailures { + timeline = append(timeline, additionalFailure) + } + for _, reportEntry := range report.ReportEntries { + timeline = append(timeline, reportEntry) + } + for _, progressReport := range report.ProgressReports { + timeline = append(timeline, progressReport) + } + for _, specEvent := range report.SpecEvents { + timeline = append(timeline, specEvent) + } + sort.Sort(timeline) + return timeline +} + type SpecReports []SpecReport -//WithLeafNodeType returns the subset of SpecReports with LeafNodeType matching one of the requested NodeTypes +// WithLeafNodeType returns the subset of SpecReports with LeafNodeType matching one of the requested NodeTypes func (reports SpecReports) WithLeafNodeType(nodeTypes NodeType) SpecReports { count := 0 for i := range reports { @@ -326,7 +382,7 @@ func (reports SpecReports) WithLeafNodeType(nodeTypes NodeType) SpecReports { return out } -//WithState returns the subset of SpecReports with State matching one of the requested SpecStates +// WithState returns the subset of SpecReports with State matching one of the requested SpecStates func (reports SpecReports) WithState(states SpecState) SpecReports { count := 0 for i := range reports { @@ -345,7 +401,7 @@ func (reports SpecReports) WithState(states SpecState) SpecReports { return out } -//CountWithState returns the number of SpecReports with State matching one of the requested SpecStates +// CountWithState returns the number of SpecReports with State matching one of the requested SpecStates func (reports SpecReports) CountWithState(states SpecState) 
int { n := 0 for i := range reports { @@ -356,17 +412,75 @@ func (reports SpecReports) CountWithState(states SpecState) int { return n } -//CountWithState returns the number of SpecReports that passed after multiple attempts +// If the Spec passes, CountOfFlakedSpecs returns the number of SpecReports that failed after multiple attempts. func (reports SpecReports) CountOfFlakedSpecs() int { n := 0 for i := range reports { - if reports[i].State.Is(SpecStatePassed) && reports[i].NumAttempts > 1 { + if reports[i].MaxFlakeAttempts > 1 && reports[i].State.Is(SpecStatePassed) && reports[i].NumAttempts > 1 { n += 1 } } return n } +// If the Spec fails, CountOfRepeatedSpecs returns the number of SpecReports that passed after multiple attempts +func (reports SpecReports) CountOfRepeatedSpecs() int { + n := 0 + for i := range reports { + if reports[i].MaxMustPassRepeatedly > 1 && reports[i].State.Is(SpecStateFailureStates) && reports[i].NumAttempts > 1 { + n += 1 + } + } + return n +} + +// TimelineLocation captures the location of an event in the spec's timeline +type TimelineLocation struct { + //Offset is the offset (in bytes) of the event relative to the GinkgoWriter stream + Offset int `json:",omitempty"` + + //Order is the order of the event with respect to other events. The absolute value of Order + //is irrelevant. All that matters is that an event with a lower Order occurs before ane vent with a higher Order + Order int `json:",omitempty"` + + Time time.Time +} + +// TimelineEvent represent an event on the timeline +// consumers of Timeline will need to check the concrete type of each entry to determine how to handle it +type TimelineEvent interface { + GetTimelineLocation() TimelineLocation +} + +type Timeline []TimelineEvent + +func (t Timeline) Len() int { return len(t) } +func (t Timeline) Less(i, j int) bool { + return t[i].GetTimelineLocation().Order < t[j].GetTimelineLocation().Order +} +func (t Timeline) Swap(i, j int) { t[i], t[j] = t[j], t[i] } +func (t Timeline) WithoutHiddenReportEntries() Timeline { + out := Timeline{} + for _, event := range t { + if reportEntry, isReportEntry := event.(ReportEntry); isReportEntry && reportEntry.Visibility == ReportEntryVisibilityNever { + continue + } + out = append(out, event) + } + return out +} + +func (t Timeline) WithoutVeryVerboseSpecEvents() Timeline { + out := Timeline{} + for _, event := range t { + if specEvent, isSpecEvent := event.(SpecEvent); isSpecEvent && specEvent.IsOnlyVisibleAtVeryVerbose() { + continue + } + out = append(out, event) + } + return out +} + // Failure captures failure information for an individual test type Failure struct { // Message - the failure message passed into Fail(...). When using a matcher library @@ -379,6 +493,8 @@ type Failure struct { // This CodeLocation will include a fully-populated StackTrace Location CodeLocation + TimelineLocation TimelineLocation + // ForwardedPanic - if the failure represents a captured panic (i.e. Summary.State == SpecStatePanicked) // then ForwardedPanic will be populated with a string representation of the captured panic. ForwardedPanic string `json:",omitempty"` @@ -391,19 +507,29 @@ type Failure struct { // FailureNodeType will contain the NodeType of the node in which the failure occurred. // FailureNodeLocation will contain the CodeLocation of the node in which the failure occurred. 
// If populated, FailureNodeContainerIndex will be the index into SpecReport.ContainerHierarchyTexts and SpecReport.ContainerHierarchyLocations that represents the parent container of the node in which the failure occurred. - FailureNodeContext FailureNodeContext - FailureNodeType NodeType - FailureNodeLocation CodeLocation - FailureNodeContainerIndex int + FailureNodeContext FailureNodeContext `json:",omitempty"` + + FailureNodeType NodeType `json:",omitempty"` + + FailureNodeLocation CodeLocation `json:",omitempty"` + + FailureNodeContainerIndex int `json:",omitempty"` //ProgressReport is populated if the spec was interrupted or timed out - ProgressReport ProgressReport + ProgressReport ProgressReport `json:",omitempty"` + + //AdditionalFailure is non-nil if a follow-on failure occurred within the same node after the primary failure. This only happens when a node has timed out or been interrupted. In such cases the AdditionalFailure can include information about where/why the spec was stuck. + AdditionalFailure *AdditionalFailure `json:",omitempty"` } func (f Failure) IsZero() bool { return f.Message == "" && (f.Location == CodeLocation{}) } +func (f Failure) GetTimelineLocation() TimelineLocation { + return f.TimelineLocation +} + // FailureNodeContext captures the location context for the node containing the failing line of code type FailureNodeContext uint @@ -434,6 +560,18 @@ func (fnc FailureNodeContext) MarshalJSON() ([]byte, error) { return fncEnumSupport.MarshJSON(uint(fnc)) } +// AdditionalFailure capturs any additional failures that occur after the initial failure of a psec +// these typically occur in clean up nodes after the spec has failed. +// We can't simply use Failure as we want to track the SpecState to know what kind of failure this is +type AdditionalFailure struct { + State SpecState + Failure Failure +} + +func (f AdditionalFailure) GetTimelineLocation() TimelineLocation { + return f.Failure.TimelineLocation +} + // SpecState captures the state of a spec // To determine if a given `state` represents a failure state, use `state.Is(SpecStateFailureStates)` type SpecState uint @@ -448,6 +586,7 @@ const ( SpecStateAborted SpecStatePanicked SpecStateInterrupted + SpecStateTimedout ) var ssEnumSupport = NewEnumSupport(map[uint]string{ @@ -459,11 +598,15 @@ var ssEnumSupport = NewEnumSupport(map[uint]string{ uint(SpecStateAborted): "aborted", uint(SpecStatePanicked): "panicked", uint(SpecStateInterrupted): "interrupted", + uint(SpecStateTimedout): "timedout", }) func (ss SpecState) String() string { return ssEnumSupport.String(uint(ss)) } +func (ss SpecState) GomegaString() string { + return ssEnumSupport.String(uint(ss)) +} func (ss *SpecState) UnmarshalJSON(b []byte) error { out, err := ssEnumSupport.UnmarshJSON(b) *ss = SpecState(out) @@ -473,7 +616,7 @@ func (ss SpecState) MarshalJSON() ([]byte, error) { return ssEnumSupport.MarshJSON(uint(ss)) } -var SpecStateFailureStates = SpecStateFailed | SpecStateAborted | SpecStatePanicked | SpecStateInterrupted +var SpecStateFailureStates = SpecStateFailed | SpecStateTimedout | SpecStateAborted | SpecStatePanicked | SpecStateInterrupted func (ss SpecState) Is(states SpecState) bool { return ss&states != 0 @@ -481,35 +624,40 @@ func (ss SpecState) Is(states SpecState) bool { // ProgressReport captures the progress of the current spec. 
It is, effectively, a structured Ginkgo-aware stack trace type ProgressReport struct { - ParallelProcess int - RunningInParallel bool + Message string `json:",omitempty"` + ParallelProcess int `json:",omitempty"` + RunningInParallel bool `json:",omitempty"` - Time time.Time + ContainerHierarchyTexts []string `json:",omitempty"` + LeafNodeText string `json:",omitempty"` + LeafNodeLocation CodeLocation `json:",omitempty"` + SpecStartTime time.Time `json:",omitempty"` - ContainerHierarchyTexts []string - LeafNodeText string - LeafNodeLocation CodeLocation - SpecStartTime time.Time + CurrentNodeType NodeType `json:",omitempty"` + CurrentNodeText string `json:",omitempty"` + CurrentNodeLocation CodeLocation `json:",omitempty"` + CurrentNodeStartTime time.Time `json:",omitempty"` - CurrentNodeType NodeType - CurrentNodeText string - CurrentNodeLocation CodeLocation - CurrentNodeStartTime time.Time + CurrentStepText string `json:",omitempty"` + CurrentStepLocation CodeLocation `json:",omitempty"` + CurrentStepStartTime time.Time `json:",omitempty"` - CurrentStepText string - CurrentStepLocation CodeLocation - CurrentStepStartTime time.Time + AdditionalReports []string `json:",omitempty"` - CapturedGinkgoWriterOutput string `json:",omitempty"` - GinkgoWriterOffset int + CapturedGinkgoWriterOutput string `json:",omitempty"` + TimelineLocation TimelineLocation `json:",omitempty"` - Goroutines []Goroutine + Goroutines []Goroutine `json:",omitempty"` } func (pr ProgressReport) IsZero() bool { return pr.CurrentNodeType == NodeTypeInvalid } +func (pr ProgressReport) Time() time.Time { + return pr.TimelineLocation.Time +} + func (pr ProgressReport) SpecGoroutine() Goroutine { for _, goroutine := range pr.Goroutines { if goroutine.IsSpecGoroutine { @@ -547,6 +695,22 @@ func (pr ProgressReport) WithoutCapturedGinkgoWriterOutput() ProgressReport { return out } +func (pr ProgressReport) WithoutOtherGoroutines() ProgressReport { + out := pr + filteredGoroutines := []Goroutine{} + for _, goroutine := range pr.Goroutines { + if goroutine.IsSpecGoroutine || goroutine.HasHighlights() { + filteredGoroutines = append(filteredGoroutines, goroutine) + } + } + out.Goroutines = filteredGoroutines + return out +} + +func (pr ProgressReport) GetTimelineLocation() TimelineLocation { + return pr.TimelineLocation +} + type Goroutine struct { ID uint64 State string @@ -601,6 +765,7 @@ const ( NodeTypeReportBeforeEach NodeTypeReportAfterEach + NodeTypeReportBeforeSuite NodeTypeReportAfterSuite NodeTypeCleanupInvalid @@ -610,7 +775,9 @@ const ( ) var NodeTypesForContainerAndIt = NodeTypeContainer | NodeTypeIt -var NodeTypesForSuiteLevelNodes = NodeTypeBeforeSuite | NodeTypeSynchronizedBeforeSuite | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeReportAfterSuite | NodeTypeCleanupAfterSuite +var NodeTypesForSuiteLevelNodes = NodeTypeBeforeSuite | NodeTypeSynchronizedBeforeSuite | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite | NodeTypeCleanupAfterSuite +var NodeTypesAllowedDuringCleanupInterrupt = NodeTypeAfterEach | NodeTypeJustAfterEach | NodeTypeAfterAll | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeCleanupAfterEach | NodeTypeCleanupAfterAll | NodeTypeCleanupAfterSuite +var NodeTypesAllowedDuringReportInterrupt = NodeTypeReportBeforeEach | NodeTypeReportAfterEach | NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite var ntEnumSupport = NewEnumSupport(map[uint]string{ uint(NodeTypeInvalid): "INVALID NODE TYPE", @@ -628,9 +795,10 @@ 
var ntEnumSupport = NewEnumSupport(map[uint]string{ uint(NodeTypeSynchronizedAfterSuite): "SynchronizedAfterSuite", uint(NodeTypeReportBeforeEach): "ReportBeforeEach", uint(NodeTypeReportAfterEach): "ReportAfterEach", + uint(NodeTypeReportBeforeSuite): "ReportBeforeSuite", uint(NodeTypeReportAfterSuite): "ReportAfterSuite", - uint(NodeTypeCleanupInvalid): "INVALID CLEANUP NODE", - uint(NodeTypeCleanupAfterEach): "DeferCleanup", + uint(NodeTypeCleanupInvalid): "DeferCleanup", + uint(NodeTypeCleanupAfterEach): "DeferCleanup (Each)", uint(NodeTypeCleanupAfterAll): "DeferCleanup (All)", uint(NodeTypeCleanupAfterSuite): "DeferCleanup (Suite)", }) @@ -650,3 +818,99 @@ func (nt NodeType) MarshalJSON() ([]byte, error) { func (nt NodeType) Is(nodeTypes NodeType) bool { return nt&nodeTypes != 0 } + +/* +SpecEvent captures a vareity of events that can occur when specs run. See SpecEventType for the list of available events. +*/ +type SpecEvent struct { + SpecEventType SpecEventType + + CodeLocation CodeLocation + TimelineLocation TimelineLocation + + Message string `json:",omitempty"` + Duration time.Duration `json:",omitempty"` + NodeType NodeType `json:",omitempty"` + Attempt int `json:",omitempty"` +} + +func (se SpecEvent) GetTimelineLocation() TimelineLocation { + return se.TimelineLocation +} + +func (se SpecEvent) IsOnlyVisibleAtVeryVerbose() bool { + return se.SpecEventType.Is(SpecEventByEnd | SpecEventNodeStart | SpecEventNodeEnd) +} + +func (se SpecEvent) GomegaString() string { + out := &strings.Builder{} + out.WriteString("[" + se.SpecEventType.String() + " SpecEvent] ") + if se.Message != "" { + out.WriteString("Message=") + out.WriteString(`"` + se.Message + `",`) + } + if se.Duration != 0 { + out.WriteString("Duration=" + se.Duration.String() + ",") + } + if se.NodeType != NodeTypeInvalid { + out.WriteString("NodeType=" + se.NodeType.String() + ",") + } + if se.Attempt != 0 { + out.WriteString(fmt.Sprintf("Attempt=%d", se.Attempt) + ",") + } + out.WriteString("CL=" + se.CodeLocation.String() + ",") + out.WriteString(fmt.Sprintf("TL.Offset=%d", se.TimelineLocation.Offset)) + + return out.String() +} + +type SpecEvents []SpecEvent + +func (se SpecEvents) WithType(seType SpecEventType) SpecEvents { + out := SpecEvents{} + for _, event := range se { + if event.SpecEventType.Is(seType) { + out = append(out, event) + } + } + return out +} + +type SpecEventType uint + +const ( + SpecEventInvalid SpecEventType = 0 + + SpecEventByStart SpecEventType = 1 << iota + SpecEventByEnd + SpecEventNodeStart + SpecEventNodeEnd + SpecEventSpecRepeat + SpecEventSpecRetry +) + +var seEnumSupport = NewEnumSupport(map[uint]string{ + uint(SpecEventInvalid): "INVALID SPEC EVENT", + uint(SpecEventByStart): "By", + uint(SpecEventByEnd): "By (End)", + uint(SpecEventNodeStart): "Node", + uint(SpecEventNodeEnd): "Node (End)", + uint(SpecEventSpecRepeat): "Repeat", + uint(SpecEventSpecRetry): "Retry", +}) + +func (se SpecEventType) String() string { + return seEnumSupport.String(uint(se)) +} +func (se *SpecEventType) UnmarshalJSON(b []byte) error { + out, err := seEnumSupport.UnmarshJSON(b) + *se = SpecEventType(out) + return err +} +func (se SpecEventType) MarshalJSON() ([]byte, error) { + return seEnumSupport.MarshJSON(uint(se)) +} + +func (se SpecEventType) Is(specEventTypes SpecEventType) bool { + return se&specEventTypes != 0 +} diff --git a/vendor/github.com/onsi/ginkgo/v2/types/version.go b/vendor/github.com/onsi/ginkgo/v2/types/version.go index 3974fdc2..43066341 100644 --- 
a/vendor/github.com/onsi/ginkgo/v2/types/version.go +++ b/vendor/github.com/onsi/ginkgo/v2/types/version.go @@ -1,3 +1,3 @@ package types -const VERSION = "2.2.0" +const VERSION = "2.9.5" diff --git a/vendor/github.com/quic-go/quic-go/README.md b/vendor/github.com/quic-go/quic-go/README.md index 977bb928..53638882 100644 --- a/vendor/github.com/quic-go/quic-go/README.md +++ b/vendor/github.com/quic-go/quic-go/README.md @@ -5,28 +5,185 @@ [![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go)](https://pkg.go.dev/github.com/quic-go/quic-go) [![Code Coverage](https://img.shields.io/codecov/c/github/quic-go/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/quic-go/quic-go/) -quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go, including the Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) and Datagram Packetization Layer Path MTU - Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899)). It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)). +quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)). + +In addition to these base RFCs, it also implements the following RFCs: +* Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) +* Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899)) +* QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369)) In addition to the RFCs listed above, it currently implements the [IETF QUIC draft-29](https://tools.ietf.org/html/draft-ietf-quic-transport-29). Support for draft-29 will eventually be dropped, as it is phased out of the ecosystem. -## Guides +This repository provides both a QUIC implementation, located in the `quic` package, as well as an HTTP/3 implementation, located in the `http3` package. -*We currently support Go 1.19.x and Go 1.20.x* +## Using QUIC -Running tests: +### Running a Server - go test ./... +The central entry point is the `quic.Transport`. A transport manages QUIC connections running on a single UDP socket. Since QUIC uses Connection IDs, it can demultiplex a listener (accepting incoming connections) and an arbitrary number of outgoing QUIC connections on the same UDP socket. -### QUIC without HTTP/3 +```go +udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 1234}) +// ... error handling +tr := quic.Transport{ + Conn: udpConn, +} +ln, err := tr.Listen(tlsConf, quicConf) +// ... error handling +go func() { + for { + conn, err := ln.Accept() + // ... error handling + // handle the connection, usually in a new Go routine + } +} +``` -Take a look at [this echo example](example/echo/echo.go). +The listener `ln` can now be used to accept incoming QUIC connections by (repeatedly) calling the `Accept` method (see below for more information on the `quic.Connection`). 
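+As one possible follow-up to the `// handle the connection` comment above, a handler could accept the peer's bidirectional streams in a loop and echo each one back. This is only a minimal sketch; the package name, the `handleConn` helper, and the echo behavior are illustrative assumptions:
+
+```go
+package quicecho
+
+import (
+	"context"
+	"io"
+
+	"github.com/quic-go/quic-go"
+)
+
+// handleConn accepts the peer's bidirectional streams in a loop and echoes
+// each one back until the QUIC connection is closed.
+func handleConn(conn quic.Connection) {
+	for {
+		str, err := conn.AcceptStream(context.Background())
+		if err != nil {
+			return // AcceptStream errors once the connection is closed
+		}
+		go func(str quic.Stream) {
+			defer str.Close()        // close our send side when the echo is done
+			_, _ = io.Copy(str, str) // write back whatever the peer sends
+		}(str)
+	}
+}
+```
+
+Each accepted stream is served on its own goroutine, mirroring the `Accept` loop shown above.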
-## Usage +As a shortcut, `quic.Listen` and `quic.ListenAddr` can be used without explicitly initializing a `quic.Transport`: + +``` +ln, err := quic.Listen(udpConn, tlsConf, quicConf) +``` + +When using the shortcut, it's not possible to reuse the same UDP socket for outgoing connections. + +### Running a Client + +As mentioned above, multiple outgoing connections can share a single UDP socket, since QUIC uses Connection IDs to demultiplex connections. + +```go +ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout +defer cancel() +conn, err := tr.Dial(ctx, , , ) +// ... error handling +``` + +As a shortcut, `quic.Dial` and `quic.DialAddr` can be used without explictly initializing a `quic.Transport`: + +```go +ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout +defer cancel() +conn, err := quic.Dial(ctx, conn, , , ) +``` + +Just as we saw before when used a similar shortcut to run a server, it's also not possible to reuse the same UDP socket for other outgoing connections, or to listen for incoming connections. + +### Using a QUIC Connection + +#### Accepting Streams + +QUIC is a stream-multiplexed transport. A `quic.Connection` fundamentally differs from the `net.Conn` and the `net.PacketConn` interface defined in the standard library. Data is sent and received on (unidirectional and bidirectional) streams (and, if supported, in [datagrams](#quic-datagrams)), not on the connection itself. The stream state machine is described in detail in [Section 3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-3). + +Note: A unidirectional stream is a stream that the initiator can only write to (`quic.SendStream`), and the receiver can only read from (`quic.ReceiveStream`). A bidirectional stream (`quic.Stream`) allows reading from and writing to for both sides. + +On the receiver side, streams are accepted using the `AcceptStream` (for bidirectional) and `AcceptUniStream` functions. For most user cases, it makes sense to call these functions in a loop: + +```go +for { + str, err := conn.AcceptStream(context.Background()) // for bidirectional streams + // ... error handling + // handle the stream, usually in a new Go routine +} +``` + +These functions return an error when the underlying QUIC connection is closed. + +#### Opening Streams + +There are two slightly different ways to open streams, one synchronous and one (potentially) asynchronous. This API is necessary since the receiver grants us a certain number of streams that we're allowed to open. It may grant us additional streams later on (typically when existing streams are closed), but it means that at the time we want to open a new stream, we might not be able to do so. + +Using the synchronous method `OpenStreamSync` for bidirectional streams, and `OpenUniStreamSync` for unidirectional streams, an application can block until the peer allows opening additional streams. In case that we're allowed to open a new stream, these methods return right away: + +```go +ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +defer cancel() +str, err := conn.OpenStreamSync(ctx) // wait up to 5s to open a new bidirectional stream +``` + +The asynchronous version never blocks. 
If it's currently not possible to open a new stream, it returns a `net.Error` timeout error: + +```go +str, err := conn.OpenStream() +if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + // It's currently not possible to open another stream, + // but it might be possible later, once the peer allowed us to do so. +} +``` + +These functions return an error when the underlying QUIC connection is closed. + +#### Using Streams + +Using QUIC streams is pretty straightforward. The `quic.ReceiveStream` implements the `io.Reader` interface, and the `quic.SendStream` implements the `io.Writer` interface. A bidirectional stream (`quic.Stream`) implements both these interfaces. Conceptually, a bidirectional stream can be thought of as the composition of two unidirectional streams in opposite directions. + +Calling `Close` on a `quic.SendStream` or a `quic.Stream` closes the send side of the stream. On the receiver side, this will be surfaced as an `io.EOF` returned from the `io.Reader` once all data has been consumed. Note that for bidirectional streams, `Close` _only_ closes the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream. + +In case the application wishes to abort sending on a `quic.SendStream` or a `quic.Stream` , it can reset the send side by calling `CancelWrite` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Reader`. Note that for bidirectional streams, `CancelWrite` _only_ resets the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream. + +Conversely, in case the application wishes to abort receiving from a `quic.ReceiveStream` or a `quic.Stream`, it can ask the sender to abort data transmission by calling `CancelRead` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Writer`. Note that for bidirectional streams, `CancelWrite` _only_ resets the receive side of the stream. It is still possible to write to the stream. + +A bidirectional stream is only closed once both the read and the write side of the stream have been either closed and reset. Only then the peer is granted a new stream according to the maximum number of concurrent streams configured via `quic.Config.MaxIncomingStreams`. + +### Configuring QUIC + +The `quic.Config` struct passed to both the listen and dial calls (see above) contains a wide range of configuration options for QUIC connections, incl. the ability to fine-tune flow control limits, the number of streams that the peer is allowed to open concurrently, keep-alives, idle timeouts, and many more. Please refer to the documentation for the `quic.Config` for details. + +The `quic.Transport` contains a few configuration options that don't apply to any single QUIC connection, but to all connections handled by that transport. It is highly recommend to set the `StatelessResetToken`, which allows endpoints to quickly recover from crashes / reboots of our node (see [Section 10.3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-10.3)). + +### Closing a Connection + +#### When the remote Peer closes the Connection + +In case the peer closes the QUIC connection, all calls to open streams, accept streams, as well as all methods on streams immediately return an error. 
Users can use errors assertions to find out what exactly went wrong: + +* `quic.VersionNegotiationError`: Happens during the handshake, if there is no overlap between our and the remote's supported QUIC versions. +* `quic.HandshakeTimeoutError`: Happens if the QUIC handshake doesn't complete within the time specified in `quic.Config.HandshakeTimeout`. +* `quic.IdleTimeoutError`: Happens after completion of the handshake if the connection is idle for longer than the minimum of both peers idle timeouts (as configured by `quic.Config.IdleTimeout`). The connection is considered idle when no stream data (and datagrams, if applicable) are exchanged for that period. The QUIC connection can be instructed to regularly send a packet to prevent a connection from going idle by setting `quic.Config.KeepAlive`. However, this is no guarantee that the peer doesn't suddenly go away (e.g. by abruptly shutting down the node or by crashing), or by a NAT binding expiring, in which case this error might still occur. +* `quic.StatelessResetError`: Happens when the remote peer lost the state required to decrypt the packet. This requires the `quic.Transport.StatelessResetToken` to be configured by the peer. +* `quic.TransportError`: Happens if when the QUIC protocol is violated. Unless the error code is `APPLICATION_ERROR`, this will not happen unless one of the QUIC stacks involved is misbehaving. Please open an issue if you encounter this error. +* `quic.ApplicationError`: Happens when the remote decides to close the connection, see below. + +#### Initiated by the Application + +A `quic.Connection` can be closed using `CloseWithError`: + +```go +conn.CloseWithError(0x42, "error 0x42 occurred") +``` + +Applications can transmit both an error code (an unsigned 62-bit number) as well as a UTF-8 encoded human-readable reason. The error code allows the receiver to learn why the connection was closed, and the reason can be useful for debugging purposes. + +On the receiver side, this is surfaced as a `quic.ApplicationError`. + +### QUIC Datagrams + +Unreliable datagrams are a QUIC extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) that is negotiated during the handshake. Support can be enabled by setting the `quic.Config.EnableDatagram` flag. Note that this doesn't guarantee that the peer also supports datagrams. Whether or not the feature negotiation succeeded can be learned from the `quic.ConnectionState.SupportsDatagrams` obtained from `quic.Connection.ConnectionState()`. + +QUIC DATAGRAMs are a new QUIC frame type sent in QUIC 1-RTT packets (i.e. after completion of the handshake). Therefore, they're end-to-end encrypted and congestion-controlled. However, if a DATAGRAM frame is deemed lost by QUIC's loss detection mechanism, they are not automatically retransmitted. + +Datagrams are sent using the `SendMessage` method on the `quic.Connection`: + +```go +conn.SendMessage([]byte("foobar")) +``` + +And received using `ReceiveMessage`: + +```go +msg, err := conn.ReceiveMessage() +``` + +Note that this code path is currently not optimized. It works for datagrams that are sent occasionally, but it doesn't achieve the same throughput as writing data on a stream. Please get in touch on issue #3766 if your use case relies on high datagram throughput, or if you'd like to help fix this issue. There are also some restrictions regarding the maximum message size (see #3599). + + + +## Using HTTP/3 ### As a server -See the [example server](example/main.go). 
Starting a QUIC server is very similar to the standard lib http in go: +See the [example server](example/main.go). Starting a QUIC server is very similar to the standard library http package in Go: ```go http.Handle("/", http.FileServer(http.Dir(wwwDir))) @@ -45,18 +202,29 @@ http.Client{ ## Projects using quic-go -| Project | Description | Stars | -|-----------------------------------------------------------|---------------------------------------------------------------------------------------------------------|-------| -| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) | -| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) | -| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) | -| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) | -| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) | -| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) | -| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) | -| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) | -| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) | -| [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System | ![GitHub Repo stars](https://img.shields.io/github/stars/yomorun/yomo?style=flat-square) | +| Project | Description | Stars | +|-----------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------| +| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. 
| ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) | +| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) | +| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) | +| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) | +| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) | +| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) | +| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) | +| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) | +| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/traefik/traefik?style=flat-square) | +| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) | +| [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System | ![GitHub Repo stars](https://img.shields.io/github/stars/yomorun/yomo?style=flat-square) | + +If you'd like to see your project added to this list, please send us a PR. + +## Release Policy + +quic-go always aims to support the latest two Go releases. + +### Dependency on forked crypto/tls + +Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20) and [qtls for Go 1.19](https://github.com/quic-go/qtls-go1-19). This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward. 
## Contributing diff --git a/vendor/github.com/quic-go/quic-go/buffer_pool.go b/vendor/github.com/quic-go/quic-go/buffer_pool.go index f6745b08..48589e12 100644 --- a/vendor/github.com/quic-go/quic-go/buffer_pool.go +++ b/vendor/github.com/quic-go/quic-go/buffer_pool.go @@ -51,18 +51,22 @@ func (b *packetBuffer) Release() { } // Len returns the length of Data -func (b *packetBuffer) Len() protocol.ByteCount { - return protocol.ByteCount(len(b.Data)) -} +func (b *packetBuffer) Len() protocol.ByteCount { return protocol.ByteCount(len(b.Data)) } +func (b *packetBuffer) Cap() protocol.ByteCount { return protocol.ByteCount(cap(b.Data)) } func (b *packetBuffer) putBack() { - if cap(b.Data) != int(protocol.MaxPacketBufferSize) { - panic("putPacketBuffer called with packet of wrong size!") + if cap(b.Data) == protocol.MaxPacketBufferSize { + bufferPool.Put(b) + return } - bufferPool.Put(b) + if cap(b.Data) == protocol.MaxLargePacketBufferSize { + largeBufferPool.Put(b) + return + } + panic("putPacketBuffer called with packet of wrong size!") } -var bufferPool sync.Pool +var bufferPool, largeBufferPool sync.Pool func getPacketBuffer() *packetBuffer { buf := bufferPool.Get().(*packetBuffer) @@ -71,10 +75,18 @@ func getPacketBuffer() *packetBuffer { return buf } +func getLargePacketBuffer() *packetBuffer { + buf := largeBufferPool.Get().(*packetBuffer) + buf.refCount = 1 + buf.Data = buf.Data[:0] + return buf +} + func init() { - bufferPool.New = func() interface{} { - return &packetBuffer{ - Data: make([]byte, 0, protocol.MaxPacketBufferSize), - } + bufferPool.New = func() any { + return &packetBuffer{Data: make([]byte, 0, protocol.MaxPacketBufferSize)} + } + largeBufferPool.New = func() any { + return &packetBuffer{Data: make([]byte, 0, protocol.MaxLargePacketBufferSize)} } } diff --git a/vendor/github.com/quic-go/quic-go/client.go b/vendor/github.com/quic-go/quic-go/client.go index 98359c22..4dfacb50 100644 --- a/vendor/github.com/quic-go/quic-go/client.go +++ b/vendor/github.com/quic-go/quic-go/client.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "errors" - "fmt" "net" "github.com/quic-go/quic-go/internal/protocol" @@ -13,20 +12,19 @@ import ( ) type client struct { - sconn sendConn - // If the client is created with DialAddr, we create a packet conn. - // If it is started with Dial, we take a packet conn as a parameter. - createdPacketConn bool + sendConn sendConn use0RTT bool packetHandlers packetHandlerManager + onClose func() tlsConf *tls.Config config *Config - srcConnID protocol.ConnectionID - destConnID protocol.ConnectionID + connIDGenerator ConnectionIDGenerator + srcConnID protocol.ConnectionID + destConnID protocol.ConnectionID initialPacketNumber protocol.PacketNumber hasNegotiatedVersion bool @@ -45,153 +43,107 @@ type client struct { var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial // DialAddr establishes a new QUIC connection to a server. -// It uses a new UDP connection and closes this connection when the QUIC connection is closed. -// The hostname for SNI is taken from the given address. -func DialAddr( - addr string, - tlsConf *tls.Config, - config *Config, -) (Connection, error) { - return DialAddrContext(context.Background(), addr, tlsConf, config) -} - -// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. -// It uses a new UDP connection and closes this connection when the QUIC connection is closed. -// The hostname for SNI is taken from the given address. 
-func DialAddrEarly( - addr string, - tlsConf *tls.Config, - config *Config, -) (EarlyConnection, error) { - return DialAddrEarlyContext(context.Background(), addr, tlsConf, config) -} - -// DialAddrEarlyContext establishes a new 0-RTT QUIC connection to a server using provided context. -// See DialAddrEarly for details -func DialAddrEarlyContext( - ctx context.Context, - addr string, - tlsConf *tls.Config, - config *Config, -) (EarlyConnection, error) { - conn, err := dialAddrContext(ctx, addr, tlsConf, config, true) - if err != nil { - return nil, err - } - utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early connection") - return conn, nil -} - -// DialAddrContext establishes a new QUIC connection to a server using the provided context. -// See DialAddr for details. -func DialAddrContext( - ctx context.Context, - addr string, - tlsConf *tls.Config, - config *Config, -) (Connection, error) { - return dialAddrContext(ctx, addr, tlsConf, config, false) -} - -func dialAddrContext( - ctx context.Context, - addr string, - tlsConf *tls.Config, - config *Config, - use0RTT bool, -) (quicConn, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } +// It resolves the address, and then creates a new UDP connection to dial the QUIC server. +// When the QUIC connection is closed, this UDP connection is closed. +// See Dial for more details. +func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) { udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { return nil, err } - return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true) -} - -// Dial establishes a new QUIC connection to a server using a net.PacketConn. If -// the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn -// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP -// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write -// packets. The same PacketConn can be used for multiple calls to Dial and -// Listen, QUIC connection IDs are used for demultiplexing the different -// connections. The host parameter is used for SNI. The tls.Config must define -// an application protocol (using NextProtos). -func Dial( - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (Connection, error) { - return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false) -} - -// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn. -// The same PacketConn can be used for multiple calls to Dial and Listen, -// QUIC connection IDs are used for demultiplexing the different connections. -// The host parameter is used for SNI. -// The tls.Config must define an application protocol (using NextProtos). -func DialEarly( - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (EarlyConnection, error) { - return DialEarlyContext(context.Background(), pconn, remoteAddr, host, tlsConf, config) -} - -// DialEarlyContext establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context. -// See DialEarly for details. 
-func DialEarlyContext( - ctx context.Context, - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (EarlyConnection, error) { - return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, true, false) -} - -// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context. -// See Dial for details. -func DialContext( - ctx context.Context, - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, -) (Connection, error) { - return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false) -} - -func dialContext( - ctx context.Context, - pconn net.PacketConn, - remoteAddr net.Addr, - host string, - tlsConf *tls.Config, - config *Config, - use0RTT bool, - createdPacketConn bool, -) (quicConn, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(config); err != nil { - return nil, err - } - config = populateClientConfig(config, createdPacketConn) - packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) + udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } - c, err := newClient(pconn, remoteAddr, config, tlsConf, host, use0RTT, createdPacketConn) + dl, err := setupTransport(udpConn, tlsConf, true) + if err != nil { + return nil, err + } + return dl.Dial(ctx, udpAddr, tlsConf, conf) +} + +// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. +// See DialAddr for more details. +func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + return nil, err + } + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + dl, err := setupTransport(udpConn, tlsConf, true) + if err != nil { + return nil, err + } + conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf) + if err != nil { + dl.Close() + return nil, err + } + return conn, nil +} + +// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn. +// See Dial for more details. +func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + dl, err := setupTransport(c, tlsConf, false) + if err != nil { + return nil, err + } + conn, err := dl.DialEarly(ctx, addr, tlsConf, conf) + if err != nil { + dl.Close() + return nil, err + } + return conn, nil +} + +// Dial establishes a new QUIC connection to a server using a net.PacketConn. +// If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does), +// ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP +// will be used instead of ReadFrom and WriteTo to read/write packets. +// The tls.Config must define an application protocol (using NextProtos). +// +// This is a convenience function. More advanced use cases should instantiate a Transport, +// which offers configuration options for a more fine-grained control of the connection establishment, +// including reusing the underlying UDP socket for multiple QUIC connections. 
+func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { + dl, err := setupTransport(c, tlsConf, false) + if err != nil { + return nil, err + } + conn, err := dl.Dial(ctx, addr, tlsConf, conf) + if err != nil { + dl.Close() + return nil, err + } + return conn, nil +} + +func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + return &Transport{ + Conn: c, + createdConn: createdPacketConn, + isSingleUse: true, + }, nil +} + +func dial( + ctx context.Context, + conn sendConn, + connIDGenerator ConnectionIDGenerator, + packetHandlers packetHandlerManager, + tlsConf *tls.Config, + config *Config, + onClose func(), + use0RTT bool, +) (quicConn, error) { + c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT) if err != nil { return nil, err } @@ -199,14 +151,10 @@ func dialContext( c.tracingID = nextConnTracingID() if c.config.Tracer != nil { - c.tracer = c.config.Tracer.TracerForConnection( - context.WithValue(ctx, ConnectionTracingKey, c.tracingID), - protocol.PerspectiveClient, - c.destConnID, - ) + c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) } if c.tracer != nil { - c.tracer.StartedConnection(c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID) + c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) } if err := c.dial(ctx); err != nil { return nil, err @@ -214,40 +162,14 @@ func dialContext( return c.conn, nil } -func newClient( - pconn net.PacketConn, - remoteAddr net.Addr, - config *Config, - tlsConf *tls.Config, - host string, - use0RTT bool, - createdPacketConn bool, -) (*client, error) { +func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) { if tlsConf == nil { tlsConf = &tls.Config{} } else { tlsConf = tlsConf.Clone() } - if tlsConf.ServerName == "" { - sni, _, err := net.SplitHostPort(host) - if err != nil { - // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port. 
- sni = host - } - tlsConf.ServerName = sni - } - - // check that all versions are actually supported - if config != nil { - for _, v := range config.Versions { - if !protocol.IsValidVersion(v) { - return nil, fmt.Errorf("%s is not a valid QUIC version", v) - } - } - } - - srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID() + srcConnID, err := connIDGenerator.GenerateConnectionID() if err != nil { return nil, err } @@ -256,28 +178,30 @@ func newClient( return nil, err } c := &client{ - srcConnID: srcConnID, - destConnID: destConnID, - sconn: newSendPconn(pconn, remoteAddr), - createdPacketConn: createdPacketConn, - use0RTT: use0RTT, - tlsConf: tlsConf, - config: config, - version: config.Versions[0], - handshakeChan: make(chan struct{}), - logger: utils.DefaultLogger.WithPrefix("client"), + connIDGenerator: connIDGenerator, + srcConnID: srcConnID, + destConnID: destConnID, + sendConn: sendConn, + use0RTT: use0RTT, + onClose: onClose, + tlsConf: tlsConf, + config: config, + version: config.Versions[0], + handshakeChan: make(chan struct{}), + logger: utils.DefaultLogger.WithPrefix("client"), } return c, nil } func (c *client) dial(ctx context.Context) error { - c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sconn.LocalAddr(), c.sconn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) + c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) c.conn = newClientConnection( - c.sconn, + c.sendConn, c.packetHandlers, c.destConnID, c.srcConnID, + c.connIDGenerator, c.config, c.tlsConf, c.initialPacketNumber, @@ -291,13 +215,18 @@ func (c *client) dial(ctx context.Context) error { c.packetHandlers.Add(c.srcConnID, c.conn) errorChan := make(chan error, 1) + recreateChan := make(chan errCloseForRecreating) go func() { - err := c.conn.run() // returns as soon as the connection is closed - - if e := (&errCloseForRecreating{}); !errors.As(err, &e) && c.createdPacketConn { - c.packetHandlers.Destroy() + err := c.conn.run() + var recreateErr *errCloseForRecreating + if errors.As(err, &recreateErr) { + recreateChan <- *recreateErr + return } - errorChan <- err + if c.onClose != nil { + c.onClose() + } + errorChan <- err // returns as soon as the connection is closed }() // only set when we're using 0-RTT @@ -312,14 +241,12 @@ func (c *client) dial(ctx context.Context) error { c.conn.shutdown() return ctx.Err() case err := <-errorChan: - var recreateErr *errCloseForRecreating - if errors.As(err, &recreateErr) { - c.initialPacketNumber = recreateErr.nextPacketNumber - c.version = recreateErr.nextVersion - c.hasNegotiatedVersion = true - return c.dial(ctx) - } return err + case recreateErr := <-recreateChan: + c.initialPacketNumber = recreateErr.nextPacketNumber + c.version = recreateErr.nextVersion + c.hasNegotiatedVersion = true + return c.dial(ctx) case <-earlyConnChan: // ready to send 0-RTT data return nil diff --git a/vendor/github.com/quic-go/quic-go/closed_conn.go b/vendor/github.com/quic-go/quic-go/closed_conn.go index 73904b84..0c988b53 100644 --- a/vendor/github.com/quic-go/quic-go/closed_conn.go +++ b/vendor/github.com/quic-go/quic-go/closed_conn.go @@ -16,13 +16,13 @@ type closedLocalConn struct { perspective protocol.Perspective logger utils.Logger - sendPacket func(net.Addr, *packetInfo) + 
sendPacket func(net.Addr, packetInfo) } var _ packetHandler = &closedLocalConn{} // newClosedLocalConn creates a new closedLocalConn and runs it. -func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler { +func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler { return &closedLocalConn{ sendPacket: sendPacket, perspective: pers, @@ -30,7 +30,7 @@ func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Pe } } -func (c *closedLocalConn) handlePacket(p *receivedPacket) { +func (c *closedLocalConn) handlePacket(p receivedPacket) { c.counter++ // exponential backoff // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving @@ -58,7 +58,7 @@ func newClosedRemoteConn(pers protocol.Perspective) packetHandler { return &closedRemoteConn{perspective: pers} } -func (s *closedRemoteConn) handlePacket(*receivedPacket) {} +func (s *closedRemoteConn) handlePacket(receivedPacket) {} func (s *closedRemoteConn) shutdown() {} func (s *closedRemoteConn) destroy(error) {} func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective } diff --git a/vendor/github.com/quic-go/quic-go/config.go b/vendor/github.com/quic-go/quic-go/config.go index 3ead9b7a..59df4cfd 100644 --- a/vendor/github.com/quic-go/quic-go/config.go +++ b/vendor/github.com/quic-go/quic-go/config.go @@ -1,12 +1,13 @@ package quic import ( - "errors" + "fmt" "net" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/quicvarint" ) // Clone clones a Config @@ -23,11 +24,24 @@ func validateConfig(config *Config) error { if config == nil { return nil } - if config.MaxIncomingStreams > 1<<60 { - return errors.New("invalid value for Config.MaxIncomingStreams") + const maxStreams = 1 << 60 + if config.MaxIncomingStreams > maxStreams { + config.MaxIncomingStreams = maxStreams } - if config.MaxIncomingUniStreams > 1<<60 { - return errors.New("invalid value for Config.MaxIncomingUniStreams") + if config.MaxIncomingUniStreams > maxStreams { + config.MaxIncomingUniStreams = maxStreams + } + if config.MaxStreamReceiveWindow > quicvarint.Max { + config.MaxStreamReceiveWindow = quicvarint.Max + } + if config.MaxConnectionReceiveWindow > quicvarint.Max { + config.MaxConnectionReceiveWindow = quicvarint.Max + } + // check that all QUIC versions are actually supported + for _, v := range config.Versions { + if !protocol.IsValidVersion(v) { + return fmt.Errorf("invalid QUIC version: %s", v) + } } return nil } @@ -35,7 +49,7 @@ func validateConfig(config *Config) error { // populateServerConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil func populateServerConfig(config *Config) *Config { - config = populateConfig(config, protocol.DefaultConnectionIDLength) + config = populateConfig(config) if config.MaxTokenAge == 0 { config.MaxTokenAge = protocol.TokenValidity } @@ -48,19 +62,9 @@ func populateServerConfig(config *Config) *Config { return config } -// populateClientConfig populates fields in the quic.Config with their default values, if none are set +// populateConfig populates fields in the quic.Config with their default values, if none are set // it may be called with nil -func populateClientConfig(config *Config, createdPacketConn bool) *Config { - defaultConnIDLen := protocol.DefaultConnectionIDLength - if 
createdPacketConn { - defaultConnIDLen = 0 - } - - config = populateConfig(config, defaultConnIDLen) - return config -} - -func populateConfig(config *Config, defaultConnIDLen int) *Config { +func populateConfig(config *Config) *Config { if config == nil { config = &Config{} } @@ -68,10 +72,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { if len(versions) == 0 { versions = protocol.SupportedVersions } - conIDLen := config.ConnectionIDLength - if config.ConnectionIDLength == 0 { - conIDLen = defaultConnIDLen - } handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout if config.HandshakeIdleTimeout != 0 { handshakeIdleTimeout = config.HandshakeIdleTimeout @@ -108,12 +108,9 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { } else if maxIncomingUniStreams < 0 { maxIncomingUniStreams = 0 } - connIDGenerator := config.ConnectionIDGenerator - if connIDGenerator == nil { - connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen} - } return &Config{ + GetConfigForClient: config.GetConfigForClient, Versions: versions, HandshakeIdleTimeout: handshakeIdleTimeout, MaxIdleTimeout: idleTimeout, @@ -128,9 +125,6 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, - ConnectionIDLength: conIDLen, - ConnectionIDGenerator: connIDGenerator, - StatelessResetKey: config.StatelessResetKey, TokenStore: config.TokenStore, EnableDatagrams: config.EnableDatagrams, DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, diff --git a/vendor/github.com/quic-go/quic-go/connection.go b/vendor/github.com/quic-go/quic-go/connection.go index eb16ece5..c544b591 100644 --- a/vendor/github.com/quic-go/quic-go/connection.go +++ b/vendor/github.com/quic-go/quic-go/connection.go @@ -61,11 +61,6 @@ type cryptoStreamHandler interface { ConnectionState() handshake.ConnectionState } -type packetInfo struct { - addr net.IP - ifIndex uint32 -} - type receivedPacket struct { buffer *packetBuffer @@ -75,7 +70,7 @@ type receivedPacket struct { ecn protocol.ECN - info *packetInfo + info packetInfo // only valid if the contained IP address is valid } func (p *receivedPacket) Size() protocol.ByteCount { return protocol.ByteCount(len(p.data)) } @@ -173,7 +168,7 @@ type connection struct { oneRTTStream cryptoStream // only set for the server cryptoStreamHandler cryptoStreamHandler - receivedPackets chan *receivedPacket + receivedPackets chan receivedPacket sendingScheduled chan struct{} closeOnce sync.Once @@ -185,8 +180,8 @@ type connection struct { handshakeCtx context.Context handshakeCtxCancel context.CancelFunc - undecryptablePackets []*receivedPacket // undecryptable packets, waiting for a change in encryption level - undecryptablePacketsToProcess []*receivedPacket + undecryptablePackets []receivedPacket // undecryptable packets, waiting for a change in encryption level + undecryptablePacketsToProcess []receivedPacket clientHelloWritten <-chan *wire.TransportParameters earlyConnReadyChan chan struct{} @@ -199,6 +194,7 @@ type connection struct { versionNegotiated bool receivedFirstPacket bool + // the minimum of the max_idle_timeout values advertised by both endpoints idleTimeout time.Duration creationTime time.Time // The idle timeout is set based on the max of the time we received the last packet... 
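Note on the dialing rework above: the client.go hunks drop the DialAddr/DialAddrContext/Dial/DialContext family in favour of context-first functions that wrap a single-use Transport, and the config.go hunks stop carrying connection-ID and stateless-reset settings on quic.Config. Callers that want the socket reuse this upgrade is about (one UDP listener shared by several QUIC connections) can hold a Transport directly. A minimal sketch, assuming only the v0.36 Transport surface visible in this diff (the Conn field, Dial, Close); the host name and ALPN are placeholders:

package main

import (
	"context"
	"crypto/tls"
	"log"
	"net"

	"github.com/quic-go/quic-go"
)

func main() {
	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
	if err != nil {
		log.Fatal(err)
	}
	// One Transport wraps one UDP socket; QUIC connection IDs demultiplex
	// the connections sharing it, so no extra listeners pile up.
	tr := &quic.Transport{Conn: udpConn}
	defer tr.Close()

	tlsConf := &tls.Config{ServerName: "example.com", NextProtos: []string{"h3"}}
	addr, err := net.ResolveUDPAddr("udp", "example.com:443")
	if err != nil {
		log.Fatal(err)
	}
	// Both connections reuse udpConn instead of each creating its own socket.
	connA, err := tr.Dial(context.Background(), addr, tlsConf, &quic.Config{})
	if err != nil {
		log.Fatal(err)
	}
	defer connA.CloseWithError(0, "done")

	connB, err := tr.Dial(context.Background(), addr, tlsConf, &quic.Config{})
	if err != nil {
		log.Fatal(err)
	}
	defer connB.CloseWithError(0, "done")
}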
@@ -240,6 +236,7 @@ var newConnection = func( clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + connIDGenerator ConnectionIDGenerator, statelessResetToken protocol.StatelessResetToken, conf *Config, tlsConf *tls.Config, @@ -283,7 +280,7 @@ var newConnection = func( runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, - s.config.ConnectionIDGenerator, + connIDGenerator, ) s.preSetup() s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) @@ -296,6 +293,7 @@ var newConnection = func( s.tracer, s.logger, ) + s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) initialStream := newCryptoStream() handshakeStream := newCryptoStream() params := &wire.TransportParameters{ @@ -311,9 +309,14 @@ var newConnection = func( DisableActiveMigration: true, StatelessResetToken: &statelessResetToken, OriginalDestinationConnectionID: origDestConnID, - ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, - InitialSourceConnectionID: srcConnID, - RetrySourceConnectionID: retrySrcConnID, + // For interoperability with quic-go versions before May 2023, this value must be set to a value + // different from protocol.DefaultActiveConnectionIDLimit. + // If set to the default value, it will be omitted from the transport parameters, which will make + // old quic-go versions interpret it as 0, instead of the default value of 2. + // See https://github.com/quic-go/quic-go/pull/3806. + ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, + InitialSourceConnectionID: srcConnID, + RetrySourceConnectionID: retrySrcConnID, } if s.config.EnableDatagrams { params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize @@ -323,10 +326,6 @@ var newConnection = func( if s.tracer != nil { s.tracer.SentTransportParameters(params) } - var allow0RTT func() bool - if conf.Allow0RTT != nil { - allow0RTT = func() bool { return conf.Allow0RTT(conn.RemoteAddr()) } - } cs := handshake.NewCryptoSetupServer( initialStream, handshakeStream, @@ -344,14 +343,14 @@ var newConnection = func( }, }, tlsConf, - allow0RTT, + conf.Allow0RTT, s.rttStats, tracer, logger, s.version, ) s.cryptoStreamHandler = cs - s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, s.RemoteAddr(), cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) + s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream) return s @@ -363,6 +362,7 @@ var newClientConnection = func( runner connRunner, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, + connIDGenerator ConnectionIDGenerator, conf *Config, tlsConf *tls.Config, initialPacketNumber protocol.PacketNumber, @@ -402,7 +402,7 @@ var newClientConnection = func( runner.Retire, runner.ReplaceWithClosed, s.queueControlFrame, - s.config.ConnectionIDGenerator, + connIDGenerator, ) s.preSetup() s.ctx, s.ctxCancel = context.WithCancel(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) @@ -415,6 +415,7 @@ var newClientConnection = func( s.tracer, s.logger, ) + s.mtuDiscoverer = 
newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) initialStream := newCryptoStream() handshakeStream := newCryptoStream() params := &wire.TransportParameters{ @@ -428,8 +429,13 @@ var newClientConnection = func( MaxAckDelay: protocol.MaxAckDelayInclGranularity, AckDelayExponent: protocol.AckDelayExponent, DisableActiveMigration: true, - ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, - InitialSourceConnectionID: srcConnID, + // For interoperability with quic-go versions before May 2023, this value must be set to a value + // different from protocol.DefaultActiveConnectionIDLimit. + // If set to the default value, it will be omitted from the transport parameters, which will make + // old quic-go versions interpret it as 0, instead of the default value of 2. + // See https://github.com/quic-go/quic-go/pull/3806. + ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, + InitialSourceConnectionID: srcConnID, } if s.config.EnableDatagrams { params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize @@ -463,7 +469,7 @@ var newClientConnection = func( s.cryptoStreamHandler = cs s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream()) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) - s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, s.RemoteAddr(), cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) + s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) if len(tlsConf.ServerName) > 0 { s.tokenStoreKey = tlsConf.ServerName } else { @@ -504,7 +510,7 @@ func (s *connection) preSetup() { s.perspective, ) s.framer = newFramer(s.streamsMap) - s.receivedPackets = make(chan *receivedPacket, protocol.MaxConnUnprocessedPackets) + s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets) s.closeChan = make(chan closeError, 1) s.sendingScheduled = make(chan struct{}, 1) s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background()) @@ -657,7 +663,7 @@ runLoop: } else { idleTimeoutStartTime := s.idleTimeoutStartTime() if (!s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.config.HandshakeIdleTimeout) || - (s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.idleTimeout) { + (s.handshakeComplete && now.After(s.nextIdleTimeoutTime())) { s.destroyImpl(qerr.ErrIdleTimeout) continue } @@ -669,7 +675,7 @@ runLoop: sendQueueAvailable = s.sendQueue.Available() continue } - if err := s.sendPackets(); err != nil { + if err := s.triggerSending(); err != nil { s.closeLocal(err) } if s.sendQueue.WouldBlock() { @@ -681,12 +687,12 @@ runLoop: s.cryptoStreamHandler.Close() <-handshaking + s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE s.handleCloseError(&closeErr) if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { s.tracer.Close() } s.logger.Infof("Connection %s closed.", s.logID) - s.sendQueue.Close() s.timer.Stop() return closeErr.err } @@ -715,13 +721,20 @@ func (s *connection) ConnectionState() ConnectionState { return s.connState } +// Time when the connection should time out +func (s *connection) nextIdleTimeoutTime() time.Time { + idleTimeout := utils.Max(s.idleTimeout, s.rttStats.PTO(true)*3) + return 
s.idleTimeoutStartTime().Add(idleTimeout) +} + // Time when the next keep-alive packet should be sent. // It returns a zero time if no keep-alive should be sent. func (s *connection) nextKeepAliveTime() time.Time { if s.config.KeepAlivePeriod == 0 || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() { return time.Time{} } - return s.lastPacketReceivedTime.Add(s.keepAliveInterval) + keepAliveInterval := utils.Max(s.keepAliveInterval, s.rttStats.PTO(true)*3/2) + return s.lastPacketReceivedTime.Add(keepAliveInterval) } func (s *connection) maybeResetTimer() { @@ -735,7 +748,7 @@ func (s *connection) maybeResetTimer() { if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() { deadline = keepAliveTime } else { - deadline = s.idleTimeoutStartTime().Add(s.idleTimeout) + deadline = s.nextIdleTimeoutTime() } } @@ -792,25 +805,16 @@ func (s *connection) handleHandshakeConfirmed() { s.sentPacketHandler.SetHandshakeConfirmed() s.cryptoStreamHandler.SetHandshakeConfirmed() - if !s.config.DisablePathMTUDiscovery { + if !s.config.DisablePathMTUDiscovery && s.conn.capabilities().DF { maxPacketSize := s.peerParams.MaxUDPPayloadSize if maxPacketSize == 0 { maxPacketSize = protocol.MaxByteCount } - maxPacketSize = utils.Min(maxPacketSize, protocol.MaxPacketBufferSize) - s.mtuDiscoverer = newMTUDiscoverer( - s.rttStats, - getMaxPacketSize(s.conn.RemoteAddr()), - maxPacketSize, - func(size protocol.ByteCount) { - s.sentPacketHandler.SetMaxDatagramSize(size) - s.packer.SetMaxPacketSize(size) - }, - ) + s.mtuDiscoverer.Start(utils.Min(maxPacketSize, protocol.MaxPacketBufferSize)) } } -func (s *connection) handlePacketImpl(rp *receivedPacket) bool { +func (s *connection) handlePacketImpl(rp receivedPacket) bool { s.sentPacketHandler.ReceivedBytes(rp.Size()) if wire.IsVersionNegotiationPacket(rp.data) { @@ -826,7 +830,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool { for len(data) > 0 { var destConnID protocol.ConnectionID if counter > 0 { - p = p.Clone() + p = *(p.Clone()) p.data = data var err error @@ -899,7 +903,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool { return processed } -func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID protocol.ConnectionID) bool { +func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protocol.ConnectionID) bool { var wasQueued bool defer func() { @@ -950,7 +954,7 @@ func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID proto return true } -func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { +func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { var wasQueued bool defer func() { @@ -1007,7 +1011,7 @@ func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header) return true } -func (s *connection) handleUnpackError(err error, p *receivedPacket, pt logging.PacketType) (wasQueued bool) { +func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool) { switch err { case handshake.ErrKeysDropped: if s.tracer != nil { @@ -1109,7 +1113,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa return true } -func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { +func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { if s.perspective == protocol.PerspectiveServer || // 
servers never receive version negotiation packets s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets if s.tracer != nil { @@ -1253,7 +1257,11 @@ func (s *connection) handleFrames( ) (isAckEliciting bool, _ error) { // Only used for tracing. // If we're not tracing, this slice will always remain empty. - var frames []wire.Frame + var frames []logging.Frame + if log != nil { + frames = make([]logging.Frame, 0, 4) + } + var handleErr error for len(data) > 0 { l, frame, err := s.frameParser.ParseNext(data, encLevel, s.version) if err != nil { @@ -1266,27 +1274,27 @@ func (s *connection) handleFrames( if ackhandler.IsFrameAckEliciting(frame) { isAckEliciting = true } - // Only process frames now if we're not logging. - // If we're logging, we need to make sure that the packet_received event is logged first. - if log == nil { - if err := s.handleFrame(frame, encLevel, destConnID); err != nil { + if log != nil { + frames = append(frames, logutils.ConvertFrame(frame)) + } + // An error occurred handling a previous frame. + // Don't handle the current frame. + if handleErr != nil { + continue + } + if err := s.handleFrame(frame, encLevel, destConnID); err != nil { + if log == nil { return false, err } - } else { - frames = append(frames, frame) + // If we're logging, we need to keep parsing (but not handling) all frames. + handleErr = err } } if log != nil { - fs := make([]logging.Frame, len(frames)) - for i, frame := range frames { - fs[i] = logutils.ConvertFrame(frame) - } - log(fs) - for _, frame := range frames { - if err := s.handleFrame(frame, encLevel, destConnID); err != nil { - return false, err - } + log(frames) + if handleErr != nil { + return false, handleErr } } return @@ -1302,7 +1310,6 @@ func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel err = s.handleStreamFrame(frame) case *wire.AckFrame: err = s.handleAckFrame(frame, encLevel) - wire.PutAckFrame(frame) case *wire.ConnectionCloseFrame: s.handleConnectionCloseFrame(frame) case *wire.ResetStreamFrame: @@ -1341,7 +1348,7 @@ func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel } // handlePacket is called by the server with a new packet -func (s *connection) handlePacket(p *receivedPacket) { +func (s *connection) handlePacket(p receivedPacket) { // Discard packets once the amount of queued packets is larger than // the channel size, protocol.MaxConnUnprocessedPackets select { @@ -1715,7 +1722,6 @@ func (s *connection) applyTransportParameters() { s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) s.keepAliveInterval = utils.Min(s.config.KeepAlivePeriod, utils.Min(s.idleTimeout/2, protocol.MaxKeepAliveInterval)) s.streamsMap.UpdateLimits(params) - s.packer.HandleTransportParameters(params) s.frameParser.SetAckDelayExponent(params.AckDelayExponent) s.connFlowController.UpdateSendWindow(params.InitialMaxData) s.rttStats.SetMaxAckDelay(params.MaxAckDelay) @@ -1730,75 +1736,208 @@ func (s *connection) applyTransportParameters() { } } -func (s *connection) sendPackets() error { +func (s *connection) triggerSending() error { s.pacingDeadline = time.Time{} + now := time.Now() - var sentPacket bool // only used in for packets sent in send mode SendAny - for { - sendMode := s.sentPacketHandler.SendMode() - if sendMode == ackhandler.SendAny && s.handshakeComplete && !s.sentPacketHandler.HasPacingBudget() { - deadline := s.sentPacketHandler.TimeUntilSend() - if deadline.IsZero() { - deadline = 
deadlineSendImmediately - } - s.pacingDeadline = deadline - // Allow sending of an ACK if we're pacing limit (if we haven't sent out a packet yet). - // This makes sure that a peer that is mostly receiving data (and thus has an inaccurate cwnd estimate) - // sends enough ACKs to allow its peer to utilize the bandwidth. - if sentPacket { - return nil - } - sendMode = ackhandler.SendAck + sendMode := s.sentPacketHandler.SendMode(now) + //nolint:exhaustive // No need to handle pacing limited here. + switch sendMode { + case ackhandler.SendAny: + return s.sendPackets(now) + case ackhandler.SendNone: + return nil + case ackhandler.SendPacingLimited: + deadline := s.sentPacketHandler.TimeUntilSend() + if deadline.IsZero() { + deadline = deadlineSendImmediately } - switch sendMode { - case ackhandler.SendNone: + s.pacingDeadline = deadline + // Allow sending of an ACK if we're pacing limit. + // This makes sure that a peer that is mostly receiving data (and thus has an inaccurate cwnd estimate) + // sends enough ACKs to allow its peer to utilize the bandwidth. + fallthrough + case ackhandler.SendAck: + // We can at most send a single ACK only packet. + // There will only be a new ACK after receiving new packets. + // SendAck is only returned when we're congestion limited, so we don't need to set the pacinggs timer. + return s.maybeSendAckOnlyPacket(now) + case ackhandler.SendPTOInitial: + if err := s.sendProbePacket(protocol.EncryptionInitial, now); err != nil { + return err + } + if s.sendQueue.WouldBlock() { + s.scheduleSending() return nil - case ackhandler.SendAck: - // If we already sent packets, and the send mode switches to SendAck, - // as we've just become congestion limited. - // There's no need to try to send an ACK at this moment. - if sentPacket { + } + return s.triggerSending() + case ackhandler.SendPTOHandshake: + if err := s.sendProbePacket(protocol.EncryptionHandshake, now); err != nil { + return err + } + if s.sendQueue.WouldBlock() { + s.scheduleSending() + return nil + } + return s.triggerSending() + case ackhandler.SendPTOAppData: + if err := s.sendProbePacket(protocol.Encryption1RTT, now); err != nil { + return err + } + if s.sendQueue.WouldBlock() { + s.scheduleSending() + return nil + } + return s.triggerSending() + default: + return fmt.Errorf("BUG: invalid send mode %d", sendMode) + } +} + +func (s *connection) sendPackets(now time.Time) error { + // Path MTU Discovery + // Can't use GSO, since we need to send a single packet that's larger than our current maximum size. + // Performance-wise, this doesn't matter, since we only send a very small (<10) number of + // MTU probe packets per connection. + if s.handshakeConfirmed && s.mtuDiscoverer != nil && s.mtuDiscoverer.ShouldSendProbe(now) { + ping, size := s.mtuDiscoverer.GetPing() + p, buf, err := s.packer.PackMTUProbePacket(ping, size, s.version) + if err != nil { + return err + } + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) + s.registerPackedShortHeaderPacket(p, now) + s.sendQueue.Send(buf, buf.Len()) + // This is kind of a hack. We need to trigger sending again somehow. 
+ s.pacingDeadline = deadlineSendImmediately + return nil + } + + if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { + s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset}) + } + s.windowUpdateQueue.QueueAll() + + if !s.handshakeConfirmed { + packet, err := s.packer.PackCoalescedPacket(false, s.mtuDiscoverer.CurrentSize(), s.version) + if err != nil || packet == nil { + return err + } + s.sentFirstPacket = true + s.sendPackedCoalescedPacket(packet, now) + sendMode := s.sentPacketHandler.SendMode(now) + if sendMode == ackhandler.SendPacingLimited { + s.resetPacingDeadline() + } else if sendMode == ackhandler.SendAny { + s.pacingDeadline = deadlineSendImmediately + } + return nil + } + + if s.conn.capabilities().GSO { + return s.sendPacketsWithGSO(now) + } + return s.sendPacketsWithoutGSO(now) +} + +func (s *connection) sendPacketsWithoutGSO(now time.Time) error { + for { + buf := getPacketBuffer() + if _, err := s.appendPacket(buf, s.mtuDiscoverer.CurrentSize(), now); err != nil { + if err == errNothingToPack { + buf.Release() return nil } - // We can at most send a single ACK only packet. - // There will only be a new ACK after receiving new packets. - // SendAck is only returned when we're congestion limited, so we don't need to set the pacinggs timer. - return s.maybeSendAckOnlyPacket() - case ackhandler.SendPTOInitial: - if err := s.sendProbePacket(protocol.EncryptionInitial); err != nil { - return err - } - case ackhandler.SendPTOHandshake: - if err := s.sendProbePacket(protocol.EncryptionHandshake); err != nil { - return err - } - case ackhandler.SendPTOAppData: - if err := s.sendProbePacket(protocol.Encryption1RTT); err != nil { - return err - } - case ackhandler.SendAny: - sent, err := s.sendPacket() - if err != nil || !sent { - return err - } - sentPacket = true - default: - return fmt.Errorf("BUG: invalid send mode %d", sendMode) + return err + } + + s.sendQueue.Send(buf, buf.Len()) + + if s.sendQueue.WouldBlock() { + return nil + } + sendMode := s.sentPacketHandler.SendMode(now) + if sendMode == ackhandler.SendPacingLimited { + s.resetPacingDeadline() + return nil + } + if sendMode != ackhandler.SendAny { + return nil } // Prioritize receiving of packets over sending out more packets. if len(s.receivedPackets) > 0 { s.pacingDeadline = deadlineSendImmediately return nil } - if s.sendQueue.WouldBlock() { - return nil - } } } -func (s *connection) maybeSendAckOnlyPacket() error { +func (s *connection) sendPacketsWithGSO(now time.Time) error { + buf := getLargePacketBuffer() + maxSize := s.mtuDiscoverer.CurrentSize() + + for { + var dontSendMore bool + size, err := s.appendPacket(buf, maxSize, now) + if err != nil { + if err != errNothingToPack { + return err + } + if buf.Len() == 0 { + buf.Release() + return nil + } + dontSendMore = true + } + + if !dontSendMore { + sendMode := s.sentPacketHandler.SendMode(now) + if sendMode == ackhandler.SendPacingLimited { + s.resetPacingDeadline() + } + if sendMode != ackhandler.SendAny { + dontSendMore = true + } + } + + // Append another packet if + // 1. The congestion controller and pacer allow sending more + // 2. The last packet appended was a full-size packet + // 3. 
We still have enough space for another full-size packet in the buffer + if !dontSendMore && size == maxSize && buf.Len()+maxSize <= buf.Cap() { + continue + } + + s.sendQueue.Send(buf, maxSize) + + if dontSendMore { + return nil + } + if s.sendQueue.WouldBlock() { + return nil + } + + // Prioritize receiving of packets over sending out more packets. + if len(s.receivedPackets) > 0 { + s.pacingDeadline = deadlineSendImmediately + return nil + } + + buf = getLargePacketBuffer() + } +} + +func (s *connection) resetPacingDeadline() { + deadline := s.sentPacketHandler.TimeUntilSend() + if deadline.IsZero() { + deadline = deadlineSendImmediately + } + s.pacingDeadline = deadline +} + +func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { if !s.handshakeConfirmed { - packet, err := s.packer.PackCoalescedPacket(true, s.version) + packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { return err } @@ -1809,20 +1948,20 @@ func (s *connection) maybeSendAckOnlyPacket() error { return nil } - now := time.Now() - p, buffer, err := s.packer.PackPacket(true, now, s.version) + p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { if err == errNothingToPack { return nil } return err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) - s.sendPackedShortHeaderPacket(buffer, p.Packet, now) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) + s.registerPackedShortHeaderPacket(p, now) + s.sendQueue.Send(buf, buf.Len()) return nil } -func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { +func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time.Time) error { // Queue probe packets until we actually send out a packet, // or until there are no more packets to queue. var packet *coalescedPacket @@ -1831,7 +1970,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { break } var err error - packet, err = s.packer.MaybePackProbePacket(encLevel, s.version) + packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { return err } @@ -1840,19 +1979,9 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { } } if packet == nil { - //nolint:exhaustive // Cannot send probe packets for 0-RTT. 
- switch encLevel { - case protocol.EncryptionInitial: - s.retransmissionQueue.AddInitial(&wire.PingFrame{}) - case protocol.EncryptionHandshake: - s.retransmissionQueue.AddHandshake(&wire.PingFrame{}) - case protocol.Encryption1RTT: - s.retransmissionQueue.AddAppData(&wire.PingFrame{}) - default: - panic("unexpected encryption level") - } + s.retransmissionQueue.AddPing(encLevel) var err error - packet, err = s.packer.MaybePackProbePacket(encLevel, s.version) + packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version) if err != nil { return err } @@ -1860,55 +1989,35 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) } - s.sendPackedCoalescedPacket(packet, time.Now()) + s.sendPackedCoalescedPacket(packet, now) return nil } -func (s *connection) sendPacket() (bool, error) { - if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked { - s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset}) - } - s.windowUpdateQueue.QueueAll() - - now := time.Now() - if !s.handshakeConfirmed { - packet, err := s.packer.PackCoalescedPacket(false, s.version) - if err != nil || packet == nil { - return false, err - } - s.sentFirstPacket = true - s.sendPackedCoalescedPacket(packet, now) - return true, nil - } else if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) { - ping, size := s.mtuDiscoverer.GetPing() - p, buffer, err := s.packer.PackMTUProbePacket(ping, size, now, s.version) - if err != nil { - return false, err - } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) - s.sendPackedShortHeaderPacket(buffer, p.Packet, now) - return true, nil - } - p, buffer, err := s.packer.PackPacket(false, now, s.version) +// appendPacket appends a new packet to the given packetBuffer. +// If there was nothing to pack, the returned size is 0. 
+func (s *connection) appendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time) (protocol.ByteCount, error) { + startLen := buf.Len() + p, err := s.packer.AppendPacket(buf, maxSize, s.version) if err != nil { - if err == errNothingToPack { - return false, nil - } - return false, err + return 0, err } - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) - s.sendPackedShortHeaderPacket(buffer, p.Packet, now) - return true, nil + size := buf.Len() - startLen + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, size, false) + s.registerPackedShortHeaderPacket(p, now) + return size, nil } -func (s *connection) sendPackedShortHeaderPacket(buffer *packetBuffer, p *ackhandler.Packet, now time.Time) { - if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && ackhandler.HasAckElicitingFrames(p.Frames) { +func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now time.Time) { + if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) { s.firstAckElicitingPacketAfterIdleSentTime = now } - s.sentPacketHandler.SentPacket(p) + largestAcked := protocol.InvalidPacketNumber + if p.Ack != nil { + largestAcked = p.Ack.LargestAcked() + } + s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) s.connIDManager.SentPacket() - s.sendQueue.Send(buffer) } func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) { @@ -1917,16 +2026,24 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { s.firstAckElicitingPacketAfterIdleSentTime = now } - s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(now, s.retransmissionQueue)) + largestAcked := protocol.InvalidPacketNumber + if p.ack != nil { + largestAcked = p.ack.LargestAcked() + } + s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), p.length, false) } if p := packet.shortHdrPacket; p != nil { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { s.firstAckElicitingPacketAfterIdleSentTime = now } - s.sentPacketHandler.SentPacket(p.Packet) + largestAcked := protocol.InvalidPacketNumber + if p.Ack != nil { + largestAcked = p.Ack.LargestAcked() + } + s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) } s.connIDManager.SentPacket() - s.sendQueue.Send(packet.buffer) + s.sendQueue.Send(packet.buffer, packet.buffer.Len()) } func (s *connection) sendConnectionClose(e error) ([]byte, error) { @@ -1935,20 +2052,20 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { var transportErr *qerr.TransportError var applicationErr *qerr.ApplicationError if errors.As(e, &transportErr) { - packet, err = s.packer.PackConnectionClose(transportErr, s.version) + packet, err = s.packer.PackConnectionClose(transportErr, s.mtuDiscoverer.CurrentSize(), s.version) } else if errors.As(e, &applicationErr) { - packet, err = s.packer.PackApplicationClose(applicationErr, s.version) + packet, err = s.packer.PackApplicationClose(applicationErr, s.mtuDiscoverer.CurrentSize(), s.version) } else { packet, err = 
s.packer.PackConnectionClose(&qerr.TransportError{ ErrorCode: qerr.InternalError, ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()), - }, s.version) + }, s.mtuDiscoverer.CurrentSize(), s.version) } if err != nil { return nil, err } s.logCoalescedPacket(packet) - return packet.buffer.Data, s.conn.Write(packet.buffer.Data) + return packet.buffer.Data, s.conn.Write(packet.buffer.Data, packet.buffer.Len()) } func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { @@ -1980,7 +2097,8 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { func (s *connection) logShortHeaderPacket( destConnID protocol.ConnectionID, ackFrame *wire.AckFrame, - frames []*ackhandler.Frame, + frames []ackhandler.Frame, + streamFrames []ackhandler.StreamFrame, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit, @@ -1996,17 +2114,23 @@ func (s *connection) logShortHeaderPacket( if ackFrame != nil { wire.LogFrame(s.logger, ackFrame, true) } - for _, frame := range frames { - wire.LogFrame(s.logger, frame.Frame, true) + for _, f := range frames { + wire.LogFrame(s.logger, f.Frame, true) + } + for _, f := range streamFrames { + wire.LogFrame(s.logger, f.Frame, true) } } // tracing if s.tracer != nil { - fs := make([]logging.Frame, 0, len(frames)) + fs := make([]logging.Frame, 0, len(frames)+len(streamFrames)) for _, f := range frames { fs = append(fs, logutils.ConvertFrame(f.Frame)) } + for _, f := range streamFrames { + fs = append(fs, logutils.ConvertFrame(f.Frame)) + } var ack *logging.AckFrame if ackFrame != nil { ack = logutils.ConvertAckFrame(ackFrame) @@ -2034,6 +2158,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) { packet.shortHdrPacket.DestConnID, packet.shortHdrPacket.Ack, packet.shortHdrPacket.Frames, + packet.shortHdrPacket.StreamFrames, packet.shortHdrPacket.PacketNumber, packet.shortHdrPacket.PacketNumberLen, packet.shortHdrPacket.KeyPhase, @@ -2052,7 +2177,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) { s.logLongHeaderPacket(p) } if p := packet.shortHdrPacket; p != nil { - s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) } } @@ -2113,7 +2238,7 @@ func (s *connection) scheduleSending() { // tryQueueingUndecryptablePacket queues a packet for which we're missing the decryption keys. // The logging.PacketType is only used for logging purposes. 
-func (s *connection) tryQueueingUndecryptablePacket(p *receivedPacket, pt logging.PacketType) { +func (s *connection) tryQueueingUndecryptablePacket(p receivedPacket, pt logging.PacketType) { if s.handshakeComplete { panic("shouldn't queue undecryptable packets after handshake completion") } diff --git a/vendor/github.com/quic-go/quic-go/framer.go b/vendor/github.com/quic-go/quic-go/framer.go index 0b205916..9409af4c 100644 --- a/vendor/github.com/quic-go/quic-go/framer.go +++ b/vendor/github.com/quic-go/quic-go/framer.go @@ -6,6 +6,7 @@ import ( "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils/ringbuffer" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/quicvarint" ) @@ -14,10 +15,10 @@ type framer interface { HasData() bool QueueControlFrame(wire.Frame) - AppendControlFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) + AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) AddActiveStream(protocol.StreamID) - AppendStreamFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) + AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) Handle0RTTRejection() error } @@ -28,7 +29,7 @@ type framerI struct { streamGetter streamGetter activeStreams map[protocol.StreamID]struct{} - streamQueue []protocol.StreamID + streamQueue ringbuffer.RingBuffer[protocol.StreamID] controlFrameMutex sync.Mutex controlFrames []wire.Frame @@ -45,7 +46,7 @@ func newFramer(streamGetter streamGetter) framer { func (f *framerI) HasData() bool { f.mutex.Lock() - hasData := len(f.streamQueue) > 0 + hasData := !f.streamQueue.Empty() f.mutex.Unlock() if hasData { return true @@ -62,7 +63,7 @@ func (f *framerI) QueueControlFrame(frame wire.Frame) { f.controlFrameMutex.Unlock() } -func (f *framerI) AppendControlFrames(frames []*ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { +func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) { var length protocol.ByteCount f.controlFrameMutex.Lock() for len(f.controlFrames) > 0 { @@ -71,9 +72,7 @@ func (f *framerI) AppendControlFrames(frames []*ackhandler.Frame, maxLen protoco if length+frameLen > maxLen { break } - af := ackhandler.GetFrame() - af.Frame = frame - frames = append(frames, af) + frames = append(frames, ackhandler.Frame{Frame: frame}) length += frameLen f.controlFrames = f.controlFrames[:len(f.controlFrames)-1] } @@ -84,24 +83,23 @@ func (f *framerI) AppendControlFrames(frames []*ackhandler.Frame, maxLen protoco func (f *framerI) AddActiveStream(id protocol.StreamID) { f.mutex.Lock() if _, ok := f.activeStreams[id]; !ok { - f.streamQueue = append(f.streamQueue, id) + f.streamQueue.PushBack(id) f.activeStreams[id] = struct{}{} } f.mutex.Unlock() } -func (f *framerI) AppendStreamFrames(frames []*ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) { +func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) { + startLen := len(frames) var length 
protocol.ByteCount - var lastFrame *ackhandler.Frame f.mutex.Lock() // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet - numActiveStreams := len(f.streamQueue) + numActiveStreams := f.streamQueue.Len() for i := 0; i < numActiveStreams; i++ { if protocol.MinStreamFrameSize+length > maxLen { break } - id := f.streamQueue[0] - f.streamQueue = f.streamQueue[1:] + id := f.streamQueue.PopFront() // This should never return an error. Better check it anyway. // The stream will only be in the streamQueue, if it enqueued itself there. str, err := f.streamGetter.GetOrOpenSendStream(id) @@ -115,28 +113,27 @@ func (f *framerI) AppendStreamFrames(frames []*ackhandler.Frame, maxLen protocol // Therefore, we can pretend to have more bytes available when popping // the STREAM frame (which will always have the DataLen set). remainingLen += quicvarint.Len(uint64(remainingLen)) - frame, hasMoreData := str.popStreamFrame(remainingLen, v) + frame, ok, hasMoreData := str.popStreamFrame(remainingLen, v) if hasMoreData { // put the stream back in the queue (at the end) - f.streamQueue = append(f.streamQueue, id) - } else { // no more data to send. Stream is not active any more + f.streamQueue.PushBack(id) + } else { // no more data to send. Stream is not active delete(f.activeStreams, id) } - // The frame can be nil + // The frame can be "nil" // * if the receiveStream was canceled after it said it had data // * the remaining size doesn't allow us to add another STREAM frame - if frame == nil { + if !ok { continue } frames = append(frames, frame) - length += frame.Length(v) - lastFrame = frame + length += frame.Frame.Length(v) } f.mutex.Unlock() - if lastFrame != nil { - lastFrameLen := lastFrame.Length(v) + if len(frames) > startLen { + l := frames[len(frames)-1].Frame.Length(v) // account for the smaller size of the last STREAM frame - lastFrame.Frame.(*wire.StreamFrame).DataLenPresent = false - length += lastFrame.Length(v) - lastFrameLen + frames[len(frames)-1].Frame.DataLenPresent = false + length += frames[len(frames)-1].Frame.Length(v) - l } return frames, length } @@ -146,7 +143,7 @@ func (f *framerI) Handle0RTTRejection() error { defer f.mutex.Unlock() f.controlFrameMutex.Lock() - f.streamQueue = f.streamQueue[:0] + f.streamQueue.Clear() for id := range f.activeStreams { delete(f.activeStreams, id) } diff --git a/vendor/github.com/quic-go/quic-go/http3/client.go b/vendor/github.com/quic-go/quic-go/http3/client.go index d89f2090..c54de5ea 100644 --- a/vendor/github.com/quic-go/quic-go/http3/client.go +++ b/vendor/github.com/quic-go/quic-go/http3/client.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "strconv" "sync" @@ -33,12 +34,11 @@ const ( var defaultQuicConfig = &quic.Config{ MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams KeepAlivePeriod: 10 * time.Second, - Versions: []protocol.VersionNumber{protocol.Version1}, } type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) -var dialAddr = quic.DialAddrEarlyContext +var dialAddr dialFunc = quic.DialAddrEarly type roundTripperOpts struct { DisableCompression bool @@ -74,9 +74,10 @@ var _ roundTripCloser = &client{} func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { if conf == nil { conf = defaultQuicConfig.Clone() - } else if len(conf.Versions) == 0 { + } + if len(conf.Versions) == 0 { conf = conf.Clone() - 
conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]} + conf.Versions = []quic.VersionNumber{protocol.SupportedVersions[0]} } if len(conf.Versions) != 1 { return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") @@ -92,6 +93,14 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con } else { tlsConf = tlsConf.Clone() } + if tlsConf.ServerName == "" { + sni, _, err := net.SplitHostPort(hostname) + if err != nil { + // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port. + sni = hostname + } + tlsConf.ServerName = sni + } // Replace existing ALPNs by H3 tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} diff --git a/vendor/github.com/quic-go/quic-go/http3/mockgen.go b/vendor/github.com/quic-go/quic-go/http3/mockgen.go index cb370373..38939e60 100644 --- a/vendor/github.com/quic-go/quic-go/http3/mockgen.go +++ b/vendor/github.com/quic-go/quic-go/http3/mockgen.go @@ -4,3 +4,5 @@ package http3 //go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser" type RoundTripCloser = roundTripCloser + +//go:generate sh -c "go run github.com/golang/mock/mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/quic-go/quic-go/http3 QUICEarlyListener" diff --git a/vendor/github.com/quic-go/quic-go/http3/response_writer.go b/vendor/github.com/quic-go/quic-go/http3/response_writer.go index 5cc32923..b7c79d50 100644 --- a/vendor/github.com/quic-go/quic-go/http3/response_writer.go +++ b/vendor/github.com/quic-go/quic-go/http3/response_writer.go @@ -6,6 +6,7 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/utils" @@ -15,6 +16,7 @@ import ( type responseWriter struct { conn quic.Connection + str quic.Stream bufferedStr *bufio.Writer buf []byte @@ -36,6 +38,7 @@ func newResponseWriter(str quic.Stream, conn quic.Connection, logger utils.Logge header: http.Header{}, buf: make([]byte, 16), conn: conn, + str: str, bufferedStr: bufio.NewWriter(str), logger: logger, } @@ -121,6 +124,14 @@ func (w *responseWriter) StreamCreator() StreamCreator { return w.conn } +func (w *responseWriter) SetReadDeadline(deadline time.Time) error { + return w.str.SetReadDeadline(deadline) +} + +func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { + return w.str.SetWriteDeadline(deadline) +} + // copied from http2/http2.go // bodyAllowedForStatus reports whether a given response status code // permits a body. See RFC 2616, section 4.4. 
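The response_writer.go hunk above exposes per-request read/write deadlines by adding SetReadDeadline and SetWriteDeadline to the HTTP/3 response writer. The concrete type stays unexported, so a handler can only reach these methods through a type assertion. A minimal sketch, assuming nothing beyond the two methods added above; the interface name, certificate files, and listen address are placeholders:

package main

import (
	"log"
	"net/http"
	"time"

	"github.com/quic-go/quic-go/http3"
)

// writeDeadliner matches the SetReadDeadline/SetWriteDeadline methods this patch
// adds to http3's unexported responseWriter; the interface itself is ad hoc.
type writeDeadliner interface {
	SetReadDeadline(time.Time) error
	SetWriteDeadline(time.Time) error
}

func handler(w http.ResponseWriter, r *http.Request) {
	if d, ok := w.(writeDeadliner); ok {
		// Give up on the response if the peer stops reading for 10 seconds.
		_ = d.SetWriteDeadline(time.Now().Add(10 * time.Second))
	}
	w.Write([]byte("hello over HTTP/3\n"))
}

func main() {
	srv := &http3.Server{Addr: ":443", Handler: http.HandlerFunc(handler)}
	// cert.pem and key.pem stand in for a real certificate and key.
	log.Fatal(srv.ListenAndServeTLS("cert.pem", "key.pem"))
}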
diff --git a/vendor/github.com/quic-go/quic-go/http3/roundtrip.go b/vendor/github.com/quic-go/quic-go/http3/roundtrip.go index 95506cac..066b762b 100644 --- a/vendor/github.com/quic-go/quic-go/http3/roundtrip.go +++ b/vendor/github.com/quic-go/quic-go/http3/roundtrip.go @@ -10,21 +10,24 @@ import ( "net/http" "strings" "sync" + "sync/atomic" "golang.org/x/net/http/httpguts" "github.com/quic-go/quic-go" ) -// declare this as a variable, such that we can it mock it in the tests -var quicDialer = quic.DialEarlyContext - type roundTripCloser interface { RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) HandshakeComplete() bool io.Closer } +type roundTripCloserWithCount struct { + roundTripCloser + useCount atomic.Int64 +} + // RoundTripper implements the http.RoundTripper interface type RoundTripper struct { mutex sync.Mutex @@ -82,8 +85,8 @@ type RoundTripper struct { MaxResponseHeaderBytes int64 newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests - clients map[string]roundTripCloser - udpConn *net.UDPConn + clients map[string]*roundTripCloserWithCount + transport *quic.Transport } // RoundTripOpt are options for the Transport.RoundTripOpt method. @@ -143,6 +146,7 @@ func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http. if err != nil { return nil, err } + defer cl.useCount.Add(-1) rsp, err := cl.RoundTripOpt(req, opt) if err != nil { r.removeClient(hostname) @@ -160,12 +164,12 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return r.RoundTripOpt(req, RoundTripOpt{}) } -func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTripCloser, isReused bool, err error) { +func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) { r.mutex.Lock() defer r.mutex.Unlock() if r.clients == nil { - r.clients = make(map[string]roundTripCloser) + r.clients = make(map[string]*roundTripCloserWithCount) } client, ok := r.clients[hostname] @@ -180,15 +184,16 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri } dial := r.Dial if dial == nil { - if r.udpConn == nil { - r.udpConn, err = net.ListenUDP("udp", nil) + if r.transport == nil { + udpConn, err := net.ListenUDP("udp", nil) if err != nil { return nil, false, err } + r.transport = &quic.Transport{Conn: udpConn} } dial = r.makeDialer() } - client, err = newCl( + c, err := newCl( hostname, r.TLSClientConfig, &roundTripperOpts{ @@ -204,10 +209,12 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc roundTri if err != nil { return nil, false, err } + client = &roundTripCloserWithCount{roundTripCloser: c} r.clients[hostname] = client } else if client.HandshakeComplete() { isReused = true } + client.useCount.Add(1) return client, isReused, nil } @@ -231,9 +238,14 @@ func (r *RoundTripper) Close() error { } } r.clients = nil - if r.udpConn != nil { - r.udpConn.Close() - r.udpConn = nil + if r.transport != nil { + if err := r.transport.Close(); err != nil { + return err + } + if err := r.transport.Conn.Close(); err != nil { + return err + } + r.transport = nil } return nil } @@ -273,6 +285,17 @@ func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCf if err != nil { return nil, err } - return quicDialer(ctx, r.udpConn, udpAddr, addr, tlsCfg, cfg) + return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) + } +} + 
+func (r *RoundTripper) CloseIdleConnections() { + r.mutex.Lock() + defer r.mutex.Unlock() + for hostname, client := range r.clients { + if client.useCount.Load() == 0 { + client.Close() + delete(r.clients, hostname) + } } } diff --git a/vendor/github.com/quic-go/quic-go/http3/server.go b/vendor/github.com/quic-go/quic-go/http3/server.go index e74247ab..a03b1fc5 100644 --- a/vendor/github.com/quic-go/quic-go/http3/server.go +++ b/vendor/github.com/quic-go/quic-go/http3/server.go @@ -23,8 +23,12 @@ import ( // allows mocking of quic.Listen and quic.ListenAddr var ( - quicListen = quic.ListenEarly - quicListenAddr = quic.ListenAddrEarly + quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { + return quic.ListenEarly(conn, tlsConf, config) + } + quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { + return quic.ListenAddrEarly(addr, tlsConf, config) + } ) const ( @@ -44,6 +48,15 @@ const ( streamTypeQPACKDecoderStream = 3 ) +// A QUICEarlyListener listens for incoming QUIC connections. +type QUICEarlyListener interface { + Accept(context.Context) (quic.EarlyConnection, error) + Addr() net.Addr + io.Closer +} + +var _ QUICEarlyListener = &quic.EarlyListener{} + func versionToALPN(v protocol.VersionNumber) string { //nolint:exhaustive // These are all the versions we care about. switch v { @@ -193,7 +206,7 @@ type Server struct { UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) mutex sync.RWMutex - listeners map[*quic.EarlyListener]listenerInfo + listeners map[*QUICEarlyListener]listenerInfo closed bool @@ -249,13 +262,26 @@ func (s *Server) ServeQUICConn(conn quic.Connection) error { // Make sure you use http3.ConfigureTLSConfig to configure a tls.Config // and use it to construct a http3-friendly QUIC listener. // Closing the server does close the listener. -func (s *Server) ServeListener(ln quic.EarlyListener) error { +// ServeListener always returns a non-nil error. After Shutdown or Close, the returned error is http.ErrServerClosed. 
+func (s *Server) ServeListener(ln QUICEarlyListener) error { if err := s.addListener(&ln); err != nil { return err } - err := s.serveListener(ln) - s.removeListener(&ln) - return err + defer s.removeListener(&ln) + for { + conn, err := ln.Accept(context.Background()) + if err == quic.ErrServerClosed { + return http.ErrServerClosed + } + if err != nil { + return err + } + go func() { + if err := s.handleConn(conn); err != nil { + s.logger.Debugf(err.Error()) + } + }() + } } var errServerWithoutTLSConfig = errors.New("use of http3.Server without TLSConfig") @@ -275,7 +301,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error { baseConf := ConfigureTLSConfig(tlsConf) quicConf := s.QuicConfig if quicConf == nil { - quicConf = &quic.Config{Allow0RTT: func(net.Addr) bool { return true }} + quicConf = &quic.Config{Allow0RTT: true} } else { quicConf = s.QuicConfig.Clone() } @@ -283,7 +309,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error { quicConf.EnableDatagrams = true } - var ln quic.EarlyListener + var ln QUICEarlyListener var err error if conn == nil { addr := s.Addr @@ -297,26 +323,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error { if err != nil { return err } - if err := s.addListener(&ln); err != nil { - return err - } - err = s.serveListener(ln) - s.removeListener(&ln) - return err -} - -func (s *Server) serveListener(ln quic.EarlyListener) error { - for { - conn, err := ln.Accept(context.Background()) - if err != nil { - return err - } - go func() { - if err := s.handleConn(conn); err != nil { - s.logger.Debugf(err.Error()) - } - }() - } + return s.ServeListener(ln) } func extractPort(addr string) (int, error) { @@ -391,7 +398,7 @@ func (s *Server) generateAltSvcHeader() { // We store a pointer to interface in the map set. This is safe because we only // call trackListener via Serve and can track+defer untrack the same pointer to // local variable there. We never need to compare a Listener from another caller. -func (s *Server) addListener(l *quic.EarlyListener) error { +func (s *Server) addListener(l *QUICEarlyListener) error { s.mutex.Lock() defer s.mutex.Unlock() @@ -402,25 +409,24 @@ func (s *Server) addListener(l *quic.EarlyListener) error { s.logger = utils.DefaultLogger.WithPrefix("server") } if s.listeners == nil { - s.listeners = make(map[*quic.EarlyListener]listenerInfo) + s.listeners = make(map[*QUICEarlyListener]listenerInfo) } if port, err := extractPort((*l).Addr().String()); err == nil { s.listeners[l] = listenerInfo{port} } else { - s.logger.Errorf( - "Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err) + s.logger.Errorf("Unable to extract port from listener %+v, will not be announced using SetQuicHeaders: %s", err) s.listeners[l] = listenerInfo{} } s.generateAltSvcHeader() return nil } -func (s *Server) removeListener(l *quic.EarlyListener) { +func (s *Server) removeListener(l *QUICEarlyListener) { s.mutex.Lock() + defer s.mutex.Unlock() delete(s.listeners, l) s.generateAltSvcHeader() - s.mutex.Unlock() } func (s *Server) handleConn(conn quic.Connection) error { diff --git a/vendor/github.com/quic-go/quic-go/interface.go b/vendor/github.com/quic-go/quic-go/interface.go index b700e7c1..8486c7fe 100644 --- a/vendor/github.com/quic-go/quic-go/interface.go +++ b/vendor/github.com/quic-go/quic-go/interface.go @@ -198,7 +198,7 @@ type EarlyConnection interface { // HandshakeComplete blocks until the handshake completes (or fails). 
// For the client, data sent before completion of the handshake is encrypted with 0-RTT keys. - // For the serfer, data sent before completion of the handshake is encrypted with 1-RTT keys, + // For the server, data sent before completion of the handshake is encrypted with 1-RTT keys, // however the client's identity is only verified once the handshake completes. HandshakeComplete() <-chan struct{} @@ -239,21 +239,12 @@ type ConnectionIDGenerator interface { // Config contains all configuration data needed for a QUIC server or client. type Config struct { + // GetConfigForClient is called for incoming connections. + // If the error is not nil, the connection attempt is refused. + GetConfigForClient func(info *ClientHelloInfo) (*Config, error) // The QUIC versions that can be negotiated. // If not set, it uses all versions available. Versions []VersionNumber - // The length of the connection ID in bytes. - // It can be 0, or any value between 4 and 18. - // If not set, the interpretation depends on where the Config is used: - // If used for dialing an address, a 0 byte connection ID will be used. - // If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used. - // When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call. - ConnectionIDLength int - // An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection. - // The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers. - // By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used. - // Otherwise, if one is provided, then ConnectionIDLength is ignored. - ConnectionIDGenerator ConnectionIDGenerator // HandshakeIdleTimeout is the idle timeout before completion of the handshake. // Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. // If this value is zero, the timeout is set to 5 seconds. @@ -285,17 +276,21 @@ type Config struct { // If the application is consuming data quickly enough, the flow control auto-tuning algorithm // will increase the window up to MaxStreamReceiveWindow. // If this value is zero, it will default to 512 KB. + // Values larger than the maximum varint (quicvarint.Max) will be clipped to that value. InitialStreamReceiveWindow uint64 // MaxStreamReceiveWindow is the maximum stream-level flow control window for receiving data. // If this value is zero, it will default to 6 MB. + // Values larger than the maximum varint (quicvarint.Max) will be clipped to that value. MaxStreamReceiveWindow uint64 // InitialConnectionReceiveWindow is the initial size of the stream-level flow control window for receiving data. // If the application is consuming data quickly enough, the flow control auto-tuning algorithm // will increase the window up to MaxConnectionReceiveWindow. // If this value is zero, it will default to 512 KB. + // Values larger than the maximum varint (quicvarint.Max) will be clipped to that value. InitialConnectionReceiveWindow uint64 // MaxConnectionReceiveWindow is the connection-level flow control window for receiving data. // If this value is zero, it will default to 15 MB. + // Values larger than the maximum varint (quicvarint.Max) will be clipped to that value. 
MaxConnectionReceiveWindow uint64 // AllowConnectionWindowIncrease is called every time the connection flow controller attempts // to increase the connection flow control window. @@ -305,64 +300,49 @@ type Config struct { // in this callback. AllowConnectionWindowIncrease func(conn Connection, delta uint64) bool // MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open. - // Values above 2^60 are invalid. // If not set, it will default to 100. // If set to a negative value, it doesn't allow any bidirectional streams. + // Values larger than 2^60 will be clipped to that value. MaxIncomingStreams int64 // MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open. - // Values above 2^60 are invalid. // If not set, it will default to 100. // If set to a negative value, it doesn't allow any unidirectional streams. + // Values larger than 2^60 will be clipped to that value. MaxIncomingUniStreams int64 - // The StatelessResetKey is used to generate stateless reset tokens. - // If no key is configured, sending of stateless resets is disabled. - StatelessResetKey *StatelessResetKey // KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive. // If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most // every half of MaxIdleTimeout, whichever is smaller). KeepAlivePeriod time.Duration // DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899). - // Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. - // Note that if Path MTU discovery is causing issues on your system, please open a new issue + // This allows the sending of QUIC packets that fully utilize the available MTU of the path. + // Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit. + // If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. DisablePathMTUDiscovery bool // DisableVersionNegotiationPackets disables the sending of Version Negotiation packets. // This can be useful if version information is exchanged out-of-band. // It has no effect for a client. DisableVersionNegotiationPackets bool // Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted. - // When set, 0-RTT is enabled. When not set, 0-RTT is disabled. // Only valid for the server. - // Warning: This API should not be considered stable and might change soon. - Allow0RTT func(net.Addr) bool + Allow0RTT bool // Enable QUIC datagram support (RFC 9221). EnableDatagrams bool - Tracer logging.Tracer + Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer +} + +type ClientHelloInfo struct { + RemoteAddr net.Addr } // ConnectionState records basic details about a QUIC connection type ConnectionState struct { - TLS handshake.ConnectionState + // TLS contains information about the TLS connection state, incl. the tls.ConnectionState. + TLS handshake.ConnectionState + // SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated. + // This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams). + // If datagram support was negotiated, datagrams can be sent and received using the + // SendMessage and ReceiveMessage methods on the Connection. 
SupportsDatagrams bool - Version VersionNumber -} - -// A Listener for incoming QUIC connections -type Listener interface { - // Close the server. All active connections will be closed. - Close() error - // Addr returns the local network addr that the server is listening on. - Addr() net.Addr - // Accept returns new connections. It should be called in a loop. - Accept(context.Context) (Connection, error) -} - -// An EarlyListener listens for incoming QUIC connections, -// and returns them before the handshake completes. -type EarlyListener interface { - // Close the server. All active connections will be closed. - Close() error - // Addr returns the local network addr that the server is listening on. - Addr() net.Addr - // Accept returns new early connections. It should be called in a loop. - Accept(context.Context) (EarlyConnection, error) + // Version is the QUIC version of the QUIC connection. + Version VersionNumber } diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/ack_eliciting.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/ack_eliciting.go index 4bab4190..34506b12 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/ack_eliciting.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/ack_eliciting.go @@ -10,7 +10,7 @@ func IsFrameAckEliciting(f wire.Frame) bool { } // HasAckElicitingFrames returns true if at least one frame is ack-eliciting. -func HasAckElicitingFrames(fs []*Frame) bool { +func HasAckElicitingFrames(fs []Frame) bool { for _, f := range fs { if IsFrameAckEliciting(f.Frame) { return true diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go index deb23cfc..e03a8080 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/frame.go @@ -1,29 +1,21 @@ package ackhandler import ( - "sync" - "github.com/quic-go/quic-go/internal/wire" ) +// FrameHandler handles the acknowledgement and the loss of a frame. 
+type FrameHandler interface { + OnAcked(wire.Frame) + OnLost(wire.Frame) +} + type Frame struct { - wire.Frame // nil if the frame has already been acknowledged in another packet - OnLost func(wire.Frame) - OnAcked func(wire.Frame) + Frame wire.Frame // nil if the frame has already been acknowledged in another packet + Handler FrameHandler } -var framePool = sync.Pool{New: func() any { return &Frame{} }} - -func GetFrame() *Frame { - f := framePool.Get().(*Frame) - f.OnLost = nil - f.OnAcked = nil - return f -} - -func putFrame(f *Frame) { - f.Frame = nil - f.OnLost = nil - f.OnAcked = nil - framePool.Put(f) +type StreamFrame struct { + Frame *wire.StreamFrame + Handler FrameHandler } diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go index 5924f84b..ba95f8a9 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/interfaces.go @@ -10,20 +10,20 @@ import ( // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { // SentPacket may modify the packet - SentPacket(packet *Packet) - ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) + SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, isPathMTUProbePacket bool) + // ReceivedAck processes an ACK frame. + // It does not store a copy of the frame. + ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) ReceivedBytes(protocol.ByteCount) DropPackets(protocol.EncryptionLevel) ResetForRetry() error SetHandshakeConfirmed() // The SendMode determines if and what kind of packets can be sent. - SendMode() SendMode + SendMode(now time.Time) SendMode // TimeUntilSend is the time when the next packet should be sent. // It is used for pacing packets. TimeUntilSend() time.Time - // HasPacingBudget says if the pacer allows sending of a (full size) packet at this moment. - HasPacingBudget() bool SetMaxDatagramSize(count protocol.ByteCount) // only to be called once the handshake is complete diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go index 394ee40a..5f43689b 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet.go @@ -8,13 +8,14 @@ import ( ) // A Packet is a packet -type Packet struct { +type packet struct { + SendTime time.Time PacketNumber protocol.PacketNumber - Frames []*Frame + StreamFrames []StreamFrame + Frames []Frame LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK Length protocol.ByteCount EncryptionLevel protocol.EncryptionLevel - SendTime time.Time IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller. 
@@ -23,15 +24,16 @@ type Packet struct { skippedPacket bool } -func (p *Packet) outstanding() bool { +func (p *packet) outstanding() bool { return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket } -var packetPool = sync.Pool{New: func() any { return &Packet{} }} +var packetPool = sync.Pool{New: func() any { return &packet{} }} -func GetPacket() *Packet { - p := packetPool.Get().(*Packet) +func getPacket() *packet { + p := packetPool.Get().(*packet) p.PacketNumber = 0 + p.StreamFrames = nil p.Frames = nil p.LargestAcked = 0 p.Length = 0 @@ -46,10 +48,8 @@ func GetPacket() *Packet { // We currently only return Packets back into the pool when they're acknowledged (not when they're lost). // This simplifies the code, and gives the vast majority of the performance benefit we can gain from using the pool. -func putPacket(p *Packet) { - for _, f := range p.Frames { - putFrame(f) - } +func putPacket(p *packet) { p.Frames = nil + p.StreamFrames = nil packetPool.Put(p) } diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go index 9cf20a0b..e84171e3 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go @@ -7,7 +7,10 @@ import ( type packetNumberGenerator interface { Peek() protocol.PacketNumber - Pop() protocol.PacketNumber + // Pop pops the packet number. + // It reports if the packet number (before the one just popped) was skipped. + // It never skips more than one packet number in a row. + Pop() (skipped bool, _ protocol.PacketNumber) } type sequentialPacketNumberGenerator struct { @@ -24,10 +27,10 @@ func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber { return p.next } -func (p *sequentialPacketNumberGenerator) Pop() protocol.PacketNumber { +func (p *sequentialPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) { next := p.next p.next++ - return next + return false, next } // The skippingPacketNumberGenerator generates the packet number for the next packet @@ -56,21 +59,26 @@ func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol } func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber { + if p.next == p.nextToSkip { + return p.next + 1 + } return p.next } -func (p *skippingPacketNumberGenerator) Pop() protocol.PacketNumber { +func (p *skippingPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) { next := p.next - p.next++ // generate a new packet number for the next packet if p.next == p.nextToSkip { - p.next++ + next++ + p.next += 2 p.generateNewSkip() + return true, next } - return next + p.next++ // generate a new packet number for the next packet + return false, next } func (p *skippingPacketNumberGenerator) generateNewSkip() { // make sure that there are never two consecutive packet numbers that are skipped - p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period))) + p.nextToSkip = p.next + 3 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period))) p.period = utils.Min(2*p.period, p.maxPeriod) } diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go index 7132ccaa..b1883866 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go +++ 
b/vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go @@ -169,16 +169,18 @@ func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame { } } - ack := wire.GetAckFrame() + // This function always returns the same ACK frame struct, filled with the most recent values. + ack := h.lastAck + if ack == nil { + ack = &wire.AckFrame{} + } + ack.Reset() ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedReceivedTime)) ack.ECT0 = h.ect0 ack.ECT1 = h.ect1 ack.ECNCE = h.ecnce ack.AckRanges = h.packetHistory.AppendAckRanges(ack.AckRanges) - if h.lastAck != nil { - wire.PutAckFrame(h.lastAck) - } h.lastAck = ack h.ackAlarm = time.Time{} h.ackQueued = false diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/send_mode.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/send_mode.go index 3d5fe560..c03f3a6f 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/send_mode.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/send_mode.go @@ -16,6 +16,10 @@ const ( SendPTOHandshake // SendPTOAppData means that an Application data probe packet should be sent SendPTOAppData + // SendPacingLimited means that the pacer doesn't allow sending of a packet right now, + // but will do in a little while. + // The timestamp when sending is allowed again can be obtained via the SentPacketHandler.TimeUntilSend. + SendPacingLimited // SendAny means that any packet should be sent SendAny ) @@ -34,6 +38,8 @@ func (s SendMode) String() string { return "pto (Application Data)" case SendAny: return "any" + case SendPacingLimited: + return "pacing limited" default: return fmt.Sprintf("invalid send mode: %d", s) } diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go index 732bbc3a..07978910 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_handler.go @@ -38,7 +38,7 @@ type packetNumberSpace struct { largestSent protocol.PacketNumber } -func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStats *utils.RTTStats) *packetNumberSpace { +func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool) *packetNumberSpace { var pns packetNumberGenerator if skipPNs { pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod) @@ -46,7 +46,7 @@ func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStat pns = newSequentialPacketNumberGenerator(initialPN) } return &packetNumberSpace{ - history: newSentPacketHistory(rttStats), + history: newSentPacketHistory(), pns: pns, largestSent: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber, @@ -75,7 +75,7 @@ type sentPacketHandler struct { // Only applies to the application-data packet number space. 
lowestNotConfirmedAcked protocol.PacketNumber - ackedPackets []*Packet // to avoid allocations in detectAndRemoveAckedPackets + ackedPackets []*packet // to avoid allocations in detectAndRemoveAckedPackets bytesInFlight protocol.ByteCount @@ -125,9 +125,9 @@ func newSentPacketHandler( return &sentPacketHandler{ peerCompletedAddressValidation: pers == protocol.PerspectiveServer, peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated, - initialPackets: newPacketNumberSpace(initialPN, false, rttStats), - handshakePackets: newPacketNumberSpace(0, false, rttStats), - appDataPackets: newPacketNumberSpace(0, true, rttStats), + initialPackets: newPacketNumberSpace(initialPN, false), + handshakePackets: newPacketNumberSpace(0, false), + appDataPackets: newPacketNumberSpace(0, true), rttStats: rttStats, congestion: congestion, perspective: pers, @@ -146,7 +146,7 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) { h.dropPackets(encLevel) } -func (h *sentPacketHandler) removeFromBytesInFlight(p *Packet) { +func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) { if p.includedInBytesInFlight { if p.Length > h.bytesInFlight { panic("negative bytes_in_flight") @@ -165,7 +165,7 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { // remove outstanding packets from bytes_in_flight if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake { pnSpace := h.getPacketNumberSpace(encLevel) - pnSpace.history.Iterate(func(p *Packet) (bool, error) { + pnSpace.history.Iterate(func(p *packet) (bool, error) { h.removeFromBytesInFlight(p) return true, nil }) @@ -182,8 +182,8 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { // and not when the client drops 0-RTT keys when the handshake completes. // When 0-RTT is rejected, all application data sent so far becomes invalid. // Delete the packets from the history and remove them from bytes_in_flight. - h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { - if p.EncryptionLevel != protocol.Encryption0RTT { + h.appDataPackets.history.Iterate(func(p *packet) (bool, error) { + if p.EncryptionLevel != protocol.Encryption0RTT && !p.skippedPacket { return false, nil } h.removeFromBytesInFlight(p) @@ -228,26 +228,64 @@ func (h *sentPacketHandler) packetsInFlight() int { return packetsInFlight } -func (h *sentPacketHandler) SentPacket(p *Packet) { - h.bytesSent += p.Length +func (h *sentPacketHandler) SentPacket( + t time.Time, + pn, largestAcked protocol.PacketNumber, + streamFrames []StreamFrame, + frames []Frame, + encLevel protocol.EncryptionLevel, + size protocol.ByteCount, + isPathMTUProbePacket bool, +) { + h.bytesSent += size // For the client, drop the Initial packet number space when the first Handshake packet is sent. - if h.perspective == protocol.PerspectiveClient && p.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil { + if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake && h.initialPackets != nil { h.dropPackets(protocol.EncryptionInitial) } - isAckEliciting := h.sentPacketImpl(p) - if isAckEliciting { - h.getPacketNumberSpace(p.EncryptionLevel).history.SentAckElicitingPacket(p) - } else { - h.getPacketNumberSpace(p.EncryptionLevel).history.SentNonAckElicitingPacket(p.PacketNumber, p.EncryptionLevel, p.SendTime) - putPacket(p) - p = nil //nolint:ineffassign // This is just to be on the safe side. 
+ + pnSpace := h.getPacketNumberSpace(encLevel) + if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() { + for p := utils.Max(0, pnSpace.largestSent+1); p < pn; p++ { + h.logger.Debugf("Skipping packet number %d", p) + } } - if h.tracer != nil && isAckEliciting { + + pnSpace.largestSent = pn + isAckEliciting := len(streamFrames) > 0 || len(frames) > 0 + + if isAckEliciting { + pnSpace.lastAckElicitingPacketTime = t + h.bytesInFlight += size + if h.numProbesToSend > 0 { + h.numProbesToSend-- + } + } + h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting) + + if !isAckEliciting { + pnSpace.history.SentNonAckElicitingPacket(pn) + if !h.peerCompletedAddressValidation { + h.setLossDetectionTimer() + } + return + } + + p := getPacket() + p.SendTime = t + p.PacketNumber = pn + p.EncryptionLevel = encLevel + p.Length = size + p.LargestAcked = largestAcked + p.StreamFrames = streamFrames + p.Frames = frames + p.IsPathMTUProbePacket = isPathMTUProbePacket + p.includedInBytesInFlight = true + + pnSpace.history.SentAckElicitingPacket(p) + if h.tracer != nil { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } - if isAckEliciting || !h.peerCompletedAddressValidation { - h.setLossDetectionTimer() - } + h.setLossDetectionTimer() } func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace { @@ -263,31 +301,6 @@ func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLev } } -func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-eliciting */ { - pnSpace := h.getPacketNumberSpace(packet.EncryptionLevel) - - if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() { - for p := utils.Max(0, pnSpace.largestSent+1); p < packet.PacketNumber; p++ { - h.logger.Debugf("Skipping packet number %d", p) - } - } - - pnSpace.largestSent = packet.PacketNumber - isAckEliciting := len(packet.Frames) > 0 - - if isAckEliciting { - pnSpace.lastAckElicitingPacketTime = packet.SendTime - packet.includedInBytesInFlight = true - h.bytesInFlight += packet.Length - if h.numProbesToSend > 0 { - h.numProbesToSend-- - } - } - h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isAckEliciting) - - return isAckEliciting -} - func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) { pnSpace := h.getPacketNumberSpace(encLevel) @@ -361,7 +374,6 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } - pnSpace.history.DeleteOldPackets(rcvTime) h.setLossDetectionTimer() return acked1RTTPacket, nil } @@ -371,13 +383,13 @@ func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNu } // Packets are returned in ascending packet number order. 
-func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*Packet, error) { +func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*packet, error) { pnSpace := h.getPacketNumberSpace(encLevel) h.ackedPackets = h.ackedPackets[:0] ackRangeIndex := 0 lowestAcked := ack.LowestAcked() largestAcked := ack.LargestAcked() - err := pnSpace.history.Iterate(func(p *Packet) (bool, error) { + err := pnSpace.history.Iterate(func(p *packet) (bool, error) { // Ignore packets below the lowest acked if p.PacketNumber < lowestAcked { return true, nil @@ -425,8 +437,13 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL } for _, f := range p.Frames { - if f.OnAcked != nil { - f.OnAcked(f.Frame) + if f.Handler != nil { + f.Handler.OnAcked(f.Frame) + } + } + for _, f := range p.StreamFrames { + if f.Handler != nil { + f.Handler.OnAcked(f.Frame) } } if err := pnSpace.history.Remove(p.PacketNumber); err != nil { @@ -587,30 +604,31 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E lostSendTime := now.Add(-lossDelay) priorInFlight := h.bytesInFlight - return pnSpace.history.Iterate(func(p *Packet) (bool, error) { + return pnSpace.history.Iterate(func(p *packet) (bool, error) { if p.PacketNumber > pnSpace.largestAcked { return false, nil } - if p.declaredLost || p.skippedPacket { - return true, nil - } var packetLost bool if p.SendTime.Before(lostSendTime) { packetLost = true - if h.logger.Debug() { - h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) - } - if h.tracer != nil { - h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold) + if !p.skippedPacket { + if h.logger.Debug() { + h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) + } + if h.tracer != nil { + h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold) + } } } else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold { packetLost = true - if h.logger.Debug() { - h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) - } - if h.tracer != nil { - h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold) + if !p.skippedPacket { + if h.logger.Debug() { + h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) + } + if h.tracer != nil { + h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold) + } } } else if pnSpace.lossTime.IsZero() { // Note: This conditional is only entered once per call @@ -621,12 +639,14 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E pnSpace.lossTime = lossTime } if packetLost { - p = pnSpace.history.DeclareLost(p) - // the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted - h.removeFromBytesInFlight(p) - h.queueFramesForRetransmission(p) - if !p.IsPathMTUProbePacket { - h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) + pnSpace.history.DeclareLost(p.PacketNumber) + if !p.skippedPacket { + // the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted + h.removeFromBytesInFlight(p) + h.queueFramesForRetransmission(p) + if !p.IsPathMTUProbePacket { + h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) + } } } return true, nil @@ -689,7 +709,8 @@ func (h *sentPacketHandler) 
OnLossDetectionTimeout() error { h.ptoMode = SendPTOHandshake case protocol.Encryption1RTT: // skip a packet number in order to elicit an immediate ACK - _ = h.PopPacketNumber(protocol.Encryption1RTT) + pn := h.PopPacketNumber(protocol.Encryption1RTT) + h.getPacketNumberSpace(protocol.Encryption1RTT).history.SkippedPacket(pn) h.ptoMode = SendPTOAppData default: return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel) @@ -703,23 +724,25 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time { func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { pnSpace := h.getPacketNumberSpace(encLevel) - - var lowestUnacked protocol.PacketNumber - if p := pnSpace.history.FirstOutstanding(); p != nil { - lowestUnacked = p.PacketNumber - } else { - lowestUnacked = pnSpace.largestAcked + 1 - } - pn := pnSpace.pns.Peek() - return pn, protocol.GetPacketNumberLengthForHeader(pn, lowestUnacked) + // See section 17.1 of RFC 9000. + return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked) } func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber { - return h.getPacketNumberSpace(encLevel).pns.Pop() + pnSpace := h.getPacketNumberSpace(encLevel) + skipped, pn := pnSpace.pns.Pop() + if skipped { + skippedPN := pn - 1 + pnSpace.history.SkippedPacket(skippedPN) + if h.logger.Debug() { + h.logger.Debugf("Skipping packet number %d", skippedPN) + } + } + return pn } -func (h *sentPacketHandler) SendMode() SendMode { +func (h *sentPacketHandler) SendMode(now time.Time) SendMode { numTrackedPackets := h.appDataPackets.history.Len() if h.initialPackets != nil { numTrackedPackets += h.initialPackets.history.Len() @@ -758,6 +781,9 @@ func (h *sentPacketHandler) SendMode() SendMode { } return SendAck } + if !h.congestion.HasPacingBudget(now) { + return SendPacingLimited + } return SendAny } @@ -765,10 +791,6 @@ func (h *sentPacketHandler) TimeUntilSend() time.Time { return h.congestion.TimeUntilSend(h.bytesInFlight) } -func (h *sentPacketHandler) HasPacingBudget() bool { - return h.congestion.HasPacingBudget() -} - func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) { h.congestion.SetMaxDatagramSize(s) } @@ -790,24 +812,32 @@ func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) // TODO: don't declare the packet lost here. // Keep track of acknowledged frames instead. h.removeFromBytesInFlight(p) - pnSpace.history.DeclareLost(p) + pnSpace.history.DeclareLost(p.PacketNumber) return true } -func (h *sentPacketHandler) queueFramesForRetransmission(p *Packet) { - if len(p.Frames) == 0 { +func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) { + if len(p.Frames) == 0 && len(p.StreamFrames) == 0 { panic("no frames") } for _, f := range p.Frames { - f.OnLost(f.Frame) + if f.Handler != nil { + f.Handler.OnLost(f.Frame) + } } + for _, f := range p.StreamFrames { + if f.Handler != nil { + f.Handler.OnLost(f.Frame) + } + } + p.StreamFrames = nil p.Frames = nil } func (h *sentPacketHandler) ResetForRetry() error { h.bytesInFlight = 0 var firstPacketSendTime time.Time - h.initialPackets.history.Iterate(func(p *Packet) (bool, error) { + h.initialPackets.history.Iterate(func(p *packet) (bool, error) { if firstPacketSendTime.IsZero() { firstPacketSendTime = p.SendTime } @@ -819,7 +849,7 @@ func (h *sentPacketHandler) ResetForRetry() error { }) // All application data packets sent at this point are 0-RTT packets. 
// In the case of a Retry, we can assume that the server dropped all of them. - h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) { + h.appDataPackets.history.Iterate(func(p *packet) (bool, error) { if !p.declaredLost && !p.skippedPacket { h.queueFramesForRetransmission(p) } @@ -839,8 +869,8 @@ func (h *sentPacketHandler) ResetForRetry() error { h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) } } - h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), false, h.rttStats) - h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), true, h.rttStats) + h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Peek(), false) + h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Peek(), true) oldAlarm := h.alarm h.alarm = time.Time{} if h.tracer != nil { diff --git a/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go index 06478399..c14c0f49 100644 --- a/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go +++ b/vendor/github.com/quic-go/quic-go/internal/ackhandler/sent_packet_history.go @@ -2,162 +2,176 @@ package ackhandler import ( "fmt" - "sync" - "time" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/utils" - list "github.com/quic-go/quic-go/internal/utils/linkedlist" ) type sentPacketHistory struct { - rttStats *utils.RTTStats - outstandingPacketList *list.List[*Packet] - etcPacketList *list.List[*Packet] - packetMap map[protocol.PacketNumber]*list.Element[*Packet] - highestSent protocol.PacketNumber + packets []*packet + + numOutstanding int + + highestPacketNumber protocol.PacketNumber } -var packetElementPool sync.Pool - -func init() { - packetElementPool = *list.NewPool[*Packet]() -} - -func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory { +func newSentPacketHistory() *sentPacketHistory { return &sentPacketHistory{ - rttStats: rttStats, - outstandingPacketList: list.NewWithPool[*Packet](&packetElementPool), - etcPacketList: list.NewWithPool[*Packet](&packetElementPool), - packetMap: make(map[protocol.PacketNumber]*list.Element[*Packet]), - highestSent: protocol.InvalidPacketNumber, + packets: make([]*packet, 0, 32), + highestPacketNumber: protocol.InvalidPacketNumber, } } -func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, t time.Time) { - h.registerSentPacket(pn, encLevel, t) +func (h *sentPacketHistory) checkSequentialPacketNumberUse(pn protocol.PacketNumber) { + if h.highestPacketNumber != protocol.InvalidPacketNumber { + if pn != h.highestPacketNumber+1 { + panic("non-sequential packet number use") + } + } } -func (h *sentPacketHistory) SentAckElicitingPacket(p *Packet) { - h.registerSentPacket(p.PacketNumber, p.EncryptionLevel, p.SendTime) +func (h *sentPacketHistory) SkippedPacket(pn protocol.PacketNumber) { + h.checkSequentialPacketNumberUse(pn) + h.highestPacketNumber = pn + h.packets = append(h.packets, &packet{ + PacketNumber: pn, + skippedPacket: true, + }) +} - var el *list.Element[*Packet] +func (h *sentPacketHistory) SentNonAckElicitingPacket(pn protocol.PacketNumber) { + h.checkSequentialPacketNumberUse(pn) + h.highestPacketNumber = pn + if len(h.packets) > 0 { + h.packets = append(h.packets, nil) + } +} + +func (h *sentPacketHistory) SentAckElicitingPacket(p *packet) { + 
h.checkSequentialPacketNumberUse(p.PacketNumber) + h.highestPacketNumber = p.PacketNumber + h.packets = append(h.packets, p) if p.outstanding() { - el = h.outstandingPacketList.PushBack(p) - } else { - el = h.etcPacketList.PushBack(p) + h.numOutstanding++ } - h.packetMap[p.PacketNumber] = el -} - -func (h *sentPacketHistory) registerSentPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, t time.Time) { - if pn <= h.highestSent { - panic("non-sequential packet number use") - } - // Skipped packet numbers. - for p := h.highestSent + 1; p < pn; p++ { - el := h.etcPacketList.PushBack(&Packet{ - PacketNumber: p, - EncryptionLevel: encLevel, - SendTime: t, - skippedPacket: true, - }) - h.packetMap[p] = el - } - h.highestSent = pn } // Iterate iterates through all packets. -func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error { - cont := true - outstandingEl := h.outstandingPacketList.Front() - etcEl := h.etcPacketList.Front() - var el *list.Element[*Packet] - // whichever has the next packet number is returned first - for cont { - if outstandingEl == nil || (etcEl != nil && etcEl.Value.PacketNumber < outstandingEl.Value.PacketNumber) { - el = etcEl - } else { - el = outstandingEl +func (h *sentPacketHistory) Iterate(cb func(*packet) (cont bool, err error)) error { + for _, p := range h.packets { + if p == nil { + continue } - if el == nil { - return nil - } - if el == outstandingEl { - outstandingEl = outstandingEl.Next() - } else { - etcEl = etcEl.Next() - } - var err error - cont, err = cb(el.Value) + cont, err := cb(p) if err != nil { return err } + if !cont { + return nil + } } return nil } // FirstOutstanding returns the first outstanding packet. -func (h *sentPacketHistory) FirstOutstanding() *Packet { - el := h.outstandingPacketList.Front() - if el == nil { +func (h *sentPacketHistory) FirstOutstanding() *packet { + if !h.HasOutstandingPackets() { return nil } - return el.Value -} - -func (h *sentPacketHistory) Len() int { - return len(h.packetMap) -} - -func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error { - el, ok := h.packetMap[p] - if !ok { - return fmt.Errorf("packet %d not found in sent packet history", p) + for _, p := range h.packets { + if p != nil && p.outstanding() { + return p + } } - el.List().Remove(el) - delete(h.packetMap, p) return nil } -func (h *sentPacketHistory) HasOutstandingPackets() bool { - return h.outstandingPacketList.Len() > 0 +func (h *sentPacketHistory) Len() int { + return len(h.packets) } -func (h *sentPacketHistory) DeleteOldPackets(now time.Time) { - maxAge := 3 * h.rttStats.PTO(false) - var nextEl *list.Element[*Packet] - // we don't iterate outstandingPacketList, as we should not delete outstanding packets. - // being outstanding for more than 3*PTO should only happen in the case of drastic RTT changes. 
- for el := h.etcPacketList.Front(); el != nil; el = nextEl { - nextEl = el.Next() - p := el.Value - if p.SendTime.After(now.Add(-maxAge)) { - break - } - delete(h.packetMap, p.PacketNumber) - h.etcPacketList.Remove(el) - } -} - -func (h *sentPacketHistory) DeclareLost(p *Packet) *Packet { - el, ok := h.packetMap[p.PacketNumber] +func (h *sentPacketHistory) Remove(pn protocol.PacketNumber) error { + idx, ok := h.getIndex(pn) if !ok { - return nil + return fmt.Errorf("packet %d not found in sent packet history", pn) } - el.List().Remove(el) - p.declaredLost = true - // move it to the correct position in the etc list (based on the packet number) - for el = h.etcPacketList.Back(); el != nil; el = el.Prev() { - if el.Value.PacketNumber < p.PacketNumber { - break + p := h.packets[idx] + if p.outstanding() { + h.numOutstanding-- + if h.numOutstanding < 0 { + panic("negative number of outstanding packets") } } - if el == nil { - el = h.etcPacketList.PushFront(p) - } else { - el = h.etcPacketList.InsertAfter(p, el) + h.packets[idx] = nil + // clean up all skipped packets directly before this packet number + for idx > 0 { + idx-- + p := h.packets[idx] + if p == nil || !p.skippedPacket { + break + } + h.packets[idx] = nil + } + if idx == 0 { + h.cleanupStart() + } + if len(h.packets) > 0 && h.packets[0] == nil { + panic("remove failed") + } + return nil +} + +// getIndex gets the index of packet p in the packets slice. +func (h *sentPacketHistory) getIndex(p protocol.PacketNumber) (int, bool) { + if len(h.packets) == 0 { + return 0, false + } + first := h.packets[0].PacketNumber + if p < first { + return 0, false + } + index := int(p - first) + if index > len(h.packets)-1 { + return 0, false + } + return index, true +} + +func (h *sentPacketHistory) HasOutstandingPackets() bool { + return h.numOutstanding > 0 +} + +// delete all nil entries at the beginning of the packets slice +func (h *sentPacketHistory) cleanupStart() { + for i, p := range h.packets { + if p != nil { + h.packets = h.packets[i:] + return + } + } + h.packets = h.packets[:0] +} + +func (h *sentPacketHistory) LowestPacketNumber() protocol.PacketNumber { + if len(h.packets) == 0 { + return protocol.InvalidPacketNumber + } + return h.packets[0].PacketNumber +} + +func (h *sentPacketHistory) DeclareLost(pn protocol.PacketNumber) { + idx, ok := h.getIndex(pn) + if !ok { + return + } + p := h.packets[idx] + if p.outstanding() { + h.numOutstanding-- + if h.numOutstanding < 0 { + panic("negative number of outstanding packets") + } + } + h.packets[idx] = nil + if idx == 0 { + h.cleanupStart() } - h.packetMap[p.PacketNumber] = el - return el.Value } diff --git a/vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go b/vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go index dac3118e..2e508475 100644 --- a/vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/cubic_sender.go @@ -120,8 +120,8 @@ func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) time.Time { return c.pacer.TimeUntilSend() } -func (c *cubicSender) HasPacingBudget() bool { - return c.pacer.Budget(c.clock.Now()) >= c.maxDatagramSize +func (c *cubicSender) HasPacingBudget(now time.Time) bool { + return c.pacer.Budget(now) >= c.maxDatagramSize } func (c *cubicSender) maxCongestionWindow() protocol.ByteCount { diff --git a/vendor/github.com/quic-go/quic-go/internal/congestion/interface.go b/vendor/github.com/quic-go/quic-go/internal/congestion/interface.go index 
5db3ebae..484bd5f8 100644 --- a/vendor/github.com/quic-go/quic-go/internal/congestion/interface.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/interface.go @@ -9,7 +9,7 @@ import ( // A SendAlgorithm performs congestion control type SendAlgorithm interface { TimeUntilSend(bytesInFlight protocol.ByteCount) time.Time - HasPacingBudget() bool + HasPacingBudget(now time.Time) bool OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) CanSend(bytesInFlight protocol.ByteCount) bool MaybeExitSlowStart() diff --git a/vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go b/vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go index a5861062..09ea2680 100644 --- a/vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go +++ b/vendor/github.com/quic-go/quic-go/internal/congestion/pacer.go @@ -12,16 +12,16 @@ const maxBurstSizePackets = 10 // The pacer implements a token bucket pacing algorithm. type pacer struct { - budgetAtLastSent protocol.ByteCount - maxDatagramSize protocol.ByteCount - lastSentTime time.Time - getAdjustedBandwidth func() uint64 // in bytes/s + budgetAtLastSent protocol.ByteCount + maxDatagramSize protocol.ByteCount + lastSentTime time.Time + adjustedBandwidth func() uint64 // in bytes/s } func newPacer(getBandwidth func() Bandwidth) *pacer { p := &pacer{ maxDatagramSize: initialMaxDatagramSize, - getAdjustedBandwidth: func() uint64 { + adjustedBandwidth: func() uint64 { // Bandwidth is in bits/s. We need the value in bytes/s. bw := uint64(getBandwidth() / BytesPerSecond) // Use a slightly higher value than the actual measured bandwidth. @@ -49,13 +49,16 @@ func (p *pacer) Budget(now time.Time) protocol.ByteCount { if p.lastSentTime.IsZero() { return p.maxBurstSize() } - budget := p.budgetAtLastSent + (protocol.ByteCount(p.getAdjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + budget := p.budgetAtLastSent + (protocol.ByteCount(p.adjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + if budget < 0 { // protect against overflows + budget = protocol.MaxByteCount + } return utils.Min(p.maxBurstSize(), budget) } func (p *pacer) maxBurstSize() protocol.ByteCount { return utils.Max( - protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9, + protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.adjustedBandwidth())/1e9, maxBurstSizePackets*p.maxDatagramSize, ) } @@ -68,7 +71,7 @@ func (p *pacer) TimeUntilSend() time.Time { } return p.lastSentTime.Add(utils.Max( protocol.MinPacingDelay, - time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond, + time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.adjustedBandwidth())))*time.Nanosecond, )) } diff --git a/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go b/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go index 0420a5f9..8c9c2a8f 100644 --- a/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go +++ b/vendor/github.com/quic-go/quic-go/internal/handshake/crypto_setup.go @@ -116,7 +116,7 @@ type cryptoSetup struct { clientHelloWritten bool clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written zeroRTTParametersChan chan<- *wire.TransportParameters - 
allow0RTT func() bool + allow0RTT bool rttStats *utils.RTTStats @@ -197,7 +197,7 @@ func NewCryptoSetupServer( tp *wire.TransportParameters, runner handshakeRunner, tlsConf *tls.Config, - allow0RTT func() bool, + allow0RTT bool, rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, @@ -210,14 +210,13 @@ func NewCryptoSetupServer( tp, runner, tlsConf, - allow0RTT != nil, + allow0RTT, rttStats, tracer, logger, protocol.PerspectiveServer, version, ) - cs.allow0RTT = allow0RTT cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf) return cs } @@ -253,6 +252,7 @@ func newCryptoSetup( readEncLevel: protocol.EncryptionInitial, writeEncLevel: protocol.EncryptionInitial, runner: runner, + allow0RTT: enable0RTT, ourParams: tp, paramsChan: extHandler.TransportParameters(), rttStats: rttStats, @@ -503,7 +503,7 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.") return false } - if !h.allow0RTT() { + if !h.allow0RTT { h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.") return false } diff --git a/vendor/github.com/quic-go/quic-go/internal/protocol/params.go b/vendor/github.com/quic-go/quic-go/internal/protocol/params.go index 60c86779..fe3a7562 100644 --- a/vendor/github.com/quic-go/quic-go/internal/protocol/params.go +++ b/vendor/github.com/quic-go/quic-go/internal/protocol/params.go @@ -5,6 +5,9 @@ import "time" // DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use. const DesiredReceiveBufferSize = (1 << 20) * 2 // 2 MB +// DesiredSendBufferSize is the kernel UDP send buffer size that we'd like to use. +const DesiredSendBufferSize = (1 << 20) * 2 // 2 MB + // InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets. const InitialPacketSizeIPv4 = 1252 diff --git a/vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go b/vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go index 8241e274..98bc9ffb 100644 --- a/vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go +++ b/vendor/github.com/quic-go/quic-go/internal/protocol/protocol.go @@ -59,7 +59,10 @@ type StatelessResetToken [16]byte // ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, // UDP adds an additional 8 bytes. This is a total overhead of 48 bytes. // Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452. -const MaxPacketBufferSize ByteCount = 1452 +const MaxPacketBufferSize = 1452 + +// MaxLargePacketBufferSize is used when using GSO +const MaxLargePacketBufferSize = 20 * 1024 // MinInitialPacketSize is the minimum size an Initial packet is required to have. const MinInitialPacketSize = 1200 @@ -77,6 +80,9 @@ const MinConnectionIDLenInitial = 8 // DefaultAckDelayExponent is the default ack delay exponent const DefaultAckDelayExponent = 3 +// DefaultActiveConnectionIDLimit is the default active connection ID limit +const DefaultActiveConnectionIDLimit = 2 + // MaxAckDelayExponent is the maximum ack delay exponent const MaxAckDelayExponent = 20 diff --git a/vendor/github.com/quic-go/quic-go/internal/utils/ringbuffer/ringbuffer.go b/vendor/github.com/quic-go/quic-go/internal/utils/ringbuffer/ringbuffer.go new file mode 100644 index 00000000..81a5ad44 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/internal/utils/ringbuffer/ringbuffer.go @@ -0,0 +1,86 @@ +package ringbuffer + +// A RingBuffer is a ring buffer. +// It acts as a heap that doesn't cause any allocations. 
+type RingBuffer[T any] struct { + ring []T + headPos, tailPos int + full bool +} + +// Init preallocs a buffer with a certain size. +func (r *RingBuffer[T]) Init(size int) { + r.ring = make([]T, size) +} + +// Len returns the number of elements in the ring buffer. +func (r *RingBuffer[T]) Len() int { + if r.full { + return len(r.ring) + } + if r.tailPos >= r.headPos { + return r.tailPos - r.headPos + } + return r.tailPos - r.headPos + len(r.ring) +} + +// Empty says if the ring buffer is empty. +func (r *RingBuffer[T]) Empty() bool { + return !r.full && r.headPos == r.tailPos +} + +// PushBack adds a new element. +// If the ring buffer is full, its capacity is increased first. +func (r *RingBuffer[T]) PushBack(t T) { + if r.full || len(r.ring) == 0 { + r.grow() + } + r.ring[r.tailPos] = t + r.tailPos++ + if r.tailPos == len(r.ring) { + r.tailPos = 0 + } + if r.tailPos == r.headPos { + r.full = true + } +} + +// PopFront returns the next element. +// It must not be called when the buffer is empty, that means that +// callers might need to check if there are elements in the buffer first. +func (r *RingBuffer[T]) PopFront() T { + if r.Empty() { + panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: pop from an empty queue") + } + r.full = false + t := r.ring[r.headPos] + r.ring[r.headPos] = *new(T) + r.headPos++ + if r.headPos == len(r.ring) { + r.headPos = 0 + } + return t +} + +// Grow the maximum size of the queue. +// This method assume the queue is full. +func (r *RingBuffer[T]) grow() { + oldRing := r.ring + newSize := len(oldRing) * 2 + if newSize == 0 { + newSize = 1 + } + r.ring = make([]T, newSize) + headLen := copy(r.ring, oldRing[r.headPos:]) + copy(r.ring[headLen:], oldRing[:r.headPos]) + r.headPos, r.tailPos, r.full = 0, len(oldRing), false +} + +// Clear removes all elements. +func (r *RingBuffer[T]) Clear() { + var zeroValue T + for i := range r.ring { + r.ring[i] = zeroValue + } + r.headPos, r.tailPos, r.full = 0, 0, false +} diff --git a/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go b/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go index 527539e1..2cd9a191 100644 --- a/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go +++ b/vendor/github.com/quic-go/quic-go/internal/utils/rtt_stats.go @@ -103,8 +103,12 @@ func (r *RTTStats) SetMaxAckDelay(mad time.Duration) { // SetInitialRTT sets the initial RTT. // It is used during the 0-RTT handshake when restoring the RTT stats from the session state. func (r *RTTStats) SetInitialRTT(t time.Duration) { + // On the server side, by the time we get to process the session ticket, + // we might already have obtained an RTT measurement. + // This can happen if we received the ClientHello in multiple pieces, and one of those pieces was lost. + // Discard the restored value. A fresh measurement is always better. 
if r.hasMeasurement { - panic("initial RTT set after first measurement") + return } r.smoothedRTT = t r.latestRTT = t diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame.go b/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame.go index f145c8b4..9b23cc25 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame.go @@ -22,19 +22,17 @@ type AckFrame struct { } // parseAckFrame reads an ACK frame -func parseAckFrame(r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) (*AckFrame, error) { +func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.VersionNumber) error { ecn := typ == ackECNFrameType - frame := GetAckFrame() - la, err := quicvarint.Read(r) if err != nil { - return nil, err + return err } largestAcked := protocol.PacketNumber(la) delay, err := quicvarint.Read(r) if err != nil { - return nil, err + return err } delayTime := time.Duration(delay*1< largestAcked { - return nil, errors.New("invalid first ACK range") + return errors.New("invalid first ACK range") } smallest := largestAcked - ackBlock @@ -65,41 +63,50 @@ func parseAckFrame(r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protoc for i := uint64(0); i < numBlocks; i++ { g, err := quicvarint.Read(r) if err != nil { - return nil, err + return err } gap := protocol.PacketNumber(g) if smallest < gap+2 { - return nil, errInvalidAckRanges + return errInvalidAckRanges } largest := smallest - gap - 2 ab, err := quicvarint.Read(r) if err != nil { - return nil, err + return err } ackBlock := protocol.PacketNumber(ab) if ackBlock > largest { - return nil, errInvalidAckRanges + return errInvalidAckRanges } smallest = largest - ackBlock frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest}) } if !frame.validateAckRanges() { - return nil, errInvalidAckRanges + return errInvalidAckRanges } - // parse (and skip) the ECN section if ecn { - for i := 0; i < 3; i++ { - if _, err := quicvarint.Read(r); err != nil { - return nil, err - } + ect0, err := quicvarint.Read(r) + if err != nil { + return err } + frame.ECT0 = ect0 + ect1, err := quicvarint.Read(r) + if err != nil { + return err + } + frame.ECT1 = ect1 + ecnce, err := quicvarint.Read(r) + if err != nil { + return err + } + frame.ECNCE = ecnce } - return frame, nil + return nil } // Append appends an ACK frame. 
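The parseAckFrame change above switches from returning a pooled *AckFrame to filling a caller-supplied one, and it now records the ECN counts (ECT0, ECT1, ECNCE) instead of skipping them; the hunks that follow add AckFrame.Reset, delete ack_frame_pool.go, and keep a single reusable frame on the frameParser. A minimal sketch of that reuse pattern, assuming it sits in the same wire package (ackParser and its parse method are illustrative names, not part of the vendored code):

    // One AckFrame is reused for every ACK this parser sees, so parsing
    // an ACK no longer allocates. Reset clears the previous packet's state.
    type ackParser struct {
        frame AckFrame
    }

    func (p *ackParser) parse(r *bytes.Reader, typ uint64, ackDelayExponent uint8) (*AckFrame, error) {
        p.frame.Reset()
        if err := parseAckFrame(&p.frame, r, typ, ackDelayExponent, protocol.Version1); err != nil {
            return nil, err
        }
        // The pointer stays valid only until the next call to parse.
        return &p.frame, nil
    }

The trade-off is ownership: anything the caller needs to keep (for example the AckRanges slice) must be copied before the next packet is parsed, which is why the frame parser below calls Reset immediately before reusing its frame.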
@@ -242,6 +249,18 @@ func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool { return p <= f.AckRanges[i].Largest } +func (f *AckFrame) Reset() { + f.DelayTime = 0 + f.ECT0 = 0 + f.ECT1 = 0 + f.ECNCE = 0 + for _, r := range f.AckRanges { + r.Largest = 0 + r.Smallest = 0 + } + f.AckRanges = f.AckRanges[:0] +} + func encodeAckDelay(delay time.Duration) uint64 { return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent))) } diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame_pool.go b/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame_pool.go deleted file mode 100644 index a0c6a21d..00000000 --- a/vendor/github.com/quic-go/quic-go/internal/wire/ack_frame_pool.go +++ /dev/null @@ -1,24 +0,0 @@ -package wire - -import "sync" - -var ackFramePool = sync.Pool{New: func() any { - return &AckFrame{} -}} - -func GetAckFrame() *AckFrame { - f := ackFramePool.Get().(*AckFrame) - f.AckRanges = f.AckRanges[:0] - f.ECNCE = 0 - f.ECT0 = 0 - f.ECT1 = 0 - f.DelayTime = 0 - return f -} - -func PutAckFrame(f *AckFrame) { - if cap(f.AckRanges) > 4 { - return - } - ackFramePool.Put(f) -} diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/frame_parser.go b/vendor/github.com/quic-go/quic-go/internal/wire/frame_parser.go index e624df94..ff35dd10 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/frame_parser.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/frame_parser.go @@ -39,9 +39,12 @@ const ( type frameParser struct { r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them - ackDelayExponent uint8 - + ackDelayExponent uint8 supportsDatagrams bool + + // To avoid allocating when parsing, keep a single ACK frame struct. + // It is used over and over again. + ackFrame *AckFrame } var _ FrameParser = &frameParser{} @@ -51,6 +54,7 @@ func NewFrameParser(supportsDatagrams bool) *frameParser { return &frameParser{ r: *bytes.NewReader(nil), supportsDatagrams: supportsDatagrams, + ackFrame: &AckFrame{}, } } @@ -105,7 +109,9 @@ func (p *frameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol. if encLevel != protocol.Encryption1RTT { ackDelayExponent = protocol.DefaultAckDelayExponent } - frame, err = parseAckFrame(r, typ, ackDelayExponent, v) + p.ackFrame.Reset() + err = parseAckFrame(p.ackFrame, r, typ, ackDelayExponent, v) + frame = p.ackFrame case resetStreamFrameType: frame, err = parseResetStreamFrame(r, v) case stopSendingFrameType: diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/header.go b/vendor/github.com/quic-go/quic-go/internal/wire/header.go index 6e8d4f9f..37f48cce 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/header.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/header.go @@ -13,8 +13,6 @@ import ( ) // ParseConnectionID parses the destination connection ID of a packet. -// It uses the data slice for the connection ID. -// That means that the connection ID must not be used after the packet buffer is released. 
func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { if len(data) == 0 { return protocol.ConnectionID{}, io.EOF diff --git a/vendor/github.com/quic-go/quic-go/internal/wire/transport_parameters.go b/vendor/github.com/quic-go/quic-go/internal/wire/transport_parameters.go index 8eb4cf46..1ec1e59d 100644 --- a/vendor/github.com/quic-go/quic-go/internal/wire/transport_parameters.go +++ b/vendor/github.com/quic-go/quic-go/internal/wire/transport_parameters.go @@ -2,14 +2,13 @@ package wire import ( "bytes" + "crypto/rand" "encoding/binary" "errors" "fmt" "io" - "math/rand" "net" "sort" - "sync" "time" "github.com/quic-go/quic-go/internal/protocol" @@ -26,15 +25,6 @@ var AdditionalTransportParametersClient map[uint64][]byte const transportParameterMarshalingVersion = 1 -var ( - randomMutex sync.Mutex - random rand.Rand -) - -func init() { - random = *rand.New(rand.NewSource(time.Now().UnixNano())) -} - type transportParameterID uint64 const ( @@ -118,6 +108,7 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec var ( readOriginalDestinationConnectionID bool readInitialSourceConnectionID bool + readActiveConnectionIDLimit bool ) p.AckDelayExponent = protocol.DefaultAckDelayExponent @@ -139,6 +130,9 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec } parameterIDs = append(parameterIDs, paramID) switch paramID { + case activeConnectionIDLimitParameterID: + readActiveConnectionIDLimit = true + fallthrough case maxIdleTimeoutParameterID, maxUDPPayloadSizeParameterID, initialMaxDataParameterID, @@ -148,7 +142,6 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec initialMaxStreamsBidiParameterID, initialMaxStreamsUniParameterID, maxAckDelayParameterID, - activeConnectionIDLimitParameterID, maxDatagramFrameSizeParameterID, ackDelayExponentParameterID: if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil { @@ -196,6 +189,9 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec } } + if !readActiveConnectionIDLimit { + p.ActiveConnectionIDLimit = protocol.DefaultActiveConnectionIDLimit + } if !fromSessionTicket { if sentBy == protocol.PerspectiveServer && !readOriginalDestinationConnectionID { return errors.New("missing original_destination_connection_id") @@ -335,13 +331,12 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { b := make([]byte, 0, 256) // add a greased value - b = quicvarint.Append(b, uint64(27+31*rand.Intn(100))) - randomMutex.Lock() - length := random.Intn(16) + random := make([]byte, 18) + rand.Read(random) + b = quicvarint.Append(b, 27+31*uint64(random[0])) + length := random[1] % 16 b = quicvarint.Append(b, uint64(length)) - b = b[:len(b)+length] - random.Read(b[len(b)-length:]) - randomMutex.Unlock() + b = append(b, random[2:2+length]...) 
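Above, the greased transport parameter is now derived from crypto/rand in a single 18-byte read, which removes the package-level seeded math/rand generator and the mutex that guarded it. The greased identifiers have the form 31*N+27, which RFC 9000 (Section 18.1) reserves so that endpoints are forced to ignore unknown transport parameters. A rough sketch of the same derivation (the helper name is illustrative, not part of the patch; it assumes the package imports crypto/rand, as the hunk above adds):

    // greasedTransportParameter returns a reserved parameter ID (31*N + 27)
    // and 0-15 random value bytes, mirroring the Marshal hunk above.
    func greasedTransportParameter() (uint64, []byte, error) {
        buf := make([]byte, 18)
        if _, err := rand.Read(buf); err != nil { // crypto/rand: no seeding, no shared generator state
            return 0, nil, err
        }
        length := int(buf[1] % 16)
        return 27 + 31*uint64(buf[0]), buf[2 : 2+length], nil
    }

Because crypto/rand's reader is safe for concurrent use, concurrent Marshal calls no longer need the randomMutex that the old math/rand-based code required.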
// initial_max_stream_data_bidi_local b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) @@ -402,7 +397,9 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { } } // active_connection_id_limit - b = p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) + if p.ActiveConnectionIDLimit != protocol.DefaultActiveConnectionIDLimit { + b = p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) + } // initial_source_connection_id b = quicvarint.Append(b, uint64(initialSourceConnectionIDParameterID)) b = quicvarint.Append(b, uint64(p.InitialSourceConnectionID.Len())) diff --git a/vendor/github.com/quic-go/quic-go/logging/interface.go b/vendor/github.com/quic-go/quic-go/logging/interface.go index efcef151..2ce8582e 100644 --- a/vendor/github.com/quic-go/quic-go/logging/interface.go +++ b/vendor/github.com/quic-go/quic-go/logging/interface.go @@ -3,7 +3,6 @@ package logging import ( - "context" "net" "time" @@ -101,12 +100,6 @@ type ShortHeader struct { // A Tracer traces events. type Tracer interface { - // TracerForConnection requests a new tracer for a connection. - // The ODCID is the original destination connection ID: - // The destination connection ID that the client used on the first Initial packet it sent on this connection. - // If nil is returned, tracing will be disabled for this connection. - TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer - SentPacket(net.Addr, *Header, ByteCount, []Frame) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) diff --git a/vendor/github.com/quic-go/quic-go/logging/multiplex.go b/vendor/github.com/quic-go/quic-go/logging/multiplex.go index 8e85db49..672a5cdb 100644 --- a/vendor/github.com/quic-go/quic-go/logging/multiplex.go +++ b/vendor/github.com/quic-go/quic-go/logging/multiplex.go @@ -1,7 +1,6 @@ package logging import ( - "context" "net" "time" ) @@ -23,16 +22,6 @@ func NewMultiplexedTracer(tracers ...Tracer) Tracer { return &tracerMultiplexer{tracers} } -func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer { - var connTracers []ConnectionTracer - for _, t := range m.tracers { - if ct := t.TracerForConnection(ctx, p, odcid); ct != nil { - connTracers = append(connTracers, ct) - } - } - return NewMultiplexedConnectionTracer(connTracers...) 
-} - func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) { for _, t := range m.tracers { t.SentPacket(remote, hdr, size, frames) diff --git a/vendor/github.com/quic-go/quic-go/logging/null_tracer.go b/vendor/github.com/quic-go/quic-go/logging/null_tracer.go index 38052ae3..de970385 100644 --- a/vendor/github.com/quic-go/quic-go/logging/null_tracer.go +++ b/vendor/github.com/quic-go/quic-go/logging/null_tracer.go @@ -1,7 +1,6 @@ package logging import ( - "context" "net" "time" ) @@ -12,9 +11,6 @@ type NullTracer struct{} var _ Tracer = &NullTracer{} -func (n NullTracer) TracerForConnection(context.Context, Perspective, ConnectionID) ConnectionTracer { - return NullConnectionTracer{} -} func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {} func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) { } diff --git a/vendor/github.com/quic-go/quic-go/mockgen.go b/vendor/github.com/quic-go/quic-go/mockgen.go index 443e9c10..eb700864 100644 --- a/vendor/github.com/quic-go/quic-go/mockgen.go +++ b/vendor/github.com/quic-go/quic-go/mockgen.go @@ -65,9 +65,6 @@ type UnknownPacketHandler = unknownPacketHandler //go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager" type PacketHandlerManager = packetHandlerManager -//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_multiplexer_test.go github.com/quic-go/quic-go Multiplexer" -type Multiplexer = multiplexer - // Need to use source mode for the batchConn, since reflect mode follows type aliases. // See https://github.com/golang/mock/issues/244 for details. // diff --git a/vendor/github.com/quic-go/quic-go/mtu_discoverer.go b/vendor/github.com/quic-go/quic-go/mtu_discoverer.go index 5a8484c7..317b0929 100644 --- a/vendor/github.com/quic-go/quic-go/mtu_discoverer.go +++ b/vendor/github.com/quic-go/quic-go/mtu_discoverer.go @@ -1,6 +1,7 @@ package quic import ( + "net" "time" "github.com/quic-go/quic-go/internal/ackhandler" @@ -10,7 +11,11 @@ import ( ) type mtuDiscoverer interface { + // Start starts the MTU discovery process. + // It's unnecessary to call ShouldSendProbe before that. + Start(maxPacketSize protocol.ByteCount) ShouldSendProbe(now time.Time) bool + CurrentSize() protocol.ByteCount GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount) } @@ -22,25 +27,38 @@ const ( mtuProbeDelay = 5 ) +func getMaxPacketSize(addr net.Addr) protocol.ByteCount { + maxSize := protocol.ByteCount(protocol.MinInitialPacketSize) + // If this is not a UDP address, we don't know anything about the MTU. + // Use the minimum size of an Initial packet as the max packet size. + if udpAddr, ok := addr.(*net.UDPAddr); ok { + if utils.IsIPv4(udpAddr.IP) { + maxSize = protocol.InitialPacketSizeIPv4 + } else { + maxSize = protocol.InitialPacketSizeIPv6 + } + } + return maxSize +} + type mtuFinder struct { lastProbeTime time.Time - probeInFlight bool mtuIncreased func(protocol.ByteCount) rttStats *utils.RTTStats + inFlight protocol.ByteCount // the size of the probe packet currently in flight. 
InvalidByteCount if none is in flight current protocol.ByteCount max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer) } var _ mtuDiscoverer = &mtuFinder{} -func newMTUDiscoverer(rttStats *utils.RTTStats, start, max protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) mtuDiscoverer { +func newMTUDiscoverer(rttStats *utils.RTTStats, start protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuFinder { return &mtuFinder{ - current: start, - rttStats: rttStats, - lastProbeTime: time.Now(), // to make sure the first probe packet is not sent immediately - mtuIncreased: mtuIncreased, - max: max, + inFlight: protocol.InvalidByteCount, + current: start, + rttStats: rttStats, + mtuIncreased: mtuIncreased, } } @@ -48,8 +66,16 @@ func (f *mtuFinder) done() bool { return f.max-f.current <= maxMTUDiff+1 } +func (f *mtuFinder) Start(maxPacketSize protocol.ByteCount) { + f.lastProbeTime = time.Now() // makes sure the first probe packet is not sent immediately + f.max = maxPacketSize +} + func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { - if f.probeInFlight || f.done() { + if f.max == 0 || f.lastProbeTime.IsZero() { + return false + } + if f.inFlight != protocol.InvalidByteCount || f.done() { return false } return !now.Before(f.lastProbeTime.Add(mtuProbeDelay * f.rttStats.SmoothedRTT())) @@ -58,17 +84,36 @@ func (f *mtuFinder) ShouldSendProbe(now time.Time) bool { func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) { size := (f.max + f.current) / 2 f.lastProbeTime = time.Now() - f.probeInFlight = true + f.inFlight = size return ackhandler.Frame{ - Frame: &wire.PingFrame{}, - OnLost: func(wire.Frame) { - f.probeInFlight = false - f.max = size - }, - OnAcked: func(wire.Frame) { - f.probeInFlight = false - f.current = size - f.mtuIncreased(size) - }, + Frame: &wire.PingFrame{}, + Handler: (*mtuFinderAckHandler)(f), }, size } + +func (f *mtuFinder) CurrentSize() protocol.ByteCount { + return f.current +} + +type mtuFinderAckHandler mtuFinder + +var _ ackhandler.FrameHandler = &mtuFinderAckHandler{} + +func (h *mtuFinderAckHandler) OnAcked(wire.Frame) { + size := h.inFlight + if size == protocol.InvalidByteCount { + panic("OnAcked callback called although there's no MTU probe packet in flight") + } + h.inFlight = protocol.InvalidByteCount + h.current = size + h.mtuIncreased(size) +} + +func (h *mtuFinderAckHandler) OnLost(wire.Frame) { + size := h.inFlight + if size == protocol.InvalidByteCount { + panic("OnLost callback called although there's no MTU probe packet in flight") + } + h.max = size + h.inFlight = protocol.InvalidByteCount +} diff --git a/vendor/github.com/quic-go/quic-go/multiplexer.go b/vendor/github.com/quic-go/quic-go/multiplexer.go index 37d4e75c..85f7f403 100644 --- a/vendor/github.com/quic-go/quic-go/multiplexer.go +++ b/vendor/github.com/quic-go/quic-go/multiplexer.go @@ -6,7 +6,6 @@ import ( "sync" "github.com/quic-go/quic-go/internal/utils" - "github.com/quic-go/quic-go/logging" ) var ( @@ -14,30 +13,19 @@ var ( connMuxer multiplexer ) -type indexableConn interface { - LocalAddr() net.Addr -} +type indexableConn interface{ LocalAddr() net.Addr } type multiplexer interface { - AddConn(c net.PacketConn, connIDLen int, statelessResetKey *StatelessResetKey, tracer logging.Tracer) (packetHandlerManager, error) + AddConn(conn indexableConn) RemoveConn(indexableConn) error } -type connManager struct { - connIDLen int - statelessResetKey *StatelessResetKey - tracer logging.Tracer - manager packetHandlerManager 
-} - // The connMultiplexer listens on multiple net.PacketConns and dispatches // incoming packets to the connection handler. type connMultiplexer struct { mutex sync.Mutex - conns map[string] /* LocalAddr().String() */ connManager - newPacketHandlerManager func(net.PacketConn, int, *StatelessResetKey, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests - + conns map[string] /* LocalAddr().String() */ indexableConn logger utils.Logger } @@ -46,57 +34,38 @@ var _ multiplexer = &connMultiplexer{} func getMultiplexer() multiplexer { connMuxerOnce.Do(func() { connMuxer = &connMultiplexer{ - conns: make(map[string]connManager), - logger: utils.DefaultLogger.WithPrefix("muxer"), - newPacketHandlerManager: newPacketHandlerMap, + conns: make(map[string]indexableConn), + logger: utils.DefaultLogger.WithPrefix("muxer"), } }) return connMuxer } -func (m *connMultiplexer) AddConn( - c net.PacketConn, - connIDLen int, - statelessResetKey *StatelessResetKey, - tracer logging.Tracer, -) (packetHandlerManager, error) { +func (m *connMultiplexer) index(addr net.Addr) string { + return addr.Network() + " " + addr.String() +} + +func (m *connMultiplexer) AddConn(c indexableConn) { m.mutex.Lock() defer m.mutex.Unlock() - addr := c.LocalAddr() - connIndex := addr.Network() + " " + addr.String() + connIndex := m.index(c.LocalAddr()) p, ok := m.conns[connIndex] - if !ok { - manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger) - if err != nil { - return nil, err - } - p = connManager{ - connIDLen: connIDLen, - statelessResetKey: statelessResetKey, - manager: manager, - tracer: tracer, - } - m.conns[connIndex] = p - } else { - if p.connIDLen != connIDLen { - return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) - } - if statelessResetKey != nil && p.statelessResetKey != statelessResetKey { - return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn") - } - if tracer != p.tracer { - return nil, fmt.Errorf("cannot use different tracers on the same packet conn") - } + if ok { + // Panics if we're already listening on this connection. + // This is a safeguard because we're introducing a breaking API change, see + // https://github.com/quic-go/quic-go/issues/3727 for details. + // We'll remove this at a later time, when most users of the library have made the switch. 
+ panic("connection already exists") // TODO: write a nice message } - return p.manager, nil + m.conns[connIndex] = p } func (m *connMultiplexer) RemoveConn(c indexableConn) error { m.mutex.Lock() defer m.mutex.Unlock() - connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String() + connIndex := m.index(c.LocalAddr()) if _, ok := m.conns[connIndex]; !ok { return fmt.Errorf("cannote remove connection, connection is unknown") } diff --git a/vendor/github.com/quic-go/quic-go/packet_handler_map.go b/vendor/github.com/quic-go/quic-go/packet_handler_map.go index 1f643412..2a16773c 100644 --- a/vendor/github.com/quic-go/quic-go/packet_handler_map.go +++ b/vendor/github.com/quic-go/quic-go/packet_handler_map.go @@ -5,147 +5,86 @@ import ( "crypto/rand" "crypto/sha256" "errors" - "fmt" "hash" "io" - "log" "net" - "os" - "strconv" - "strings" "sync" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" - "github.com/quic-go/quic-go/internal/wire" - "github.com/quic-go/quic-go/logging" ) -// rawConn is a connection that allow reading of a receivedPacket. +type connCapabilities struct { + // This connection has the Don't Fragment (DF) bit set. + // This means it makes to run DPLPMTUD. + DF bool + // GSO (Generic Segmentation Offload) supported + GSO bool +} + +// rawConn is a connection that allow reading of a receivedPackeh. type rawConn interface { - ReadPacket() (*receivedPacket, error) - WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) + ReadPacket() (receivedPacket, error) + // The size parameter is used for GSO. + // If GSO is not support, len(b) must be equal to size. + WritePacket(b []byte, size uint16, addr net.Addr, oob []byte) (int, error) LocalAddr() net.Addr + SetReadDeadline(time.Time) error io.Closer + + capabilities() connCapabilities } type closePacket struct { payload []byte addr net.Addr - info *packetInfo + info packetInfo } -// The packetHandlerMap stores packetHandlers, identified by connection ID. -// It is used: -// * by the server to store connections -// * when multiplexing outgoing connections to store clients +type unknownPacketHandler interface { + handlePacket(receivedPacket) + setCloseError(error) +} + +var errListenerAlreadySet = errors.New("listener already set") + type packetHandlerMap struct { - mutex sync.Mutex + mutex sync.Mutex + handlers map[protocol.ConnectionID]packetHandler + resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler - conn rawConn - connIDLen int - - closeQueue chan closePacket - - handlers map[protocol.ConnectionID]packetHandler - resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler - server unknownPacketHandler - numZeroRTTEntries int - - listening chan struct{} // is closed when listen returns closed bool + closeChan chan struct{} + + enqueueClosePacket func(closePacket) deleteRetiredConnsAfter time.Duration - zeroRTTQueueDuration time.Duration - statelessResetEnabled bool - statelessResetMutex sync.Mutex - statelessResetHasher hash.Hash + statelessResetMutex sync.Mutex + statelessResetHasher hash.Hash - tracer logging.Tracer logger utils.Logger } var _ packetHandlerManager = &packetHandlerMap{} -func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { - conn, ok := c.(interface{ SetReadBuffer(int) error }) - if !ok { - return errors.New("connection doesn't allow setting of receive buffer size. 
Not a *net.UDPConn?") - } - size, err := inspectReadBuffer(c) - if err != nil { - return fmt.Errorf("failed to determine receive buffer size: %w", err) - } - if size >= protocol.DesiredReceiveBufferSize { - logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) - return nil - } - if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { - return fmt.Errorf("failed to increase receive buffer size: %w", err) - } - newSize, err := inspectReadBuffer(c) - if err != nil { - return fmt.Errorf("failed to determine receive buffer size: %w", err) - } - if newSize == size { - return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024) - } - if newSize < protocol.DesiredReceiveBufferSize { - return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024) - } - logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024) - return nil -} - -// only print warnings about the UDP receive buffer size once -var receiveBufferWarningOnce sync.Once - -func newPacketHandlerMap( - c net.PacketConn, - connIDLen int, - statelessResetKey *StatelessResetKey, - tracer logging.Tracer, - logger utils.Logger, -) (packetHandlerManager, error) { - if err := setReceiveBuffer(c, logger); err != nil { - if !strings.Contains(err.Error(), "use of closed network connection") { - receiveBufferWarningOnce.Do(func() { - if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { - return - } - log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Receive-Buffer-Size for details.", err) - }) - } - } - conn, err := wrapConn(c) - if err != nil { - return nil, err - } - m := &packetHandlerMap{ - conn: conn, - connIDLen: connIDLen, - listening: make(chan struct{}), +func newPacketHandlerMap(key *StatelessResetKey, enqueueClosePacket func(closePacket), logger utils.Logger) *packetHandlerMap { + h := &packetHandlerMap{ + closeChan: make(chan struct{}), handlers: make(map[protocol.ConnectionID]packetHandler), resetTokens: make(map[protocol.StatelessResetToken]packetHandler), deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, - zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, - closeQueue: make(chan closePacket, 4), - statelessResetEnabled: statelessResetKey != nil, - tracer: tracer, + enqueueClosePacket: enqueueClosePacket, logger: logger, } - if m.statelessResetEnabled { - m.statelessResetHasher = hmac.New(sha256.New, statelessResetKey[:]) + if key != nil { + h.statelessResetHasher = hmac.New(sha256.New, key[:]) } - go m.listen() - go m.runCloseQueue() - - if logger.Debug() { - go m.logUsage() + if h.logger.Debug() { + go h.logUsage() } - return m, nil + return h } func (h *packetHandlerMap) logUsage() { @@ -153,7 +92,7 @@ func (h *packetHandlerMap) logUsage() { var printedZero bool for { select { - case <-h.listening: + case <-h.closeChan: return case <-ticker.C: } @@ -174,6 +113,14 @@ func (h *packetHandlerMap) logUsage() { } } +func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) { + h.mutex.Lock() + defer h.mutex.Unlock() + + handler, ok := h.handlers[id] + return handler, ok +} + func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { h.mutex.Lock() defer h.mutex.Unlock() @@ -187,26 +134,17 @@ 
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) return true } -func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool { +func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() (packetHandler, bool)) bool { h.mutex.Lock() defer h.mutex.Unlock() - var q *zeroRTTQueue - if handler, ok := h.handlers[clientDestConnID]; ok { - q, ok = handler.(*zeroRTTQueue) - if !ok { - h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) - return false - } - q.retireTimer.Stop() - h.numZeroRTTEntries-- - if h.numZeroRTTEntries < 0 { - panic("number of 0-RTT queues < 0") - } + if _, ok := h.handlers[clientDestConnID]; ok { + h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) + return false } - conn := fn() - if q != nil { - q.EnqueueAll(conn) + conn, ok := fn() + if !ok { + return false } h.handlers[clientDestConnID] = conn h.handlers[newConnID] = conn @@ -239,13 +177,8 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p var handler packetHandler if connClosePacket != nil { handler = newClosedLocalConn( - func(addr net.Addr, info *packetInfo) { - select { - case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}: - default: - // Oops, we're backlogged. - // Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway. - } + func(addr net.Addr, info packetInfo) { + h.enqueueClosePacket(closePacket{payload: connClosePacket, addr: addr, info: info}) }, pers, h.logger, @@ -272,17 +205,6 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p }) } -func (h *packetHandlerMap) runCloseQueue() { - for { - select { - case <-h.listening: - return - case p := <-h.closeQueue: - h.conn.WritePacket(p.payload, p.addr, p.info.OOB()) - } - } -} - func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { h.mutex.Lock() h.resetTokens[token] = handler @@ -295,19 +217,16 @@ func (h *packetHandlerMap) RemoveResetToken(token protocol.StatelessResetToken) h.mutex.Unlock() } -func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { +func (h *packetHandlerMap) GetByResetToken(token protocol.StatelessResetToken) (packetHandler, bool) { h.mutex.Lock() - h.server = s - h.mutex.Unlock() + defer h.mutex.Unlock() + + handler, ok := h.resetTokens[token] + return handler, ok } func (h *packetHandlerMap) CloseServer() { h.mutex.Lock() - if h.server == nil { - h.mutex.Unlock() - return - } - h.server = nil var wg sync.WaitGroup for _, handler := range h.handlers { if handler.getPerspective() == protocol.PerspectiveServer { @@ -323,23 +242,16 @@ func (h *packetHandlerMap) CloseServer() { wg.Wait() } -// Destroy closes the underlying connection and waits until listen() has returned. -// It does not close active connections. 
-func (h *packetHandlerMap) Destroy() error { - if err := h.conn.Close(); err != nil { - return err - } - <-h.listening // wait until listening returns - return nil -} - -func (h *packetHandlerMap) close(e error) error { +func (h *packetHandlerMap) Close(e error) { h.mutex.Lock() + if h.closed { h.mutex.Unlock() - return nil + return } + close(h.closeChan) + var wg sync.WaitGroup for _, handler := range h.handlers { wg.Add(1) @@ -348,129 +260,14 @@ func (h *packetHandlerMap) close(e error) error { wg.Done() }(handler) } - - if h.server != nil { - h.server.setCloseError(e) - } h.closed = true h.mutex.Unlock() wg.Wait() - return getMultiplexer().RemoveConn(h.conn) -} - -func (h *packetHandlerMap) listen() { - defer close(h.listening) - for { - p, err := h.conn.ReadPacket() - //nolint:staticcheck // SA1019 ignore this! - // TODO: This code is used to ignore wsa errors on Windows. - // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. - // See https://github.com/quic-go/quic-go/issues/1737 for details. - if nerr, ok := err.(net.Error); ok && nerr.Temporary() { - h.logger.Debugf("Temporary error reading from conn: %w", err) - continue - } - if err != nil { - h.close(err) - return - } - h.handlePacket(p) - } -} - -func (h *packetHandlerMap) handlePacket(p *receivedPacket) { - connID, err := wire.ParseConnectionID(p.data, h.connIDLen) - if err != nil { - h.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) - if h.tracer != nil { - h.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) - } - p.buffer.MaybeRelease() - return - } - - h.mutex.Lock() - defer h.mutex.Unlock() - - if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset { - return - } - - if handler, ok := h.handlers[connID]; ok { - if ha, ok := handler.(*zeroRTTQueue); ok { // only enqueue 0-RTT packets in the 0-RTT queue - if wire.Is0RTTPacket(p.data) { - ha.handlePacket(p) - return - } - } else { // existing connection - handler.handlePacket(p) - return - } - } - if !wire.IsLongHeaderPacket(p.data[0]) { - go h.maybeSendStatelessReset(p, connID) - return - } - if h.server == nil { // no server set - h.logger.Debugf("received a packet with an unexpected connection ID %s", connID) - return - } - if wire.Is0RTTPacket(p.data) { - if h.numZeroRTTEntries >= protocol.Max0RTTQueues { - if h.tracer != nil { - h.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) - } - return - } - h.numZeroRTTEntries++ - queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)} - h.handlers[connID] = queue - queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() { - h.mutex.Lock() - defer h.mutex.Unlock() - // The entry might have been replaced by an actual connection. - // Only delete it if it's still a 0-RTT queue. 
- if handler, ok := h.handlers[connID]; ok { - if q, ok := handler.(*zeroRTTQueue); ok { - delete(h.handlers, connID) - h.numZeroRTTEntries-- - if h.numZeroRTTEntries < 0 { - panic("number of 0-RTT queues < 0") - } - q.Clear() - if h.logger.Debug() { - h.logger.Debugf("Removing 0-RTT queue for %s.", connID) - } - } - } - }) - queue.handlePacket(p) - return - } - h.server.handlePacket(p) -} - -func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { - // stateless resets are always short header packets - if wire.IsLongHeaderPacket(data[0]) { - return false - } - if len(data) < 17 /* type byte + 16 bytes for the reset token */ { - return false - } - - token := *(*protocol.StatelessResetToken)(data[len(data)-16:]) - if sess, ok := h.resetTokens[token]; ok { - h.logger.Debugf("Received a stateless reset with token %#x. Closing connection.", token) - go sess.destroy(&StatelessResetError{Token: token}) - return true - } - return false } func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) protocol.StatelessResetToken { var token protocol.StatelessResetToken - if !h.statelessResetEnabled { + if h.statelessResetHasher == nil { // Return a random stateless reset token. // This token will be sent in the server's transport parameters. // By using a random token, an off-path attacker won't be able to disrupt the connection. @@ -484,24 +281,3 @@ func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) h.statelessResetMutex.Unlock() return token } - -func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { - defer p.buffer.Release() - if !h.statelessResetEnabled { - return - } - // Don't send a stateless reset in response to very small packets. - // This includes packets that could be stateless resets. - if len(p.data) <= protocol.MinStatelessResetSize { - return - } - token := h.GetStatelessResetToken(connID) - h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) - data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) - rand.Read(data) - data[0] = (data[0] & 0x7f) | 0x40 - data = append(data, token[:]...) 
- if _, err := h.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { - h.logger.Debugf("Error sending Stateless Reset: %s", err) - } -} diff --git a/vendor/github.com/quic-go/quic-go/packet_packer.go b/vendor/github.com/quic-go/quic-go/packet_packer.go index 14befd46..577f3b04 100644 --- a/vendor/github.com/quic-go/quic-go/packet_packer.go +++ b/vendor/github.com/quic-go/quic-go/packet_packer.go @@ -3,30 +3,25 @@ package quic import ( "errors" "fmt" - "net" - "time" "github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" - "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" ) var errNothingToPack = errors.New("nothing to pack") type packer interface { - PackCoalescedPacket(onlyAck bool, v protocol.VersionNumber) (*coalescedPacket, error) - PackPacket(onlyAck bool, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) - MaybePackProbePacket(protocol.EncryptionLevel, protocol.VersionNumber) (*coalescedPacket, error) - PackConnectionClose(*qerr.TransportError, protocol.VersionNumber) (*coalescedPacket, error) - PackApplicationClose(*qerr.ApplicationError, protocol.VersionNumber) (*coalescedPacket, error) + PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) + PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) + AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) + MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) + PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) + PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error) + PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) - SetMaxPacketSize(protocol.ByteCount) - PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) - - HandleTransportParameters(*wire.TransportParameters) SetToken([]byte) } @@ -35,26 +30,31 @@ type sealer interface { } type payload struct { - frames []*ackhandler.Frame - ack *wire.AckFrame - length protocol.ByteCount + streamFrames []ackhandler.StreamFrame + frames []ackhandler.Frame + ack *wire.AckFrame + length protocol.ByteCount } type longHeaderPacket struct { - header *wire.ExtendedHeader - ack *wire.AckFrame - frames []*ackhandler.Frame + header *wire.ExtendedHeader + ack *wire.AckFrame + frames []ackhandler.Frame + streamFrames []ackhandler.StreamFrame // only used for 0-RTT packets length protocol.ByteCount - - isMTUProbePacket bool } type shortHeaderPacket struct { - *ackhandler.Packet + PacketNumber protocol.PacketNumber + Frames []ackhandler.Frame + StreamFrames []ackhandler.StreamFrame + Ack *wire.AckFrame + Length protocol.ByteCount + IsPathMTUProbePacket bool + // used for logging DestConnID protocol.ConnectionID - Ack *wire.AckFrame PacketNumberLen protocol.PacketNumberLen KeyPhase protocol.KeyPhaseBit } @@ -83,52 +83,6 @@ func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel { func (p *longHeaderPacket) IsAckEliciting() bool { 
return ackhandler.HasAckElicitingFrames(p.frames) } -func (p *longHeaderPacket) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet { - largestAcked := protocol.InvalidPacketNumber - if p.ack != nil { - largestAcked = p.ack.LargestAcked() - } - encLevel := p.EncryptionLevel() - for i := range p.frames { - if p.frames[i].OnLost != nil { - continue - } - //nolint:exhaustive // Short header packets are handled separately. - switch encLevel { - case protocol.EncryptionInitial: - p.frames[i].OnLost = q.AddInitial - case protocol.EncryptionHandshake: - p.frames[i].OnLost = q.AddHandshake - case protocol.Encryption0RTT: - p.frames[i].OnLost = q.AddAppData - } - } - - ap := ackhandler.GetPacket() - ap.PacketNumber = p.header.PacketNumber - ap.LargestAcked = largestAcked - ap.Frames = p.frames - ap.Length = p.length - ap.EncryptionLevel = encLevel - ap.SendTime = now - ap.IsPathMTUProbePacket = p.isMTUProbePacket - return ap -} - -func getMaxPacketSize(addr net.Addr) protocol.ByteCount { - maxSize := protocol.ByteCount(protocol.MinInitialPacketSize) - // If this is not a UDP address, we don't know anything about the MTU. - // Use the minimum size of an Initial packet as the max packet size. - if udpAddr, ok := addr.(*net.UDPAddr); ok { - if utils.IsIPv4(udpAddr.IP) { - maxSize = protocol.InitialPacketSizeIPv4 - } else { - maxSize = protocol.InitialPacketSizeIPv6 - } - } - return maxSize -} - type packetNumberManager interface { PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber @@ -143,8 +97,8 @@ type sealingManager interface { type frameSource interface { HasData() bool - AppendStreamFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) - AppendControlFrames([]*ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]*ackhandler.Frame, protocol.ByteCount) + AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount) + AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) } type ackFrameSource interface { @@ -169,13 +123,23 @@ type packetPacker struct { datagramQueue *datagramQueue retransmissionQueue *retransmissionQueue - maxPacketSize protocol.ByteCount numNonAckElicitingAcks int } var _ packer = &packetPacker{} -func newPacketPacker(srcConnID protocol.ConnectionID, getDestConnID func() protocol.ConnectionID, initialStream cryptoStream, handshakeStream cryptoStream, packetNumberManager packetNumberManager, retransmissionQueue *retransmissionQueue, remoteAddr net.Addr, cryptoSetup sealingManager, framer frameSource, acks ackFrameSource, datagramQueue *datagramQueue, perspective protocol.Perspective) *packetPacker { +func newPacketPacker( + srcConnID protocol.ConnectionID, + getDestConnID func() protocol.ConnectionID, + initialStream, handshakeStream cryptoStream, + packetNumberManager packetNumberManager, + retransmissionQueue *retransmissionQueue, + cryptoSetup sealingManager, + framer frameSource, + acks ackFrameSource, + datagramQueue *datagramQueue, + perspective protocol.Perspective, +) *packetPacker { return &packetPacker{ cryptoSetup: cryptoSetup, getDestConnID: getDestConnID, @@ -188,23 +152,22 @@ func newPacketPacker(srcConnID protocol.ConnectionID, getDestConnID func() proto framer: framer, acks: acks, pnManager: packetNumberManager, - 
maxPacketSize: getMaxPacketSize(remoteAddr), } } // PackConnectionClose packs a packet that closes the connection with a transport error. -func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, v protocol.VersionNumber) (*coalescedPacket, error) { +func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { var reason string // don't send details of crypto errors if !e.ErrorCode.IsCryptoError() { reason = e.ErrorMessage } - return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason, v) + return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason, maxPacketSize, v) } // PackApplicationClose packs a packet that closes the connection with an application error. -func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, v protocol.VersionNumber) (*coalescedPacket, error) { - return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, v) +func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { + return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v) } func (p *packetPacker) packConnectionClose( @@ -212,6 +175,7 @@ func (p *packetPacker) packConnectionClose( errorCode uint64, frameType uint64, reason string, + maxPacketSize protocol.ByteCount, v protocol.VersionNumber, ) (*coalescedPacket, error) { var sealers [4]sealer @@ -241,7 +205,7 @@ func (p *packetPacker) packConnectionClose( ccf.ReasonPhrase = "" } pl := payload{ - frames: []*ackhandler.Frame{{Frame: ccf}}, + frames: []ackhandler.Frame{{Frame: ccf}}, length: ccf.Length(v), } @@ -293,20 +257,14 @@ func (p *packetPacker) packConnectionClose( } var paddingLen protocol.ByteCount if encLevel == protocol.EncryptionInitial { - paddingLen = p.initialPaddingLen(payloads[i].frames, size) + paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize) } if encLevel == protocol.Encryption1RTT { - ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, sealers[i], false, v) + shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, maxPacketSize, sealers[i], false, v) if err != nil { return nil, err } - packet.shortHdrPacket = &shortHeaderPacket{ - Packet: ap, - DestConnID: connID, - Ack: ack, - PacketNumberLen: oneRTTPacketNumberLen, - KeyPhase: keyPhase, - } + packet.shortHdrPacket = &shp } else { longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], v) if err != nil { @@ -342,25 +300,21 @@ func (p *packetPacker) shortHeaderPacketLength(connID protocol.ConnectionID, pnL } // size is the expected size of the packet, if no padding was applied. -func (p *packetPacker) initialPaddingLen(frames []*ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount { +func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize, maxPacketSize protocol.ByteCount) protocol.ByteCount { // For the server, only ack-eliciting Initial packets need to be padded. 
if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) { return 0 } - if size >= p.maxPacketSize { + if currentSize >= maxPacketSize { return 0 } - return p.maxPacketSize - size + return maxPacketSize - currentSize } // PackCoalescedPacket packs a new packet. // It packs an Initial / Handshake if there is data to send in these packet number spaces. // It should only be called before the handshake is confirmed. -func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumber) (*coalescedPacket, error) { - maxPacketSize := p.maxPacketSize - if p.perspective == protocol.PerspectiveClient { - maxPacketSize = protocol.MinInitialPacketSize - } +func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { var ( initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload @@ -417,7 +371,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumbe if oneRTTPayload.length > 0 { size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead()) } - } else if p.perspective == protocol.PerspectiveClient { // 0-RTT + } else if p.perspective == protocol.PerspectiveClient && !onlyAck { // 0-RTT packets can't contain ACK frames var err error zeroRTTSealer, err = p.cryptoSetup.Get0RTTSealer() if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { @@ -442,7 +396,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumbe longHdrPackets: make([]*longHeaderPacket, 0, 3), } if initialPayload.length > 0 { - padding := p.initialPaddingLen(initialPayload.frames, size) + padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize) cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v) if err != nil { return nil, err @@ -463,48 +417,44 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, v protocol.VersionNumbe } packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) } else if oneRTTPayload.length > 0 { - ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, oneRTTSealer, false, v) + shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxPacketSize, oneRTTSealer, false, v) if err != nil { return nil, err } - packet.shortHdrPacket = &shortHeaderPacket{ - Packet: ap, - DestConnID: connID, - Ack: ack, - PacketNumberLen: oneRTTPacketNumberLen, - KeyPhase: kp, - } + packet.shortHdrPacket = &shp } return packet, nil } -// PackPacket packs a packet in the application data packet number space. +// PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space. // It should be called after the handshake is confirmed. -func (p *packetPacker) PackPacket(onlyAck bool, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { +func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { + buf := getPacketBuffer() + packet, err := p.appendPacket(buf, true, maxPacketSize, v) + return packet, buf, err +} + +// AppendPacket packs a packet in the application data packet number space. 
+// It should be called after the handshake is confirmed. +func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) { + return p.appendPacket(buf, false, maxPacketSize, v) +} + +func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) { sealer, err := p.cryptoSetup.Get1RTTSealer() if err != nil { - return shortHeaderPacket{}, nil, err + return shortHeaderPacket{}, err } pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) connID := p.getDestConnID() hdrLen := wire.ShortHeaderLen(connID, pnLen) - pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, p.maxPacketSize, onlyAck, true, v) + pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, maxPacketSize, onlyAck, true, v) if pl.length == 0 { - return shortHeaderPacket{}, nil, errNothingToPack + return shortHeaderPacket{}, errNothingToPack } kp := sealer.KeyPhase() - buffer := getPacketBuffer() - ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, sealer, false, v) - if err != nil { - return shortHeaderPacket{}, nil, err - } - return shortHeaderPacket{ - Packet: ap, - DestConnID: connID, - Ack: ack, - PacketNumberLen: pnLen, - KeyPhase: kp, - }, buffer, nil + + return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v) } func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) { @@ -519,14 +469,17 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en } var s cryptoStream + var handler ackhandler.FrameHandler var hasRetransmission bool //nolint:exhaustive // Initial and Handshake are the only two encryption levels here. 
switch encLevel { case protocol.EncryptionInitial: s = p.initialStream + handler = p.retransmissionQueue.InitialAckHandler() hasRetransmission = p.retransmissionQueue.HasInitialData() case protocol.EncryptionHandshake: s = p.handshakeStream + handler = p.retransmissionQueue.HandshakeAckHandler() hasRetransmission = p.retransmissionQueue.HasHandshakeData() } @@ -550,27 +503,27 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en maxPacketSize -= hdr.GetLength(v) if hasRetransmission { for { - var f wire.Frame + var f ackhandler.Frame //nolint:exhaustive // 0-RTT packets can't contain any retransmission.s switch encLevel { case protocol.EncryptionInitial: - f = p.retransmissionQueue.GetInitialFrame(maxPacketSize, v) + f.Frame = p.retransmissionQueue.GetInitialFrame(maxPacketSize, v) + f.Handler = p.retransmissionQueue.InitialAckHandler() case protocol.EncryptionHandshake: - f = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize, v) + f.Frame = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize, v) + f.Handler = p.retransmissionQueue.HandshakeAckHandler() } - if f == nil { + if f.Frame == nil { break } - af := ackhandler.GetFrame() - af.Frame = f - pl.frames = append(pl.frames, af) - frameLen := f.Length(v) + pl.frames = append(pl.frames, f) + frameLen := f.Frame.Length(v) pl.length += frameLen maxPacketSize -= frameLen } } else if s.HasData() { cf := s.PopCryptoFrame(maxPacketSize) - pl.frames = []*ackhandler.Frame{{Frame: cf}} + pl.frames = []ackhandler.Frame{{Frame: cf, Handler: handler}} pl.length += cf.Length(v) } return hdr, pl @@ -595,18 +548,14 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v) // check if we have anything to send - if len(pl.frames) == 0 { + if len(pl.frames) == 0 && len(pl.streamFrames) == 0 { if pl.ack == nil { return payload{} } // the packet only contains an ACK if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks { ping := &wire.PingFrame{} - // don't retransmit the PING frame when it is lost - af := ackhandler.GetFrame() - af.Frame = ping - af.OnLost = func(wire.Frame) {} - pl.frames = append(pl.frames, af) + pl.frames = append(pl.frames, ackhandler.Frame{Frame: ping}) pl.length += ping.Length(v) p.numNonAckElicitingAcks = 0 } else { @@ -621,15 +570,12 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload { if onlyAck { if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil { - return payload{ - ack: ack, - length: ack.Length(v), - } + return payload{ack: ack, length: ack.Length(v)} } return payload{} } - pl := payload{frames: make([]*ackhandler.Frame, 0, 1)} + pl := payload{streamFrames: make([]ackhandler.StreamFrame, 0, 1)} hasData := p.framer.HasData() hasRetransmission := p.retransmissionQueue.HasAppData() @@ -647,11 +593,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc if f := p.datagramQueue.Peek(); f != nil { size := f.Length(v) if size <= maxFrameSize-pl.length { - af := ackhandler.GetFrame() - af.Frame = f - // set it to a no-op. Then we won't set the default callback, which would retransmit the frame. 
- af.OnLost = func(wire.Frame) {} - pl.frames = append(pl.frames, af) + pl.frames = append(pl.frames, ackhandler.Frame{Frame: f}) pl.length += size p.datagramQueue.Pop() } @@ -672,25 +614,28 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc if f == nil { break } - af := ackhandler.GetFrame() - af.Frame = f - pl.frames = append(pl.frames, af) + pl.frames = append(pl.frames, ackhandler.Frame{Frame: f, Handler: p.retransmissionQueue.AppDataAckHandler()}) pl.length += f.Length(v) } } if hasData { var lengthAdded protocol.ByteCount + startLen := len(pl.frames) pl.frames, lengthAdded = p.framer.AppendControlFrames(pl.frames, maxFrameSize-pl.length, v) pl.length += lengthAdded + // add handlers for the control frames that were added + for i := startLen; i < len(pl.frames); i++ { + pl.frames[i].Handler = p.retransmissionQueue.AppDataAckHandler() + } - pl.frames, lengthAdded = p.framer.AppendStreamFrames(pl.frames, maxFrameSize-pl.length, v) + pl.streamFrames, lengthAdded = p.framer.AppendStreamFrames(pl.streamFrames, maxFrameSize-pl.length, v) pl.length += lengthAdded } return pl } -func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) (*coalescedPacket, error) { +func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { if encLevel == protocol.Encryption1RTT { s, err := p.cryptoSetup.Get1RTTSealer() if err != nil { @@ -700,23 +645,17 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v connID := p.getDestConnID() pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) hdrLen := wire.ShortHeaderLen(connID, pnLen) - pl := p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, v) + pl := p.maybeGetAppDataPacket(maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, v) if pl.length == 0 { return nil, nil } buffer := getPacketBuffer() packet := &coalescedPacket{buffer: buffer} - ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, s, false, v) + shp, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, maxPacketSize, s, false, v) if err != nil { return nil, err } - packet.shortHdrPacket = &shortHeaderPacket{ - Packet: ap, - DestConnID: connID, - Ack: ack, - PacketNumberLen: pnLen, - KeyPhase: kp, - } + packet.shortHdrPacket = &shp return packet, nil } @@ -731,14 +670,14 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v if err != nil { return nil, err } - hdr, pl = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v) + hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v) case protocol.EncryptionHandshake: var err error sealer, err = p.cryptoSetup.GetHandshakeSealer() if err != nil { return nil, err } - hdr, pl = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true, v) + hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true, v) default: panic("unknown encryption level") } @@ -751,7 +690,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v size := p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead()) var 
padding protocol.ByteCount if encLevel == protocol.EncryptionInitial { - padding = p.initialPaddingLen(pl.frames, size) + padding = p.initialPaddingLen(pl.frames, size, maxPacketSize) } longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, pl, padding, encLevel, sealer, v) @@ -762,10 +701,10 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, v return packet, nil } -func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { +func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) { pl := payload{ - frames: []*ackhandler.Frame{&ping}, - length: ping.Length(v), + frames: []ackhandler.Frame{ping}, + length: ping.Frame.Length(v), } buffer := getPacketBuffer() s, err := p.cryptoSetup.Get1RTTSealer() @@ -776,17 +715,8 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) padding := size - p.shortHeaderPacketLength(connID, pnLen, pl) - protocol.ByteCount(s.Overhead()) kp := s.KeyPhase() - ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, padding, s, true, v) - if err != nil { - return shortHeaderPacket{}, nil, err - } - return shortHeaderPacket{ - Packet: ap, - DestConnID: connID, - Ack: ack, - PacketNumberLen: pnLen, - KeyPhase: kp, - }, buffer, nil + packet, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, padding, size, s, true, v) + return packet, buffer, err } func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) *wire.ExtendedHeader { @@ -829,23 +759,22 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire } payloadOffset := protocol.ByteCount(len(raw)) - pn := p.pnManager.PopPacketNumber(encLevel) - if pn != header.PacketNumber { - return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") - } - raw, err = p.appendPacketPayload(raw, pl, paddingLen, v) if err != nil { return nil, err } - raw = p.encryptPacket(raw, sealer, pn, payloadOffset, pnLen) + raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen) buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] + if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber { + return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber) + } return &longHeaderPacket{ - header: header, - ack: pl.ack, - frames: pl.frames, - length: protocol.ByteCount(len(raw)), + header: header, + ack: pl.ack, + frames: pl.frames, + streamFrames: pl.streamFrames, + length: protocol.ByteCount(len(raw)), }, nil } @@ -856,11 +785,11 @@ func (p *packetPacker) appendShortHeaderPacket( pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit, pl payload, - padding protocol.ByteCount, + padding, maxPacketSize protocol.ByteCount, sealer sealer, isMTUProbePacket bool, v protocol.VersionNumber, -) (*ackhandler.Packet, *wire.AckFrame, error) { +) (shortHeaderPacket, error) { var paddingLen protocol.ByteCount if pl.length < 4-protocol.ByteCount(pnLen) { paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length @@ -871,48 +800,36 @@ func (p *packetPacker) appendShortHeaderPacket( raw := buffer.Data[startLen:] raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp) if err != 
nil { - return nil, nil, err + return shortHeaderPacket{}, err } payloadOffset := protocol.ByteCount(len(raw)) - if pn != p.pnManager.PopPacketNumber(protocol.Encryption1RTT) { - return nil, nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") - } - raw, err = p.appendPacketPayload(raw, pl, paddingLen, v) if err != nil { - return nil, nil, err + return shortHeaderPacket{}, err } if !isMTUProbePacket { - if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > p.maxPacketSize { - return nil, nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) + if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > maxPacketSize { + return shortHeaderPacket{}, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, maxPacketSize) } } raw = p.encryptPacket(raw, sealer, pn, payloadOffset, protocol.ByteCount(pnLen)) buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] - // create the ackhandler.Packet - largestAcked := protocol.InvalidPacketNumber - if pl.ack != nil { - largestAcked = pl.ack.LargestAcked() + if newPN := p.pnManager.PopPacketNumber(protocol.Encryption1RTT); newPN != pn { + return shortHeaderPacket{}, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, newPN) } - for i := range pl.frames { - if pl.frames[i].OnLost != nil { - continue - } - pl.frames[i].OnLost = p.retransmissionQueue.AddAppData - } - - ap := ackhandler.GetPacket() - ap.PacketNumber = pn - ap.LargestAcked = largestAcked - ap.Frames = pl.frames - ap.Length = protocol.ByteCount(len(raw)) - ap.EncryptionLevel = protocol.Encryption1RTT - ap.SendTime = time.Now() - ap.IsPathMTUProbePacket = isMTUProbePacket - - return ap, pl.ack, nil + return shortHeaderPacket{ + PacketNumber: pn, + PacketNumberLen: pnLen, + KeyPhase: kp, + StreamFrames: pl.streamFrames, + Frames: pl.frames, + Ack: pl.ack, + Length: protocol.ByteCount(len(raw)), + DestConnID: connID, + IsPathMTUProbePacket: isMTUProbePacket, + }, nil } func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) { @@ -927,9 +844,16 @@ func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen pr if paddingLen > 0 { raw = append(raw, make([]byte, paddingLen)...) } - for _, frame := range pl.frames { + for _, f := range pl.frames { var err error - raw, err = frame.Append(raw, v) + raw, err = f.Frame.Append(raw, v) + if err != nil { + return nil, err + } + } + for _, f := range pl.streamFrames { + var err error + raw, err = f.Frame.Append(raw, v) if err != nil { return nil, err } @@ -953,16 +877,3 @@ func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.Pack func (p *packetPacker) SetToken(token []byte) { p.token = token } - -// When a higher MTU is discovered, use it. -func (p *packetPacker) SetMaxPacketSize(s protocol.ByteCount) { - p.maxPacketSize = s -} - -// If the peer sets a max_packet_size that's smaller than the size we're currently using, -// we need to reduce the size of packets we send. 
-func (p *packetPacker) HandleTransportParameters(params *wire.TransportParameters) { - if params.MaxUDPPayloadSize != 0 { - p.maxPacketSize = utils.Min(p.maxPacketSize, params.MaxUDPPayloadSize) - } -} diff --git a/vendor/github.com/quic-go/quic-go/quicvarint/varint.go b/vendor/github.com/quic-go/quic-go/quicvarint/varint.go index cbebfe61..3f12c076 100644 --- a/vendor/github.com/quic-go/quic-go/quicvarint/varint.go +++ b/vendor/github.com/quic-go/quic-go/quicvarint/varint.go @@ -70,25 +70,6 @@ func Read(r io.ByteReader) (uint64, error) { return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil } -// Write writes i in the QUIC varint format to w. -// Deprecated: use Append instead. -func Write(w Writer, i uint64) { - if i <= maxVarInt1 { - w.WriteByte(uint8(i)) - } else if i <= maxVarInt2 { - w.Write([]byte{uint8(i>>8) | 0x40, uint8(i)}) - } else if i <= maxVarInt4 { - w.Write([]byte{uint8(i>>24) | 0x80, uint8(i >> 16), uint8(i >> 8), uint8(i)}) - } else if i <= maxVarInt8 { - w.Write([]byte{ - uint8(i>>56) | 0xc0, uint8(i >> 48), uint8(i >> 40), uint8(i >> 32), - uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i), - }) - } else { - panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) - } -} - // Append appends i in the QUIC varint format. func Append(b []byte, i uint64) []byte { if i <= maxVarInt1 { diff --git a/vendor/github.com/quic-go/quic-go/receive_stream.go b/vendor/github.com/quic-go/quic-go/receive_stream.go index 0a7e9416..89d02b73 100644 --- a/vendor/github.com/quic-go/quic-go/receive_stream.go +++ b/vendor/github.com/quic-go/quic-go/receive_stream.go @@ -179,6 +179,10 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast { s.finRead = true + s.currentFrame = nil + if s.currentFrameDone != nil { + s.currentFrameDone() + } return true, bytesRead, io.EOF } } diff --git a/vendor/github.com/quic-go/quic-go/retransmission_queue.go b/vendor/github.com/quic-go/quic-go/retransmission_queue.go index 2ce0b893..909ad622 100644 --- a/vendor/github.com/quic-go/quic-go/retransmission_queue.go +++ b/vendor/github.com/quic-go/quic-go/retransmission_queue.go @@ -3,6 +3,8 @@ package quic import ( "fmt" + "github.com/quic-go/quic-go/internal/ackhandler" + "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" ) @@ -21,7 +23,23 @@ func newRetransmissionQueue() *retransmissionQueue { return &retransmissionQueue{} } -func (q *retransmissionQueue) AddInitial(f wire.Frame) { +// AddPing queues a ping. +// It is used when a probe packet needs to be sent +func (q *retransmissionQueue) AddPing(encLevel protocol.EncryptionLevel) { + //nolint:exhaustive // Cannot send probe packets for 0-RTT. 
+ switch encLevel { + case protocol.EncryptionInitial: + q.addInitial(&wire.PingFrame{}) + case protocol.EncryptionHandshake: + q.addHandshake(&wire.PingFrame{}) + case protocol.Encryption1RTT: + q.addAppData(&wire.PingFrame{}) + default: + panic("unexpected encryption level") + } +} + +func (q *retransmissionQueue) addInitial(f wire.Frame) { if cf, ok := f.(*wire.CryptoFrame); ok { q.initialCryptoData = append(q.initialCryptoData, cf) return @@ -29,7 +47,7 @@ func (q *retransmissionQueue) AddInitial(f wire.Frame) { q.initial = append(q.initial, f) } -func (q *retransmissionQueue) AddHandshake(f wire.Frame) { +func (q *retransmissionQueue) addHandshake(f wire.Frame) { if cf, ok := f.(*wire.CryptoFrame); ok { q.handshakeCryptoData = append(q.handshakeCryptoData, cf) return @@ -49,7 +67,7 @@ func (q *retransmissionQueue) HasAppData() bool { return len(q.appData) > 0 } -func (q *retransmissionQueue) AddAppData(f wire.Frame) { +func (q *retransmissionQueue) addAppData(f wire.Frame) { if _, ok := f.(*wire.StreamFrame); ok { panic("STREAM frames are handled with their respective streams.") } @@ -127,3 +145,36 @@ func (q *retransmissionQueue) DropPackets(encLevel protocol.EncryptionLevel) { panic(fmt.Sprintf("unexpected encryption level: %s", encLevel)) } } + +func (q *retransmissionQueue) InitialAckHandler() ackhandler.FrameHandler { + return (*retransmissionQueueInitialAckHandler)(q) +} + +func (q *retransmissionQueue) HandshakeAckHandler() ackhandler.FrameHandler { + return (*retransmissionQueueHandshakeAckHandler)(q) +} + +func (q *retransmissionQueue) AppDataAckHandler() ackhandler.FrameHandler { + return (*retransmissionQueueAppDataAckHandler)(q) +} + +type retransmissionQueueInitialAckHandler retransmissionQueue + +func (q *retransmissionQueueInitialAckHandler) OnAcked(wire.Frame) {} +func (q *retransmissionQueueInitialAckHandler) OnLost(f wire.Frame) { + (*retransmissionQueue)(q).addInitial(f) +} + +type retransmissionQueueHandshakeAckHandler retransmissionQueue + +func (q *retransmissionQueueHandshakeAckHandler) OnAcked(wire.Frame) {} +func (q *retransmissionQueueHandshakeAckHandler) OnLost(f wire.Frame) { + (*retransmissionQueue)(q).addHandshake(f) +} + +type retransmissionQueueAppDataAckHandler retransmissionQueue + +func (q *retransmissionQueueAppDataAckHandler) OnAcked(wire.Frame) {} +func (q *retransmissionQueueAppDataAckHandler) OnLost(f wire.Frame) { + (*retransmissionQueue)(q).addAppData(f) +} diff --git a/vendor/github.com/quic-go/quic-go/send_conn.go b/vendor/github.com/quic-go/quic-go/send_conn.go index c53ebdfa..4e7007fa 100644 --- a/vendor/github.com/quic-go/quic-go/send_conn.go +++ b/vendor/github.com/quic-go/quic-go/send_conn.go @@ -1,38 +1,65 @@ package quic import ( + "math" "net" + + "github.com/quic-go/quic-go/internal/protocol" ) // A sendConn allows sending using a simple Write() on a non-connected packet conn. 
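// Illustrative sketch (not part of the patch): the ack-handler wiring above attaches
// per-encryption-level OnAcked/OnLost callbacks to one underlying queue by converting
// it to distinct named types. The names below (frameHandler, queue, initialHandler)
// are hypothetical and only demonstrate the technique used by the retransmission queue.
package main

import "fmt"

type frameHandler interface {
	OnAcked(frame string)
	OnLost(frame string)
}

type queue struct {
	initial []string
}

// initialHandler shares queue's memory layout; the conversion (*initialHandler)(q)
// only reinterprets the pointer, so no allocation or copying happens.
type initialHandler queue

func (h *initialHandler) OnAcked(string) {}
func (h *initialHandler) OnLost(f string) {
	(*queue)(h).initial = append((*queue)(h).initial, f) // re-queue the lost frame
}

func (q *queue) InitialHandler() frameHandler { return (*initialHandler)(q) }

func main() {
	q := &queue{}
	h := q.InitialHandler()
	h.OnLost("CRYPTO")
	fmt.Println(q.initial) // [CRYPTO]
}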
type sendConn interface { - Write([]byte) error + Write(b []byte, size protocol.ByteCount) error Close() error LocalAddr() net.Addr RemoteAddr() net.Addr + + capabilities() connCapabilities } type sconn struct { rawConn remoteAddr net.Addr - info *packetInfo + info packetInfo oob []byte } var _ sendConn = &sconn{} -func newSendConn(c rawConn, remote net.Addr, info *packetInfo) sendConn { +func newSendConn(c rawConn, remote net.Addr) *sconn { + sc := &sconn{ + rawConn: c, + remoteAddr: remote, + } + if c.capabilities().GSO { + // add 32 bytes, so we can add the UDP_SEGMENT msg + sc.oob = make([]byte, 0, 32) + } + return sc +} + +func newSendConnWithPacketInfo(c rawConn, remote net.Addr, info packetInfo) *sconn { + oob := info.OOB() + if c.capabilities().GSO { + // add 32 bytes, so we can add the UDP_SEGMENT msg + l := len(oob) + oob = append(oob, make([]byte, 32)...) + oob = oob[:l] + } return &sconn{ rawConn: c, remoteAddr: remote, info: info, - oob: info.OOB(), + oob: oob, } } -func (c *sconn) Write(p []byte) error { - _, err := c.WritePacket(p, c.remoteAddr, c.oob) +func (c *sconn) Write(p []byte, size protocol.ByteCount) error { + if size > math.MaxUint16 { + panic("size overflow") + } + _, err := c.WritePacket(p, uint16(size), c.remoteAddr, c.oob) return err } @@ -42,33 +69,12 @@ func (c *sconn) RemoteAddr() net.Addr { func (c *sconn) LocalAddr() net.Addr { addr := c.rawConn.LocalAddr() - if c.info != nil { + if c.info.addr.IsValid() { if udpAddr, ok := addr.(*net.UDPAddr); ok { addrCopy := *udpAddr - addrCopy.IP = c.info.addr + addrCopy.IP = c.info.addr.AsSlice() addr = &addrCopy } } return addr } - -type spconn struct { - net.PacketConn - - remoteAddr net.Addr -} - -var _ sendConn = &spconn{} - -func newSendPconn(c net.PacketConn, remote net.Addr) sendConn { - return &spconn{PacketConn: c, remoteAddr: remote} -} - -func (c *spconn) Write(p []byte) error { - _, err := c.WriteTo(p, c.remoteAddr) - return err -} - -func (c *spconn) RemoteAddr() net.Addr { - return c.remoteAddr -} diff --git a/vendor/github.com/quic-go/quic-go/send_queue.go b/vendor/github.com/quic-go/quic-go/send_queue.go index 9eafcd37..ab7a45ca 100644 --- a/vendor/github.com/quic-go/quic-go/send_queue.go +++ b/vendor/github.com/quic-go/quic-go/send_queue.go @@ -1,15 +1,22 @@ package quic +import "github.com/quic-go/quic-go/internal/protocol" + type sender interface { - Send(p *packetBuffer) + Send(p *packetBuffer, packetSize protocol.ByteCount) Run() error WouldBlock() bool Available() <-chan struct{} Close() } +type queueEntry struct { + buf *packetBuffer + size protocol.ByteCount +} + type sendQueue struct { - queue chan *packetBuffer + queue chan queueEntry closeCalled chan struct{} // runStopped when Close() is called runStopped chan struct{} // runStopped when the run loop returns available chan struct{} @@ -26,16 +33,16 @@ func newSendQueue(conn sendConn) sender { runStopped: make(chan struct{}), closeCalled: make(chan struct{}), available: make(chan struct{}, 1), - queue: make(chan *packetBuffer, sendQueueCapacity), + queue: make(chan queueEntry, sendQueueCapacity), } } // Send sends out a packet. It's guaranteed to not block. // Callers need to make sure that there's actually space in the send queue by calling WouldBlock. // Otherwise Send will panic. 
-func (h *sendQueue) Send(p *packetBuffer) { +func (h *sendQueue) Send(p *packetBuffer, size protocol.ByteCount) { select { - case h.queue <- p: + case h.queue <- queueEntry{buf: p, size: size}: // clear available channel if we've reached capacity if len(h.queue) == sendQueueCapacity { select { @@ -69,8 +76,8 @@ func (h *sendQueue) Run() error { h.closeCalled = nil // prevent this case from being selected again // make sure that all queued packets are actually sent out shouldClose = true - case p := <-h.queue: - if err := h.conn.Write(p.Data); err != nil { + case e := <-h.queue: + if err := h.conn.Write(e.buf.Data, e.size); err != nil { // This additional check enables: // 1. Checking for "datagram too large" message from the kernel, as such, // 2. Path MTU discovery,and @@ -79,7 +86,7 @@ func (h *sendQueue) Run() error { return err } } - p.Release() + e.buf.Release() select { case h.available <- struct{}{}: default: diff --git a/vendor/github.com/quic-go/quic-go/send_stream.go b/vendor/github.com/quic-go/quic-go/send_stream.go index cebe30ef..62ebe2ea 100644 --- a/vendor/github.com/quic-go/quic-go/send_stream.go +++ b/vendor/github.com/quic-go/quic-go/send_stream.go @@ -18,7 +18,7 @@ type sendStreamI interface { SendStream handleStopSendingFrame(*wire.StopSendingFrame) hasData() bool - popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool) + popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (frame ackhandler.StreamFrame, ok, hasMore bool) closeForShutdown(error) updateSendWindow(protocol.ByteCount) } @@ -198,7 +198,7 @@ func (s *sendStream) canBufferStreamFrame() bool { // popStreamFrame returns the next STREAM frame that is supposed to be sent on this stream // maxBytes is the maximum length this frame (including frame header) will have. 
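// Illustrative sketch (not part of the patch): the sendQueue contract documented above,
// where Send must never block, so callers check WouldBlock first and Send panics if the
// fixed-capacity channel is full. Entries now carry the packet size alongside the buffer
// so the writer can hand it to the size-aware WritePacket. All names here are hypothetical
// stand-ins for the internal types.
package main

import "fmt"

type entry struct {
	data []byte
	size int
}

type boundedQueue struct {
	ch chan entry
}

func newBoundedQueue(capacity int) *boundedQueue {
	return &boundedQueue{ch: make(chan entry, capacity)}
}

func (q *boundedQueue) WouldBlock() bool { return len(q.ch) == cap(q.ch) }

// Send mirrors the non-blocking contract: it panics rather than block.
func (q *boundedQueue) Send(data []byte, size int) {
	select {
	case q.ch <- entry{data: data, size: size}:
	default:
		panic("Send called although queue is full")
	}
}

func main() {
	q := newBoundedQueue(2)
	for i := 0; i < 3; i++ {
		if q.WouldBlock() {
			fmt.Println("queue full, waiting for the writer to drain it")
			break
		}
		q.Send([]byte{0xff}, 1)
	}
	fmt.Println("queued packets:", len(q.ch)) // queued packets: 2
}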
-func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool /* has more data to send */) { +func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (af ackhandler.StreamFrame, ok, hasMore bool) { s.mutex.Lock() f, hasMoreData := s.popNewOrRetransmittedStreamFrame(maxBytes, v) if f != nil { @@ -207,13 +207,12 @@ func (s *sendStream) popStreamFrame(maxBytes protocol.ByteCount, v protocol.Vers s.mutex.Unlock() if f == nil { - return nil, hasMoreData + return ackhandler.StreamFrame{}, false, hasMoreData } - af := ackhandler.GetFrame() - af.Frame = f - af.OnLost = s.queueRetransmission - af.OnAcked = s.frameAcked - return af, hasMoreData + return ackhandler.StreamFrame{ + Frame: f, + Handler: (*sendStreamAckHandler)(s), + }, true, hasMoreData } func (s *sendStream) popNewOrRetransmittedStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*wire.StreamFrame, bool /* has more data to send */) { @@ -348,26 +347,6 @@ func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.By } } -func (s *sendStream) frameAcked(f wire.Frame) { - f.(*wire.StreamFrame).PutBack() - - s.mutex.Lock() - if s.cancelWriteErr != nil { - s.mutex.Unlock() - return - } - s.numOutstandingFrames-- - if s.numOutstandingFrames < 0 { - panic("numOutStandingFrames negative") - } - newlyCompleted := s.isNewlyCompleted() - s.mutex.Unlock() - - if newlyCompleted { - s.sender.onStreamCompleted(s.streamID) - } -} - func (s *sendStream) isNewlyCompleted() bool { completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0 if completed && !s.completed { @@ -377,24 +356,6 @@ func (s *sendStream) isNewlyCompleted() bool { return false } -func (s *sendStream) queueRetransmission(f wire.Frame) { - sf := f.(*wire.StreamFrame) - sf.DataLenPresent = true - s.mutex.Lock() - if s.cancelWriteErr != nil { - s.mutex.Unlock() - return - } - s.retransmissionQueue = append(s.retransmissionQueue, sf) - s.numOutstandingFrames-- - if s.numOutstandingFrames < 0 { - panic("numOutStandingFrames negative") - } - s.mutex.Unlock() - - s.sender.onHasStreamData(s.streamID) -} - func (s *sendStream) Close() error { s.mutex.Lock() if s.closeForShutdownErr != nil { @@ -487,3 +448,45 @@ func (s *sendStream) signalWrite() { default: } } + +type sendStreamAckHandler sendStream + +var _ ackhandler.FrameHandler = &sendStreamAckHandler{} + +func (s *sendStreamAckHandler) OnAcked(f wire.Frame) { + sf := f.(*wire.StreamFrame) + sf.PutBack() + s.mutex.Lock() + if s.cancelWriteErr != nil { + s.mutex.Unlock() + return + } + s.numOutstandingFrames-- + if s.numOutstandingFrames < 0 { + panic("numOutStandingFrames negative") + } + newlyCompleted := (*sendStream)(s).isNewlyCompleted() + s.mutex.Unlock() + + if newlyCompleted { + s.sender.onStreamCompleted(s.streamID) + } +} + +func (s *sendStreamAckHandler) OnLost(f wire.Frame) { + sf := f.(*wire.StreamFrame) + s.mutex.Lock() + if s.cancelWriteErr != nil { + s.mutex.Unlock() + return + } + sf.DataLenPresent = true + s.retransmissionQueue = append(s.retransmissionQueue, sf) + s.numOutstandingFrames-- + if s.numOutstandingFrames < 0 { + panic("numOutStandingFrames negative") + } + s.mutex.Unlock() + + s.sender.onHasStreamData(s.streamID) +} diff --git a/vendor/github.com/quic-go/quic-go/server.go b/vendor/github.com/quic-go/quic-go/server.go index edacdd85..0f8219e3 100644 --- a/vendor/github.com/quic-go/quic-go/server.go +++ 
b/vendor/github.com/quic-go/quic-go/server.go @@ -20,33 +20,29 @@ import ( ) // ErrServerClosed is returned by the Listener or EarlyListener's Accept method after a call to Close. -var ErrServerClosed = errors.New("quic: Server closed") +var ErrServerClosed = errors.New("quic: server closed") // packetHandler handles packets type packetHandler interface { - handlePacket(*receivedPacket) + handlePacket(receivedPacket) shutdown() destroy(error) getPerspective() protocol.Perspective } -type unknownPacketHandler interface { - handlePacket(*receivedPacket) - setCloseError(error) -} - type packetHandlerManager interface { - AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool - Destroy() error - connRunner - SetServer(unknownPacketHandler) + Get(protocol.ConnectionID) (packetHandler, bool) + GetByResetToken(protocol.StatelessResetToken) (packetHandler, bool) + AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() (packetHandler, bool)) bool + Close(error) CloseServer() + connRunner } type quicConn interface { EarlyConnection earlyConnReady() <-chan struct{} - handlePacket(*receivedPacket) + handlePacket(receivedPacket) GetVersion() protocol.VersionNumber getPerspective() protocol.Perspective run() error @@ -54,6 +50,11 @@ type quicConn interface { shutdown() } +type zeroRTTQueue struct { + packets []receivedPacket + expiration time.Time +} + // A Listener of QUIC type baseServer struct { mutex sync.Mutex @@ -64,15 +65,17 @@ type baseServer struct { config *Config conn rawConn - // If the server is started with ListenAddr, we create a packet conn. - // If it is started with Listen, we take a packet conn as a parameter. - createdPacketConn bool tokenGenerator *handshake.TokenGenerator - connHandler packetHandlerManager + connIDGenerator ConnectionIDGenerator + connHandler packetHandlerManager + onClose func() - receivedPackets chan *receivedPacket + receivedPackets chan receivedPacket + + nextZeroRTTCleanup time.Time + zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true // set as a member, so they can be set in the tests newConn func( @@ -83,6 +86,7 @@ type baseServer struct { protocol.ConnectionID, /* client dest connection ID */ protocol.ConnectionID, /* destination connection ID */ protocol.ConnectionID, /* source connection ID */ + ConnectionIDGenerator, protocol.StatelessResetToken, *Config, *tls.Config, @@ -94,128 +98,164 @@ type baseServer struct { protocol.VersionNumber, ) quicConn - serverError error - errorChan chan struct{} - closed bool - running chan struct{} // closed as soon as run() returns + serverError error + errorChan chan struct{} + closed bool + running chan struct{} // closed as soon as run() returns + versionNegotiationQueue chan receivedPacket + invalidTokenQueue chan receivedPacket connQueue chan quicConn connQueueLen int32 // to be used as an atomic + tracer logging.Tracer + logger utils.Logger } -var ( - _ Listener = &baseServer{} - _ unknownPacketHandler = &baseServer{} -) +// A Listener listens for incoming QUIC connections. +// It returns connections once the handshake has completed. +type Listener struct { + baseServer *baseServer +} -type earlyServer struct{ *baseServer } +// Accept returns new connections. It should be called in a loop. +func (l *Listener) Accept(ctx context.Context) (Connection, error) { + return l.baseServer.Accept(ctx) +} -var _ EarlyListener = &earlyServer{} +// Close the server. All active connections will be closed. 
+func (l *Listener) Close() error { + return l.baseServer.Close() +} -func (s *earlyServer) Accept(ctx context.Context) (EarlyConnection, error) { - return s.baseServer.accept(ctx) +// Addr returns the local network address that the server is listening on. +func (l *Listener) Addr() net.Addr { + return l.baseServer.Addr() +} + +// An EarlyListener listens for incoming QUIC connections, and returns them before the handshake completes. +// For connections that don't use 0-RTT, this allows the server to send 0.5-RTT data. +// This data is encrypted with forward-secure keys, however, the client's identity has not yet been verified. +// For connection using 0-RTT, this allows the server to accept and respond to streams that the client opened in the +// 0-RTT data it sent. Note that at this point during the handshake, the live-ness of the +// client has not yet been confirmed, and the 0-RTT data could have been replayed by an attacker. +type EarlyListener struct { + baseServer *baseServer +} + +// Accept returns a new connections. It should be called in a loop. +func (l *EarlyListener) Accept(ctx context.Context) (EarlyConnection, error) { + return l.baseServer.accept(ctx) +} + +// Close the server. All active connections will be closed. +func (l *EarlyListener) Close() error { + return l.baseServer.Close() +} + +// Addr returns the local network addr that the server is listening on. +func (l *EarlyListener) Addr() net.Addr { + return l.baseServer.Addr() } // ListenAddr creates a QUIC server listening on a given address. -// The tls.Config must not be nil and must contain a certificate configuration. -// The quic.Config may be nil, in that case the default values will be used. -func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { - return listenAddr(addr, tlsConf, config, false) -} - -// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes. -func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (EarlyListener, error) { - s, err := listenAddr(addr, tlsConf, config, true) +// See Listen for more details. +func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (*Listener, error) { + conn, err := listenUDP(addr) if err != nil { return nil, err } - return &earlyServer{s}, nil + return (&Transport{ + Conn: conn, + createdConn: true, + isSingleUse: true, + }).Listen(tlsConf, config) } -func listenAddr(addr string, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { +// ListenAddrEarly works like ListenAddr, but it returns connections before the handshake completes. +func ListenAddrEarly(addr string, tlsConf *tls.Config, config *Config) (*EarlyListener, error) { + conn, err := listenUDP(addr) + if err != nil { + return nil, err + } + return (&Transport{ + Conn: conn, + createdConn: true, + isSingleUse: true, + }).ListenEarly(tlsConf, config) +} + +func listenUDP(addr string) (*net.UDPConn, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } - conn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - return nil, err - } - serv, err := listen(conn, tlsConf, config, acceptEarly) - if err != nil { - return nil, err - } - serv.createdPacketConn = true - return serv, nil + return net.ListenUDP("udp", udpAddr) } -// Listen listens for QUIC connections on a given net.PacketConn. If the -// PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn -// does), ECN and packet info support will be enabled. 
In this case, ReadMsgUDP -// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write -// packets. A single net.PacketConn only be used for a single call to Listen. -// The PacketConn can be used for simultaneous calls to Dial. QUIC connection -// IDs are used for demultiplexing the different connections. +// Listen listens for QUIC connections on a given net.PacketConn. +// If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does), +// ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP +// will be used instead of ReadFrom and WriteTo to read/write packets. +// A single net.PacketConn can only be used for a single call to Listen. +// // The tls.Config must not be nil and must contain a certificate configuration. // Furthermore, it must define an application control (using NextProtos). // The quic.Config may be nil, in that case the default values will be used. -func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { - return listen(conn, tlsConf, config, false) +// +// This is a convenience function. More advanced use cases should instantiate a Transport, +// which offers configuration options for a more fine-grained control of the connection establishment, +// including reusing the underlying UDP socket for outgoing QUIC connections. +func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*Listener, error) { + tr := &Transport{Conn: conn, isSingleUse: true} + return tr.Listen(tlsConf, config) } // ListenEarly works like Listen, but it returns connections before the handshake completes. -func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (EarlyListener, error) { - s, err := listen(conn, tlsConf, config, true) - if err != nil { - return nil, err - } - return &earlyServer{s}, nil +func ListenEarly(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*EarlyListener, error) { + tr := &Transport{Conn: conn, isSingleUse: true} + return tr.ListenEarly(tlsConf, config) } -func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarly bool) (*baseServer, error) { - if tlsConf == nil { - return nil, errors.New("quic: tls.Config not set") - } - if err := validateConfig(config); err != nil { - return nil, err - } - config = populateServerConfig(config) - for _, v := range config.Versions { - if !protocol.IsValidVersion(v) { - return nil, fmt.Errorf("%s is not a valid QUIC version", v) - } - } - - connHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer) - if err != nil { - return nil, err - } +func newServer( + conn rawConn, + connHandler packetHandlerManager, + connIDGenerator ConnectionIDGenerator, + tlsConf *tls.Config, + config *Config, + tracer logging.Tracer, + onClose func(), + acceptEarly bool, +) (*baseServer, error) { tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) if err != nil { return nil, err } - c, err := wrapConn(conn) - if err != nil { - return nil, err - } s := &baseServer{ - conn: c, - tlsConf: tlsConf, - config: config, - tokenGenerator: tokenGenerator, - connHandler: connHandler, - connQueue: make(chan quicConn), - errorChan: make(chan struct{}), - running: make(chan struct{}), - receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), - newConn: newConnection, - logger: utils.DefaultLogger.WithPrefix("server"), - acceptEarlyConns: acceptEarly, + conn: conn, + tlsConf: tlsConf, + 
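// Illustrative sketch (not part of the patch): with the v0.36 API, ListenAddr is a thin
// wrapper around a single-use Transport. Sharing one UDP socket between a listener and
// outgoing dials is done by creating the Transport explicitly, roughly as below (error
// handling elided; the exact Transport.Dial signature should be checked against the
// vendored quic-go sources, and a real server needs certificates in its tls.Config).
package main

import (
	"context"
	"crypto/tls"
	"log"
	"net"

	"github.com/quic-go/quic-go"
)

func main() {
	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: 4433})
	if err != nil {
		log.Fatal(err)
	}
	tr := &quic.Transport{Conn: udpConn}

	tlsConf := &tls.Config{NextProtos: []string{"example"}} // certificates omitted in this sketch
	ln, err := tr.Listen(tlsConf, &quic.Config{})
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()

	// The same Transport, and therefore the same UDP socket, can also dial out.
	remote := &net.UDPAddr{IP: net.ParseIP("192.0.2.1"), Port: 4433}
	conn, err := tr.Dial(context.Background(), remote, &tls.Config{NextProtos: []string{"example"}}, &quic.Config{})
	if err != nil {
		log.Fatal(err)
	}
	_ = conn
}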
config: config, + tokenGenerator: tokenGenerator, + connIDGenerator: connIDGenerator, + connHandler: connHandler, + connQueue: make(chan quicConn), + errorChan: make(chan struct{}), + running: make(chan struct{}), + receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), + versionNegotiationQueue: make(chan receivedPacket, 4), + invalidTokenQueue: make(chan receivedPacket, 4), + newConn: newConnection, + tracer: tracer, + logger: utils.DefaultLogger.WithPrefix("server"), + acceptEarlyConns: acceptEarly, + onClose: onClose, + } + if acceptEarly { + s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} } go s.run() - connHandler.SetServer(s) + go s.runSendQueue() s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil } @@ -239,6 +279,19 @@ func (s *baseServer) run() { } } +func (s *baseServer) runSendQueue() { + for { + select { + case <-s.running: + return + case p := <-s.versionNegotiationQueue: + s.maybeSendVersionNegotiationPacket(p) + case p := <-s.invalidTokenQueue: + s.maybeSendInvalidToken(p) + } + } +} + // Accept returns connections that already completed the handshake. // It is only valid if acceptEarlyConns is false. func (s *baseServer) Accept(ctx context.Context) (Connection, error) { @@ -267,18 +320,12 @@ func (s *baseServer) Close() error { if s.serverError == nil { s.serverError = ErrServerClosed } - // If the server was started with ListenAddr, we created the packet conn. - // We need to close it in order to make the go routine reading from that conn return. - createdPacketConn := s.createdPacketConn s.closed = true close(s.errorChan) s.mutex.Unlock() <-s.running - s.connHandler.CloseServer() - if createdPacketConn { - return s.connHandler.Destroy() - } + s.onClose() return nil } @@ -298,22 +345,26 @@ func (s *baseServer) Addr() net.Addr { return s.conn.LocalAddr() } -func (s *baseServer) handlePacket(p *receivedPacket) { +func (s *baseServer) handlePacket(p receivedPacket) { select { case s.receivedPackets <- p: default: s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) } } } -func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ { +func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer still in use? 
*/ { + if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) { + defer s.cleanupZeroRTTQueues(p.rcvTime) + } + if wire.IsVersionNegotiationPacket(p.data) { s.logger.Debugf("Dropping Version Negotiation packet.") - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -322,42 +373,54 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s panic(fmt.Sprintf("misrouted packet: %#v", p.data)) } v, err := wire.ParseVersion(p.data) - // send a Version Negotiation Packet if the client is speaking a different protocol version - if err != nil || !protocol.IsSupportedVersion(s.config.Versions, v) { - if err != nil || p.Size() < protocol.MinUnknownVersionPacketSize { - s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) - if err != nil { // should never happen - s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - if !s.config.DisableVersionNegotiationPackets { - go s.sendVersionNegotiationPacket(p.remoteAddr, src, dest, p.info.OOB(), v) + // drop the packet if we failed to parse the protocol version + if err != nil { + s.logger.Debugf("Dropping a packet with an unknown version") + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false } + // send a Version Negotiation Packet if the client is speaking a different protocol version + if !protocol.IsSupportedVersion(s.config.Versions, v) { + if s.config.DisableVersionNegotiationPackets { + return false + } + + if p.Size() < protocol.MinUnknownVersionPacketSize { + s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size()) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + return s.enqueueVersionNegotiationPacket(p) + } + + if wire.Is0RTTPacket(p.data) { + if !s.acceptEarlyConns { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + return s.handle0RTTPacket(p) + } + // If we're creating a new connection, the packet will be passed to the connection. // The header will then be parsed again. 
hdr, _, _, err := wire.ParsePacket(p.data) if err != nil { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing packet: %s", err) return false } if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -367,8 +430,8 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s // There's little point in sending a Stateless Reset, since the client // might not have received the token yet. s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) } return false } @@ -383,6 +446,74 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s return true } +func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { + connID, err := wire.ParseConnectionID(p.data, 0) + if err != nil { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError) + } + return false + } + + // check again if we might have a connection now + if handler, ok := s.connHandler.Get(connID); ok { + handler.handlePacket(p) + return true + } + + if q, ok := s.zeroRTTQueues[connID]; ok { + if len(q.packets) >= protocol.Max0RTTQueueLen { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + } + return false + } + q.packets = append(q.packets, p) + return true + } + + if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + } + return false + } + queue := &zeroRTTQueue{packets: make([]receivedPacket, 1, 8)} + queue.packets[0] = p + expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration) + queue.expiration = expiration + if s.nextZeroRTTCleanup.IsZero() || s.nextZeroRTTCleanup.After(expiration) { + s.nextZeroRTTCleanup = expiration + } + s.zeroRTTQueues[connID] = queue + return true +} + +func (s *baseServer) cleanupZeroRTTQueues(now time.Time) { + // Iterate over all queues to find those that are expired. + // This is ok since we're placing a pretty low limit on the number of queues. 
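// Illustrative sketch (not part of the patch): the 0-RTT buffering above bounds both the
// number of per-connection queues and the packets per queue, and remembers only the
// earliest expiration time so cleanup can be deferred until a packet arrives after that
// point. The constants and names here are hypothetical, not the protocol.* values.
package main

import (
	"fmt"
	"time"
)

const (
	maxQueues        = 32
	maxQueueLen      = 31
	queueingDuration = 5 * time.Second
)

type bufferedQueue struct {
	packets    [][]byte
	expiration time.Time
}

type zeroRTTBuffer struct {
	queues      map[string]*bufferedQueue
	nextCleanup time.Time
}

func (b *zeroRTTBuffer) enqueue(connID string, pkt []byte, now time.Time) bool {
	if q, ok := b.queues[connID]; ok {
		if len(q.packets) >= maxQueueLen {
			return false // per-connection limit reached: drop
		}
		q.packets = append(q.packets, pkt)
		return true
	}
	if len(b.queues) >= maxQueues {
		return false // global limit reached: drop
	}
	exp := now.Add(queueingDuration)
	b.queues[connID] = &bufferedQueue{packets: [][]byte{pkt}, expiration: exp}
	if b.nextCleanup.IsZero() || b.nextCleanup.After(exp) {
		b.nextCleanup = exp // only the earliest deadline needs to be tracked
	}
	return true
}

func main() {
	b := &zeroRTTBuffer{queues: map[string]*bufferedQueue{}}
	fmt.Println(b.enqueue("conn-1", []byte{0x1}, time.Now())) // true
}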
+ var nextCleanup time.Time + for connID, q := range s.zeroRTTQueues { + if q.expiration.After(now) { + if nextCleanup.IsZero() || nextCleanup.After(q.expiration) { + nextCleanup = q.expiration + } + continue + } + for _, p := range q.packets { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + } + p.buffer.Release() + } + delete(s.zeroRTTQueues, connID) + if s.logger.Debug() { + s.logger.Debugf("Removing 0-RTT queue for %s.", connID) + } + } + s.nextZeroRTTCleanup = nextCleanup +} + // validateToken returns false if: // - address is invalid // - token is expired @@ -403,15 +534,23 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool { return true } -func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) error { +func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { p.buffer.Release() - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) } return errors.New("too short connection ID") } + // The server queues packets for a while, and we might already have established a connection by now. + // This results in a second check in the connection map. + // That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets). + if handler, ok := s.connHandler.Get(hdr.DestConnectionID); ok { + handler.handlePacket(p) + return nil + } + var ( token *handshake.Token retrySrcConnID *protocol.ConnectionID @@ -429,7 +568,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro } clientAddrIsValid := s.validateToken(token, p.remoteAddr) - if token != nil && !clientAddrIsValid { // For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error. // We just ignore them, and act as if there was no token on this packet at all. @@ -440,16 +578,13 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro // For Retry tokens, we send an INVALID_ERROR if // * the token is too old, or // * the token is invalid, in case of a retry token. - go func() { - defer p.buffer.Release() - if err := s.maybeSendInvalidToken(p, hdr); err != nil { - s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) - } - }() + s.enqueueInvalidToken(p) return nil } } if token == nil && s.config.RequireAddressValidation(p.remoteAddr) { + // Retry invalidates all 0-RTT packets sent. 
+ delete(s.zeroRTTQueues, hdr.DestConnectionID) go func() { defer p.buffer.Release() if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil { @@ -470,37 +605,43 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro return nil } - connID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() + connID, err := s.connIDGenerator.GenerateConnectionID() if err != nil { return err } s.logger.Debugf("Changing connection ID to %s.", connID) var conn quicConn tracingID := nextConnTracingID() - if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { + if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() (packetHandler, bool) { + config := s.config + if s.config.GetConfigForClient != nil { + conf, err := s.config.GetConfigForClient(&ClientHelloInfo{RemoteAddr: p.remoteAddr}) + if err != nil { + s.logger.Debugf("Rejecting new connection due to GetConfigForClient callback") + return nil, false + } + config = populateConfig(conf) + } var tracer logging.ConnectionTracer - if s.config.Tracer != nil { + if config.Tracer != nil { // Use the same connection ID that is passed to the client's GetLogWriter callback. connID := hdr.DestConnectionID if origDestConnID.Len() > 0 { connID = origDestConnID } - tracer = s.config.Tracer.TracerForConnection( - context.WithValue(context.Background(), ConnectionTracingKey, tracingID), - protocol.PerspectiveServer, - connID, - ) + tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID) } conn = s.newConn( - newSendConn(s.conn, p.remoteAddr, p.info), + newSendConnWithPacketInfo(s.conn, p.remoteAddr, p.info), s.connHandler, origDestConnID, retrySrcConnID, hdr.DestConnectionID, hdr.SrcConnectionID, connID, + s.connIDGenerator, s.connHandler.GetStatelessResetToken(connID), - s.config, + config, s.tlsConf, s.tokenGenerator, clientAddrIsValid, @@ -510,8 +651,22 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro hdr.Version, ) conn.handlePacket(p) - return conn + + if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok { + for _, p := range q.packets { + conn.handlePacket(p) + } + delete(s.zeroRTTQueues, hdr.DestConnectionID) + } + + return conn, true }); !added { + go func() { + defer p.buffer.Release() + if err := s.sendConnectionRefused(p.remoteAddr, hdr, p.info); err != nil { + s.logger.Debugf("Error rejecting connection: %s", err) + } + }() return nil } go conn.run() @@ -551,11 +706,11 @@ func (s *baseServer) handleNewConn(conn quicConn) { } } -func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { +func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packetInfo) error { // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the connection. (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) - srcConnID, err := s.config.ConnectionIDGenerator.GenerateConnectionID() + srcConnID, err := s.connIDGenerator.GenerateConnectionID() if err != nil { return err } @@ -584,47 +739,69 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack // append the Retry integrity tag tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version) buf.Data = append(buf.Data, tag[:]...) 
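// Illustrative sketch (not part of the patch): the server code above consults
// Config.GetConfigForClient before creating a connection; returning an error makes the
// connection callback fail, and the server then sends CONNECTION_REFUSED. A per-client
// policy could look roughly like this (the callback signature follows the usage in the
// vendored code; the filtering logic is invented for the example).
package main

import (
	"errors"
	"net"

	"github.com/quic-go/quic-go"
)

func serverConfig() *quic.Config {
	return &quic.Config{
		GetConfigForClient: func(info *quic.ClientHelloInfo) (*quic.Config, error) {
			udpAddr, ok := info.RemoteAddr.(*net.UDPAddr)
			if !ok || !udpAddr.IP.IsPrivate() {
				// Rejecting here causes the server to refuse the connection.
				return nil, errors.New("only private-network clients allowed")
			}
			// The returned Config is used for this connection only.
			return &quic.Config{MaxIncomingStreams: 16}, nil
		},
	}
}

func main() {
	_ = serverConfig()
}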
- if s.config.Tracer != nil { - s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) + if s.tracer != nil { + s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) } - _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(buf.Data, uint16(len(buf.Data)), remoteAddr, info.OOB()) return err } -func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) error { +func (s *baseServer) enqueueInvalidToken(p receivedPacket) { + select { + case s.invalidTokenQueue <- p: + default: + // it's fine to drop INVALID_TOKEN packets when we are busy + p.buffer.Release() + } +} + +func (s *baseServer) maybeSendInvalidToken(p receivedPacket) { + defer p.buffer.Release() + + hdr, _, _, err := wire.ParsePacket(p.data) + if err != nil { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + } + s.logger.Debugf("Error parsing packet: %s", err) + return + } + // Only send INVALID_TOKEN if we can unprotect the packet. // This makes sure that we won't send it for packets that were corrupted. sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) data := p.data[:hdr.ParsedLen()+hdr.Length] extHdr, err := unpackLongHeader(opener, hdr, data, hdr.Version) + // Only send INVALID_TOKEN if we can unprotect the packet. + // This makes sure that we won't send it for packets that were corrupted. if err != nil { - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) } - // don't return the error here. Just drop the packet. - return nil + return } hdrLen := extHdr.ParsedLen() if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { - // don't return the error here. Just drop the packet. - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) } - return nil + return } if s.logger.Debug() { s.logger.Debugf("Client sent an invalid retry token. 
Sending INVALID_TOKEN to %s.", p.remoteAddr) } - return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info) + if err := s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info); err != nil { + s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) + } } -func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { +func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info packetInfo) error { sealer, _ := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) return s.sendError(remoteAddr, hdr, sealer, qerr.ConnectionRefused, info) } // sendError sends the error as a response to the packet received with header hdr -func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info *packetInfo) error { +func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer handshake.LongHeaderSealer, errorCode qerr.TransportErrorCode, info packetInfo) error { b := getPacketBuffer() defer b.Release() @@ -661,21 +838,48 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han replyHdr.Log(s.logger) wire.LogFrame(s.logger, ccf, true) - if s.config.Tracer != nil { - s.config.Tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) + if s.tracer != nil { + s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) } - _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) + _, err = s.conn.WritePacket(b.Data, uint16(len(b.Data)), remoteAddr, info.OOB()) return err } -func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest protocol.ArbitraryLenConnectionID, oob []byte, v protocol.VersionNumber) { +func (s *baseServer) enqueueVersionNegotiationPacket(p receivedPacket) (bufferInUse bool) { + select { + case s.versionNegotiationQueue <- p: + return true + default: + // it's fine to not send version negotiation packets when we are busy + } + return false +} + +func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) { + defer p.buffer.Release() + + v, err := wire.ParseVersion(p.data) + if err != nil { + s.logger.Debugf("failed to parse version for sending version negotiation packet: %s", err) + return + } + + _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) + if err != nil { // should never happen + s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + } + return + } + s.logger.Debugf("Client offered version %s, sending Version Negotiation", v) data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) - if s.config.Tracer != nil { - s.config.Tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions) + if s.tracer != nil { + s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, remote, oob); err != nil { + if _, err := s.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/vendor/github.com/quic-go/quic-go/stream.go b/vendor/github.com/quic-go/quic-go/stream.go index 98d2fc6e..ab76eaf8 100644 --- 
a/vendor/github.com/quic-go/quic-go/stream.go +++ b/vendor/github.com/quic-go/quic-go/stream.go @@ -60,7 +60,7 @@ type streamI interface { // for sending hasData() bool handleStopSendingFrame(*wire.StopSendingFrame) - popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (*ackhandler.Frame, bool) + popStreamFrame(maxBytes protocol.ByteCount, v protocol.VersionNumber) (ackhandler.StreamFrame, bool, bool) updateSendWindow(protocol.ByteCount) } diff --git a/vendor/github.com/quic-go/quic-go/sys_conn.go b/vendor/github.com/quic-go/quic-go/sys_conn.go index d6c1d616..29aa6add 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn.go @@ -1,6 +1,7 @@ package quic import ( + "fmt" "net" "syscall" "time" @@ -15,16 +16,38 @@ import ( type OOBCapablePacketConn interface { net.PacketConn SyscallConn() (syscall.RawConn, error) + SetReadBuffer(int) error ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) } var _ OOBCapablePacketConn = &net.UDPConn{} -func wrapConn(pc net.PacketConn) (rawConn, error) { +// OptimizeConn takes a net.PacketConn and attempts to enable various optimizations that will improve QUIC performance: +// 1. It enables the Don't Fragment (DF) bit on the IP header. +// This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899). +// 2. It enables reading of the ECN bits from the IP header. +// This allows the remote node to speed up its loss detection and recovery. +// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket. +// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux). +// +// In order for this to work, the connection needs to implement the OOBCapablePacketConn interface (as a *net.UDPConn does). +// +// It's only necessary to call this function explicitly if the application calls WriteTo +// after passing the connection to the Transport. +func OptimizeConn(c net.PacketConn) (net.PacketConn, error) { + return wrapConn(c) +} + +func wrapConn(pc net.PacketConn) (interface { + net.PacketConn + rawConn +}, error, +) { conn, ok := pc.(interface { SyscallConn() (syscall.RawConn, error) }) + var supportsDF bool if ok { rawConn, err := conn.SyscallConn() if err != nil { @@ -33,7 +56,8 @@ func wrapConn(pc net.PacketConn) (rawConn, error) { if _, ok := pc.LocalAddr().(*net.UDPAddr); ok { // Only set DF on sockets that we expect to be able to handle that configuration. - err = setDF(rawConn) + var err error + supportsDF, err = setDF(rawConn) if err != nil { return nil, err } @@ -42,32 +66,33 @@ func wrapConn(pc net.PacketConn) (rawConn, error) { c, ok := pc.(OOBCapablePacketConn) if !ok { utils.DefaultLogger.Infof("PacketConn is not a net.UDPConn. Disabling optimizations possible on UDP connections.") - return &basicConn{PacketConn: pc}, nil + return &basicConn{PacketConn: pc, supportsDF: supportsDF}, nil } - return newConn(c) + return newConn(c, supportsDF) } -// The basicConn is the most trivial implementation of a connection. +// The basicConn is the most trivial implementation of a rawConn. // It reads a single packet from the underlying net.PacketConn. // It is used when // * the net.PacketConn is not a OOBCapablePacketConn, and // * when the OS doesn't support OOB. 
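// Illustrative sketch (not part of the patch): OptimizeConn is normally applied implicitly
// when a connection is handed to a Transport; calling it yourself, as the doc comment above
// describes, only matters if the application keeps using WriteTo on the same socket. A
// minimal explicit use could look like this (addresses are placeholders).
package main

import (
	"log"
	"net"

	"github.com/quic-go/quic-go"
)

func main() {
	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0})
	if err != nil {
		log.Fatal(err)
	}
	// Enable DF, ECN and (where supported) batched reads / GSO up front, then share the
	// optimized conn between QUIC and the application's own writes.
	optimized, err := quic.OptimizeConn(udpConn)
	if err != nil {
		log.Fatal(err)
	}
	tr := &quic.Transport{Conn: optimized}
	defer tr.Close()

	// Non-QUIC datagrams can still be written through the same socket.
	_, _ = optimized.WriteTo([]byte("probe"), &net.UDPAddr{IP: net.ParseIP("192.0.2.1"), Port: 9})
}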
type basicConn struct { net.PacketConn + supportsDF bool } var _ rawConn = &basicConn{} -func (c *basicConn) ReadPacket() (*receivedPacket, error) { +func (c *basicConn) ReadPacket() (receivedPacket, error) { buffer := getPacketBuffer() // The packet size should not exceed protocol.MaxPacketBufferSize bytes // If it does, we only read a truncated packet, which will then end up undecryptable buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize] n, addr, err := c.PacketConn.ReadFrom(buffer.Data) if err != nil { - return nil, err + return receivedPacket{}, err } - return &receivedPacket{ + return receivedPacket{ remoteAddr: addr, rcvTime: time.Now(), data: buffer.Data[:n], @@ -75,6 +100,11 @@ func (c *basicConn) ReadPacket() (*receivedPacket, error) { }, nil } -func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { +func (c *basicConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, _ []byte) (n int, err error) { + if uint16(len(b)) != packetSize { + panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b))) + } return c.PacketConn.WriteTo(b, addr) } + +func (c *basicConn) capabilities() connCapabilities { return connCapabilities{DF: c.supportsDF} } diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_buffers.go b/vendor/github.com/quic-go/quic-go/sys_conn_buffers.go new file mode 100644 index 00000000..0c4e024a --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/sys_conn_buffers.go @@ -0,0 +1,68 @@ +package quic + +import ( + "errors" + "fmt" + "net" + "syscall" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" +) + +//go:generate sh -c "echo '// Code generated by go generate. DO NOT EDIT.\n// Source: sys_conn_buffers.go\n' > sys_conn_buffers_write.go && sed -e 's/SetReadBuffer/SetWriteBuffer/g' -e 's/setReceiveBuffer/setSendBuffer/g' -e 's/inspectReadBuffer/inspectWriteBuffer/g' -e 's/protocol\\.DesiredReceiveBufferSize/protocol\\.DesiredSendBufferSize/g' -e 's/forceSetReceiveBuffer/forceSetSendBuffer/g' -e 's/receive buffer/send buffer/g' sys_conn_buffers.go | sed '/^\\/\\/go:generate/d' >> sys_conn_buffers_write.go" +func setReceiveBuffer(c net.PacketConn, logger utils.Logger) error { + conn, ok := c.(interface{ SetReadBuffer(int) error }) + if !ok { + return errors.New("connection doesn't allow setting of receive buffer size. Not a *net.UDPConn?") + } + + var syscallConn syscall.RawConn + if sc, ok := c.(interface { + SyscallConn() (syscall.RawConn, error) + }); ok { + var err error + syscallConn, err = sc.SyscallConn() + if err != nil { + syscallConn = nil + } + } + // The connection has a SetReadBuffer method, but we couldn't obtain a syscall.RawConn. + // This shouldn't happen for a net.UDPConn, but is possible if the connection just implements the + // net.PacketConn interface and the SetReadBuffer method. + // We have no way of checking if increasing the buffer size actually worked. + if syscallConn == nil { + return conn.SetReadBuffer(protocol.DesiredReceiveBufferSize) + } + + size, err := inspectReadBuffer(syscallConn) + if err != nil { + return fmt.Errorf("failed to determine receive buffer size: %w", err) + } + if size >= protocol.DesiredReceiveBufferSize { + logger.Debugf("Conn has receive buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024) + return nil + } + // Ignore the error. We check if we succeeded by querying the buffer size afterward. 
+ _ = conn.SetReadBuffer(protocol.DesiredReceiveBufferSize) + newSize, err := inspectReadBuffer(syscallConn) + if newSize < protocol.DesiredReceiveBufferSize { + // Try again with RCVBUFFORCE on Linux + _ = forceSetReceiveBuffer(syscallConn, protocol.DesiredReceiveBufferSize) + newSize, err = inspectReadBuffer(syscallConn) + if err != nil { + return fmt.Errorf("failed to determine receive buffer size: %w", err) + } + } + if err != nil { + return fmt.Errorf("failed to determine receive buffer size: %w", err) + } + if newSize == size { + return fmt.Errorf("failed to increase receive buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredReceiveBufferSize/1024, newSize/1024) + } + if newSize < protocol.DesiredReceiveBufferSize { + return fmt.Errorf("failed to sufficiently increase receive buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredReceiveBufferSize/1024, newSize/1024) + } + logger.Debugf("Increased receive buffer size to %d kiB", newSize/1024) + return nil +} diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_buffers_write.go b/vendor/github.com/quic-go/quic-go/sys_conn_buffers_write.go new file mode 100644 index 00000000..86fead96 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/sys_conn_buffers_write.go @@ -0,0 +1,70 @@ +// Code generated by go generate. DO NOT EDIT. +// Source: sys_conn_buffers.go + +package quic + +import ( + "errors" + "fmt" + "net" + "syscall" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" +) + +func setSendBuffer(c net.PacketConn, logger utils.Logger) error { + conn, ok := c.(interface{ SetWriteBuffer(int) error }) + if !ok { + return errors.New("connection doesn't allow setting of send buffer size. Not a *net.UDPConn?") + } + + var syscallConn syscall.RawConn + if sc, ok := c.(interface { + SyscallConn() (syscall.RawConn, error) + }); ok { + var err error + syscallConn, err = sc.SyscallConn() + if err != nil { + syscallConn = nil + } + } + // The connection has a SetWriteBuffer method, but we couldn't obtain a syscall.RawConn. + // This shouldn't happen for a net.UDPConn, but is possible if the connection just implements the + // net.PacketConn interface and the SetWriteBuffer method. + // We have no way of checking if increasing the buffer size actually worked. + if syscallConn == nil { + return conn.SetWriteBuffer(protocol.DesiredSendBufferSize) + } + + size, err := inspectWriteBuffer(syscallConn) + if err != nil { + return fmt.Errorf("failed to determine send buffer size: %w", err) + } + if size >= protocol.DesiredSendBufferSize { + logger.Debugf("Conn has send buffer of %d kiB (wanted: at least %d kiB)", size/1024, protocol.DesiredSendBufferSize/1024) + return nil + } + // Ignore the error. We check if we succeeded by querying the buffer size afterward. 
+ _ = conn.SetWriteBuffer(protocol.DesiredSendBufferSize) + newSize, err := inspectWriteBuffer(syscallConn) + if newSize < protocol.DesiredSendBufferSize { + // Try again with RCVBUFFORCE on Linux + _ = forceSetSendBuffer(syscallConn, protocol.DesiredSendBufferSize) + newSize, err = inspectWriteBuffer(syscallConn) + if err != nil { + return fmt.Errorf("failed to determine send buffer size: %w", err) + } + } + if err != nil { + return fmt.Errorf("failed to determine send buffer size: %w", err) + } + if newSize == size { + return fmt.Errorf("failed to increase send buffer size (wanted: %d kiB, got %d kiB)", protocol.DesiredSendBufferSize/1024, newSize/1024) + } + if newSize < protocol.DesiredSendBufferSize { + return fmt.Errorf("failed to sufficiently increase send buffer size (was: %d kiB, wanted: %d kiB, got: %d kiB)", size/1024, protocol.DesiredSendBufferSize/1024, newSize/1024) + } + logger.Debugf("Increased send buffer size to %d kiB", newSize/1024) + return nil +} diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_df.go b/vendor/github.com/quic-go/quic-go/sys_conn_df.go index ef9f981a..a2189412 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_df.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_df.go @@ -2,11 +2,13 @@ package quic -import "syscall" +import ( + "syscall" +) -func setDF(rawConn syscall.RawConn) error { +func setDF(syscall.RawConn) (bool, error) { // no-op on unsupported platforms - return nil + return false, nil } func isMsgSizeErr(err error) bool { diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_df_linux.go b/vendor/github.com/quic-go/quic-go/sys_conn_df_linux.go index 98542b41..92f7e725 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_df_linux.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_df_linux.go @@ -4,14 +4,23 @@ package quic import ( "errors" + "log" + "os" + "strconv" "syscall" + "unsafe" "golang.org/x/sys/unix" "github.com/quic-go/quic-go/internal/utils" ) -func setDF(rawConn syscall.RawConn) error { +// UDP_SEGMENT controls GSO (Generic Segmentation Offload) +// +//nolint:stylecheck +const UDP_SEGMENT = 103 + +func setDF(rawConn syscall.RawConn) (bool, error) { // Enabling IP_MTU_DISCOVER will force the kernel to return "sendto: message too long" // and the datagram will not be fragmented var errDFIPv4, errDFIPv6 error @@ -19,7 +28,7 @@ func setDF(rawConn syscall.RawConn) error { errDFIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_MTU_DISCOVER, unix.IP_PMTUDISC_DO) errDFIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_MTU_DISCOVER, unix.IPV6_PMTUDISC_DO) }); err != nil { - return err + return false, err } switch { case errDFIPv4 == nil && errDFIPv6 == nil: @@ -29,12 +38,46 @@ func setDF(rawConn syscall.RawConn) error { case errDFIPv4 != nil && errDFIPv6 == nil: utils.DefaultLogger.Debugf("Setting DF for IPv6.") case errDFIPv4 != nil && errDFIPv6 != nil: - return errors.New("setting DF failed for both IPv4 and IPv6") + return false, errors.New("setting DF failed for both IPv4 and IPv6") } - return nil + return true, nil +} + +func maybeSetGSO(rawConn syscall.RawConn) bool { + enable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_ENABLE_GSO")) + if !enable { + return false + } + + var setErr error + if err := rawConn.Control(func(fd uintptr) { + setErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_UDP, UDP_SEGMENT, 1) + }); err != nil { + setErr = err + } + if setErr != nil { + log.Println("failed to enable GSO") + return false + } + return true } func isMsgSizeErr(err error) bool { // 
https://man7.org/linux/man-pages/man7/udp.7.html return errors.Is(err, unix.EMSGSIZE) } + +func appendUDPSegmentSizeMsg(b []byte, size uint16) []byte { + startLen := len(b) + const dataLen = 2 // payload is a uint16 + b = append(b, make([]byte, unix.CmsgSpace(dataLen))...) + h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen])) + h.Level = syscall.IPPROTO_UDP + h.Type = UDP_SEGMENT + h.SetLen(unix.CmsgLen(dataLen)) + + // UnixRights uses the private `data` method, but I *think* this achieves the same goal. + offset := startLen + unix.CmsgSpace(0) + *(*uint16)(unsafe.Pointer(&b[offset])) = size + return b +} diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_df_windows.go b/vendor/github.com/quic-go/quic-go/sys_conn_df_windows.go index 9855e8de..a7f1351a 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_df_windows.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_df_windows.go @@ -12,20 +12,23 @@ import ( ) const ( - // same for both IPv4 and IPv6 on Windows + // IP_DONTFRAGMENT controls the Don't Fragment (DF) bit. + // + // It's the same code point for both IPv4 and IPv6 on Windows. // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IP_DONTFRAG.html // https://microsoft.github.io/windows-docs-rs/doc/windows/Win32/Networking/WinSock/constant.IPV6_DONTFRAG.html + // + //nolint:stylecheck IP_DONTFRAGMENT = 14 - IPV6_DONTFRAG = 14 ) -func setDF(rawConn syscall.RawConn) error { +func setDF(rawConn syscall.RawConn) (bool, error) { var errDFIPv4, errDFIPv6 error if err := rawConn.Control(func(fd uintptr) { errDFIPv4 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IP_DONTFRAGMENT, 1) - errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_DONTFRAG, 1) + errDFIPv6 = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IP_DONTFRAGMENT, 1) }); err != nil { - return err + return false, err } switch { case errDFIPv4 == nil && errDFIPv6 == nil: @@ -35,9 +38,9 @@ func setDF(rawConn syscall.RawConn) error { case errDFIPv4 != nil && errDFIPv6 == nil: utils.DefaultLogger.Debugf("Setting DF for IPv6.") case errDFIPv4 != nil && errDFIPv6 != nil: - return errors.New("setting DF failed for both IPv4 and IPv6") + return false, errors.New("setting DF failed for both IPv4 and IPv6") } - return nil + return true, nil } func isMsgSizeErr(err error) bool { diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_helper_linux.go b/vendor/github.com/quic-go/quic-go/sys_conn_helper_linux.go index 61c3f54b..2c12233c 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_helper_linux.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_helper_linux.go @@ -2,7 +2,11 @@ package quic -import "golang.org/x/sys/unix" +import ( + "syscall" + + "golang.org/x/sys/unix" +) const msgTypeIPTOS = unix.IP_TOS @@ -17,3 +21,23 @@ const ( ) const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed) + +func forceSetReceiveBuffer(c syscall.RawConn, bytes int) error { + var serr error + if err := c.Control(func(fd uintptr) { + serr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, bytes) + }); err != nil { + return err + } + return serr +} + +func forceSetSendBuffer(c syscall.RawConn, bytes int) error { + var serr error + if err := c.Control(func(fd uintptr) { + serr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, bytes) + }); err != nil { + return err + } + return serr +} diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_helper_nonlinux.go 
b/vendor/github.com/quic-go/quic-go/sys_conn_helper_nonlinux.go new file mode 100644 index 00000000..80b795c3 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/sys_conn_helper_nonlinux.go @@ -0,0 +1,6 @@ +//go:build !linux + +package quic + +func forceSetReceiveBuffer(c any, bytes int) error { return nil } +func forceSetSendBuffer(c any, bytes int) error { return nil } diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_no_gso.go b/vendor/github.com/quic-go/quic-go/sys_conn_no_gso.go new file mode 100644 index 00000000..6f6a8c91 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/sys_conn_no_gso.go @@ -0,0 +1,8 @@ +//go:build darwin || freebsd + +package quic + +import "syscall" + +func maybeSetGSO(_ syscall.RawConn) bool { return false } +func appendUDPSegmentSizeMsg(_ []byte, _ uint16) []byte { return nil } diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_no_oob.go b/vendor/github.com/quic-go/quic-go/sys_conn_no_oob.go index 7ab5040a..2a1f807e 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_no_oob.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_no_oob.go @@ -2,14 +2,20 @@ package quic -import "net" +import ( + "net" + "net/netip" +) -func newConn(c net.PacketConn) (rawConn, error) { - return &basicConn{PacketConn: c}, nil +func newConn(c net.PacketConn, supportsDF bool) (*basicConn, error) { + return &basicConn{PacketConn: c, supportsDF: supportsDF}, nil } -func inspectReadBuffer(interface{}) (int, error) { - return 0, nil +func inspectReadBuffer(any) (int, error) { return 0, nil } +func inspectWriteBuffer(any) (int, error) { return 0, nil } + +type packetInfo struct { + addr netip.Addr } func (i *packetInfo) OOB() []byte { return nil } diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_oob.go b/vendor/github.com/quic-go/quic-go/sys_conn_oob.go index 806dfb81..fc28f94b 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_oob.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_oob.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net" + "net/netip" "syscall" "time" @@ -32,20 +33,10 @@ type batchConn interface { ReadBatch(ms []ipv4.Message, flags int) (int, error) } -func inspectReadBuffer(c interface{}) (int, error) { - conn, ok := c.(interface { - SyscallConn() (syscall.RawConn, error) - }) - if !ok { - return 0, errors.New("doesn't have a SyscallConn") - } - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err) - } +func inspectReadBuffer(c syscall.RawConn) (int, error) { var size int var serr error - if err := rawConn.Control(func(fd uintptr) { + if err := c.Control(func(fd uintptr) { size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF) }); err != nil { return 0, err @@ -53,6 +44,17 @@ func inspectReadBuffer(c interface{}) (int, error) { return size, serr } +func inspectWriteBuffer(c syscall.RawConn) (int, error) { + var size int + var serr error + if err := c.Control(func(fd uintptr) { + size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF) + }); err != nil { + return 0, err + } + return size, serr +} + type oobConn struct { OOBCapablePacketConn batchConn batchConn @@ -61,11 +63,13 @@ type oobConn struct { // Packets received from the kernel, but not yet returned by ReadPacket(). 
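A standalone, Linux-only sketch of the set-then-verify-then-force pattern that setReceiveBuffer and setSendBuffer above follow (the 2 MiB target is a made-up value, not the one quic-go uses): SetReadBuffer is capped by net.core.rmem_max for unprivileged processes, so the size is read back via SO_RCVBUF and, if still too small, retried with SO_RCVBUFFORCE, which only succeeds with CAP_NET_ADMIN.

package main

import (
	"fmt"
	"log"
	"net"

	"golang.org/x/sys/unix"
)

func main() {
	const want = 2 * 1024 * 1024 // hypothetical target size

	c, err := net.ListenUDP("udp", nil)
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close()

	// First attempt: bounded by net.core.rmem_max for unprivileged processes.
	_ = c.SetReadBuffer(want)

	raw, err := c.SyscallConn()
	if err != nil {
		log.Fatal(err)
	}
	var got int
	_ = raw.Control(func(fd uintptr) {
		got, _ = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
		if got < want {
			// Needs CAP_NET_ADMIN; if it fails we simply keep the smaller buffer.
			_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, want)
			got, _ = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
		}
	})
	fmt.Printf("receive buffer: %d kiB\n", got/1024)
}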
messages []ipv4.Message buffers [batchSize]*packetBuffer + + cap connCapabilities } var _ rawConn = &oobConn{} -func newConn(c OOBCapablePacketConn) (*oobConn, error) { +func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) { rawConn, err := c.SyscallConn() if err != nil { return nil, err @@ -122,6 +126,10 @@ func newConn(c OOBCapablePacketConn) (*oobConn, error) { bc = ipv4.NewPacketConn(c) } + // Try enabling GSO. + // This will only succeed on Linux, and only for kernels > 4.18. + supportsGSO := maybeSetGSO(rawConn) + msgs := make([]ipv4.Message, batchSize) for i := range msgs { // preallocate the [][]byte @@ -133,13 +141,15 @@ func newConn(c OOBCapablePacketConn) (*oobConn, error) { messages: msgs, readPos: batchSize, } + oobConn.cap.DF = supportsDF + oobConn.cap.GSO = supportsGSO for i := 0; i < batchSize; i++ { oobConn.messages[i].OOB = make([]byte, oobBufferSize) } return oobConn, nil } -func (c *oobConn) ReadPacket() (*receivedPacket, error) { +func (c *oobConn) ReadPacket() (receivedPacket, error) { if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages. c.messages = c.messages[:batchSize] // replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call @@ -153,7 +163,7 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) { n, err := c.batchConn.ReadBatch(c.messages, 0) if n == 0 || err != nil { - return nil, err + return receivedPacket{}, err } c.messages = c.messages[:n] } @@ -163,18 +173,21 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) { c.readPos++ data := msg.OOB[:msg.NN] - var ecn protocol.ECN - var destIP net.IP - var ifIndex uint32 + p := receivedPacket{ + remoteAddr: msg.Addr, + rcvTime: time.Now(), + data: msg.Buffers[0][:msg.N], + buffer: buffer, + } for len(data) > 0 { hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data) if err != nil { - return nil, err + return receivedPacket{}, err } if hdr.Level == unix.IPPROTO_IP { switch hdr.Type { case msgTypeIPTOS: - ecn = protocol.ECN(body[0] & ecnMask) + p.ecn = protocol.ECN(body[0] & ecnMask) case msgTypeIPv4PKTINFO: // struct in_pktinfo { // unsigned int ipi_ifindex; /* Interface index */ @@ -182,80 +195,94 @@ func (c *oobConn) ReadPacket() (*receivedPacket, error) { // struct in_addr ipi_addr; /* Header Destination // address */ // }; - ip := make([]byte, 4) + var ip [4]byte if len(body) == 12 { - ifIndex = binary.LittleEndian.Uint32(body) - copy(ip, body[8:12]) + copy(ip[:], body[8:12]) + p.info.ifIndex = binary.LittleEndian.Uint32(body) } else if len(body) == 4 { // FreeBSD - copy(ip, body) + copy(ip[:], body) } - destIP = net.IP(ip) + p.info.addr = netip.AddrFrom4(ip) } } if hdr.Level == unix.IPPROTO_IPV6 { switch hdr.Type { case unix.IPV6_TCLASS: - ecn = protocol.ECN(body[0] & ecnMask) + p.ecn = protocol.ECN(body[0] & ecnMask) case msgTypeIPv6PKTINFO: // struct in6_pktinfo { // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ // unsigned int ipi6_ifindex; /* send/recv interface index */ // }; if len(body) == 20 { - ip := make([]byte, 16) - copy(ip, body[:16]) - destIP = net.IP(ip) - ifIndex = binary.LittleEndian.Uint32(body[16:]) + var ip [16]byte + copy(ip[:], body[:16]) + p.info.addr = netip.AddrFrom16(ip) + p.info.ifIndex = binary.LittleEndian.Uint32(body[16:]) } } } data = remainder } - var info *packetInfo - if destIP != nil { - info = &packetInfo{ - addr: destIP, - ifIndex: ifIndex, - } - } - return &receivedPacket{ - remoteAddr: msg.Addr, - rcvTime: time.Now(), - data: 
msg.Buffers[0][:msg.N], - ecn: ecn, - info: info, - buffer: buffer, - }, nil + return p, nil } -func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (n int, err error) { +// WriteTo (re)implements the net.PacketConn method. +// This is needed for users who call OptimizeConn to be able to send (non-QUIC) packets on the underlying connection. +// With GSO enabled, this would otherwise not be needed, as the kernel requires the UDP_SEGMENT message to be set. +func (c *oobConn) WriteTo(p []byte, addr net.Addr) (int, error) { + return c.WritePacket(p, uint16(len(p)), addr, nil) +} + +// WritePacket writes a new packet. +// If the connection supports GSO (and we activated GSO support before), +// it appends the UDP_SEGMENT size message to oob. +// Callers are advised to make sure that oob has a sufficient capacity, +// such that appending the UDP_SEGMENT size message doesn't cause an allocation. +func (c *oobConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, oob []byte) (n int, err error) { + if c.cap.GSO { + oob = appendUDPSegmentSizeMsg(oob, packetSize) + } else if uint16(len(b)) != packetSize { + panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b))) + } n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) return n, err } +func (c *oobConn) capabilities() connCapabilities { + return c.cap +} + +type packetInfo struct { + addr netip.Addr + ifIndex uint32 +} + func (info *packetInfo) OOB() []byte { if info == nil { return nil } - if ip4 := info.addr.To4(); ip4 != nil { + if info.addr.Is4() { + ip := info.addr.As4() // struct in_pktinfo { // unsigned int ipi_ifindex; /* Interface index */ // struct in_addr ipi_spec_dst; /* Local address */ // struct in_addr ipi_addr; /* Header Destination address */ // }; cm := ipv4.ControlMessage{ - Src: ip4, + Src: ip[:], IfIndex: int(info.ifIndex), } return cm.Marshal() - } else if len(info.addr) == 16 { + } else if info.addr.Is6() { + ip := info.addr.As16() // struct in6_pktinfo { // struct in6_addr ipi6_addr; /* src/dst IPv6 address */ // unsigned int ipi6_ifindex; /* send/recv interface index */ // }; cm := ipv6.ControlMessage{ - Src: info.addr, + Src: ip[:], IfIndex: int(info.ifIndex), } return cm.Marshal() diff --git a/vendor/github.com/quic-go/quic-go/sys_conn_windows.go b/vendor/github.com/quic-go/quic-go/sys_conn_windows.go index b003fe94..b9c1cbc8 100644 --- a/vendor/github.com/quic-go/quic-go/sys_conn_windows.go +++ b/vendor/github.com/quic-go/quic-go/sys_conn_windows.go @@ -3,32 +3,20 @@ package quic import ( - "errors" - "fmt" - "net" + "net/netip" "syscall" "golang.org/x/sys/windows" ) -func newConn(c OOBCapablePacketConn) (rawConn, error) { - return &basicConn{PacketConn: c}, nil +func newConn(c OOBCapablePacketConn, supportsDF bool) (*basicConn, error) { + return &basicConn{PacketConn: c, supportsDF: supportsDF}, nil } -func inspectReadBuffer(c net.PacketConn) (int, error) { - conn, ok := c.(interface { - SyscallConn() (syscall.RawConn, error) - }) - if !ok { - return 0, errors.New("doesn't have a SyscallConn") - } - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err) - } +func inspectReadBuffer(c syscall.RawConn) (int, error) { var size int var serr error - if err := rawConn.Control(func(fd uintptr) { + if err := c.Control(func(fd uintptr) { size, serr = windows.GetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF) }); err != nil { return 0, err @@ -36,4 +24,19 @@ func 
inspectReadBuffer(c net.PacketConn) (int, error) { return size, serr } +func inspectWriteBuffer(c syscall.RawConn) (int, error) { + var size int + var serr error + if err := c.Control(func(fd uintptr) { + size, serr = windows.GetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF) + }); err != nil { + return 0, err + } + return size, serr +} + +type packetInfo struct { + addr netip.Addr +} + func (i *packetInfo) OOB() []byte { return nil } diff --git a/vendor/github.com/quic-go/quic-go/transport.go b/vendor/github.com/quic-go/quic-go/transport.go new file mode 100644 index 00000000..5ce02d25 --- /dev/null +++ b/vendor/github.com/quic-go/quic-go/transport.go @@ -0,0 +1,435 @@ +package quic + +import ( + "context" + "crypto/rand" + "crypto/tls" + "errors" + "log" + "net" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/quic-go/quic-go/internal/wire" + + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" +) + +// The Transport is the central point to manage incoming and outgoing QUIC connections. +// QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple. +// This means that a single UDP socket can be used for listening for incoming connections, as well as +// for dialing an arbitrary number of outgoing connections. +// A Transport handles a single net.PacketConn, and offers a range of configuration options +// compared to the simple helper functions like Listen and Dial that this package provides. +type Transport struct { + // A single net.PacketConn can only be handled by one Transport. + // Bad things will happen if passed to multiple Transports. + // + // If not done by the user, the connection is passed through OptimizeConn to enable a number of optimizations. + // After passing the connection to the Transport, it's invalid to call ReadFrom on the connection. + // Calling WriteTo is only valid on the connection returned by OptimizeConn. + Conn net.PacketConn + + // The length of the connection ID in bytes. + // It can be 0, or any value between 4 and 18. + // If unset, a 4 byte connection ID will be used. + ConnectionIDLength int + + // Use for generating new connection IDs. + // This allows the application to control of the connection IDs used, + // which allows routing / load balancing based on connection IDs. + // All Connection IDs returned by the ConnectionIDGenerator MUST + // have the same length. + ConnectionIDGenerator ConnectionIDGenerator + + // The StatelessResetKey is used to generate stateless reset tokens. + // If no key is configured, sending of stateless resets is disabled. + // It is highly recommended to configure a stateless reset key, as stateless resets + // allow the peer to quickly recover from crashes and reboots of this node. + // See section 10.3 of RFC 9000 for details. + StatelessResetKey *StatelessResetKey + + // A Tracer traces events that don't belong to a single QUIC connection. + Tracer logging.Tracer + + handlerMap packetHandlerManager + + mutex sync.Mutex + initOnce sync.Once + initErr error + + // Set in init. + // If no ConnectionIDGenerator is set, this is the ConnectionIDLength. + connIDLen int + // Set in init. + // If no ConnectionIDGenerator is set, this is set to a default. 
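For the ConnectionIDGenerator hook, a rough sketch of a generator that embeds a routing prefix in every connection ID (the one-byte prefix scheme is invented for illustration; it assumes the two-method generator interface and the ConnectionIDFromBytes helper exposed by this quic-go version):

package main

import (
	"crypto/rand"
	"log"
	"net"

	"github.com/quic-go/quic-go"
)

// prefixedConnIDGenerator issues 8-byte connection IDs whose first byte is a
// fixed, hypothetical routing prefix, so a load balancer could route packets
// of established connections back to the right backend.
type prefixedConnIDGenerator struct {
	prefix byte
}

func (g *prefixedConnIDGenerator) GenerateConnectionID() (quic.ConnectionID, error) {
	b := make([]byte, 8)
	if _, err := rand.Read(b[1:]); err != nil {
		return quic.ConnectionID{}, err
	}
	b[0] = g.prefix
	return quic.ConnectionIDFromBytes(b), nil
}

// All connection IDs returned by a generator must have the same length.
func (g *prefixedConnIDGenerator) ConnectionIDLen() int { return 8 }

func main() {
	udpConn, err := net.ListenUDP("udp", nil)
	if err != nil {
		log.Fatal(err)
	}
	tr := &quic.Transport{
		Conn:                  udpConn,
		ConnectionIDGenerator: &prefixedConnIDGenerator{prefix: 0x42},
	}
	defer tr.Close()
}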
+ connIDGenerator ConnectionIDGenerator + + server unknownPacketHandler + + conn rawConn + + closeQueue chan closePacket + statelessResetQueue chan receivedPacket + + listening chan struct{} // is closed when listen returns + closed bool + createdConn bool + isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial + + logger utils.Logger +} + +// Listen starts listening for incoming QUIC connections. +// There can only be a single listener on any net.PacketConn. +// Listen may only be called again after the current Listener was closed. +func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + if err := validateConfig(conf); err != nil { + return nil, err + } + + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.server != nil { + return nil, errListenerAlreadySet + } + conf = populateServerConfig(conf) + if err := t.init(true); err != nil { + return nil, err + } + s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false) + if err != nil { + return nil, err + } + t.server = s + return &Listener{baseServer: s}, nil +} + +// ListenEarly starts listening for incoming QUIC connections. +// There can only be a single listener on any net.PacketConn. +// Listen may only be called again after the current Listener was closed. +func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) { + if tlsConf == nil { + return nil, errors.New("quic: tls.Config not set") + } + if err := validateConfig(conf); err != nil { + return nil, err + } + + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.server != nil { + return nil, errListenerAlreadySet + } + conf = populateServerConfig(conf) + if err := t.init(true); err != nil { + return nil, err + } + s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true) + if err != nil { + return nil, err + } + t.server = s + return &EarlyListener{baseServer: s}, nil +} + +// Dial dials a new connection to a remote host (not using 0-RTT). +func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { + if err := validateConfig(conf); err != nil { + return nil, err + } + conf = populateConfig(conf) + if err := t.init(false); err != nil { + return nil, err + } + var onClose func() + if t.isSingleUse { + onClose = func() { t.Close() } + } + return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false) +} + +// DialEarly dials a new connection, attempting to use 0-RTT if possible. 
+func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + if err := validateConfig(conf); err != nil { + return nil, err + } + conf = populateConfig(conf) + if err := t.init(false); err != nil { + return nil, err + } + var onClose func() + if t.isSingleUse { + onClose = func() { t.Close() } + } + return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) +} + +func (t *Transport) init(isServer bool) error { + t.initOnce.Do(func() { + getMultiplexer().AddConn(t.Conn) + + var conn rawConn + if c, ok := t.Conn.(rawConn); ok { + conn = c + } else { + var err error + conn, err = wrapConn(t.Conn) + if err != nil { + t.initErr = err + return + } + } + t.conn = conn + + t.logger = utils.DefaultLogger // TODO: make this configurable + t.conn = conn + t.handlerMap = newPacketHandlerMap(t.StatelessResetKey, t.enqueueClosePacket, t.logger) + t.listening = make(chan struct{}) + + t.closeQueue = make(chan closePacket, 4) + t.statelessResetQueue = make(chan receivedPacket, 4) + + if t.ConnectionIDGenerator != nil { + t.connIDGenerator = t.ConnectionIDGenerator + t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen() + } else { + connIDLen := t.ConnectionIDLength + if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) { + connIDLen = protocol.DefaultConnectionIDLength + } + t.connIDLen = connIDLen + t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen} + } + + go t.listen(conn) + go t.runSendQueue() + }) + return t.initErr +} + +func (t *Transport) enqueueClosePacket(p closePacket) { + select { + case t.closeQueue <- p: + default: + // Oops, we're backlogged. + // Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway. + } +} + +func (t *Transport) runSendQueue() { + for { + select { + case <-t.listening: + return + case p := <-t.closeQueue: + t.conn.WritePacket(p.payload, uint16(len(p.payload)), p.addr, p.info.OOB()) + case p := <-t.statelessResetQueue: + t.sendStatelessReset(p) + } + } +} + +// Close closes the underlying connection and waits until listen has returned. +// It is invalid to start new listeners or connections after that. 
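A minimal server-side sketch for the Listen path (certificate paths, ALPN and the per-connection handler are placeholders, and it assumes StatelessResetKey is the fixed-size key type; a real deployment would load a stable key rather than generating a random one per start):

package main

import (
	"context"
	"crypto/rand"
	"crypto/tls"
	"log"
	"net"

	"github.com/quic-go/quic-go"
)

func main() {
	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: 8443})
	if err != nil {
		log.Fatal(err)
	}

	// A stateless reset key lets peers detect a restart of this endpoint quickly
	// (RFC 9000, section 10.3). It should be stable across restarts in practice.
	var key quic.StatelessResetKey
	if _, err := rand.Read(key[:]); err != nil {
		log.Fatal(err)
	}

	cert, err := tls.LoadX509KeyPair("cert.pem", "key.pem") // placeholder paths
	if err != nil {
		log.Fatal(err)
	}

	tr := &quic.Transport{Conn: udpConn, StatelessResetKey: &key}
	defer tr.Close()

	// The same Transport could also Dial outgoing connections over this socket;
	// QUIC demultiplexes by connection ID, not by 4-tuple.
	ln, err := tr.Listen(&tls.Config{
		Certificates: []tls.Certificate{cert},
		NextProtos:   []string{"h3"},
	}, &quic.Config{})
	if err != nil {
		log.Fatal(err)
	}
	for {
		conn, err := ln.Accept(context.Background())
		if err != nil {
			log.Fatal(err)
		}
		go func() { _ = conn.CloseWithError(0, "not implemented") }()
	}
}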
+func (t *Transport) Close() error { + t.close(errors.New("closing")) + if t.createdConn { + if err := t.Conn.Close(); err != nil { + return err + } + } else if t.conn != nil { + t.conn.SetReadDeadline(time.Now()) + defer func() { t.conn.SetReadDeadline(time.Time{}) }() + } + if t.listening != nil { + <-t.listening // wait until listening returns + } + return nil +} + +func (t *Transport) closeServer() { + t.handlerMap.CloseServer() + t.mutex.Lock() + t.server = nil + if t.isSingleUse { + t.closed = true + } + t.mutex.Unlock() + if t.createdConn { + t.Conn.Close() + } + if t.isSingleUse { + t.conn.SetReadDeadline(time.Now()) + defer func() { t.conn.SetReadDeadline(time.Time{}) }() + <-t.listening // wait until listening returns + } +} + +func (t *Transport) close(e error) { + t.mutex.Lock() + defer t.mutex.Unlock() + if t.closed { + return + } + + if t.handlerMap != nil { + t.handlerMap.Close(e) + } + if t.server != nil { + t.server.setCloseError(e) + } + t.closed = true +} + +// only print warnings about the UDP receive buffer size once +var setBufferWarningOnce sync.Once + +func (t *Transport) listen(conn rawConn) { + defer close(t.listening) + defer getMultiplexer().RemoveConn(t.Conn) + + if err := setReceiveBuffer(t.Conn, t.logger); err != nil { + if !strings.Contains(err.Error(), "use of closed network connection") { + setBufferWarningOnce.Do(func() { + if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { + return + } + log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err) + }) + } + } + if err := setSendBuffer(t.Conn, t.logger); err != nil { + if !strings.Contains(err.Error(), "use of closed network connection") { + setBufferWarningOnce.Do(func() { + if disable, _ := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING")); disable { + return + } + log.Printf("%s. See https://github.com/quic-go/quic-go/wiki/UDP-Buffer-Sizes for details.", err) + }) + } + } + + for { + p, err := conn.ReadPacket() + //nolint:staticcheck // SA1019 ignore this! + // TODO: This code is used to ignore wsa errors on Windows. + // Since net.Error.Temporary is deprecated as of Go 1.18, we should find a better solution. + // See https://github.com/quic-go/quic-go/issues/1737 for details. 
+ if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + t.mutex.Lock() + closed := t.closed + t.mutex.Unlock() + if closed { + return + } + t.logger.Debugf("Temporary error reading from conn: %w", err) + continue + } + if err != nil { + t.close(err) + return + } + t.handlePacket(p) + } +} + +func (t *Transport) handlePacket(p receivedPacket) { + connID, err := wire.ParseConnectionID(p.data, t.connIDLen) + if err != nil { + t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) + if t.Tracer != nil { + t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + } + p.buffer.MaybeRelease() + return + } + + if isStatelessReset := t.maybeHandleStatelessReset(p.data); isStatelessReset { + return + } + if handler, ok := t.handlerMap.Get(connID); ok { + handler.handlePacket(p) + return + } + if !wire.IsLongHeaderPacket(p.data[0]) { + t.maybeSendStatelessReset(p) + return + } + + t.mutex.Lock() + defer t.mutex.Unlock() + if t.server == nil { // no server set + t.logger.Debugf("received a packet with an unexpected connection ID %s", connID) + return + } + t.server.handlePacket(p) +} + +func (t *Transport) maybeSendStatelessReset(p receivedPacket) { + if t.StatelessResetKey == nil { + p.buffer.Release() + return + } + + // Don't send a stateless reset in response to very small packets. + // This includes packets that could be stateless resets. + if len(p.data) <= protocol.MinStatelessResetSize { + p.buffer.Release() + return + } + + select { + case t.statelessResetQueue <- p: + default: + // it's fine to not send a stateless reset when we're busy + p.buffer.Release() + } +} + +func (t *Transport) sendStatelessReset(p receivedPacket) { + defer p.buffer.Release() + + connID, err := wire.ParseConnectionID(p.data, t.connIDLen) + if err != nil { + t.logger.Errorf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) + return + } + token := t.handlerMap.GetStatelessResetToken(connID) + t.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) + data := make([]byte, protocol.MinStatelessResetSize-16, protocol.MinStatelessResetSize) + rand.Read(data) + data[0] = (data[0] & 0x7f) | 0x40 + data = append(data, token[:]...) + if _, err := t.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil { + t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) + } +} + +func (t *Transport) maybeHandleStatelessReset(data []byte) bool { + // stateless resets are always short header packets + if wire.IsLongHeaderPacket(data[0]) { + return false + } + if len(data) < 17 /* type byte + 16 bytes for the reset token */ { + return false + } + + token := *(*protocol.StatelessResetToken)(data[len(data)-16:]) + if conn, ok := t.handlerMap.GetByResetToken(token); ok { + t.logger.Debugf("Received a stateless reset with token %#x. 
Closing connection.", token) + go conn.destroy(&StatelessResetError{Token: token}) + return true + } + return false +} diff --git a/vendor/github.com/quic-go/quic-go/zero_rtt_queue.go b/vendor/github.com/quic-go/quic-go/zero_rtt_queue.go deleted file mode 100644 index b81a936e..00000000 --- a/vendor/github.com/quic-go/quic-go/zero_rtt_queue.go +++ /dev/null @@ -1,34 +0,0 @@ -package quic - -import ( - "time" - - "github.com/quic-go/quic-go/internal/protocol" -) - -type zeroRTTQueue struct { - queue []*receivedPacket - retireTimer *time.Timer -} - -var _ packetHandler = &zeroRTTQueue{} - -func (h *zeroRTTQueue) handlePacket(p *receivedPacket) { - if len(h.queue) < protocol.Max0RTTQueueLen { - h.queue = append(h.queue, p) - } -} -func (h *zeroRTTQueue) shutdown() {} -func (h *zeroRTTQueue) destroy(error) {} -func (h *zeroRTTQueue) getPerspective() protocol.Perspective { return protocol.PerspectiveClient } -func (h *zeroRTTQueue) EnqueueAll(sess packetHandler) { - for _, p := range h.queue { - sess.handlePacket(p) - } -} - -func (h *zeroRTTQueue) Clear() { - for _, p := range h.queue { - p.buffer.Release() - } -} diff --git a/vendor/golang.org/x/tools/go/types/objectpath/objectpath.go b/vendor/golang.org/x/tools/go/types/objectpath/objectpath.go new file mode 100644 index 00000000..aa7dfacc --- /dev/null +++ b/vendor/golang.org/x/tools/go/types/objectpath/objectpath.go @@ -0,0 +1,764 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package objectpath defines a naming scheme for types.Objects +// (that is, named entities in Go programs) relative to their enclosing +// package. +// +// Type-checker objects are canonical, so they are usually identified by +// their address in memory (a pointer), but a pointer has meaning only +// within one address space. By contrast, objectpath names allow the +// identity of an object to be sent from one program to another, +// establishing a correspondence between types.Object variables that are +// distinct but logically equivalent. +// +// A single object may have multiple paths. In this example, +// +// type A struct{ X int } +// type B A +// +// the field X has two paths due to its membership of both A and B. +// The For(obj) function always returns one of these paths, arbitrarily +// but consistently. +package objectpath + +import ( + "fmt" + "go/types" + "sort" + "strconv" + "strings" + + "golang.org/x/tools/internal/typeparams" + + _ "unsafe" // for go:linkname +) + +// A Path is an opaque name that identifies a types.Object +// relative to its package. Conceptually, the name consists of a +// sequence of destructuring operations applied to the package scope +// to obtain the original object. +// The name does not include the package itself. +type Path string + +// Encoding +// +// An object path is a textual and (with training) human-readable encoding +// of a sequence of destructuring operators, starting from a types.Package. +// The sequences represent a path through the package/object/type graph. 
+// We classify these operators by their type: +// +// PO package->object Package.Scope.Lookup +// OT object->type Object.Type +// TT type->type Type.{Elem,Key,Params,Results,Underlying} [EKPRU] +// TO type->object Type.{At,Field,Method,Obj} [AFMO] +// +// All valid paths start with a package and end at an object +// and thus may be defined by the regular language: +// +// objectpath = PO (OT TT* TO)* +// +// The concrete encoding follows directly: +// - The only PO operator is Package.Scope.Lookup, which requires an identifier. +// - The only OT operator is Object.Type, +// which we encode as '.' because dot cannot appear in an identifier. +// - The TT operators are encoded as [EKPRUTC]; +// one of these (TypeParam) requires an integer operand, +// which is encoded as a string of decimal digits. +// - The TO operators are encoded as [AFMO]; +// three of these (At,Field,Method) require an integer operand, +// which is encoded as a string of decimal digits. +// These indices are stable across different representations +// of the same package, even source and export data. +// The indices used are implementation specific and may not correspond to +// the argument to the go/types function. +// +// In the example below, +// +// package p +// +// type T interface { +// f() (a string, b struct{ X int }) +// } +// +// field X has the path "T.UM0.RA1.F0", +// representing the following sequence of operations: +// +// p.Lookup("T") T +// .Type().Underlying().Method(0). f +// .Type().Results().At(1) b +// .Type().Field(0) X +// +// The encoding is not maximally compact---every R or P is +// followed by an A, for example---but this simplifies the +// encoder and decoder. +const ( + // object->type operators + opType = '.' // .Type() (Object) + + // type->type operators + opElem = 'E' // .Elem() (Pointer, Slice, Array, Chan, Map) + opKey = 'K' // .Key() (Map) + opParams = 'P' // .Params() (Signature) + opResults = 'R' // .Results() (Signature) + opUnderlying = 'U' // .Underlying() (Named) + opTypeParam = 'T' // .TypeParams.At(i) (Named, Signature) + opConstraint = 'C' // .Constraint() (TypeParam) + + // type->object operators + opAt = 'A' // .At(i) (Tuple) + opField = 'F' // .Field(i) (Struct) + opMethod = 'M' // .Method(i) (Named or Interface; not Struct: "promoted" names are ignored) + opObj = 'O' // .Obj() (Named, TypeParam) +) + +// For is equivalent to new(Encoder).For(obj). +// +// It may be more efficient to reuse a single Encoder across several calls. +func For(obj types.Object) (Path, error) { + return new(Encoder).For(obj) +} + +// An Encoder amortizes the cost of encoding the paths of multiple objects. +// The zero value of an Encoder is ready to use. +type Encoder struct { + scopeNamesMemo map[*types.Scope][]string // memoization of Scope.Names() + namedMethodsMemo map[*types.Named][]*types.Func // memoization of namedMethods() +} + +// For returns the path to an object relative to its package, +// or an error if the object is not accessible from the package's Scope. +// +// The For function guarantees to return a path only for the following objects: +// - package-level types +// - exported package-level non-types +// - methods +// - parameter and result variables +// - struct fields +// These objects are sufficient to define the API of their package. +// The objects described by a package's export data are drawn from this set. +// +// For does not return a path for predeclared names, imported package +// names, local names, and unexported package-level names (except +// types). 
+// +// Example: given this definition, +// +// package p +// +// type T interface { +// f() (a string, b struct{ X int }) +// } +// +// For(X) would return a path that denotes the following sequence of operations: +// +// p.Scope().Lookup("T") (TypeName T) +// .Type().Underlying().Method(0). (method Func f) +// .Type().Results().At(1) (field Var b) +// .Type().Field(0) (field Var X) +// +// where p is the package (*types.Package) to which X belongs. +func (enc *Encoder) For(obj types.Object) (Path, error) { + pkg := obj.Pkg() + + // This table lists the cases of interest. + // + // Object Action + // ------ ------ + // nil reject + // builtin reject + // pkgname reject + // label reject + // var + // package-level accept + // func param/result accept + // local reject + // struct field accept + // const + // package-level accept + // local reject + // func + // package-level accept + // init functions reject + // concrete method accept + // interface method accept + // type + // package-level accept + // local reject + // + // The only accessible package-level objects are members of pkg itself. + // + // The cases are handled in four steps: + // + // 1. reject nil and builtin + // 2. accept package-level objects + // 3. reject obviously invalid objects + // 4. search the API for the path to the param/result/field/method. + + // 1. reference to nil or builtin? + if pkg == nil { + return "", fmt.Errorf("predeclared %s has no path", obj) + } + scope := pkg.Scope() + + // 2. package-level object? + if scope.Lookup(obj.Name()) == obj { + // Only exported objects (and non-exported types) have a path. + // Non-exported types may be referenced by other objects. + if _, ok := obj.(*types.TypeName); !ok && !obj.Exported() { + return "", fmt.Errorf("no path for non-exported %v", obj) + } + return Path(obj.Name()), nil + } + + // 3. Not a package-level object. + // Reject obviously non-viable cases. + switch obj := obj.(type) { + case *types.TypeName: + if _, ok := obj.Type().(*typeparams.TypeParam); !ok { + // With the exception of type parameters, only package-level type names + // have a path. + return "", fmt.Errorf("no path for %v", obj) + } + case *types.Const, // Only package-level constants have a path. + *types.Label, // Labels are function-local. + *types.PkgName: // PkgNames are file-local. + return "", fmt.Errorf("no path for %v", obj) + + case *types.Var: + // Could be: + // - a field (obj.IsField()) + // - a func parameter or result + // - a local var. + // Sadly there is no way to distinguish + // a param/result from a local + // so we must proceed to the find. + + case *types.Func: + // A func, if not package-level, must be a method. + if recv := obj.Type().(*types.Signature).Recv(); recv == nil { + return "", fmt.Errorf("func is not a method: %v", obj) + } + + if path, ok := enc.concreteMethod(obj); ok { + // Fast path for concrete methods that avoids looping over scope. + return path, nil + } + + default: + panic(obj) + } + + // 4. Search the API for the path to the var (field/param/result) or method. + + // First inspect package-level named types. + // In the presence of path aliases, these give + // the best paths because non-types may + // refer to types, but not the reverse. + empty := make([]byte, 0, 48) // initial space + names := enc.scopeNames(scope) + for _, name := range names { + o := scope.Lookup(name) + tname, ok := o.(*types.TypeName) + if !ok { + continue // handle non-types in second pass + } + + path := append(empty, name...) 
+ path = append(path, opType) + + T := o.Type() + + if tname.IsAlias() { + // type alias + if r := find(obj, T, path, nil); r != nil { + return Path(r), nil + } + } else { + if named, _ := T.(*types.Named); named != nil { + if r := findTypeParam(obj, typeparams.ForNamed(named), path, nil); r != nil { + // generic named type + return Path(r), nil + } + } + // defined (named) type + if r := find(obj, T.Underlying(), append(path, opUnderlying), nil); r != nil { + return Path(r), nil + } + } + } + + // Then inspect everything else: + // non-types, and declared methods of defined types. + for _, name := range names { + o := scope.Lookup(name) + path := append(empty, name...) + if _, ok := o.(*types.TypeName); !ok { + if o.Exported() { + // exported non-type (const, var, func) + if r := find(obj, o.Type(), append(path, opType), nil); r != nil { + return Path(r), nil + } + } + continue + } + + // Inspect declared methods of defined types. + if T, ok := o.Type().(*types.Named); ok { + path = append(path, opType) + // Note that method index here is always with respect + // to canonical ordering of methods, regardless of how + // they appear in the underlying type. + for i, m := range enc.namedMethods(T) { + path2 := appendOpArg(path, opMethod, i) + if m == obj { + return Path(path2), nil // found declared method + } + if r := find(obj, m.Type(), append(path2, opType), nil); r != nil { + return Path(r), nil + } + } + } + } + + return "", fmt.Errorf("can't find path for %v in %s", obj, pkg.Path()) +} + +func appendOpArg(path []byte, op byte, arg int) []byte { + path = append(path, op) + path = strconv.AppendInt(path, int64(arg), 10) + return path +} + +// concreteMethod returns the path for meth, which must have a non-nil receiver. +// The second return value indicates success and may be false if the method is +// an interface method or if it is an instantiated method. +// +// This function is just an optimization that avoids the general scope walking +// approach. You are expected to fall back to the general approach if this +// function fails. +func (enc *Encoder) concreteMethod(meth *types.Func) (Path, bool) { + // Concrete methods can only be declared on package-scoped named types. For + // that reason we can skip the expensive walk over the package scope: the + // path will always be package -> named type -> method. We can trivially get + // the type name from the receiver, and only have to look over the type's + // methods to find the method index. + // + // Methods on generic types require special consideration, however. Consider + // the following package: + // + // L1: type S[T any] struct{} + // L2: func (recv S[A]) Foo() { recv.Bar() } + // L3: func (recv S[B]) Bar() { } + // L4: type Alias = S[int] + // L5: func _[T any]() { var s S[int]; s.Foo() } + // + // The receivers of methods on generic types are instantiations. L2 and L3 + // instantiate S with the type-parameters A and B, which are scoped to the + // respective methods. L4 and L5 each instantiate S with int. Each of these + // instantiations has its own method set, full of methods (and thus objects) + // with receivers whose types are the respective instantiations. In other + // words, we have + // + // S[A].Foo, S[A].Bar + // S[B].Foo, S[B].Bar + // S[int].Foo, S[int].Bar + // + // We may thus be trying to produce object paths for any of these objects. + // + // S[A].Foo and S[B].Bar are the origin methods, and their paths are S.Foo + // and S.Bar, which are the paths that this function naturally produces. 
+ // + // S[A].Bar, S[B].Foo, and both methods on S[int] are instantiations that + // don't correspond to the origin methods. For S[int], this is significant. + // The most precise object path for S[int].Foo, for example, is Alias.Foo, + // not S.Foo. Our function, however, would produce S.Foo, which would + // resolve to a different object. + // + // For S[A].Bar and S[B].Foo it could be argued that S.Bar and S.Foo are + // still the correct paths, since only the origin methods have meaningful + // paths. But this is likely only true for trivial cases and has edge cases. + // Since this function is only an optimization, we err on the side of giving + // up, deferring to the slower but definitely correct algorithm. Most users + // of objectpath will only be giving us origin methods, anyway, as referring + // to instantiated methods is usually not useful. + + if typeparams.OriginMethod(meth) != meth { + return "", false + } + + recvT := meth.Type().(*types.Signature).Recv().Type() + if ptr, ok := recvT.(*types.Pointer); ok { + recvT = ptr.Elem() + } + + named, ok := recvT.(*types.Named) + if !ok { + return "", false + } + + if types.IsInterface(named) { + // Named interfaces don't have to be package-scoped + // + // TODO(dominikh): opt: if scope.Lookup(name) == named, then we can apply this optimization to interface + // methods, too, I think. + return "", false + } + + // Preallocate space for the name, opType, opMethod, and some digits. + name := named.Obj().Name() + path := make([]byte, 0, len(name)+8) + path = append(path, name...) + path = append(path, opType) + for i, m := range enc.namedMethods(named) { + if m == meth { + path = appendOpArg(path, opMethod, i) + return Path(path), true + } + } + + // Due to golang/go#59944, go/types fails to associate the receiver with + // certain methods on cgo types. + // + // TODO(rfindley): replace this panic once golang/go#59944 is fixed in all Go + // versions gopls supports. + return "", false + // panic(fmt.Sprintf("couldn't find method %s on type %s; methods: %#v", meth, named, enc.namedMethods(named))) +} + +// find finds obj within type T, returning the path to it, or nil if not found. +// +// The seen map is used to short circuit cycles through type parameters. If +// nil, it will be allocated as necessary. +func find(obj types.Object, T types.Type, path []byte, seen map[*types.TypeName]bool) []byte { + switch T := T.(type) { + case *types.Basic, *types.Named: + // Named types belonging to pkg were handled already, + // so T must belong to another package. No path. 
+ return nil + case *types.Pointer: + return find(obj, T.Elem(), append(path, opElem), seen) + case *types.Slice: + return find(obj, T.Elem(), append(path, opElem), seen) + case *types.Array: + return find(obj, T.Elem(), append(path, opElem), seen) + case *types.Chan: + return find(obj, T.Elem(), append(path, opElem), seen) + case *types.Map: + if r := find(obj, T.Key(), append(path, opKey), seen); r != nil { + return r + } + return find(obj, T.Elem(), append(path, opElem), seen) + case *types.Signature: + if r := findTypeParam(obj, typeparams.ForSignature(T), path, seen); r != nil { + return r + } + if r := find(obj, T.Params(), append(path, opParams), seen); r != nil { + return r + } + return find(obj, T.Results(), append(path, opResults), seen) + case *types.Struct: + for i := 0; i < T.NumFields(); i++ { + fld := T.Field(i) + path2 := appendOpArg(path, opField, i) + if fld == obj { + return path2 // found field var + } + if r := find(obj, fld.Type(), append(path2, opType), seen); r != nil { + return r + } + } + return nil + case *types.Tuple: + for i := 0; i < T.Len(); i++ { + v := T.At(i) + path2 := appendOpArg(path, opAt, i) + if v == obj { + return path2 // found param/result var + } + if r := find(obj, v.Type(), append(path2, opType), seen); r != nil { + return r + } + } + return nil + case *types.Interface: + for i := 0; i < T.NumMethods(); i++ { + m := T.Method(i) + path2 := appendOpArg(path, opMethod, i) + if m == obj { + return path2 // found interface method + } + if r := find(obj, m.Type(), append(path2, opType), seen); r != nil { + return r + } + } + return nil + case *typeparams.TypeParam: + name := T.Obj() + if name == obj { + return append(path, opObj) + } + if seen[name] { + return nil + } + if seen == nil { + seen = make(map[*types.TypeName]bool) + } + seen[name] = true + if r := find(obj, T.Constraint(), append(path, opConstraint), seen); r != nil { + return r + } + return nil + } + panic(T) +} + +func findTypeParam(obj types.Object, list *typeparams.TypeParamList, path []byte, seen map[*types.TypeName]bool) []byte { + for i := 0; i < list.Len(); i++ { + tparam := list.At(i) + path2 := appendOpArg(path, opTypeParam, i) + if r := find(obj, tparam, path2, seen); r != nil { + return r + } + } + return nil +} + +// Object returns the object denoted by path p within the package pkg. +func Object(pkg *types.Package, p Path) (types.Object, error) { + if p == "" { + return nil, fmt.Errorf("empty path") + } + + pathstr := string(p) + var pkgobj, suffix string + if dot := strings.IndexByte(pathstr, opType); dot < 0 { + pkgobj = pathstr + } else { + pkgobj = pathstr[:dot] + suffix = pathstr[dot:] // suffix starts with "." + } + + obj := pkg.Scope().Lookup(pkgobj) + if obj == nil { + return nil, fmt.Errorf("package %s does not contain %q", pkg.Path(), pkgobj) + } + + // abstraction of *types.{Pointer,Slice,Array,Chan,Map} + type hasElem interface { + Elem() types.Type + } + // abstraction of *types.{Named,Signature} + type hasTypeParams interface { + TypeParams() *typeparams.TypeParamList + } + // abstraction of *types.{Named,TypeParam} + type hasObj interface { + Obj() *types.TypeName + } + + // The loop state is the pair (t, obj), + // exactly one of which is non-nil, initially obj. + // All suffixes start with '.' (the only object->type operation), + // followed by optional type->type operations, + // then a type->object operation. + // The cycle then repeats. 
+ var t types.Type + for suffix != "" { + code := suffix[0] + suffix = suffix[1:] + + // Codes [AFM] have an integer operand. + var index int + switch code { + case opAt, opField, opMethod, opTypeParam: + rest := strings.TrimLeft(suffix, "0123456789") + numerals := suffix[:len(suffix)-len(rest)] + suffix = rest + i, err := strconv.Atoi(numerals) + if err != nil { + return nil, fmt.Errorf("invalid path: bad numeric operand %q for code %q", numerals, code) + } + index = int(i) + case opObj: + // no operand + default: + // The suffix must end with a type->object operation. + if suffix == "" { + return nil, fmt.Errorf("invalid path: ends with %q, want [AFMO]", code) + } + } + + if code == opType { + if t != nil { + return nil, fmt.Errorf("invalid path: unexpected %q in type context", opType) + } + t = obj.Type() + obj = nil + continue + } + + if t == nil { + return nil, fmt.Errorf("invalid path: code %q in object context", code) + } + + // Inv: t != nil, obj == nil + + switch code { + case opElem: + hasElem, ok := t.(hasElem) // Pointer, Slice, Array, Chan, Map + if !ok { + return nil, fmt.Errorf("cannot apply %q to %s (got %T, want pointer, slice, array, chan or map)", code, t, t) + } + t = hasElem.Elem() + + case opKey: + mapType, ok := t.(*types.Map) + if !ok { + return nil, fmt.Errorf("cannot apply %q to %s (got %T, want map)", code, t, t) + } + t = mapType.Key() + + case opParams: + sig, ok := t.(*types.Signature) + if !ok { + return nil, fmt.Errorf("cannot apply %q to %s (got %T, want signature)", code, t, t) + } + t = sig.Params() + + case opResults: + sig, ok := t.(*types.Signature) + if !ok { + return nil, fmt.Errorf("cannot apply %q to %s (got %T, want signature)", code, t, t) + } + t = sig.Results() + + case opUnderlying: + named, ok := t.(*types.Named) + if !ok { + return nil, fmt.Errorf("cannot apply %q to %s (got %T, want named)", code, t, t) + } + t = named.Underlying() + + case opTypeParam: + hasTypeParams, ok := t.(hasTypeParams) // Named, Signature + if !ok { + return nil, fmt.Errorf("cannot apply %q to %s (got %T, want named or signature)", code, t, t) + } + tparams := hasTypeParams.TypeParams() + if n := tparams.Len(); index >= n { + return nil, fmt.Errorf("tuple index %d out of range [0-%d)", index, n) + } + t = tparams.At(index) + + case opConstraint: + tparam, ok := t.(*typeparams.TypeParam) + if !ok { + return nil, fmt.Errorf("cannot apply %q to %s (got %T, want type parameter)", code, t, t) + } + t = tparam.Constraint() + + case opAt: + tuple, ok := t.(*types.Tuple) + if !ok { + return nil, fmt.Errorf("cannot apply %q to %s (got %T, want tuple)", code, t, t) + } + if n := tuple.Len(); index >= n { + return nil, fmt.Errorf("tuple index %d out of range [0-%d)", index, n) + } + obj = tuple.At(index) + t = nil + + case opField: + structType, ok := t.(*types.Struct) + if !ok { + return nil, fmt.Errorf("cannot apply %q to %s (got %T, want struct)", code, t, t) + } + if n := structType.NumFields(); index >= n { + return nil, fmt.Errorf("field index %d out of range [0-%d)", index, n) + } + obj = structType.Field(index) + t = nil + + case opMethod: + switch t := t.(type) { + case *types.Interface: + if index >= t.NumMethods() { + return nil, fmt.Errorf("method index %d out of range [0-%d)", index, t.NumMethods()) + } + obj = t.Method(index) // Id-ordered + + case *types.Named: + methods := namedMethods(t) // (unmemoized) + if index >= len(methods) { + return nil, fmt.Errorf("method index %d out of range [0-%d)", index, len(methods)) + } + obj = methods[index] // Id-ordered + + 
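// Illustration (hypothetical names): for a type declared in package
// "example.com/p" with methods Get and put, Id order compares "Get" with
// "example.com/p.put", so index 0 selects Get and index 1 selects put.
// namedMethods below produces the same ordering the encoder uses, which is
// why an index written by For resolves to the same method here.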
default: + return nil, fmt.Errorf("cannot apply %q to %s (got %T, want interface or named)", code, t, t) + } + t = nil + + case opObj: + hasObj, ok := t.(hasObj) + if !ok { + return nil, fmt.Errorf("cannot apply %q to %s (got %T, want named or type param)", code, t, t) + } + obj = hasObj.Obj() + t = nil + + default: + return nil, fmt.Errorf("invalid path: unknown code %q", code) + } + } + + if obj.Pkg() != pkg { + return nil, fmt.Errorf("path denotes %s, which belongs to a different package", obj) + } + + return obj, nil // success +} + +// namedMethods returns the methods of a Named type in ascending Id order. +func namedMethods(named *types.Named) []*types.Func { + methods := make([]*types.Func, named.NumMethods()) + for i := range methods { + methods[i] = named.Method(i) + } + sort.Slice(methods, func(i, j int) bool { + return methods[i].Id() < methods[j].Id() + }) + return methods +} + +// namedMethods is a memoization of the namedMethods function. Callers must not modify the result. +func (enc *Encoder) namedMethods(named *types.Named) []*types.Func { + m := enc.namedMethodsMemo + if m == nil { + m = make(map[*types.Named][]*types.Func) + enc.namedMethodsMemo = m + } + methods, ok := m[named] + if !ok { + methods = namedMethods(named) // allocates and sorts + m[named] = methods + } + return methods +} + +// scopeNames is a memoization of scope.Names. Callers must not modify the result. +func (enc *Encoder) scopeNames(scope *types.Scope) []string { + m := enc.scopeNamesMemo + if m == nil { + m = make(map[*types.Scope][]string) + enc.scopeNamesMemo = m + } + names, ok := m[scope] + if !ok { + names = scope.Names() // allocates and sorts + m[scope] = names + } + return names +} diff --git a/vendor/golang.org/x/tools/internal/gcimporter/gcimporter.go b/vendor/golang.org/x/tools/internal/gcimporter/gcimporter.go index 0372fb3a..a973dece 100644 --- a/vendor/golang.org/x/tools/internal/gcimporter/gcimporter.go +++ b/vendor/golang.org/x/tools/internal/gcimporter/gcimporter.go @@ -7,6 +7,18 @@ // Package gcimporter provides various functions for reading // gc-generated object files that can be used to implement the // Importer interface defined by the Go 1.5 standard library package. +// +// The encoding is deterministic: if the encoder is applied twice to +// the same types.Package data structure, both encodings are equal. +// This property may be important to avoid spurious changes in +// applications such as build systems. +// +// However, the encoder is not necessarily idempotent. Importing an +// exported package may yield a types.Package that, while it +// represents the same set of Go types as the original, may differ in +// the details of its internal representation. Because of these +// differences, re-encoding the imported package may yield a +// different, but equally valid, encoding of the package. package gcimporter // import "golang.org/x/tools/internal/gcimporter" import ( diff --git a/vendor/golang.org/x/tools/internal/gcimporter/iexport.go b/vendor/golang.org/x/tools/internal/gcimporter/iexport.go index ba53cdcd..a0dc0b5e 100644 --- a/vendor/golang.org/x/tools/internal/gcimporter/iexport.go +++ b/vendor/golang.org/x/tools/internal/gcimporter/iexport.go @@ -44,12 +44,12 @@ func IExportShallow(fset *token.FileSet, pkg *types.Package) ([]byte, error) { return out.Bytes(), err } -// IImportShallow decodes "shallow" types.Package data encoded by IExportShallow -// in the same executable. 
This function cannot import data from +// IImportShallow decodes "shallow" types.Package data encoded by +// IExportShallow in the same executable. This function cannot import data from // cmd/compile or gcexportdata.Write. -func IImportShallow(fset *token.FileSet, imports map[string]*types.Package, data []byte, path string, insert InsertType) (*types.Package, error) { +func IImportShallow(fset *token.FileSet, getPackage GetPackageFunc, data []byte, path string, insert InsertType) (*types.Package, error) { const bundle = false - pkgs, err := iimportCommon(fset, imports, data, bundle, path, insert) + pkgs, err := iimportCommon(fset, getPackage, data, bundle, path, insert) if err != nil { return nil, err } diff --git a/vendor/golang.org/x/tools/internal/gcimporter/iimport.go b/vendor/golang.org/x/tools/internal/gcimporter/iimport.go index 448f903e..be6dace1 100644 --- a/vendor/golang.org/x/tools/internal/gcimporter/iimport.go +++ b/vendor/golang.org/x/tools/internal/gcimporter/iimport.go @@ -85,7 +85,7 @@ const ( // If the export data version is not recognized or the format is otherwise // compromised, an error is returned. func IImportData(fset *token.FileSet, imports map[string]*types.Package, data []byte, path string) (int, *types.Package, error) { - pkgs, err := iimportCommon(fset, imports, data, false, path, nil) + pkgs, err := iimportCommon(fset, GetPackageFromMap(imports), data, false, path, nil) if err != nil { return 0, nil, err } @@ -94,10 +94,33 @@ func IImportData(fset *token.FileSet, imports map[string]*types.Package, data [] // IImportBundle imports a set of packages from the serialized package bundle. func IImportBundle(fset *token.FileSet, imports map[string]*types.Package, data []byte) ([]*types.Package, error) { - return iimportCommon(fset, imports, data, true, "", nil) + return iimportCommon(fset, GetPackageFromMap(imports), data, true, "", nil) } -func iimportCommon(fset *token.FileSet, imports map[string]*types.Package, data []byte, bundle bool, path string, insert InsertType) (pkgs []*types.Package, err error) { +// A GetPackageFunc is a function that gets the package with the given path +// from the importer state, creating it (with the specified name) if necessary. +// It is an abstraction of the map historically used to memoize package creation. +// +// Two calls with the same path must return the same package. +// +// If the given getPackage func returns nil, the import will fail. +type GetPackageFunc = func(path, name string) *types.Package + +// GetPackageFromMap returns a GetPackageFunc that retrieves packages from the +// given map of package path -> package. +// +// The resulting func may mutate m: if a requested package is not found, a new +// package will be inserted into m. 
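//
// A minimal usage sketch (map contents and import path are hypothetical):
//
//	pkgs := make(map[string]*types.Package)
//	get := GetPackageFromMap(pkgs)
//	p := get("example.com/m", "m") // not yet present: created and stored in pkgs
//	q := get("example.com/m", "m") // already present: the same *types.Package
//	// p == q, as the GetPackageFunc contract above requires.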
+func GetPackageFromMap(m map[string]*types.Package) GetPackageFunc { + return func(path, name string) *types.Package { + if _, ok := m[path]; !ok { + m[path] = types.NewPackage(path, name) + } + return m[path] + } +} + +func iimportCommon(fset *token.FileSet, getPackage GetPackageFunc, data []byte, bundle bool, path string, insert InsertType) (pkgs []*types.Package, err error) { const currentVersion = iexportVersionCurrent version := int64(-1) if !debug { @@ -195,10 +218,9 @@ func iimportCommon(fset *token.FileSet, imports map[string]*types.Package, data if pkgPath == "" { pkgPath = path } - pkg := imports[pkgPath] + pkg := getPackage(pkgPath, pkgName) if pkg == nil { - pkg = types.NewPackage(pkgPath, pkgName) - imports[pkgPath] = pkg + errorf("internal error: getPackage returned nil package for %s", pkgPath) } else if pkg.Name() != pkgName { errorf("conflicting names %s and %s for package %q", pkg.Name(), pkgName, path) } diff --git a/vendor/golang.org/x/tools/internal/gcimporter/ureader_yes.go b/vendor/golang.org/x/tools/internal/gcimporter/ureader_yes.go index b285a11c..34fc783f 100644 --- a/vendor/golang.org/x/tools/internal/gcimporter/ureader_yes.go +++ b/vendor/golang.org/x/tools/internal/gcimporter/ureader_yes.go @@ -12,6 +12,7 @@ package gcimporter import ( "go/token" "go/types" + "sort" "strings" "golang.org/x/tools/internal/pkgbits" @@ -121,6 +122,16 @@ func readUnifiedPackage(fset *token.FileSet, ctxt *types.Context, imports map[st iface.Complete() } + // Imports() of pkg are all of the transitive packages that were loaded. + var imps []*types.Package + for _, imp := range pr.pkgs { + if imp != nil && imp != pkg { + imps = append(imps, imp) + } + } + sort.Sort(byPath(imps)) + pkg.SetImports(imps) + pkg.MarkComplete() return pkg } @@ -260,39 +271,9 @@ func (r *reader) doPkg() *types.Package { pkg := types.NewPackage(path, name) r.p.imports[path] = pkg - imports := make([]*types.Package, r.Len()) - for i := range imports { - imports[i] = r.pkg() - } - pkg.SetImports(flattenImports(imports)) - return pkg } -// flattenImports returns the transitive closure of all imported -// packages rooted from pkgs. -func flattenImports(pkgs []*types.Package) []*types.Package { - var res []*types.Package - seen := make(map[*types.Package]struct{}) - for _, pkg := range pkgs { - if _, ok := seen[pkg]; ok { - continue - } - seen[pkg] = struct{}{} - res = append(res, pkg) - - // pkg.Imports() is already flattened. - for _, pkg := range pkg.Imports() { - if _, ok := seen[pkg]; ok { - continue - } - seen[pkg] = struct{}{} - res = append(res, pkg) - } - } - return res -} - // @@@ Types func (r *reader) typ() types.Type { diff --git a/vendor/golang.org/x/tools/internal/gocommand/invoke.go b/vendor/golang.org/x/tools/internal/gocommand/invoke.go index d5055169..3c0afe72 100644 --- a/vendor/golang.org/x/tools/internal/gocommand/invoke.go +++ b/vendor/golang.org/x/tools/internal/gocommand/invoke.go @@ -8,10 +8,12 @@ package gocommand import ( "bytes" "context" + "errors" "fmt" "io" "log" "os" + "reflect" "regexp" "runtime" "strconv" @@ -215,6 +217,18 @@ func (i *Invocation) run(ctx context.Context, stdout, stderr io.Writer) error { cmd := exec.Command("go", goArgs...) cmd.Stdout = stdout cmd.Stderr = stderr + + // cmd.WaitDelay was added only in go1.20 (see #50436). + if waitDelay := reflect.ValueOf(cmd).Elem().FieldByName("WaitDelay"); waitDelay.IsValid() { + // https://go.dev/issue/59541: don't wait forever copying stderr + // after the command has exited. 
+ // After CL 484741 we copy stdout manually, so we we'll stop reading that as + // soon as ctx is done. However, we also don't want to wait around forever + // for stderr. Give a much-longer-than-reasonable delay and then assume that + // something has wedged in the kernel or runtime. + waitDelay.Set(reflect.ValueOf(30 * time.Second)) + } + // On darwin the cwd gets resolved to the real path, which breaks anything that // expects the working directory to keep the original path, including the // go command when dealing with modules. @@ -229,6 +243,7 @@ func (i *Invocation) run(ctx context.Context, stdout, stderr io.Writer) error { cmd.Env = append(cmd.Env, "PWD="+i.WorkingDir) cmd.Dir = i.WorkingDir } + defer func(start time.Time) { log("%s for %v", time.Since(start), cmdDebugStr(cmd)) }(time.Now()) return runCmdContext(ctx, cmd) @@ -242,10 +257,85 @@ var DebugHangingGoCommands = false // runCmdContext is like exec.CommandContext except it sends os.Interrupt // before os.Kill. -func runCmdContext(ctx context.Context, cmd *exec.Cmd) error { - if err := cmd.Start(); err != nil { +func runCmdContext(ctx context.Context, cmd *exec.Cmd) (err error) { + // If cmd.Stdout is not an *os.File, the exec package will create a pipe and + // copy it to the Writer in a goroutine until the process has finished and + // either the pipe reaches EOF or command's WaitDelay expires. + // + // However, the output from 'go list' can be quite large, and we don't want to + // keep reading (and allocating buffers) if we've already decided we don't + // care about the output. We don't want to wait for the process to finish, and + // we don't wait to wait for the WaitDelay to expire either. + // + // Instead, if cmd.Stdout requires a copying goroutine we explicitly replace + // it with a pipe (which is an *os.File), which we can close in order to stop + // copying output as soon as we realize we don't care about it. + var stdoutW *os.File + if cmd.Stdout != nil { + if _, ok := cmd.Stdout.(*os.File); !ok { + var stdoutR *os.File + stdoutR, stdoutW, err = os.Pipe() + if err != nil { + return err + } + prevStdout := cmd.Stdout + cmd.Stdout = stdoutW + + stdoutErr := make(chan error, 1) + go func() { + _, err := io.Copy(prevStdout, stdoutR) + if err != nil { + err = fmt.Errorf("copying stdout: %w", err) + } + stdoutErr <- err + }() + defer func() { + // We started a goroutine to copy a stdout pipe. + // Wait for it to finish, or terminate it if need be. + var err2 error + select { + case err2 = <-stdoutErr: + stdoutR.Close() + case <-ctx.Done(): + stdoutR.Close() + // Per https://pkg.go.dev/os#File.Close, the call to stdoutR.Close + // should cause the Read call in io.Copy to unblock and return + // immediately, but we still need to receive from stdoutErr to confirm + // that that has happened. + <-stdoutErr + err2 = ctx.Err() + } + if err == nil { + err = err2 + } + }() + + // Per https://pkg.go.dev/os/exec#Cmd, “If Stdout and Stderr are the + // same writer, and have a type that can be compared with ==, at most + // one goroutine at a time will call Write.” + // + // Since we're starting a goroutine that writes to cmd.Stdout, we must + // also update cmd.Stderr so that that still holds. + func() { + defer func() { recover() }() + if cmd.Stderr == prevStdout { + cmd.Stderr = cmd.Stdout + } + }() + } + } + + err = cmd.Start() + if stdoutW != nil { + // The child process has inherited the pipe file, + // so close the copy held in this process. 
+ stdoutW.Close() + stdoutW = nil + } + if err != nil { return err } + resChan := make(chan error, 1) go func() { resChan <- cmd.Wait() @@ -253,11 +343,14 @@ func runCmdContext(ctx context.Context, cmd *exec.Cmd) error { // If we're interested in debugging hanging Go commands, stop waiting after a // minute and panic with interesting information. - if DebugHangingGoCommands { + debug := DebugHangingGoCommands + if debug { + timer := time.NewTimer(1 * time.Minute) + defer timer.Stop() select { case err := <-resChan: return err - case <-time.After(1 * time.Minute): + case <-timer.C: HandleHangingGoCommand(cmd.Process) case <-ctx.Done(): } @@ -270,30 +363,25 @@ func runCmdContext(ctx context.Context, cmd *exec.Cmd) error { } // Cancelled. Interrupt and see if it ends voluntarily. - cmd.Process.Signal(os.Interrupt) - select { - case err := <-resChan: - return err - case <-time.After(time.Second): + if err := cmd.Process.Signal(os.Interrupt); err == nil { + // (We used to wait only 1s but this proved + // fragile on loaded builder machines.) + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + select { + case err := <-resChan: + return err + case <-timer.C: + } } // Didn't shut down in response to interrupt. Kill it hard. // TODO(rfindley): per advice from bcmills@, it may be better to send SIGQUIT // on certain platforms, such as unix. - if err := cmd.Process.Kill(); err != nil && DebugHangingGoCommands { - // Don't panic here as this reliably fails on windows with EINVAL. + if err := cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) && debug { log.Printf("error killing the Go command: %v", err) } - // See above: don't wait indefinitely if we're debugging hanging Go commands. - if DebugHangingGoCommands { - select { - case err := <-resChan: - return err - case <-time.After(10 * time.Second): // a shorter wait as resChan should return quickly following Kill - HandleHangingGoCommand(cmd.Process) - } - } return <-resChan } diff --git a/vendor/golang.org/x/tools/internal/gocommand/version.go b/vendor/golang.org/x/tools/internal/gocommand/version.go index 307a76d4..446c5846 100644 --- a/vendor/golang.org/x/tools/internal/gocommand/version.go +++ b/vendor/golang.org/x/tools/internal/gocommand/version.go @@ -23,21 +23,11 @@ import ( func GoVersion(ctx context.Context, inv Invocation, r *Runner) (int, error) { inv.Verb = "list" inv.Args = []string{"-e", "-f", `{{context.ReleaseTags}}`, `--`, `unsafe`} - inv.Env = append(append([]string{}, inv.Env...), "GO111MODULE=off") - // Unset any unneeded flags, and remove them from BuildFlags, if they're - // present. - inv.ModFile = "" + inv.BuildFlags = nil // This is not a build command. inv.ModFlag = "" - var buildFlags []string - for _, flag := range inv.BuildFlags { - // Flags can be prefixed by one or two dashes. 
- f := strings.TrimPrefix(strings.TrimPrefix(flag, "-"), "-") - if strings.HasPrefix(f, "mod=") || strings.HasPrefix(f, "modfile=") { - continue - } - buildFlags = append(buildFlags, flag) - } - inv.BuildFlags = buildFlags + inv.ModFile = "" + inv.Env = append(inv.Env[:len(inv.Env):len(inv.Env)], "GO111MODULE=off") + stdoutBytes, err := r.Run(ctx, inv) if err != nil { return 0, err diff --git a/vendor/golang.org/x/tools/internal/imports/fix.go b/vendor/golang.org/x/tools/internal/imports/fix.go index 642a5ac2..6b493525 100644 --- a/vendor/golang.org/x/tools/internal/imports/fix.go +++ b/vendor/golang.org/x/tools/internal/imports/fix.go @@ -414,9 +414,16 @@ func (p *pass) fix() ([]*ImportFix, bool) { }) } } + // Collecting fixes involved map iteration, so sort for stability. See + // golang/go#59976. + sortFixes(fixes) + // collect selected fixes in a separate slice, so that it can be sorted + // separately. Note that these fixes must occur after fixes to existing + // imports. TODO(rfindley): figure out why. + var selectedFixes []*ImportFix for _, imp := range selected { - fixes = append(fixes, &ImportFix{ + selectedFixes = append(selectedFixes, &ImportFix{ StmtInfo: ImportInfo{ Name: p.importSpecName(imp), ImportPath: imp.ImportPath, @@ -425,8 +432,25 @@ func (p *pass) fix() ([]*ImportFix, bool) { FixType: AddImport, }) } + sortFixes(selectedFixes) - return fixes, true + return append(fixes, selectedFixes...), true +} + +func sortFixes(fixes []*ImportFix) { + sort.Slice(fixes, func(i, j int) bool { + fi, fj := fixes[i], fixes[j] + if fi.StmtInfo.ImportPath != fj.StmtInfo.ImportPath { + return fi.StmtInfo.ImportPath < fj.StmtInfo.ImportPath + } + if fi.StmtInfo.Name != fj.StmtInfo.Name { + return fi.StmtInfo.Name < fj.StmtInfo.Name + } + if fi.IdentName != fj.IdentName { + return fi.IdentName < fj.IdentName + } + return fi.FixType < fj.FixType + }) } // importSpecName gets the import name of imp in the import spec. diff --git a/vendor/golang.org/x/tools/internal/tokeninternal/tokeninternal.go b/vendor/golang.org/x/tools/internal/tokeninternal/tokeninternal.go index a3fb2d4f..7e638ec2 100644 --- a/vendor/golang.org/x/tools/internal/tokeninternal/tokeninternal.go +++ b/vendor/golang.org/x/tools/internal/tokeninternal/tokeninternal.go @@ -7,7 +7,9 @@ package tokeninternal import ( + "fmt" "go/token" + "sort" "sync" "unsafe" ) @@ -57,3 +59,93 @@ func GetLines(file *token.File) []int { panic("unexpected token.File size") } } + +// AddExistingFiles adds the specified files to the FileSet if they +// are not already present. It panics if any pair of files in the +// resulting FileSet would overlap. +func AddExistingFiles(fset *token.FileSet, files []*token.File) { + // Punch through the FileSet encapsulation. + type tokenFileSet struct { + // This type remained essentially consistent from go1.16 to go1.21. + mutex sync.RWMutex + base int + files []*token.File + _ *token.File // changed to atomic.Pointer[token.File] in go1.19 + } + + // If the size of token.FileSet changes, this will fail to compile. + const delta = int64(unsafe.Sizeof(tokenFileSet{})) - int64(unsafe.Sizeof(token.FileSet{})) + var _ [-delta * delta]int + + type uP = unsafe.Pointer + var ptr *tokenFileSet + *(*uP)(uP(&ptr)) = uP(fset) + ptr.mutex.Lock() + defer ptr.mutex.Unlock() + + // Merge and sort. + newFiles := append(ptr.files, files...) + sort.Slice(newFiles, func(i, j int) bool { + return newFiles[i].Base() < newFiles[j].Base() + }) + + // Reject overlapping files. + // Discard adjacent identical files. 
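// A concrete reading of the checks below (numbers are illustrative): a file
// with Base 100 and Size 50 covers offsets [100, 150], so the next file must
// have a Base of at least 151 (prev.Base()+prev.Size()+1); a smaller Base
// panics as an overlap, while the same *token.File appearing twice in a row
// is silently dropped.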
+ out := newFiles[:0] + for i, file := range newFiles { + if i > 0 { + prev := newFiles[i-1] + if file == prev { + continue + } + if prev.Base()+prev.Size()+1 > file.Base() { + panic(fmt.Sprintf("file %s (%d-%d) overlaps with file %s (%d-%d)", + prev.Name(), prev.Base(), prev.Base()+prev.Size(), + file.Name(), file.Base(), file.Base()+file.Size())) + } + } + out = append(out, file) + } + newFiles = out + + ptr.files = newFiles + + // Advance FileSet.Base(). + if len(newFiles) > 0 { + last := newFiles[len(newFiles)-1] + newBase := last.Base() + last.Size() + 1 + if ptr.base < newBase { + ptr.base = newBase + } + } +} + +// FileSetFor returns a new FileSet containing a sequence of new Files with +// the same base, size, and line as the input files, for use in APIs that +// require a FileSet. +// +// Precondition: the input files must be non-overlapping, and sorted in order +// of their Base. +func FileSetFor(files ...*token.File) *token.FileSet { + fset := token.NewFileSet() + for _, f := range files { + f2 := fset.AddFile(f.Name(), f.Base(), f.Size()) + lines := GetLines(f) + f2.SetLines(lines) + } + return fset +} + +// CloneFileSet creates a new FileSet holding all files in fset. It does not +// create copies of the token.Files in fset: they are added to the resulting +// FileSet unmodified. +func CloneFileSet(fset *token.FileSet) *token.FileSet { + var files []*token.File + fset.Iterate(func(f *token.File) bool { + files = append(files, f) + return true + }) + newFileSet := token.NewFileSet() + AddExistingFiles(newFileSet, files) + return newFileSet +} diff --git a/vendor/golang.org/x/tools/internal/typeparams/common.go b/vendor/golang.org/x/tools/internal/typeparams/common.go index 25a1426d..cfba8189 100644 --- a/vendor/golang.org/x/tools/internal/typeparams/common.go +++ b/vendor/golang.org/x/tools/internal/typeparams/common.go @@ -87,7 +87,6 @@ func IsTypeParam(t types.Type) bool { func OriginMethod(fn *types.Func) *types.Func { recv := fn.Type().(*types.Signature).Recv() if recv == nil { - return fn } base := recv.Type() diff --git a/vendor/golang.org/x/tools/internal/typesinternal/types.go b/vendor/golang.org/x/tools/internal/typesinternal/types.go index ce7d4351..3c53fbc6 100644 --- a/vendor/golang.org/x/tools/internal/typesinternal/types.go +++ b/vendor/golang.org/x/tools/internal/typesinternal/types.go @@ -11,6 +11,8 @@ import ( "go/types" "reflect" "unsafe" + + "golang.org/x/tools/go/types/objectpath" ) func SetUsesCgo(conf *types.Config) bool { @@ -50,3 +52,10 @@ func ReadGo116ErrorData(err types.Error) (code ErrorCode, start, end token.Pos, } var SetGoVersion = func(conf *types.Config, version string) bool { return false } + +// NewObjectpathEncoder returns a function closure equivalent to +// objectpath.For but amortized for multiple (sequential) calls. +// It is a temporary workaround, pending the approval of proposal 58668. 
+// +//go:linkname NewObjectpathFunc golang.org/x/tools/go/types/objectpath.newEncoderFor +func NewObjectpathFunc() func(types.Object) (objectpath.Path, error) diff --git a/vendor/modules.txt b/vendor/modules.txt index 6cdd771b..c9531586 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -15,14 +15,14 @@ github.com/davecgh/go-spew/spew # github.com/dchest/safefile v0.0.0-20151022103144-855e8d98f185 ## explicit github.com/dchest/safefile -# github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 +# github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 ## explicit; go 1.13 github.com/go-task/slim-sprig # github.com/golang/mock v1.6.0 ## explicit; go 1.11 github.com/golang/mock/mockgen github.com/golang/mock/mockgen/model -# github.com/golang/protobuf v1.5.2 +# github.com/golang/protobuf v1.5.3 ## explicit; go 1.9 github.com/golang/protobuf/proto github.com/golang/protobuf/ptypes @@ -73,7 +73,7 @@ github.com/kardianos/service # github.com/miekg/dns v1.1.55 ## explicit; go 1.19 github.com/miekg/dns -# github.com/onsi/ginkgo/v2 v2.2.0 +# github.com/onsi/ginkgo/v2 v2.9.5 ## explicit; go 1.18 github.com/onsi/ginkgo/v2/config github.com/onsi/ginkgo/v2/formatter @@ -112,7 +112,7 @@ github.com/quic-go/qtls-go1-19 # github.com/quic-go/qtls-go1-20 v0.2.2 ## explicit; go 1.20 github.com/quic-go/qtls-go1-20 -# github.com/quic-go/quic-go v0.34.0 +# github.com/quic-go/quic-go v0.36.1 ## explicit; go 1.19 github.com/quic-go/quic-go github.com/quic-go/quic-go/http3 @@ -126,6 +126,7 @@ github.com/quic-go/quic-go/internal/qerr github.com/quic-go/quic-go/internal/qtls github.com/quic-go/quic-go/internal/utils github.com/quic-go/quic-go/internal/utils/linkedlist +github.com/quic-go/quic-go/internal/utils/ringbuffer github.com/quic-go/quic-go/internal/wire github.com/quic-go/quic-go/logging github.com/quic-go/quic-go/quicvarint @@ -153,7 +154,7 @@ golang.org/x/crypto/salsa20/salsa # golang.org/x/exp v0.0.0-20221205204356-47842c84f3db ## explicit; go 1.18 golang.org/x/exp/constraints -# golang.org/x/mod v0.8.0 +# golang.org/x/mod v0.10.0 ## explicit; go 1.17 golang.org/x/mod/internal/lazyregexp golang.org/x/mod/modfile @@ -189,13 +190,14 @@ golang.org/x/text/secure/bidirule golang.org/x/text/transform golang.org/x/text/unicode/bidi golang.org/x/text/unicode/norm -# golang.org/x/tools v0.6.0 +# golang.org/x/tools v0.9.1 ## explicit; go 1.18 golang.org/x/tools/go/ast/astutil golang.org/x/tools/go/ast/inspector golang.org/x/tools/go/gcexportdata golang.org/x/tools/go/internal/packagesdriver golang.org/x/tools/go/packages +golang.org/x/tools/go/types/objectpath golang.org/x/tools/imports golang.org/x/tools/internal/event golang.org/x/tools/internal/event/core