Update deps

This commit is contained in:
Frank Denis 2023-10-12 15:53:53 +02:00
parent 0232870870
commit b37a5c991a
161 changed files with 3502 additions and 3245 deletions

14
go.mod
View File

@ -1,6 +1,6 @@
module github.com/dnscrypt/dnscrypt-proxy module github.com/dnscrypt/dnscrypt-proxy
go 1.21.1 go 1.21.3
require ( require (
github.com/BurntSushi/toml v1.3.2 github.com/BurntSushi/toml v1.3.2
@ -20,17 +20,16 @@ require (
github.com/kardianos/service v1.2.2 github.com/kardianos/service v1.2.2
github.com/miekg/dns v1.1.56 github.com/miekg/dns v1.1.56
github.com/powerman/check v1.7.0 github.com/powerman/check v1.7.0
github.com/quic-go/quic-go v0.38.1 github.com/quic-go/quic-go v0.39.0
golang.org/x/crypto v0.13.0 golang.org/x/crypto v0.14.0
golang.org/x/net v0.15.0 golang.org/x/net v0.17.0
golang.org/x/sys v0.12.0 golang.org/x/sys v0.13.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/natefinch/lumberjack.v2 v2.2.1
) )
require ( require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/golang/mock v1.6.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
github.com/hashicorp/go-syslog v1.0.0 // indirect github.com/hashicorp/go-syslog v1.0.0 // indirect
@ -39,8 +38,9 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/powerman/deepequal v0.1.0 // indirect github.com/powerman/deepequal v0.1.0 // indirect
github.com/quic-go/qpack v0.4.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect
github.com/quic-go/qtls-go1-20 v0.3.3 // indirect github.com/quic-go/qtls-go1-20 v0.3.4 // indirect
github.com/smartystreets/goconvey v1.7.2 // indirect github.com/smartystreets/goconvey v1.7.2 // indirect
go.uber.org/mock v0.3.0 // indirect
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.12.0 // indirect golang.org/x/mod v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect

44
go.sum
View File

@ -16,8 +16,6 @@ github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
@ -74,10 +72,10 @@ github.com/powerman/deepequal v0.1.0 h1:sVwtyTsBuYIvdbLR1O2wzRY63YgPqdGZmk/o80l+
github.com/powerman/deepequal v0.1.0/go.mod h1:3k7aG/slufBhUANdN67o/UPg8i5YaiJ6FmibWX0cn04= github.com/powerman/deepequal v0.1.0/go.mod h1:3k7aG/slufBhUANdN67o/UPg8i5YaiJ6FmibWX0cn04=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/qtls-go1-20 v0.3.3 h1:17/glZSLI9P9fDAeyCHBFSWSqJcwx1byhLwP5eUIDCM= github.com/quic-go/qtls-go1-20 v0.3.4 h1:MfFAPULvst4yoMgY9QmtpYmfij/em7O8UUi+bNVm7Cg=
github.com/quic-go/qtls-go1-20 v0.3.3/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/quic-go/qtls-go1-20 v0.3.4/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/quic-go/quic-go v0.38.1 h1:M36YWA5dEhEeT+slOu/SwMEucbYd0YFidxG3KlGPZaE= github.com/quic-go/quic-go v0.39.0 h1:AgP40iThFMY0bj8jGxROhw3S0FMGa8ryqsmi9tBH3So=
github.com/quic-go/quic-go v0.38.1/go.mod h1:ijnZM7JsFIkp4cRyjxJNIzdSfCLmUMg9wdyhGmg+SN4= github.com/quic-go/quic-go v0.39.0/go.mod h1:T09QsDQWjLiQ74ZmacDfqZmhY/NLnw5BC40MANNNZ1Q=
github.com/smartystreets/assertions v1.2.0 h1:42S6lae5dvLc7BrLu/0ugRtcFVjoJNMC/N3yZFZkDFs= github.com/smartystreets/assertions v1.2.0 h1:42S6lae5dvLc7BrLu/0ugRtcFVjoJNMC/N3yZFZkDFs=
github.com/smartystreets/assertions v1.2.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= github.com/smartystreets/assertions v1.2.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg= github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg=
@ -85,51 +83,33 @@ github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo=
go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o=
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190529164535-6a60838ec259/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190529164535-6a60838ec259/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ= golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f h1:BWUVssLB0HVOSY78gIdvk1dTVYtT1y8SBWtPYuTJ/6w= google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f h1:BWUVssLB0HVOSY78gIdvk1dTVYtT1y8SBWtPYuTJ/6w=
google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM= google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM=
google.golang.org/grpc v1.53.0 h1:LAv2ds7cmFV/XTS3XG1NneeENYrXGmorPxsBbptIjNc= google.golang.org/grpc v1.53.0 h1:LAv2ds7cmFV/XTS3XG1NneeENYrXGmorPxsBbptIjNc=

View File

@ -1,26 +0,0 @@
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !go1.12
package main
import (
"log"
)
func printModuleVersion() {
log.Printf("No version information is available for Mockgen compiled with " +
"version 1.11")
}

View File

@ -230,16 +230,22 @@ func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
return nil return nil
} }
// The handshake goroutine has exited. // The handshake goroutine has exited.
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
c.hand.Write(c.quic.readbuf) c.hand.Write(c.quic.readbuf)
c.quic.readbuf = nil c.quic.readbuf = nil
for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil { for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil {
b := q.conn.hand.Bytes() b := q.conn.hand.Bytes()
n := int(b[1])<<16 | int(b[2])<<8 | int(b[3]) n := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
if 4+n < len(b) { if n > maxHandshake {
q.conn.handshakeErr = fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)
break
}
if len(b) < 4+n {
return nil return nil
} }
if err := q.conn.handlePostHandshakeMessage(); err != nil { if err := q.conn.handlePostHandshakeMessage(); err != nil {
return quicError(err) q.conn.handshakeErr = err
} }
} }
if q.conn.handshakeErr != nil { if q.conn.handshakeErr != nil {

View File

@ -200,12 +200,13 @@ http.Client{
## Projects using quic-go ## Projects using quic-go
| Project | Description | Stars | | Project | Description | Stars |
|-----------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------| | --------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) | | [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) |
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) | | [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) |
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) | | [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) |
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) | | [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) |
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) | | [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) |
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) |
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) | | [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) |
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) | | [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) |
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) | | [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) |

View File

@ -34,7 +34,7 @@ type client struct {
conn quicConn conn quicConn
tracer logging.ConnectionTracer tracer *logging.ConnectionTracer
tracingID uint64 tracingID uint64
logger utils.Logger logger utils.Logger
} }
@ -153,7 +153,7 @@ func dial(
if c.config.Tracer != nil { if c.config.Tracer != nil {
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
} }
if c.tracer != nil { if c.tracer != nil && c.tracer.StartedConnection != nil {
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
} }
if err := c.dial(ctx); err != nil { if err := c.dial(ctx); err != nil {
@ -233,7 +233,7 @@ func (c *client) dial(ctx context.Context) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
c.conn.shutdown() c.conn.shutdown()
return ctx.Err() return context.Cause(ctx)
case err := <-errorChan: case err := <-errorChan:
return err return err
case recreateErr := <-recreateChan: case recreateErr := <-recreateChan:

View File

@ -1,19 +1,10 @@
coverage: coverage:
round: nearest round: nearest
ignore: ignore:
- streams_map_incoming_bidi.go
- streams_map_incoming_uni.go
- streams_map_outgoing_bidi.go
- streams_map_outgoing_uni.go
- http3/gzip_reader.go - http3/gzip_reader.go
- interop/ - interop/
- internal/ackhandler/packet_linkedlist.go
- internal/handshake/cipher_suite.go - internal/handshake/cipher_suite.go
- internal/utils/byteinterval_linkedlist.go
- internal/utils/newconnectionid_linkedlist.go
- internal/utils/packetinterval_linkedlist.go
- internal/utils/linkedlist/linkedlist.go - internal/utils/linkedlist/linkedlist.go
- logging/null_tracer.go
- fuzzing/ - fuzzing/
- metrics/ - metrics/
status: status:

View File

@ -6,7 +6,6 @@ import (
"time" "time"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint" "github.com/quic-go/quic-go/quicvarint"
) )
@ -17,7 +16,11 @@ func (c *Config) Clone() *Config {
} }
func (c *Config) handshakeTimeout() time.Duration { func (c *Config) handshakeTimeout() time.Duration {
return utils.Max(protocol.DefaultHandshakeTimeout, 2*c.HandshakeIdleTimeout) return 2 * c.HandshakeIdleTimeout
}
func (c *Config) maxRetryTokenAge() time.Duration {
return c.handshakeTimeout()
} }
func validateConfig(config *Config) error { func validateConfig(config *Config) error {
@ -50,12 +53,6 @@ func validateConfig(config *Config) error {
// it may be called with nil // it may be called with nil
func populateServerConfig(config *Config) *Config { func populateServerConfig(config *Config) *Config {
config = populateConfig(config) config = populateConfig(config)
if config.MaxTokenAge == 0 {
config.MaxTokenAge = protocol.TokenValidity
}
if config.MaxRetryTokenAge == 0 {
config.MaxRetryTokenAge = protocol.RetryTokenValidity
}
if config.RequireAddressValidation == nil { if config.RequireAddressValidation == nil {
config.RequireAddressValidation = func(net.Addr) bool { return false } config.RequireAddressValidation = func(net.Addr) bool { return false }
} }
@ -110,26 +107,23 @@ func populateConfig(config *Config) *Config {
} }
return &Config{ return &Config{
GetConfigForClient: config.GetConfigForClient, GetConfigForClient: config.GetConfigForClient,
Versions: versions, Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout, HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout, MaxIdleTimeout: idleTimeout,
MaxTokenAge: config.MaxTokenAge, RequireAddressValidation: config.RequireAddressValidation,
MaxRetryTokenAge: config.MaxRetryTokenAge, KeepAlivePeriod: config.KeepAlivePeriod,
RequireAddressValidation: config.RequireAddressValidation, InitialStreamReceiveWindow: initialStreamReceiveWindow,
KeepAlivePeriod: config.KeepAlivePeriod, MaxStreamReceiveWindow: maxStreamReceiveWindow,
InitialStreamReceiveWindow: initialStreamReceiveWindow, InitialConnectionReceiveWindow: initialConnectionReceiveWindow,
MaxStreamReceiveWindow: maxStreamReceiveWindow, MaxConnectionReceiveWindow: maxConnectionReceiveWindow,
InitialConnectionReceiveWindow: initialConnectionReceiveWindow, AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
MaxConnectionReceiveWindow: maxConnectionReceiveWindow, MaxIncomingStreams: maxIncomingStreams,
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease, MaxIncomingUniStreams: maxIncomingUniStreams,
MaxIncomingStreams: maxIncomingStreams, TokenStore: config.TokenStore,
MaxIncomingUniStreams: maxIncomingUniStreams, EnableDatagrams: config.EnableDatagrams,
TokenStore: config.TokenStore, DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
EnableDatagrams: config.EnableDatagrams, Allow0RTT: config.Allow0RTT,
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, Tracer: config.Tracer,
DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets,
Allow0RTT: config.Allow0RTT,
Tracer: config.Tracer,
} }
} }

View File

@ -208,7 +208,7 @@ type connection struct {
connState ConnectionState connState ConnectionState
logID string logID string
tracer logging.ConnectionTracer tracer *logging.ConnectionTracer
logger utils.Logger logger utils.Logger
} }
@ -232,7 +232,7 @@ var newConnection = func(
tlsConf *tls.Config, tlsConf *tls.Config,
tokenGenerator *handshake.TokenGenerator, tokenGenerator *handshake.TokenGenerator,
clientAddressValidated bool, clientAddressValidated bool,
tracer logging.ConnectionTracer, tracer *logging.ConnectionTracer,
tracingID uint64, tracingID uint64,
logger utils.Logger, logger utils.Logger,
v protocol.VersionNumber, v protocol.VersionNumber,
@ -243,7 +243,7 @@ var newConnection = func(
handshakeDestConnID: destConnID, handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(), srcConnIDLen: srcConnID.Len(),
tokenGenerator: tokenGenerator, tokenGenerator: tokenGenerator,
oneRTTStream: newCryptoStream(true), oneRTTStream: newCryptoStream(),
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
tracer: tracer, tracer: tracer,
logger: logger, logger: logger,
@ -278,6 +278,7 @@ var newConnection = func(
getMaxPacketSize(s.conn.RemoteAddr()), getMaxPacketSize(s.conn.RemoteAddr()),
s.rttStats, s.rttStats,
clientAddressValidated, clientAddressValidated,
s.conn.capabilities().ECN,
s.perspective, s.perspective,
s.tracer, s.tracer,
s.logger, s.logger,
@ -310,7 +311,7 @@ var newConnection = func(
} else { } else {
params.MaxDatagramFrameSize = protocol.InvalidByteCount params.MaxDatagramFrameSize = protocol.InvalidByteCount
} }
if s.tracer != nil { if s.tracer != nil && s.tracer.SentTransportParameters != nil {
s.tracer.SentTransportParameters(params) s.tracer.SentTransportParameters(params)
} }
cs := handshake.NewCryptoSetupServer( cs := handshake.NewCryptoSetupServer(
@ -344,7 +345,7 @@ var newClientConnection = func(
initialPacketNumber protocol.PacketNumber, initialPacketNumber protocol.PacketNumber,
enable0RTT bool, enable0RTT bool,
hasNegotiatedVersion bool, hasNegotiatedVersion bool,
tracer logging.ConnectionTracer, tracer *logging.ConnectionTracer,
tracingID uint64, tracingID uint64,
logger utils.Logger, logger utils.Logger,
v protocol.VersionNumber, v protocol.VersionNumber,
@ -385,13 +386,14 @@ var newClientConnection = func(
initialPacketNumber, initialPacketNumber,
getMaxPacketSize(s.conn.RemoteAddr()), getMaxPacketSize(s.conn.RemoteAddr()),
s.rttStats, s.rttStats,
false, /* has no effect */ false, // has no effect
s.conn.capabilities().ECN,
s.perspective, s.perspective,
s.tracer, s.tracer,
s.logger, s.logger,
) )
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
oneRTTStream := newCryptoStream(true) oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{ params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@ -416,7 +418,7 @@ var newClientConnection = func(
} else { } else {
params.MaxDatagramFrameSize = protocol.InvalidByteCount params.MaxDatagramFrameSize = protocol.InvalidByteCount
} }
if s.tracer != nil { if s.tracer != nil && s.tracer.SentTransportParameters != nil {
s.tracer.SentTransportParameters(params) s.tracer.SentTransportParameters(params)
} }
cs := handshake.NewCryptoSetupClient( cs := handshake.NewCryptoSetupClient(
@ -447,8 +449,8 @@ var newClientConnection = func(
} }
func (s *connection) preSetup() { func (s *connection) preSetup() {
s.initialStream = newCryptoStream(false) s.initialStream = newCryptoStream()
s.handshakeStream = newCryptoStream(false) s.handshakeStream = newCryptoStream()
s.sendQueue = newSendQueue(s.conn) s.sendQueue = newSendQueue(s.conn)
s.retransmissionQueue = newRetransmissionQueue() s.retransmissionQueue = newRetransmissionQueue()
s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams) s.frameParser = wire.NewFrameParser(s.config.EnableDatagrams)
@ -640,8 +642,10 @@ runLoop:
s.cryptoStreamHandler.Close() s.cryptoStreamHandler.Close()
s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE
s.handleCloseError(&closeErr) s.handleCloseError(&closeErr)
if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { if s.tracer != nil && s.tracer.Close != nil {
s.tracer.Close() if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) {
s.tracer.Close()
}
} }
s.logger.Infof("Connection %s closed.", s.logID) s.logger.Infof("Connection %s closed.", s.logID)
s.timer.Stop() s.timer.Stop()
@ -671,6 +675,7 @@ func (s *connection) ConnectionState() ConnectionState {
cs := s.cryptoStreamHandler.ConnectionState() cs := s.cryptoStreamHandler.ConnectionState()
s.connState.TLS = cs.ConnectionState s.connState.TLS = cs.ConnectionState
s.connState.Used0RTT = cs.Used0RTT s.connState.Used0RTT = cs.Used0RTT
s.connState.GSO = s.conn.capabilities().GSO
return s.connState return s.connState
} }
@ -798,14 +803,14 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool {
var err error var err error
destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen) destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen)
if err != nil { if err != nil {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError)
} }
s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err) s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err)
break break
} }
if destConnID != lastConnID { if destConnID != lastConnID {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID)
} }
s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID) s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID)
@ -816,7 +821,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool {
if wire.IsLongHeaderPacket(p.data[0]) { if wire.IsLongHeaderPacket(p.data[0]) {
hdr, packetData, rest, err := wire.ParsePacket(p.data) hdr, packetData, rest, err := wire.ParsePacket(p.data)
if err != nil { if err != nil {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
dropReason := logging.PacketDropHeaderParseError dropReason := logging.PacketDropHeaderParseError
if err == wire.ErrUnsupportedVersion { if err == wire.ErrUnsupportedVersion {
dropReason = logging.PacketDropUnsupportedVersion dropReason = logging.PacketDropUnsupportedVersion
@ -829,7 +834,7 @@ func (s *connection) handlePacketImpl(rp receivedPacket) bool {
lastConnID = hdr.DestConnectionID lastConnID = hdr.DestConnectionID
if hdr.Version != s.version { if hdr.Version != s.version {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion) s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion)
} }
s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version) s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version)
@ -888,14 +893,14 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protoc
if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) { if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) {
s.logger.Debugf("Dropping (potentially) duplicate packet.") s.logger.Debugf("Dropping (potentially) duplicate packet.")
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate) s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate)
} }
return false return false
} }
var log func([]logging.Frame) var log func([]logging.Frame)
if s.tracer != nil { if s.tracer != nil && s.tracer.ReceivedShortHeaderPacket != nil {
log = func(frames []logging.Frame) { log = func(frames []logging.Frame) {
s.tracer.ReceivedShortHeaderPacket( s.tracer.ReceivedShortHeaderPacket(
&logging.ShortHeader{ &logging.ShortHeader{
@ -905,6 +910,7 @@ func (s *connection) handleShortHeaderPacket(p receivedPacket, destConnID protoc
KeyPhase: keyPhase, KeyPhase: keyPhase,
}, },
p.Size(), p.Size(),
p.ecn,
frames, frames,
) )
} }
@ -927,13 +933,13 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header)
}() }()
if hdr.Type == protocol.PacketTypeRetry { if hdr.Type == protocol.PacketTypeRetry {
return s.handleRetryPacket(hdr, p.data) return s.handleRetryPacket(hdr, p.data, p.rcvTime)
} }
// The server can change the source connection ID with the first Handshake packet. // The server can change the source connection ID with the first Handshake packet.
// After this, all packets with a different source connection have to be ignored. // After this, all packets with a different source connection have to be ignored.
if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID { if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID)
} }
s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID) s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID)
@ -941,7 +947,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header)
} }
// drop 0-RTT packets, if we are a client // drop 0-RTT packets, if we are a client
if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT { if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketType0RTT, p.Size(), logging.PacketDropKeyUnavailable) s.tracer.DroppedPacket(logging.PacketType0RTT, p.Size(), logging.PacketDropKeyUnavailable)
} }
return false return false
@ -960,7 +966,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header)
if s.receivedPacketHandler.IsPotentiallyDuplicate(packet.hdr.PacketNumber, packet.encryptionLevel) { if s.receivedPacketHandler.IsPotentiallyDuplicate(packet.hdr.PacketNumber, packet.encryptionLevel) {
s.logger.Debugf("Dropping (potentially) duplicate packet.") s.logger.Debugf("Dropping (potentially) duplicate packet.")
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDuplicate) s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropDuplicate)
} }
return false return false
@ -976,7 +982,7 @@ func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header)
func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool) { func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool) {
switch err { switch err {
case handshake.ErrKeysDropped: case handshake.ErrKeysDropped:
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropKeyUnavailable) s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropKeyUnavailable)
} }
s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size()) s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size())
@ -992,7 +998,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P
}) })
case handshake.ErrDecryptionFailed: case handshake.ErrDecryptionFailed:
// This might be a packet injected by an attacker. Drop it. // This might be a packet injected by an attacker. Drop it.
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropPayloadDecryptError) s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropPayloadDecryptError)
} }
s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err) s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err)
@ -1000,7 +1006,7 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P
var headerErr *headerParseError var headerErr *headerParseError
if errors.As(err, &headerErr) { if errors.As(err, &headerErr) {
// This might be a packet injected by an attacker. Drop it. // This might be a packet injected by an attacker. Drop it.
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropHeaderParseError) s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropHeaderParseError)
} }
s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err) s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err)
@ -1013,16 +1019,16 @@ func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.P
return false return false
} }
func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ { func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime time.Time) bool /* was this a valid Retry */ {
if s.perspective == protocol.PerspectiveServer { if s.perspective == protocol.PerspectiveServer {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket)
} }
s.logger.Debugf("Ignoring Retry.") s.logger.Debugf("Ignoring Retry.")
return false return false
} }
if s.receivedFirstPacket { if s.receivedFirstPacket {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket)
} }
s.logger.Debugf("Ignoring Retry, since we already received a packet.") s.logger.Debugf("Ignoring Retry, since we already received a packet.")
@ -1030,7 +1036,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa
} }
destConnID := s.connIDManager.Get() destConnID := s.connIDManager.Get()
if hdr.SrcConnectionID == destConnID { if hdr.SrcConnectionID == destConnID {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket)
} }
s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.") s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
@ -1045,7 +1051,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa
tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version) tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version)
if !bytes.Equal(data[len(data)-16:], tag[:]) { if !bytes.Equal(data[len(data)-16:], tag[:]) {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError) s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError)
} }
s.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.") s.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.")
@ -1057,12 +1063,12 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID) s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID)
} }
if s.tracer != nil { if s.tracer != nil && s.tracer.ReceivedRetry != nil {
s.tracer.ReceivedRetry(hdr) s.tracer.ReceivedRetry(hdr)
} }
newDestConnID := hdr.SrcConnectionID newDestConnID := hdr.SrcConnectionID
s.receivedRetry = true s.receivedRetry = true
if err := s.sentPacketHandler.ResetForRetry(); err != nil { if err := s.sentPacketHandler.ResetForRetry(rcvTime); err != nil {
s.closeLocal(err) s.closeLocal(err)
return false return false
} }
@ -1078,7 +1084,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa
func (s *connection) handleVersionNegotiationPacket(p receivedPacket) { func (s *connection) handleVersionNegotiationPacket(p receivedPacket) {
if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets
s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
} }
return return
@ -1086,7 +1092,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) {
src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data) src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data)
if err != nil { if err != nil {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError)
} }
s.logger.Debugf("Error parsing Version Negotiation packet: %s", err) s.logger.Debugf("Error parsing Version Negotiation packet: %s", err)
@ -1095,7 +1101,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) {
for _, v := range supportedVersions { for _, v := range supportedVersions {
if v == s.version { if v == s.version {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion) s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedVersion)
} }
// The Version Negotiation packet contains the version that we offered. // The Version Negotiation packet contains the version that we offered.
@ -1105,7 +1111,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) {
} }
s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions) s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions)
if s.tracer != nil { if s.tracer != nil && s.tracer.ReceivedVersionNegotiationPacket != nil {
s.tracer.ReceivedVersionNegotiationPacket(dest, src, supportedVersions) s.tracer.ReceivedVersionNegotiationPacket(dest, src, supportedVersions)
} }
newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions) newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions)
@ -1117,7 +1123,7 @@ func (s *connection) handleVersionNegotiationPacket(p receivedPacket) {
s.logger.Infof("No compatible QUIC version found.") s.logger.Infof("No compatible QUIC version found.")
return return
} }
if s.tracer != nil { if s.tracer != nil && s.tracer.NegotiatedVersion != nil {
s.tracer.NegotiatedVersion(newVersion, s.config.Versions, supportedVersions) s.tracer.NegotiatedVersion(newVersion, s.config.Versions, supportedVersions)
} }
@ -1137,7 +1143,7 @@ func (s *connection) handleUnpackedLongHeaderPacket(
) error { ) error {
if !s.receivedFirstPacket { if !s.receivedFirstPacket {
s.receivedFirstPacket = true s.receivedFirstPacket = true
if !s.versionNegotiated && s.tracer != nil { if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil {
var clientVersions, serverVersions []protocol.VersionNumber var clientVersions, serverVersions []protocol.VersionNumber
switch s.perspective { switch s.perspective {
case protocol.PerspectiveClient: case protocol.PerspectiveClient:
@ -1164,7 +1170,7 @@ func (s *connection) handleUnpackedLongHeaderPacket(
s.handshakeDestConnID = packet.hdr.SrcConnectionID s.handshakeDestConnID = packet.hdr.SrcConnectionID
s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID)
} }
if s.tracer != nil { if s.tracer != nil && s.tracer.StartedConnection != nil {
s.tracer.StartedConnection( s.tracer.StartedConnection(
s.conn.LocalAddr(), s.conn.LocalAddr(),
s.conn.RemoteAddr(), s.conn.RemoteAddr(),
@ -1188,9 +1194,9 @@ func (s *connection) handleUnpackedLongHeaderPacket(
s.keepAlivePingSent = false s.keepAlivePingSent = false
var log func([]logging.Frame) var log func([]logging.Frame)
if s.tracer != nil { if s.tracer != nil && s.tracer.ReceivedLongHeaderPacket != nil {
log = func(frames []logging.Frame) { log = func(frames []logging.Frame) {
s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, frames) s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, ecn, frames)
} }
} }
isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log) isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log)
@ -1336,7 +1342,7 @@ func (s *connection) handlePacket(p receivedPacket) {
select { select {
case s.receivedPackets <- p: case s.receivedPackets <- p:
default: default:
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
} }
} }
@ -1616,7 +1622,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
s.datagramQueue.CloseWithError(e) s.datagramQueue.CloseWithError(e)
} }
if s.tracer != nil && !errors.As(e, &recreateErr) { if s.tracer != nil && s.tracer.ClosedConnection != nil && !errors.As(e, &recreateErr) {
s.tracer.ClosedConnection(e) s.tracer.ClosedConnection(e)
} }
@ -1643,7 +1649,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
} }
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error { func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) error {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedEncryptionLevel != nil {
s.tracer.DroppedEncryptionLevel(encLevel) s.tracer.DroppedEncryptionLevel(encLevel)
} }
s.sentPacketHandler.DropPackets(encLevel) s.sentPacketHandler.DropPackets(encLevel)
@ -1678,7 +1684,7 @@ func (s *connection) restoreTransportParameters(params *wire.TransportParameters
} }
func (s *connection) handleTransportParameters(params *wire.TransportParameters) error { func (s *connection) handleTransportParameters(params *wire.TransportParameters) error {
if s.tracer != nil { if s.tracer != nil && s.tracer.ReceivedTransportParameters != nil {
s.tracer.ReceivedTransportParameters(params) s.tracer.ReceivedTransportParameters(params)
} }
if err := s.checkTransportParameters(params); err != nil { if err := s.checkTransportParameters(params); err != nil {
@ -1830,9 +1836,10 @@ func (s *connection) sendPackets(now time.Time) error {
if err != nil { if err != nil {
return err return err
} }
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) ecn := s.sentPacketHandler.ECNMode(true)
s.registerPackedShortHeaderPacket(p, now) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false)
s.sendQueue.Send(buf, buf.Len()) s.registerPackedShortHeaderPacket(p, ecn, now)
s.sendQueue.Send(buf, 0, ecn)
// This is kind of a hack. We need to trigger sending again somehow. // This is kind of a hack. We need to trigger sending again somehow.
s.pacingDeadline = deadlineSendImmediately s.pacingDeadline = deadlineSendImmediately
return nil return nil
@ -1852,7 +1859,7 @@ func (s *connection) sendPackets(now time.Time) error {
return err return err
} }
s.sentFirstPacket = true s.sentFirstPacket = true
if err := s.sendPackedCoalescedPacket(packet, now); err != nil { if err := s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now); err != nil {
return err return err
} }
sendMode := s.sentPacketHandler.SendMode(now) sendMode := s.sentPacketHandler.SendMode(now)
@ -1873,7 +1880,8 @@ func (s *connection) sendPackets(now time.Time) error {
func (s *connection) sendPacketsWithoutGSO(now time.Time) error { func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
for { for {
buf := getPacketBuffer() buf := getPacketBuffer()
if _, err := s.appendPacket(buf, s.mtuDiscoverer.CurrentSize(), now); err != nil { ecn := s.sentPacketHandler.ECNMode(true)
if _, err := s.appendOneShortHeaderPacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil {
if err == errNothingToPack { if err == errNothingToPack {
buf.Release() buf.Release()
return nil return nil
@ -1881,7 +1889,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
return err return err
} }
s.sendQueue.Send(buf, buf.Len()) s.sendQueue.Send(buf, 0, ecn)
if s.sendQueue.WouldBlock() { if s.sendQueue.WouldBlock() {
return nil return nil
@ -1906,9 +1914,10 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error {
buf := getLargePacketBuffer() buf := getLargePacketBuffer()
maxSize := s.mtuDiscoverer.CurrentSize() maxSize := s.mtuDiscoverer.CurrentSize()
ecn := s.sentPacketHandler.ECNMode(true)
for { for {
var dontSendMore bool var dontSendMore bool
size, err := s.appendPacket(buf, maxSize, now) size, err := s.appendOneShortHeaderPacket(buf, maxSize, ecn, now)
if err != nil { if err != nil {
if err != errNothingToPack { if err != errNothingToPack {
return err return err
@ -1930,15 +1939,19 @@ func (s *connection) sendPacketsWithGSO(now time.Time) error {
} }
} }
// Don't send more packets in this batch if they require a different ECN marking than the previous ones.
nextECN := s.sentPacketHandler.ECNMode(true)
// Append another packet if // Append another packet if
// 1. The congestion controller and pacer allow sending more // 1. The congestion controller and pacer allow sending more
// 2. The last packet appended was a full-size packet // 2. The last packet appended was a full-size packet
// 3. We still have enough space for another full-size packet in the buffer // 3. The next packet will have the same ECN marking
if !dontSendMore && size == maxSize && buf.Len()+maxSize <= buf.Cap() { // 4. We still have enough space for another full-size packet in the buffer
if !dontSendMore && size == maxSize && nextECN == ecn && buf.Len()+maxSize <= buf.Cap() {
continue continue
} }
s.sendQueue.Send(buf, maxSize) s.sendQueue.Send(buf, uint16(maxSize), ecn)
if dontSendMore { if dontSendMore {
return nil return nil
@ -1967,6 +1980,7 @@ func (s *connection) resetPacingDeadline() {
func (s *connection) maybeSendAckOnlyPacket(now time.Time) error { func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
if !s.handshakeConfirmed { if !s.handshakeConfirmed {
ecn := s.sentPacketHandler.ECNMode(false)
packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version) packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version)
if err != nil { if err != nil {
return err return err
@ -1974,9 +1988,10 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
if packet == nil { if packet == nil {
return nil return nil
} }
return s.sendPackedCoalescedPacket(packet, time.Now()) return s.sendPackedCoalescedPacket(packet, ecn, time.Now())
} }
ecn := s.sentPacketHandler.ECNMode(true)
p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version) p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version)
if err != nil { if err != nil {
if err == errNothingToPack { if err == errNothingToPack {
@ -1984,9 +1999,9 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
} }
return err return err
} }
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buf.Len(), false) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false)
s.registerPackedShortHeaderPacket(p, now) s.registerPackedShortHeaderPacket(p, ecn, now)
s.sendQueue.Send(buf, buf.Len()) s.sendQueue.Send(buf, 0, ecn)
return nil return nil
} }
@ -2018,24 +2033,24 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) {
return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel)
} }
return s.sendPackedCoalescedPacket(packet, now) return s.sendPackedCoalescedPacket(packet, s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket()), now)
} }
// appendPacket appends a new packet to the given packetBuffer. // appendOneShortHeaderPacket appends a new packet to the given packetBuffer.
// If there was nothing to pack, the returned size is 0. // If there was nothing to pack, the returned size is 0.
func (s *connection) appendPacket(buf *packetBuffer, maxSize protocol.ByteCount, now time.Time) (protocol.ByteCount, error) { func (s *connection) appendOneShortHeaderPacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) {
startLen := buf.Len() startLen := buf.Len()
p, err := s.packer.AppendPacket(buf, maxSize, s.version) p, err := s.packer.AppendPacket(buf, maxSize, s.version)
if err != nil { if err != nil {
return 0, err return 0, err
} }
size := buf.Len() - startLen size := buf.Len() - startLen
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, size, false) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, size, false)
s.registerPackedShortHeaderPacket(p, now) s.registerPackedShortHeaderPacket(p, ecn, now)
return size, nil return size, nil
} }
func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now time.Time) { func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, ecn protocol.ECN, now time.Time) {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && (len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)) {
s.firstAckElicitingPacketAfterIdleSentTime = now s.firstAckElicitingPacketAfterIdleSentTime = now
} }
@ -2044,12 +2059,12 @@ func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, now ti
if p.Ack != nil { if p.Ack != nil {
largestAcked = p.Ack.LargestAcked() largestAcked = p.Ack.LargestAcked()
} }
s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket)
s.connIDManager.SentPacket() s.connIDManager.SentPacket()
} }
func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) error { func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN, now time.Time) error {
s.logCoalescedPacket(packet) s.logCoalescedPacket(packet, ecn)
for _, p := range packet.longHdrPackets { for _, p := range packet.longHdrPackets {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
s.firstAckElicitingPacketAfterIdleSentTime = now s.firstAckElicitingPacketAfterIdleSentTime = now
@ -2058,7 +2073,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
if p.ack != nil { if p.ack != nil {
largestAcked = p.ack.LargestAcked() largestAcked = p.ack.LargestAcked()
} }
s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), p.length, false) s.sentPacketHandler.SentPacket(now, p.header.PacketNumber, largestAcked, p.streamFrames, p.frames, p.EncryptionLevel(), ecn, p.length, false)
if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake { if s.perspective == protocol.PerspectiveClient && p.EncryptionLevel() == protocol.EncryptionHandshake {
// On the client side, Initial keys are dropped as soon as the first Handshake packet is sent. // On the client side, Initial keys are dropped as soon as the first Handshake packet is sent.
// See Section 4.9.1 of RFC 9001. // See Section 4.9.1 of RFC 9001.
@ -2075,10 +2090,10 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
if p.Ack != nil { if p.Ack != nil {
largestAcked = p.Ack.LargestAcked() largestAcked = p.Ack.LargestAcked()
} }
s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, p.Length, p.IsPathMTUProbePacket) s.sentPacketHandler.SentPacket(now, p.PacketNumber, largestAcked, p.StreamFrames, p.Frames, protocol.Encryption1RTT, ecn, p.Length, p.IsPathMTUProbePacket)
} }
s.connIDManager.SentPacket() s.connIDManager.SentPacket()
s.sendQueue.Send(packet.buffer, packet.buffer.Len()) s.sendQueue.Send(packet.buffer, 0, ecn)
return nil return nil
} }
@ -2100,11 +2115,12 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.logCoalescedPacket(packet) ecn := s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket())
return packet.buffer.Data, s.conn.Write(packet.buffer.Data, packet.buffer.Len()) s.logCoalescedPacket(packet, ecn)
return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn)
} }
func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) {
// quic-go logging // quic-go logging
if s.logger.Debug() { if s.logger.Debug() {
p.header.Log(s.logger) p.header.Log(s.logger)
@ -2120,7 +2136,7 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket) {
} }
// tracing // tracing
if s.tracer != nil { if s.tracer != nil && s.tracer.SentLongHeaderPacket != nil {
frames := make([]logging.Frame, 0, len(p.frames)) frames := make([]logging.Frame, 0, len(p.frames))
for _, f := range p.frames { for _, f := range p.frames {
frames = append(frames, logutils.ConvertFrame(f.Frame)) frames = append(frames, logutils.ConvertFrame(f.Frame))
@ -2132,7 +2148,7 @@ func (s *connection) logLongHeaderPacket(p *longHeaderPacket) {
if p.ack != nil { if p.ack != nil {
ack = logutils.ConvertAckFrame(p.ack) ack = logutils.ConvertAckFrame(p.ack)
} }
s.tracer.SentLongHeaderPacket(p.header, p.length, ack, frames) s.tracer.SentLongHeaderPacket(p.header, p.length, ecn, ack, frames)
} }
} }
@ -2144,11 +2160,12 @@ func (s *connection) logShortHeaderPacket(
pn protocol.PacketNumber, pn protocol.PacketNumber,
pnLen protocol.PacketNumberLen, pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit, kp protocol.KeyPhaseBit,
ecn protocol.ECN,
size protocol.ByteCount, size protocol.ByteCount,
isCoalesced bool, isCoalesced bool,
) { ) {
if s.logger.Debug() && !isCoalesced { if s.logger.Debug() && !isCoalesced {
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT", pn, size, s.logID) s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", pn, size, s.logID, ecn)
} }
// quic-go logging // quic-go logging
if s.logger.Debug() { if s.logger.Debug() {
@ -2165,7 +2182,7 @@ func (s *connection) logShortHeaderPacket(
} }
// tracing // tracing
if s.tracer != nil { if s.tracer != nil && s.tracer.SentShortHeaderPacket != nil {
fs := make([]logging.Frame, 0, len(frames)+len(streamFrames)) fs := make([]logging.Frame, 0, len(frames)+len(streamFrames))
for _, f := range frames { for _, f := range frames {
fs = append(fs, logutils.ConvertFrame(f.Frame)) fs = append(fs, logutils.ConvertFrame(f.Frame))
@ -2185,13 +2202,14 @@ func (s *connection) logShortHeaderPacket(
KeyPhase: kp, KeyPhase: kp,
}, },
size, size,
ecn,
ack, ack,
fs, fs,
) )
} }
} }
func (s *connection) logCoalescedPacket(packet *coalescedPacket) { func (s *connection) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) {
if s.logger.Debug() { if s.logger.Debug() {
// There's a short period between dropping both Initial and Handshake keys and completion of the handshake, // There's a short period between dropping both Initial and Handshake keys and completion of the handshake,
// during which we might call PackCoalescedPacket but just pack a short header packet. // during which we might call PackCoalescedPacket but just pack a short header packet.
@ -2204,6 +2222,7 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) {
packet.shortHdrPacket.PacketNumber, packet.shortHdrPacket.PacketNumber,
packet.shortHdrPacket.PacketNumberLen, packet.shortHdrPacket.PacketNumberLen,
packet.shortHdrPacket.KeyPhase, packet.shortHdrPacket.KeyPhase,
ecn,
packet.shortHdrPacket.Length, packet.shortHdrPacket.Length,
false, false,
) )
@ -2216,10 +2235,10 @@ func (s *connection) logCoalescedPacket(packet *coalescedPacket) {
} }
} }
for _, p := range packet.longHdrPackets { for _, p := range packet.longHdrPackets {
s.logLongHeaderPacket(p) s.logLongHeaderPacket(p, ecn)
} }
if p := packet.shortHdrPacket; p != nil { if p := packet.shortHdrPacket; p != nil {
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true)
} }
} }
@ -2285,14 +2304,14 @@ func (s *connection) tryQueueingUndecryptablePacket(p receivedPacket, pt logging
panic("shouldn't queue undecryptable packets after handshake completion") panic("shouldn't queue undecryptable packets after handshake completion")
} }
if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropDOSPrevention) s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropDOSPrevention)
} }
s.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", p.Size()) s.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", p.Size())
return return
} }
s.logger.Infof("Queueing packet (%d bytes) for later decryption", p.Size()) s.logger.Infof("Queueing packet (%d bytes) for later decryption", p.Size())
if s.tracer != nil { if s.tracer != nil && s.tracer.BufferedPacket != nil {
s.tracer.BufferedPacket(pt, p.Size()) s.tracer.BufferedPacket(pt, p.Size())
} }
s.undecryptablePackets = append(s.undecryptablePackets, p) s.undecryptablePackets = append(s.undecryptablePackets, p)

View File

@ -30,17 +30,10 @@ type cryptoStreamImpl struct {
writeOffset protocol.ByteCount writeOffset protocol.ByteCount
writeBuf []byte writeBuf []byte
// Reassemble TLS handshake messages before returning them from GetCryptoData.
// This is only needed because crypto/tls doesn't correctly handle post-handshake messages.
onlyCompleteMsg bool
} }
func newCryptoStream(onlyCompleteMsg bool) cryptoStream { func newCryptoStream() cryptoStream {
return &cryptoStreamImpl{ return &cryptoStreamImpl{queue: newFrameSorter()}
queue: newFrameSorter(),
onlyCompleteMsg: onlyCompleteMsg,
}
} }
func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
@ -78,20 +71,6 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
// GetCryptoData retrieves data that was received in CRYPTO frames // GetCryptoData retrieves data that was received in CRYPTO frames
func (s *cryptoStreamImpl) GetCryptoData() []byte { func (s *cryptoStreamImpl) GetCryptoData() []byte {
if s.onlyCompleteMsg {
if len(s.msgBuf) < 4 {
return nil
}
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
if len(s.msgBuf) < msgLen {
return nil
}
msg := make([]byte, msgLen)
copy(msg, s.msgBuf[:msgLen])
s.msgBuf = s.msgBuf[msgLen:]
return msg
}
b := s.msgBuf b := s.msgBuf
s.msgBuf = nil s.msgBuf = nil
return b return b

View File

@ -63,7 +63,8 @@ func (r *body) wasStreamHijacked() bool {
} }
func (r *body) Read(b []byte) (int, error) { func (r *body) Read(b []byte) (int, error) {
return r.str.Read(b) n, err := r.str.Read(b)
return n, maybeReplaceError(err)
} }
func (r *body) Close() error { func (r *body) Close() error {
@ -106,7 +107,7 @@ func (r *hijackableBody) Read(b []byte) (int, error) {
if err != nil { if err != nil {
r.requestDone() r.requestDone()
} }
return n, err return n, maybeReplaceError(err)
} }
func (r *hijackableBody) requestDone() { func (r *hijackableBody) requestDone() {

View File

@ -318,13 +318,13 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
} }
conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
} }
return nil, rerr.err return nil, maybeReplaceError(rerr.err)
} }
if opt.DontCloseRequestStream { if opt.DontCloseRequestStream {
close(reqDone) close(reqDone)
<-done <-done
} }
return rsp, rerr.err return rsp, maybeReplaceError(rerr.err)
} }
// cancelingReader reads from the io.Reader. // cancelingReader reads from the io.Reader.

58
vendor/github.com/quic-go/quic-go/http3/error.go generated vendored Normal file
View File

@ -0,0 +1,58 @@
package http3
import (
"errors"
"fmt"
"github.com/quic-go/quic-go"
)
// Error is returned from the round tripper (for HTTP clients)
// and inside the HTTP handler (for HTTP servers) if an HTTP/3 error occurs.
// See section 8 of RFC 9114.
type Error struct {
Remote bool
ErrorCode ErrCode
ErrorMessage string
}
var _ error = &Error{}
func (e *Error) Error() string {
s := e.ErrorCode.string()
if s == "" {
s = fmt.Sprintf("H3 error (%#x)", uint64(e.ErrorCode))
}
// Usually errors are remote. Only make it explicit for local errors.
if !e.Remote {
s += " (local)"
}
if e.ErrorMessage != "" {
s += ": " + e.ErrorMessage
}
return s
}
func maybeReplaceError(err error) error {
if err == nil {
return nil
}
var (
e Error
strErr *quic.StreamError
appErr *quic.ApplicationError
)
switch {
default:
return err
case errors.As(err, &strErr):
e.Remote = strErr.Remote
e.ErrorCode = ErrCode(strErr.ErrorCode)
case errors.As(err, &appErr):
e.Remote = appErr.Remote
e.ErrorCode = ErrCode(appErr.ErrorCode)
e.ErrorMessage = appErr.ErrorMessage
}
return &e
}

View File

@ -30,6 +30,14 @@ const (
) )
func (e ErrCode) String() string { func (e ErrCode) String() string {
s := e.string()
if s != "" {
return s
}
return fmt.Sprintf("unknown error code: %#x", uint16(e))
}
func (e ErrCode) string() string {
switch e { switch e {
case ErrCodeNoError: case ErrCodeNoError:
return "H3_NO_ERROR" return "H3_NO_ERROR"
@ -68,6 +76,6 @@ func (e ErrCode) String() string {
case ErrCodeDatagramError: case ErrCodeDatagramError:
return "H3_DATAGRAM_ERROR" return "H3_DATAGRAM_ERROR"
default: default:
return fmt.Sprintf("unknown error code: %#x", uint16(e)) return ""
} }
} }

View File

@ -2,7 +2,7 @@
package http3 package http3
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package http3 -destination mock_roundtripcloser_test.go github.com/quic-go/quic-go/http3 RoundTripCloser"
type RoundTripCloser = roundTripCloser type RoundTripCloser = roundTripCloser
//go:generate sh -c "go run github.com/golang/mock/mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/quic-go/quic-go/http3 QUICEarlyListener" //go:generate sh -c "go run go.uber.org/mock/mockgen -package http3 -destination mock_quic_early_listener_test.go github.com/quic-go/quic-go/http3 QUICEarlyListener"

View File

@ -166,9 +166,10 @@ func (w *responseWriter) Write(p []byte) (int, error) {
w.buf = w.buf[:0] w.buf = w.buf[:0]
w.buf = df.Append(w.buf) w.buf = df.Append(w.buf)
if _, err := w.bufferedStr.Write(w.buf); err != nil { if _, err := w.bufferedStr.Write(w.buf); err != nil {
return 0, err return 0, maybeReplaceError(err)
} }
return w.bufferedStr.Write(p) n, err := w.bufferedStr.Write(p)
return n, maybeReplaceError(err)
} }
func (w *responseWriter) FlushError() error { func (w *responseWriter) FlushError() error {
@ -177,7 +178,7 @@ func (w *responseWriter) FlushError() error {
} }
if !w.written { if !w.written {
if err := w.writeHeader(); err != nil { if err := w.writeHeader(); err != nil {
return err return maybeReplaceError(err)
} }
w.written = true w.written = true
} }

View File

@ -212,6 +212,9 @@ type EarlyConnection interface {
// StatelessResetKey is a key used to derive stateless reset tokens. // StatelessResetKey is a key used to derive stateless reset tokens.
type StatelessResetKey [32]byte type StatelessResetKey [32]byte
// TokenGeneratorKey is a key used to encrypt session resumption tokens.
type TokenGeneratorKey = handshake.TokenProtectorKey
// A ConnectionID is a QUIC Connection ID, as defined in RFC 9000. // A ConnectionID is a QUIC Connection ID, as defined in RFC 9000.
// It is not able to handle QUIC Connection IDs longer than 20 bytes, // It is not able to handle QUIC Connection IDs longer than 20 bytes,
// as they are allowed by RFC 8999. // as they are allowed by RFC 8999.
@ -250,7 +253,8 @@ type Config struct {
// If not set, it uses all versions available. // If not set, it uses all versions available.
Versions []VersionNumber Versions []VersionNumber
// HandshakeIdleTimeout is the idle timeout before completion of the handshake. // HandshakeIdleTimeout is the idle timeout before completion of the handshake.
// Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted. // If we don't receive any packet from the peer within this time, the connection attempt is aborted.
// Additionally, if the handshake doesn't complete in twice this time, the connection attempt is also aborted.
// If this value is zero, the timeout is set to 5 seconds. // If this value is zero, the timeout is set to 5 seconds.
HandshakeIdleTimeout time.Duration HandshakeIdleTimeout time.Duration
// MaxIdleTimeout is the maximum duration that may pass without any incoming network activity. // MaxIdleTimeout is the maximum duration that may pass without any incoming network activity.
@ -264,13 +268,6 @@ type Config struct {
// See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details. // See https://datatracker.ietf.org/doc/html/rfc9000#section-8 for details.
// If not set, every client is forced to prove its remote address. // If not set, every client is forced to prove its remote address.
RequireAddressValidation func(net.Addr) bool RequireAddressValidation func(net.Addr) bool
// MaxRetryTokenAge is the maximum age of a Retry token.
// If not set, it defaults to 5 seconds. Only valid for a server.
MaxRetryTokenAge time.Duration
// MaxTokenAge is the maximum age of the token presented during the handshake,
// for tokens that were issued on a previous connection.
// If not set, it defaults to 24 hours. Only valid for a server.
MaxTokenAge time.Duration
// The TokenStore stores tokens received from the server. // The TokenStore stores tokens received from the server.
// Tokens are used to skip address validation on future connection attempts. // Tokens are used to skip address validation on future connection attempts.
// The key used to store tokens is the ServerName from the tls.Config, if set // The key used to store tokens is the ServerName from the tls.Config, if set
@ -322,16 +319,12 @@ type Config struct {
// Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit. // Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
// If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. // If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
DisablePathMTUDiscovery bool DisablePathMTUDiscovery bool
// DisableVersionNegotiationPackets disables the sending of Version Negotiation packets.
// This can be useful if version information is exchanged out-of-band.
// It has no effect for a client.
DisableVersionNegotiationPackets bool
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted. // Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
// Only valid for the server. // Only valid for the server.
Allow0RTT bool Allow0RTT bool
// Enable QUIC datagram support (RFC 9221). // Enable QUIC datagram support (RFC 9221).
EnableDatagrams bool EnableDatagrams bool
Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer Tracer func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer
} }
type ClientHelloInfo struct { type ClientHelloInfo struct {
@ -351,4 +344,6 @@ type ConnectionState struct {
Used0RTT bool Used0RTT bool
// Version is the QUIC version of the QUIC connection. // Version is the QUIC version of the QUIC connection.
Version VersionNumber Version VersionNumber
// GSO says if generic segmentation offload is used
GSO bool
} }

View File

@ -14,10 +14,11 @@ func NewAckHandler(
initialMaxDatagramSize protocol.ByteCount, initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats, rttStats *utils.RTTStats,
clientAddressValidated bool, clientAddressValidated bool,
enableECN bool,
pers protocol.Perspective, pers protocol.Perspective,
tracer logging.ConnectionTracer, tracer *logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
) (SentPacketHandler, ReceivedPacketHandler) { ) (SentPacketHandler, ReceivedPacketHandler) {
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger) sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, enableECN, pers, tracer, logger)
return sph, newReceivedPacketHandler(sph, rttStats, logger) return sph, newReceivedPacketHandler(sph, rttStats, logger)
} }

View File

@ -0,0 +1,296 @@
package ackhandler
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
type ecnState uint8
const (
ecnStateInitial ecnState = iota
ecnStateTesting
ecnStateUnknown
ecnStateCapable
ecnStateFailed
)
// must fit into an uint8, otherwise numSentTesting and numLostTesting must have a larger type
const numECNTestingPackets = 10
type ecnHandler interface {
SentPacket(protocol.PacketNumber, protocol.ECN)
Mode() protocol.ECN
HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool)
LostPacket(protocol.PacketNumber)
}
// The ecnTracker performs ECN validation of a path.
// Once failed, it doesn't do any re-validation of the path.
// It is designed only work for 1-RTT packets, it doesn't handle multiple packet number spaces.
// In order to avoid revealing any internal state to on-path observers,
// callers should make sure to start using ECN (i.e. calling Mode) for the very first 1-RTT packet sent.
// The validation logic implemented here strictly follows the algorithm described in RFC 9000 section 13.4.2 and A.4.
type ecnTracker struct {
state ecnState
numSentTesting, numLostTesting uint8
firstTestingPacket protocol.PacketNumber
lastTestingPacket protocol.PacketNumber
firstCapablePacket protocol.PacketNumber
numSentECT0, numSentECT1 int64
numAckedECT0, numAckedECT1, numAckedECNCE int64
tracer *logging.ConnectionTracer
logger utils.Logger
}
var _ ecnHandler = &ecnTracker{}
func newECNTracker(logger utils.Logger, tracer *logging.ConnectionTracer) *ecnTracker {
return &ecnTracker{
firstTestingPacket: protocol.InvalidPacketNumber,
lastTestingPacket: protocol.InvalidPacketNumber,
firstCapablePacket: protocol.InvalidPacketNumber,
state: ecnStateInitial,
logger: logger,
tracer: tracer,
}
}
func (e *ecnTracker) SentPacket(pn protocol.PacketNumber, ecn protocol.ECN) {
//nolint:exhaustive // These are the only ones we need to take care of.
switch ecn {
case protocol.ECNNon:
return
case protocol.ECT0:
e.numSentECT0++
case protocol.ECT1:
e.numSentECT1++
case protocol.ECNUnsupported:
if e.state != ecnStateFailed {
panic("didn't expect ECN to be unsupported")
}
default:
panic(fmt.Sprintf("sent packet with unexpected ECN marking: %s", ecn))
}
if e.state == ecnStateCapable && e.firstCapablePacket == protocol.InvalidPacketNumber {
e.firstCapablePacket = pn
}
if e.state != ecnStateTesting {
return
}
e.numSentTesting++
if e.firstTestingPacket == protocol.InvalidPacketNumber {
e.firstTestingPacket = pn
}
if e.numSentECT0+e.numSentECT1 >= numECNTestingPackets {
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateUnknown, logging.ECNTriggerNoTrigger)
}
e.state = ecnStateUnknown
e.lastTestingPacket = pn
}
}
func (e *ecnTracker) Mode() protocol.ECN {
switch e.state {
case ecnStateInitial:
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateTesting, logging.ECNTriggerNoTrigger)
}
e.state = ecnStateTesting
return e.Mode()
case ecnStateTesting, ecnStateCapable:
return protocol.ECT0
case ecnStateUnknown, ecnStateFailed:
return protocol.ECNNon
default:
panic(fmt.Sprintf("unknown ECN state: %d", e.state))
}
}
func (e *ecnTracker) LostPacket(pn protocol.PacketNumber) {
if e.state != ecnStateTesting && e.state != ecnStateUnknown {
return
}
if !e.isTestingPacket(pn) {
return
}
e.numLostTesting++
// Only proceed if we have sent all 10 testing packets.
if e.state != ecnStateUnknown {
return
}
if e.numLostTesting >= e.numSentTesting {
e.logger.Debugf("Disabling ECN. All testing packets were lost.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedLostAllTestingPackets)
}
e.state = ecnStateFailed
return
}
// Path validation also fails if some testing packets are lost, and all other testing packets where CE-marked
e.failIfMangled()
}
// HandleNewlyAcked handles the ECN counts on an ACK frame.
// It must only be called for ACK frames that increase the largest acknowledged packet number,
// see section 13.4.2.1 of RFC 9000.
func (e *ecnTracker) HandleNewlyAcked(packets []*packet, ect0, ect1, ecnce int64) (congested bool) {
if e.state == ecnStateFailed {
return false
}
// ECN validation can fail if the received total count for either ECT(0) or ECT(1) exceeds
// the total number of packets sent with each corresponding ECT codepoint.
if ect0 > e.numSentECT0 || ect1 > e.numSentECT1 {
e.logger.Debugf("Disabling ECN. Received more ECT(0) / ECT(1) acknowledgements than packets sent.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedMoreECNCountsThanSent)
}
e.state = ecnStateFailed
return false
}
// Count ECT0 and ECT1 marks that we used when sending the packets that are now being acknowledged.
var ackedECT0, ackedECT1 int64
for _, p := range packets {
//nolint:exhaustive // We only ever send ECT(0) and ECT(1).
switch e.ecnMarking(p.PacketNumber) {
case protocol.ECT0:
ackedECT0++
case protocol.ECT1:
ackedECT1++
}
}
// If an ACK frame newly acknowledges a packet that the endpoint sent with either the ECT(0) or ECT(1)
// codepoint set, ECN validation fails if the corresponding ECN counts are not present in the ACK frame.
// This check detects:
// * paths that bleach all ECN marks, and
// * peers that don't report any ECN counts
if (ackedECT0 > 0 || ackedECT1 > 0) && ect0 == 0 && ect1 == 0 && ecnce == 0 {
e.logger.Debugf("Disabling ECN. ECN-marked packet acknowledged, but no ECN counts on ACK frame.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedNoECNCounts)
}
e.state = ecnStateFailed
return false
}
// Determine the increase in ECT0, ECT1 and ECNCE marks
newECT0 := ect0 - e.numAckedECT0
newECT1 := ect1 - e.numAckedECT1
newECNCE := ecnce - e.numAckedECNCE
// We're only processing ACKs that increase the Largest Acked.
// Therefore, the ECN counters should only ever increase.
// Any decrease means that the peer's counting logic is broken.
if newECT0 < 0 || newECT1 < 0 || newECNCE < 0 {
e.logger.Debugf("Disabling ECN. ECN counts decreased unexpectedly.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedDecreasedECNCounts)
}
e.state = ecnStateFailed
return false
}
// ECN validation also fails if the sum of the increase in ECT(0) and ECN-CE counts is less than the number
// of newly acknowledged packets that were originally sent with an ECT(0) marking.
// This could be the result of (partial) bleaching.
if newECT0+newECNCE < ackedECT0 {
e.logger.Debugf("Disabling ECN. Received less ECT(0) + ECN-CE than packets sent with ECT(0).")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts)
}
e.state = ecnStateFailed
return false
}
// Similarly, ECN validation fails if the sum of the increases to ECT(1) and ECN-CE counts is less than
// the number of newly acknowledged packets sent with an ECT(1) marking.
if newECT1+newECNCE < ackedECT1 {
e.logger.Debugf("Disabling ECN. Received less ECT(1) + ECN-CE than packets sent with ECT(1).")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedTooFewECNCounts)
}
e.state = ecnStateFailed
return false
}
// update our counters
e.numAckedECT0 = ect0
e.numAckedECT1 = ect1
e.numAckedECNCE = ecnce
// Detect mangling (a path remarking all ECN-marked testing packets as CE),
// once all 10 testing packets have been sent out.
if e.state == ecnStateUnknown {
e.failIfMangled()
if e.state == ecnStateFailed {
return false
}
}
if e.state == ecnStateTesting || e.state == ecnStateUnknown {
var ackedTestingPacket bool
for _, p := range packets {
if e.isTestingPacket(p.PacketNumber) {
ackedTestingPacket = true
break
}
}
// This check won't succeed if the path is mangling ECN-marks (i.e. rewrites all ECN-marked packets to CE).
if ackedTestingPacket && (newECT0 > 0 || newECT1 > 0) {
e.logger.Debugf("ECN capability confirmed.")
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateCapable, logging.ECNTriggerNoTrigger)
}
e.state = ecnStateCapable
}
}
// Don't trust CE marks before having confirmed ECN capability of the path.
// Otherwise, mangling would be misinterpreted as actual congestion.
return e.state == ecnStateCapable && newECNCE > 0
}
// failIfMangled fails ECN validation if all testing packets are lost or CE-marked.
func (e *ecnTracker) failIfMangled() {
numAckedECNCE := e.numAckedECNCE + int64(e.numLostTesting)
if e.numSentECT0+e.numSentECT1 > numAckedECNCE {
return
}
if e.tracer != nil && e.tracer.ECNStateUpdated != nil {
e.tracer.ECNStateUpdated(logging.ECNStateFailed, logging.ECNFailedManglingDetected)
}
e.state = ecnStateFailed
}
func (e *ecnTracker) ecnMarking(pn protocol.PacketNumber) protocol.ECN {
if pn < e.firstTestingPacket || e.firstTestingPacket == protocol.InvalidPacketNumber {
return protocol.ECNNon
}
if pn < e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber {
return protocol.ECT0
}
if pn < e.firstCapablePacket || e.firstCapablePacket == protocol.InvalidPacketNumber {
return protocol.ECNNon
}
// We don't need to deal with the case when ECN validation fails,
// since we're ignoring any ECN counts reported in ACK frames in that case.
return protocol.ECT0
}
func (e *ecnTracker) isTestingPacket(pn protocol.PacketNumber) bool {
if e.firstTestingPacket == protocol.InvalidPacketNumber {
return false
}
return pn >= e.firstTestingPacket && (pn <= e.lastTestingPacket || e.lastTestingPacket == protocol.InvalidPacketNumber)
}

View File

@ -10,13 +10,13 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets // SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface { type SentPacketHandler interface {
// SentPacket may modify the packet // SentPacket may modify the packet
SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, size protocol.ByteCount, isPathMTUProbePacket bool) SentPacket(t time.Time, pn, largestAcked protocol.PacketNumber, streamFrames []StreamFrame, frames []Frame, encLevel protocol.EncryptionLevel, ecn protocol.ECN, size protocol.ByteCount, isPathMTUProbePacket bool)
// ReceivedAck processes an ACK frame. // ReceivedAck processes an ACK frame.
// It does not store a copy of the frame. // It does not store a copy of the frame.
ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error) ReceivedAck(f *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* 1-RTT packet acked */, error)
ReceivedBytes(protocol.ByteCount) ReceivedBytes(protocol.ByteCount)
DropPackets(protocol.EncryptionLevel) DropPackets(protocol.EncryptionLevel)
ResetForRetry() error ResetForRetry(rcvTime time.Time) error
SetHandshakeConfirmed() SetHandshakeConfirmed()
// The SendMode determines if and what kind of packets can be sent. // The SendMode determines if and what kind of packets can be sent.
@ -29,6 +29,7 @@ type SentPacketHandler interface {
// only to be called once the handshake is complete // only to be called once the handshake is complete
QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */ QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */
ECNMode(isShortHeaderPacket bool) protocol.ECN // isShortHeaderPacket should only be true for non-coalesced 1-RTT packets
PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
@ -44,7 +45,7 @@ type sentPacketTracker interface {
// ReceivedPacketHandler handles ACKs needed to send for incoming packets // ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface { type ReceivedPacketHandler interface {
IsPotentiallyDuplicate(protocol.PacketNumber, protocol.EncryptionLevel) bool IsPotentiallyDuplicate(protocol.PacketNumber, protocol.EncryptionLevel) bool
ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, ackEliciting bool) error
DropPackets(protocol.EncryptionLevel) DropPackets(protocol.EncryptionLevel)
GetAlarmTimeout() time.Time GetAlarmTimeout() time.Time

View File

@ -2,5 +2,8 @@
package ackhandler package ackhandler
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler SentPacketTracker" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_sent_packet_tracker_test.go github.com/quic-go/quic-go/internal/ackhandler SentPacketTracker"
type SentPacketTracker = sentPacketTracker type SentPacketTracker = sentPacketTracker
//go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package ackhandler -destination mock_ecn_handler_test.go github.com/quic-go/quic-go/internal/ackhandler ECNHandler"
type ECNHandler = ecnHandler

View File

@ -40,29 +40,29 @@ func (h *receivedPacketHandler) ReceivedPacket(
ecn protocol.ECN, ecn protocol.ECN,
encLevel protocol.EncryptionLevel, encLevel protocol.EncryptionLevel,
rcvTime time.Time, rcvTime time.Time,
shouldInstigateAck bool, ackEliciting bool,
) error { ) error {
h.sentPackets.ReceivedPacket(encLevel) h.sentPackets.ReceivedPacket(encLevel)
switch encLevel { switch encLevel {
case protocol.EncryptionInitial: case protocol.EncryptionInitial:
return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) return h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
// The Handshake packet number space might already have been dropped as a result // The Handshake packet number space might already have been dropped as a result
// of processing the CRYPTO frame that was contained in this packet. // of processing the CRYPTO frame that was contained in this packet.
if h.handshakePackets == nil { if h.handshakePackets == nil {
return nil return nil
} }
return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) return h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
case protocol.Encryption0RTT: case protocol.Encryption0RTT:
if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket { if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket {
return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket) return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket)
} }
return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck) return h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting)
case protocol.Encryption1RTT: case protocol.Encryption1RTT:
if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket { if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket {
h.lowest1RTTPacket = pn h.lowest1RTTPacket = pn
} }
if err := h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck); err != nil { if err := h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, ackEliciting); err != nil {
return err return err
} }
h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked()) h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked())

View File

@ -13,10 +13,10 @@ import (
const packetsBeforeAck = 2 const packetsBeforeAck = 2
type receivedPacketTracker struct { type receivedPacketTracker struct {
largestObserved protocol.PacketNumber largestObserved protocol.PacketNumber
ignoreBelow protocol.PacketNumber ignoreBelow protocol.PacketNumber
largestObservedReceivedTime time.Time largestObservedRcvdTime time.Time
ect0, ect1, ecnce uint64 ect0, ect1, ecnce uint64
packetHistory *receivedPacketHistory packetHistory *receivedPacketHistory
@ -45,25 +45,25 @@ func newReceivedPacketTracker(
} }
} }
func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, shouldInstigateAck bool) error { func (h *receivedPacketTracker) ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, ackEliciting bool) error {
if isNew := h.packetHistory.ReceivedPacket(packetNumber); !isNew { if isNew := h.packetHistory.ReceivedPacket(pn); !isNew {
return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", packetNumber) return fmt.Errorf("recevedPacketTracker BUG: ReceivedPacket called for old / duplicate packet %d", pn)
} }
isMissing := h.isMissing(packetNumber) isMissing := h.isMissing(pn)
if packetNumber >= h.largestObserved { if pn >= h.largestObserved {
h.largestObserved = packetNumber h.largestObserved = pn
h.largestObservedReceivedTime = rcvTime h.largestObservedRcvdTime = rcvTime
} }
if shouldInstigateAck { if ackEliciting {
h.hasNewAck = true h.hasNewAck = true
} }
if shouldInstigateAck { if ackEliciting {
h.maybeQueueAck(packetNumber, rcvTime, isMissing) h.maybeQueueACK(pn, rcvTime, isMissing)
} }
//nolint:exhaustive // Only need to count ECT(0), ECT(1) and ECNCE.
switch ecn { switch ecn {
case protocol.ECNNon:
case protocol.ECT0: case protocol.ECT0:
h.ect0++ h.ect0++
case protocol.ECT1: case protocol.ECT1:
@ -76,14 +76,14 @@ func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumbe
// IgnoreBelow sets a lower limit for acknowledging packets. // IgnoreBelow sets a lower limit for acknowledging packets.
// Packets with packet numbers smaller than p will not be acked. // Packets with packet numbers smaller than p will not be acked.
func (h *receivedPacketTracker) IgnoreBelow(p protocol.PacketNumber) { func (h *receivedPacketTracker) IgnoreBelow(pn protocol.PacketNumber) {
if p <= h.ignoreBelow { if pn <= h.ignoreBelow {
return return
} }
h.ignoreBelow = p h.ignoreBelow = pn
h.packetHistory.DeleteBelow(p) h.packetHistory.DeleteBelow(pn)
if h.logger.Debug() { if h.logger.Debug() {
h.logger.Debugf("\tIgnoring all packets below %d.", p) h.logger.Debugf("\tIgnoring all packets below %d.", pn)
} }
} }
@ -103,8 +103,8 @@ func (h *receivedPacketTracker) hasNewMissingPackets() bool {
return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1 return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1
} }
// maybeQueueAck queues an ACK, if necessary. // maybeQueueACK queues an ACK, if necessary.
func (h *receivedPacketTracker) maybeQueueAck(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) { func (h *receivedPacketTracker) maybeQueueACK(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) {
// always acknowledge the first packet // always acknowledge the first packet
if h.lastAck == nil { if h.lastAck == nil {
if !h.ackQueued { if !h.ackQueued {
@ -175,7 +175,7 @@ func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
ack = &wire.AckFrame{} ack = &wire.AckFrame{}
} }
ack.Reset() ack.Reset()
ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedReceivedTime)) ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedRcvdTime))
ack.ECT0 = h.ect0 ack.ECT0 = h.ect0
ack.ECT1 = h.ect1 ack.ECT1 = h.ect1
ack.ECNCE = h.ecnce ack.ECNCE = h.ecnce

View File

@ -92,9 +92,12 @@ type sentPacketHandler struct {
// The alarm timeout // The alarm timeout
alarm time.Time alarm time.Time
enableECN bool
ecnTracker ecnHandler
perspective protocol.Perspective perspective protocol.Perspective
tracer logging.ConnectionTracer tracer *logging.ConnectionTracer
logger utils.Logger logger utils.Logger
} }
@ -110,8 +113,9 @@ func newSentPacketHandler(
initialMaxDatagramSize protocol.ByteCount, initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats, rttStats *utils.RTTStats,
clientAddressValidated bool, clientAddressValidated bool,
enableECN bool,
pers protocol.Perspective, pers protocol.Perspective,
tracer logging.ConnectionTracer, tracer *logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
) *sentPacketHandler { ) *sentPacketHandler {
congestion := congestion.NewCubicSender( congestion := congestion.NewCubicSender(
@ -122,7 +126,7 @@ func newSentPacketHandler(
tracer, tracer,
) )
return &sentPacketHandler{ h := &sentPacketHandler{
peerCompletedAddressValidation: pers == protocol.PerspectiveServer, peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated, peerAddressValidated: pers == protocol.PerspectiveClient || clientAddressValidated,
initialPackets: newPacketNumberSpace(initialPN, false), initialPackets: newPacketNumberSpace(initialPN, false),
@ -134,6 +138,11 @@ func newSentPacketHandler(
tracer: tracer, tracer: tracer,
logger: logger, logger: logger,
} }
if enableECN {
h.enableECN = true
h.ecnTracker = newECNTracker(logger, tracer)
}
return h
} }
func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) { func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
@ -187,7 +196,7 @@ func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
default: default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel)) panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
} }
if h.tracer != nil && h.ptoCount != 0 { if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 {
h.tracer.UpdatedPTOCount(0) h.tracer.UpdatedPTOCount(0)
} }
h.ptoCount = 0 h.ptoCount = 0
@ -228,6 +237,7 @@ func (h *sentPacketHandler) SentPacket(
streamFrames []StreamFrame, streamFrames []StreamFrame,
frames []Frame, frames []Frame,
encLevel protocol.EncryptionLevel, encLevel protocol.EncryptionLevel,
ecn protocol.ECN,
size protocol.ByteCount, size protocol.ByteCount,
isPathMTUProbePacket bool, isPathMTUProbePacket bool,
) { ) {
@ -252,6 +262,10 @@ func (h *sentPacketHandler) SentPacket(
} }
h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting) h.congestion.OnPacketSent(t, h.bytesInFlight, pn, size, isAckEliciting)
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil {
h.ecnTracker.SentPacket(pn, ecn)
}
if !isAckEliciting { if !isAckEliciting {
pnSpace.history.SentNonAckElicitingPacket(pn) pnSpace.history.SentNonAckElicitingPacket(pn)
if !h.peerCompletedAddressValidation { if !h.peerCompletedAddressValidation {
@ -272,7 +286,7 @@ func (h *sentPacketHandler) SentPacket(
p.includedInBytesInFlight = true p.includedInBytesInFlight = true
pnSpace.history.SentAckElicitingPacket(p) pnSpace.history.SentAckElicitingPacket(p)
if h.tracer != nil { if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
} }
h.setLossDetectionTimer() h.setLossDetectionTimer()
@ -302,8 +316,6 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
} }
} }
pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked)
// Servers complete address validation when a protected packet is received. // Servers complete address validation when a protected packet is received.
if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation && if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation &&
(encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) { (encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) {
@ -333,6 +345,17 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
h.congestion.MaybeExitSlowStart() h.congestion.MaybeExitSlowStart()
} }
} }
// Only inform the ECN tracker about new 1-RTT ACKs if the ACK increases the largest acked.
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil && largestAcked > pnSpace.largestAcked {
congested := h.ecnTracker.HandleNewlyAcked(ackedPackets, int64(ack.ECT0), int64(ack.ECT1), int64(ack.ECNCE))
if congested {
h.congestion.OnCongestionEvent(largestAcked, 0, priorInFlight)
}
}
pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked)
if err := h.detectLostPackets(rcvTime, encLevel); err != nil { if err := h.detectLostPackets(rcvTime, encLevel); err != nil {
return false, err return false, err
} }
@ -353,14 +376,14 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
// Reset the pto_count unless the client is unsure if the server has validated the client's address. // Reset the pto_count unless the client is unsure if the server has validated the client's address.
if h.peerCompletedAddressValidation { if h.peerCompletedAddressValidation {
if h.tracer != nil && h.ptoCount != 0 { if h.tracer != nil && h.tracer.UpdatedPTOCount != nil && h.ptoCount != 0 {
h.tracer.UpdatedPTOCount(0) h.tracer.UpdatedPTOCount(0)
} }
h.ptoCount = 0 h.ptoCount = 0
} }
h.numProbesToSend = 0 h.numProbesToSend = 0
if h.tracer != nil { if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
} }
@ -439,7 +462,7 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
if err := pnSpace.history.Remove(p.PacketNumber); err != nil { if err := pnSpace.history.Remove(p.PacketNumber); err != nil {
return nil, err return nil, err
} }
if h.tracer != nil { if h.tracer != nil && h.tracer.AcknowledgedPacket != nil {
h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber) h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber)
} }
} }
@ -532,7 +555,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() {
if !lossTime.IsZero() { if !lossTime.IsZero() {
// Early retransmit timer or time loss detection. // Early retransmit timer or time loss detection.
h.alarm = lossTime h.alarm = lossTime
if h.tracer != nil && h.alarm != oldAlarm { if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm {
h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm) h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm)
} }
return return
@ -543,7 +566,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() {
h.alarm = time.Time{} h.alarm = time.Time{}
if !oldAlarm.IsZero() { if !oldAlarm.IsZero() {
h.logger.Debugf("Canceling loss detection timer. Amplification limited.") h.logger.Debugf("Canceling loss detection timer. Amplification limited.")
if h.tracer != nil { if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled() h.tracer.LossTimerCanceled()
} }
} }
@ -555,7 +578,7 @@ func (h *sentPacketHandler) setLossDetectionTimer() {
h.alarm = time.Time{} h.alarm = time.Time{}
if !oldAlarm.IsZero() { if !oldAlarm.IsZero() {
h.logger.Debugf("Canceling loss detection timer. No packets in flight.") h.logger.Debugf("Canceling loss detection timer. No packets in flight.")
if h.tracer != nil { if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled() h.tracer.LossTimerCanceled()
} }
} }
@ -568,14 +591,14 @@ func (h *sentPacketHandler) setLossDetectionTimer() {
if !oldAlarm.IsZero() { if !oldAlarm.IsZero() {
h.alarm = time.Time{} h.alarm = time.Time{}
h.logger.Debugf("Canceling loss detection timer. No PTO needed..") h.logger.Debugf("Canceling loss detection timer. No PTO needed..")
if h.tracer != nil { if h.tracer != nil && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled() h.tracer.LossTimerCanceled()
} }
} }
return return
} }
h.alarm = ptoTime h.alarm = ptoTime
if h.tracer != nil && h.alarm != oldAlarm { if h.tracer != nil && h.tracer.SetLossTimer != nil && h.alarm != oldAlarm {
h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm) h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm)
} }
} }
@ -606,7 +629,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
if h.logger.Debug() { if h.logger.Debug() {
h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber) h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber)
} }
if h.tracer != nil { if h.tracer != nil && h.tracer.LostPacket != nil {
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold) h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold)
} }
} }
@ -616,7 +639,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
if h.logger.Debug() { if h.logger.Debug() {
h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber) h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber)
} }
if h.tracer != nil { if h.tracer != nil && h.tracer.LostPacket != nil {
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold) h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold)
} }
} }
@ -635,7 +658,10 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
h.removeFromBytesInFlight(p) h.removeFromBytesInFlight(p)
h.queueFramesForRetransmission(p) h.queueFramesForRetransmission(p)
if !p.IsPathMTUProbePacket { if !p.IsPathMTUProbePacket {
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight) h.congestion.OnCongestionEvent(p.PacketNumber, p.Length, priorInFlight)
}
if encLevel == protocol.Encryption1RTT && h.ecnTracker != nil {
h.ecnTracker.LostPacket(p.PacketNumber)
} }
} }
} }
@ -650,7 +676,7 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error {
if h.logger.Debug() { if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime) h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime)
} }
if h.tracer != nil { if h.tracer != nil && h.tracer.LossTimerExpired != nil {
h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel) h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel)
} }
// Early retransmit or time loss detection // Early retransmit or time loss detection
@ -687,8 +713,12 @@ func (h *sentPacketHandler) OnLossDetectionTimeout() error {
h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount) h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount)
} }
if h.tracer != nil { if h.tracer != nil {
h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel) if h.tracer.LossTimerExpired != nil {
h.tracer.UpdatedPTOCount(h.ptoCount) h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel)
}
if h.tracer.UpdatedPTOCount != nil {
h.tracer.UpdatedPTOCount(h.ptoCount)
}
} }
h.numProbesToSend += 2 h.numProbesToSend += 2
//nolint:exhaustive // We never arm a PTO timer for 0-RTT packets. //nolint:exhaustive // We never arm a PTO timer for 0-RTT packets.
@ -712,6 +742,16 @@ func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
return h.alarm return h.alarm
} }
func (h *sentPacketHandler) ECNMode(isShortHeaderPacket bool) protocol.ECN {
if !h.enableECN {
return protocol.ECNUnsupported
}
if !isShortHeaderPacket {
return protocol.ECNNon
}
return h.ecnTracker.Mode()
}
func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
pn := pnSpace.pns.Peek() pn := pnSpace.pns.Peek()
@ -824,7 +864,7 @@ func (h *sentPacketHandler) queueFramesForRetransmission(p *packet) {
p.Frames = nil p.Frames = nil
} }
func (h *sentPacketHandler) ResetForRetry() error { func (h *sentPacketHandler) ResetForRetry(now time.Time) error {
h.bytesInFlight = 0 h.bytesInFlight = 0
var firstPacketSendTime time.Time var firstPacketSendTime time.Time
h.initialPackets.history.Iterate(func(p *packet) (bool, error) { h.initialPackets.history.Iterate(func(p *packet) (bool, error) {
@ -850,12 +890,11 @@ func (h *sentPacketHandler) ResetForRetry() error {
// Otherwise, we don't know which Initial the Retry was sent in response to. // Otherwise, we don't know which Initial the Retry was sent in response to.
if h.ptoCount == 0 { if h.ptoCount == 0 {
// Don't set the RTT to a value lower than 5ms here. // Don't set the RTT to a value lower than 5ms here.
now := time.Now()
h.rttStats.UpdateRTT(utils.Max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now) h.rttStats.UpdateRTT(utils.Max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now)
if h.logger.Debug() { if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation()) h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
} }
if h.tracer != nil { if h.tracer != nil && h.tracer.UpdatedMetrics != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight()) h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
} }
} }
@ -864,8 +903,10 @@ func (h *sentPacketHandler) ResetForRetry() error {
oldAlarm := h.alarm oldAlarm := h.alarm
h.alarm = time.Time{} h.alarm = time.Time{}
if h.tracer != nil { if h.tracer != nil {
h.tracer.UpdatedPTOCount(0) if h.tracer.UpdatedPTOCount != nil {
if !oldAlarm.IsZero() { h.tracer.UpdatedPTOCount(0)
}
if !oldAlarm.IsZero() && h.tracer.LossTimerCanceled != nil {
h.tracer.LossTimerCanceled() h.tracer.LossTimerCanceled()
} }
} }

View File

@ -56,7 +56,7 @@ type cubicSender struct {
maxDatagramSize protocol.ByteCount maxDatagramSize protocol.ByteCount
lastState logging.CongestionState lastState logging.CongestionState
tracer logging.ConnectionTracer tracer *logging.ConnectionTracer
} }
var ( var (
@ -70,7 +70,7 @@ func NewCubicSender(
rttStats *utils.RTTStats, rttStats *utils.RTTStats,
initialMaxDatagramSize protocol.ByteCount, initialMaxDatagramSize protocol.ByteCount,
reno bool, reno bool,
tracer logging.ConnectionTracer, tracer *logging.ConnectionTracer,
) *cubicSender { ) *cubicSender {
return newCubicSender( return newCubicSender(
clock, clock,
@ -90,7 +90,7 @@ func newCubicSender(
initialMaxDatagramSize, initialMaxDatagramSize,
initialCongestionWindow, initialCongestionWindow,
initialMaxCongestionWindow protocol.ByteCount, initialMaxCongestionWindow protocol.ByteCount,
tracer logging.ConnectionTracer, tracer *logging.ConnectionTracer,
) *cubicSender { ) *cubicSender {
c := &cubicSender{ c := &cubicSender{
rttStats: rttStats, rttStats: rttStats,
@ -108,7 +108,7 @@ func newCubicSender(
maxDatagramSize: initialMaxDatagramSize, maxDatagramSize: initialMaxDatagramSize,
} }
c.pacer = newPacer(c.BandwidthEstimate) c.pacer = newPacer(c.BandwidthEstimate)
if c.tracer != nil { if c.tracer != nil && c.tracer.UpdatedCongestionState != nil {
c.lastState = logging.CongestionStateSlowStart c.lastState = logging.CongestionStateSlowStart
c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart) c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
} }
@ -188,7 +188,7 @@ func (c *cubicSender) OnPacketAcked(
} }
} }
func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) { func (c *cubicSender) OnCongestionEvent(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) {
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets // TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
// already sent should be treated as a single loss event, since it's expected. // already sent should be treated as a single loss event, since it's expected.
if packetNumber <= c.largestSentAtLastCutback { if packetNumber <= c.largestSentAtLastCutback {
@ -296,7 +296,7 @@ func (c *cubicSender) OnConnectionMigration() {
} }
func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) { func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
if c.tracer == nil || new == c.lastState { if c.tracer == nil || c.tracer.UpdatedCongestionState == nil || new == c.lastState {
return return
} }
c.tracer.UpdatedCongestionState(new) c.tracer.UpdatedCongestionState(new)

View File

@ -14,7 +14,7 @@ type SendAlgorithm interface {
CanSend(bytesInFlight protocol.ByteCount) bool CanSend(bytesInFlight protocol.ByteCount) bool
MaybeExitSlowStart() MaybeExitSlowStart()
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time) OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount) OnCongestionEvent(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
OnRetransmissionTimeout(packetsRetransmitted bool) OnRetransmissionTimeout(packetsRetransmitted bool)
SetMaxDatagramSize(protocol.ByteCount) SetMaxDatagramSize(protocol.ByteCount)
} }

View File

@ -43,7 +43,7 @@ type cryptoSetup struct {
rttStats *utils.RTTStats rttStats *utils.RTTStats
tracer logging.ConnectionTracer tracer *logging.ConnectionTracer
logger utils.Logger logger utils.Logger
perspective protocol.Perspective perspective protocol.Perspective
@ -77,7 +77,7 @@ func NewCryptoSetupClient(
tlsConf *tls.Config, tlsConf *tls.Config,
enable0RTT bool, enable0RTT bool,
rttStats *utils.RTTStats, rttStats *utils.RTTStats,
tracer logging.ConnectionTracer, tracer *logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
version protocol.VersionNumber, version protocol.VersionNumber,
) CryptoSetup { ) CryptoSetup {
@ -111,7 +111,7 @@ func NewCryptoSetupServer(
tlsConf *tls.Config, tlsConf *tls.Config,
allow0RTT bool, allow0RTT bool,
rttStats *utils.RTTStats, rttStats *utils.RTTStats,
tracer logging.ConnectionTracer, tracer *logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
version protocol.VersionNumber, version protocol.VersionNumber,
) CryptoSetup { ) CryptoSetup {
@ -127,7 +127,7 @@ func NewCryptoSetupServer(
cs.allow0RTT = allow0RTT cs.allow0RTT = allow0RTT
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT) qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket)
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr) addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
cs.tlsConf = quicConf.TLSConfig cs.tlsConf = quicConf.TLSConfig
@ -165,13 +165,13 @@ func newCryptoSetup(
connID protocol.ConnectionID, connID protocol.ConnectionID,
tp *wire.TransportParameters, tp *wire.TransportParameters,
rttStats *utils.RTTStats, rttStats *utils.RTTStats,
tracer logging.ConnectionTracer, tracer *logging.ConnectionTracer,
logger utils.Logger, logger utils.Logger,
perspective protocol.Perspective, perspective protocol.Perspective,
version protocol.VersionNumber, version protocol.VersionNumber,
) *cryptoSetup { ) *cryptoSetup {
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version) initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
if tracer != nil { if tracer != nil && tracer.UpdatedKeyFromTLS != nil {
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
} }
@ -193,7 +193,7 @@ func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version) initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version)
h.initialSealer = initialSealer h.initialSealer = initialSealer
h.initialOpener = initialOpener h.initialOpener = initialOpener
if h.tracer != nil { if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
} }
@ -347,10 +347,13 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo
} }
func (h *cryptoSetup) getDataForSessionTicket() []byte { func (h *cryptoSetup) getDataForSessionTicket() []byte {
return (&sessionTicket{ ticket := &sessionTicket{
Parameters: h.ourParams, RTT: h.rttStats.SmoothedRTT(),
RTT: h.rttStats.SmoothedRTT(), }
}).Marshal() if h.allow0RTT {
ticket.Parameters = h.ourParams
}
return ticket.Marshal()
} }
// GetSessionTicket generates a new session ticket. // GetSessionTicket generates a new session ticket.
@ -379,12 +382,16 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
return ticket, nil return ticket, nil
} }
// accept0RTT is called for the server when receiving the client's session ticket. // handleSessionTicket is called for the server when receiving the client's session ticket.
// It decides whether to accept 0-RTT. // It reads parameters from the session ticket and decides whether to accept 0-RTT when the session ticket is used for 0-RTT.
func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { func (h *cryptoSetup) handleSessionTicket(sessionTicketData []byte, using0RTT bool) bool {
var t sessionTicket var t sessionTicket
if err := t.Unmarshal(sessionTicketData); err != nil { if err := t.Unmarshal(sessionTicketData, using0RTT); err != nil {
h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error()) h.logger.Debugf("Unmarshalling session ticket failed: %s", err.Error())
return false
}
h.rttStats.SetInitialRTT(t.RTT)
if !using0RTT {
return false return false
} }
valid := h.ourParams.ValidFor0RTT(t.Parameters) valid := h.ourParams.ValidFor0RTT(t.Parameters)
@ -397,7 +404,6 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
return false return false
} }
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
h.rttStats.SetInitialRTT(t.RTT)
return true return true
} }
@ -451,7 +457,7 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr
} }
h.mutex.Unlock() h.mutex.Unlock()
h.events = append(h.events, Event{Kind: EventReceivedReadKeys}) h.events = append(h.events, Event{Kind: EventReceivedReadKeys})
if h.tracer != nil { if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite())
} }
} }
@ -473,7 +479,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t
if h.logger.Debug() { if h.logger.Debug() {
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
} }
if h.tracer != nil { if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective)
} }
// don't set used0RTT here. 0-RTT might still get rejected. // don't set used0RTT here. 0-RTT might still get rejected.
@ -497,7 +503,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t
h.used0RTT.Store(true) h.used0RTT.Store(true)
h.zeroRTTSealer = nil h.zeroRTTSealer = nil
h.logger.Debugf("Dropping 0-RTT keys.") h.logger.Debugf("Dropping 0-RTT keys.")
if h.tracer != nil { if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil {
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
} }
} }
@ -505,7 +511,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t
panic("unexpected write encryption level") panic("unexpected write encryption level")
} }
h.mutex.Unlock() h.mutex.Unlock()
if h.tracer != nil { if h.tracer != nil && h.tracer.UpdatedKeyFromTLS != nil {
h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective) h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective)
} }
} }
@ -645,7 +651,7 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) { if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
h.zeroRTTOpener = nil h.zeroRTTOpener = nil
h.logger.Debugf("Dropping 0-RTT keys.") h.logger.Debugf("Dropping 0-RTT keys.")
if h.tracer != nil { if h.tracer != nil && h.tracer.DroppedEncryptionLevel != nil {
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT) h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
} }
} }

View File

@ -10,7 +10,7 @@ import (
"github.com/quic-go/quic-go/quicvarint" "github.com/quic-go/quic-go/quicvarint"
) )
const sessionTicketRevision = 3 const sessionTicketRevision = 4
type sessionTicket struct { type sessionTicket struct {
Parameters *wire.TransportParameters Parameters *wire.TransportParameters
@ -21,10 +21,13 @@ func (t *sessionTicket) Marshal() []byte {
b := make([]byte, 0, 256) b := make([]byte, 0, 256)
b = quicvarint.Append(b, sessionTicketRevision) b = quicvarint.Append(b, sessionTicketRevision)
b = quicvarint.Append(b, uint64(t.RTT.Microseconds())) b = quicvarint.Append(b, uint64(t.RTT.Microseconds()))
if t.Parameters == nil {
return b
}
return t.Parameters.MarshalForSessionTicket(b) return t.Parameters.MarshalForSessionTicket(b)
} }
func (t *sessionTicket) Unmarshal(b []byte) error { func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
r := bytes.NewReader(b) r := bytes.NewReader(b)
rev, err := quicvarint.Read(r) rev, err := quicvarint.Read(r)
if err != nil { if err != nil {
@ -37,11 +40,15 @@ func (t *sessionTicket) Unmarshal(b []byte) error {
if err != nil { if err != nil {
return errors.New("failed to read RTT") return errors.New("failed to read RTT")
} }
var tp wire.TransportParameters if using0RTT {
if err := tp.UnmarshalFromSessionTicket(r); err != nil { var tp wire.TransportParameters
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error()) if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
}
t.Parameters = &tp
} else if r.Len() > 0 {
return fmt.Errorf("the session ticket has more bytes than expected")
} }
t.Parameters = &tp
t.RTT = time.Duration(rtt) * time.Microsecond t.RTT = time.Duration(rtt) * time.Microsecond
return nil return nil
} }

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"encoding/asn1" "encoding/asn1"
"fmt" "fmt"
"io"
"net" "net"
"time" "time"
@ -45,15 +44,9 @@ type TokenGenerator struct {
tokenProtector tokenProtector tokenProtector tokenProtector
} }
// NewTokenGenerator initializes a new TookenGenerator // NewTokenGenerator initializes a new TokenGenerator
func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) { func NewTokenGenerator(key TokenProtectorKey) *TokenGenerator {
tokenProtector, err := newTokenProtector(rand) return &TokenGenerator{tokenProtector: newTokenProtector(key)}
if err != nil {
return nil, err
}
return &TokenGenerator{
tokenProtector: tokenProtector,
}, nil
} }
// NewRetryToken generates a new token for a Retry for a given source address // NewRetryToken generates a new token for a Retry for a given source address

View File

@ -3,6 +3,7 @@ package handshake
import ( import (
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"io" "io"
@ -10,6 +11,9 @@ import (
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
) )
// TokenProtectorKey is the key used to encrypt both Retry and session resumption tokens.
type TokenProtectorKey [32]byte
// TokenProtector is used to create and verify a token // TokenProtector is used to create and verify a token
type tokenProtector interface { type tokenProtector interface {
// NewToken creates a new token // NewToken creates a new token
@ -18,40 +22,29 @@ type tokenProtector interface {
DecodeToken([]byte) ([]byte, error) DecodeToken([]byte) ([]byte, error)
} }
const ( const tokenNonceSize = 32
tokenSecretSize = 32
tokenNonceSize = 32
)
// tokenProtector is used to create and verify a token // tokenProtector is used to create and verify a token
type tokenProtectorImpl struct { type tokenProtectorImpl struct {
rand io.Reader key TokenProtectorKey
secret []byte
} }
// newTokenProtector creates a source for source address tokens // newTokenProtector creates a source for source address tokens
func newTokenProtector(rand io.Reader) (tokenProtector, error) { func newTokenProtector(key TokenProtectorKey) tokenProtector {
secret := make([]byte, tokenSecretSize) return &tokenProtectorImpl{key: key}
if _, err := rand.Read(secret); err != nil {
return nil, err
}
return &tokenProtectorImpl{
rand: rand,
secret: secret,
}, nil
} }
// NewToken encodes data into a new token. // NewToken encodes data into a new token.
func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) { func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) {
nonce := make([]byte, tokenNonceSize) var nonce [tokenNonceSize]byte
if _, err := s.rand.Read(nonce); err != nil { if _, err := rand.Read(nonce[:]); err != nil {
return nil, err return nil, err
} }
aead, aeadNonce, err := s.createAEAD(nonce) aead, aeadNonce, err := s.createAEAD(nonce[:])
if err != nil { if err != nil {
return nil, err return nil, err
} }
return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil return append(nonce[:], aead.Seal(nil, aeadNonce, data, nil)...), nil
} }
// DecodeToken decodes a token. // DecodeToken decodes a token.
@ -68,7 +61,7 @@ func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
} }
func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) { func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source")) h := hkdf.New(sha256.New, s.key[:], nonce, []byte("quic-go token source"))
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256 key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
if _, err := io.ReadFull(h, key); err != nil { if _, err := io.ReadFull(h, key); err != nil {
return nil, nil, err return nil, nil, err

View File

@ -57,7 +57,7 @@ type updatableAEAD struct {
rttStats *utils.RTTStats rttStats *utils.RTTStats
tracer logging.ConnectionTracer tracer *logging.ConnectionTracer
logger utils.Logger logger utils.Logger
version protocol.VersionNumber version protocol.VersionNumber
@ -70,7 +70,7 @@ var (
_ ShortHeaderSealer = &updatableAEAD{} _ ShortHeaderSealer = &updatableAEAD{}
) )
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD { func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.VersionNumber) *updatableAEAD {
return &updatableAEAD{ return &updatableAEAD{
firstPacketNumber: protocol.InvalidPacketNumber, firstPacketNumber: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber, largestAcked: protocol.InvalidPacketNumber,
@ -86,7 +86,7 @@ func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer,
func (a *updatableAEAD) rollKeys() { func (a *updatableAEAD) rollKeys() {
if a.prevRcvAEAD != nil { if a.prevRcvAEAD != nil {
a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry) a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry)
if a.tracer != nil { if a.tracer != nil && a.tracer.DroppedKey != nil {
a.tracer.DroppedKey(a.keyPhase - 1) a.tracer.DroppedKey(a.keyPhase - 1)
} }
a.prevRcvAEADExpiry = time.Time{} a.prevRcvAEADExpiry = time.Time{}
@ -182,7 +182,7 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac
a.prevRcvAEAD = nil a.prevRcvAEAD = nil
a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) a.logger.Debugf("Dropping key phase %d", a.keyPhase-1)
a.prevRcvAEADExpiry = time.Time{} a.prevRcvAEADExpiry = time.Time{}
if a.tracer != nil { if a.tracer != nil && a.tracer.DroppedKey != nil {
a.tracer.DroppedKey(a.keyPhase - 1) a.tracer.DroppedKey(a.keyPhase - 1)
} }
} }
@ -216,7 +216,7 @@ func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.Pac
// The peer initiated this key update. It's safe to drop the keys for the previous generation now. // The peer initiated this key update. It's safe to drop the keys for the previous generation now.
// Start a timer to drop the previous key generation. // Start a timer to drop the previous key generation.
a.startKeyDropTimer(rcvTime) a.startKeyDropTimer(rcvTime)
if a.tracer != nil { if a.tracer != nil && a.tracer.UpdatedKey != nil {
a.tracer.UpdatedKey(a.keyPhase, true) a.tracer.UpdatedKey(a.keyPhase, true)
} }
a.firstRcvdWithCurrentKey = pn a.firstRcvdWithCurrentKey = pn
@ -308,7 +308,7 @@ func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
if a.shouldInitiateKeyUpdate() { if a.shouldInitiateKeyUpdate() {
a.rollKeys() a.rollKeys()
a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase) a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase)
if a.tracer != nil { if a.tracer != nil && a.tracer.UpdatedKey != nil {
a.tracer.UpdatedKey(a.keyPhase, false) a.tracer.UpdatedKey(a.keyPhase, false)
} }
} }

View File

@ -65,9 +65,6 @@ const MaxAcceptQueueSize = 32
// TokenValidity is the duration that a (non-retry) token is considered valid // TokenValidity is the duration that a (non-retry) token is considered valid
const TokenValidity = 24 * time.Hour const TokenValidity = 24 * time.Hour
// RetryTokenValidity is the duration that a retry token is considered valid
const RetryTokenValidity = 10 * time.Second
// MaxOutstandingSentPackets is maximum number of packets saved for retransmission. // MaxOutstandingSentPackets is maximum number of packets saved for retransmission.
// When reached, it imposes a soft limit on sending new packets: // When reached, it imposes a soft limit on sending new packets:
// Sending ACKs and retransmission is still allowed, but now new regular packets can be sent. // Sending ACKs and retransmission is still allowed, but now new regular packets can be sent.
@ -108,9 +105,6 @@ const DefaultIdleTimeout = 30 * time.Second
// DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion. // DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion.
const DefaultHandshakeIdleTimeout = 5 * time.Second const DefaultHandshakeIdleTimeout = 5 * time.Second
// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds.
const DefaultHandshakeTimeout = 10 * time.Second
// MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive. // MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive.
// It should be shorter than the time that NATs clear their mapping. // It should be shorter than the time that NATs clear their mapping.
const MaxKeepAliveInterval = 20 * time.Second const MaxKeepAliveInterval = 20 * time.Second

View File

@ -37,14 +37,48 @@ func (t PacketType) String() string {
type ECN uint8 type ECN uint8
const ( const (
ECNNon ECN = iota // 00 ECNUnsupported ECN = iota
ECT1 // 01 ECNNon // 00
ECT0 // 10 ECT1 // 01
ECNCE // 11 ECT0 // 10
ECNCE // 11
) )
func ParseECNHeaderBits(bits byte) ECN {
switch bits {
case 0:
return ECNNon
case 0b00000010:
return ECT0
case 0b00000001:
return ECT1
case 0b00000011:
return ECNCE
default:
panic("invalid ECN bits")
}
}
func (e ECN) ToHeaderBits() byte {
//nolint:exhaustive // There are only 4 values.
switch e {
case ECNNon:
return 0
case ECT0:
return 0b00000010
case ECT1:
return 0b00000001
case ECNCE:
return 0b00000011
default:
panic("ECN unsupported")
}
}
func (e ECN) String() string { func (e ECN) String() string {
switch e { switch e {
case ECNUnsupported:
return "ECN unsupported"
case ECNNon: case ECNNon:
return "Not-ECT" return "Not-ECT"
case ECT1: case ECT1:

View File

@ -39,13 +39,15 @@ const (
QUICHandshakeDone = qtls.QUICHandshakeDone QUICHandshakeDone = qtls.QUICHandshakeDone
) )
func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, accept0RTT func([]byte) bool) { func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, handleSessionTicket func([]byte, bool) bool) {
qtls.InitSessionTicketKeys(conf.TLSConfig) qtls.InitSessionTicketKeys(conf.TLSConfig)
conf.TLSConfig = conf.TLSConfig.Clone() conf.TLSConfig = conf.TLSConfig.Clone()
conf.TLSConfig.MinVersion = tls.VersionTLS13 conf.TLSConfig.MinVersion = tls.VersionTLS13
conf.ExtraConfig = &qtls.ExtraConfig{ conf.ExtraConfig = &qtls.ExtraConfig{
Enable0RTT: enable0RTT, Enable0RTT: enable0RTT,
Accept0RTT: accept0RTT, Accept0RTT: func(data []byte) bool {
return handleSessionTicket(data, true)
},
GetAppDataForSessionTicket: getDataForSessionTicket, GetAppDataForSessionTicket: getDataForSessionTicket,
} }
} }

View File

@ -41,7 +41,7 @@ const (
func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) } func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) }
func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) } func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) }
func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, accept0RTT func([]byte) bool) { func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
conf := qconf.TLSConfig conf := qconf.TLSConfig
// Workaround for https://github.com/golang/go/issues/60506. // Workaround for https://github.com/golang/go/issues/60506.
@ -55,11 +55,9 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce
// add callbacks to save transport parameters into the session ticket // add callbacks to save transport parameters into the session ticket
origWrapSession := conf.WrapSession origWrapSession := conf.WrapSession
conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) { conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) {
// Add QUIC transport parameters if this is a 0-RTT packet. // Add QUIC session ticket
// TODO(#3853): also save the RTT for non-0-RTT tickets state.Extra = append(state.Extra, addExtraPrefix(getData()))
if state.EarlyData {
state.Extra = append(state.Extra, addExtraPrefix(getData()))
}
if origWrapSession != nil { if origWrapSession != nil {
return origWrapSession(cs, state) return origWrapSession(cs, state)
} }
@ -83,14 +81,14 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, acce
if err != nil || state == nil { if err != nil || state == nil {
return nil, err return nil, err
} }
if state.EarlyData {
extra := findExtraData(state.Extra) extra := findExtraData(state.Extra)
if unwrapCount == 1 && extra != nil { // first session ticket if extra != nil {
state.EarlyData = accept0RTT(extra) state.EarlyData = handleSessionTicket(extra, state.EarlyData && unwrapCount == 1)
} else { // subsequent session ticket, can't be used for 0-RTT } else {
state.EarlyData = false state.EarlyData = false
}
} }
return state, nil return state, nil
} }
} }

View File

@ -0,0 +1,255 @@
package logging
import (
"net"
"time"
)
// A ConnectionTracer records events.
type ConnectionTracer struct {
StartedConnection func(local, remote net.Addr, srcConnID, destConnID ConnectionID)
NegotiatedVersion func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber)
ClosedConnection func(error)
SentTransportParameters func(*TransportParameters)
ReceivedTransportParameters func(*TransportParameters)
RestoredTransportParameters func(parameters *TransportParameters) // for 0-RTT
SentLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, *AckFrame, []Frame)
SentShortHeaderPacket func(*ShortHeader, ByteCount, ECN, *AckFrame, []Frame)
ReceivedVersionNegotiationPacket func(dest, src ArbitraryLenConnectionID, _ []VersionNumber)
ReceivedRetry func(*Header)
ReceivedLongHeaderPacket func(*ExtendedHeader, ByteCount, ECN, []Frame)
ReceivedShortHeaderPacket func(*ShortHeader, ByteCount, ECN, []Frame)
BufferedPacket func(PacketType, ByteCount)
DroppedPacket func(PacketType, ByteCount, PacketDropReason)
UpdatedMetrics func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int)
AcknowledgedPacket func(EncryptionLevel, PacketNumber)
LostPacket func(EncryptionLevel, PacketNumber, PacketLossReason)
UpdatedCongestionState func(CongestionState)
UpdatedPTOCount func(value uint32)
UpdatedKeyFromTLS func(EncryptionLevel, Perspective)
UpdatedKey func(generation KeyPhase, remote bool)
DroppedEncryptionLevel func(EncryptionLevel)
DroppedKey func(generation KeyPhase)
SetLossTimer func(TimerType, EncryptionLevel, time.Time)
LossTimerExpired func(TimerType, EncryptionLevel)
LossTimerCanceled func()
ECNStateUpdated func(state ECNState, trigger ECNStateTrigger)
// Close is called when the connection is closed.
Close func()
Debug func(name, msg string)
}
// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers.
func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTracer {
if len(tracers) == 0 {
return nil
}
if len(tracers) == 1 {
return tracers[0]
}
return &ConnectionTracer{
StartedConnection: func(local, remote net.Addr, srcConnID, destConnID ConnectionID) {
for _, t := range tracers {
if t.StartedConnection != nil {
t.StartedConnection(local, remote, srcConnID, destConnID)
}
}
},
NegotiatedVersion: func(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) {
for _, t := range tracers {
if t.NegotiatedVersion != nil {
t.NegotiatedVersion(chosen, clientVersions, serverVersions)
}
}
},
ClosedConnection: func(e error) {
for _, t := range tracers {
if t.ClosedConnection != nil {
t.ClosedConnection(e)
}
}
},
SentTransportParameters: func(tp *TransportParameters) {
for _, t := range tracers {
if t.SentTransportParameters != nil {
t.SentTransportParameters(tp)
}
}
},
ReceivedTransportParameters: func(tp *TransportParameters) {
for _, t := range tracers {
if t.ReceivedTransportParameters != nil {
t.ReceivedTransportParameters(tp)
}
}
},
RestoredTransportParameters: func(tp *TransportParameters) {
for _, t := range tracers {
if t.RestoredTransportParameters != nil {
t.RestoredTransportParameters(tp)
}
}
},
SentLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) {
for _, t := range tracers {
if t.SentLongHeaderPacket != nil {
t.SentLongHeaderPacket(hdr, size, ecn, ack, frames)
}
}
},
SentShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, ack *AckFrame, frames []Frame) {
for _, t := range tracers {
if t.SentShortHeaderPacket != nil {
t.SentShortHeaderPacket(hdr, size, ecn, ack, frames)
}
}
},
ReceivedVersionNegotiationPacket: func(dest, src ArbitraryLenConnectionID, versions []VersionNumber) {
for _, t := range tracers {
if t.ReceivedVersionNegotiationPacket != nil {
t.ReceivedVersionNegotiationPacket(dest, src, versions)
}
}
},
ReceivedRetry: func(hdr *Header) {
for _, t := range tracers {
if t.ReceivedRetry != nil {
t.ReceivedRetry(hdr)
}
}
},
ReceivedLongHeaderPacket: func(hdr *ExtendedHeader, size ByteCount, ecn ECN, frames []Frame) {
for _, t := range tracers {
if t.ReceivedLongHeaderPacket != nil {
t.ReceivedLongHeaderPacket(hdr, size, ecn, frames)
}
}
},
ReceivedShortHeaderPacket: func(hdr *ShortHeader, size ByteCount, ecn ECN, frames []Frame) {
for _, t := range tracers {
if t.ReceivedShortHeaderPacket != nil {
t.ReceivedShortHeaderPacket(hdr, size, ecn, frames)
}
}
},
BufferedPacket: func(typ PacketType, size ByteCount) {
for _, t := range tracers {
if t.BufferedPacket != nil {
t.BufferedPacket(typ, size)
}
}
},
DroppedPacket: func(typ PacketType, size ByteCount, reason PacketDropReason) {
for _, t := range tracers {
if t.DroppedPacket != nil {
t.DroppedPacket(typ, size, reason)
}
}
},
UpdatedMetrics: func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) {
for _, t := range tracers {
if t.UpdatedMetrics != nil {
t.UpdatedMetrics(rttStats, cwnd, bytesInFlight, packetsInFlight)
}
}
},
AcknowledgedPacket: func(encLevel EncryptionLevel, pn PacketNumber) {
for _, t := range tracers {
if t.AcknowledgedPacket != nil {
t.AcknowledgedPacket(encLevel, pn)
}
}
},
LostPacket: func(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) {
for _, t := range tracers {
if t.LostPacket != nil {
t.LostPacket(encLevel, pn, reason)
}
}
},
UpdatedCongestionState: func(state CongestionState) {
for _, t := range tracers {
if t.UpdatedCongestionState != nil {
t.UpdatedCongestionState(state)
}
}
},
UpdatedPTOCount: func(value uint32) {
for _, t := range tracers {
if t.UpdatedPTOCount != nil {
t.UpdatedPTOCount(value)
}
}
},
UpdatedKeyFromTLS: func(encLevel EncryptionLevel, perspective Perspective) {
for _, t := range tracers {
if t.UpdatedKeyFromTLS != nil {
t.UpdatedKeyFromTLS(encLevel, perspective)
}
}
},
UpdatedKey: func(generation KeyPhase, remote bool) {
for _, t := range tracers {
if t.UpdatedKey != nil {
t.UpdatedKey(generation, remote)
}
}
},
DroppedEncryptionLevel: func(encLevel EncryptionLevel) {
for _, t := range tracers {
if t.DroppedEncryptionLevel != nil {
t.DroppedEncryptionLevel(encLevel)
}
}
},
DroppedKey: func(generation KeyPhase) {
for _, t := range tracers {
if t.DroppedKey != nil {
t.DroppedKey(generation)
}
}
},
SetLossTimer: func(typ TimerType, encLevel EncryptionLevel, exp time.Time) {
for _, t := range tracers {
if t.SetLossTimer != nil {
t.SetLossTimer(typ, encLevel, exp)
}
}
},
LossTimerExpired: func(typ TimerType, encLevel EncryptionLevel) {
for _, t := range tracers {
if t.LossTimerExpired != nil {
t.LossTimerExpired(typ, encLevel)
}
}
},
LossTimerCanceled: func() {
for _, t := range tracers {
if t.LossTimerCanceled != nil {
t.LossTimerCanceled()
}
}
},
ECNStateUpdated: func(state ECNState, trigger ECNStateTrigger) {
for _, t := range tracers {
if t.ECNStateUpdated != nil {
t.ECNStateUpdated(state, trigger)
}
}
},
Close: func() {
for _, t := range tracers {
if t.Close != nil {
t.Close()
}
}
},
Debug: func(name, msg string) {
for _, t := range tracers {
if t.Debug != nil {
t.Debug(name, msg)
}
}
},
}
}

View File

@ -3,9 +3,6 @@
package logging package logging
import ( import (
"net"
"time"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
@ -15,6 +12,8 @@ import (
type ( type (
// A ByteCount is used to count bytes. // A ByteCount is used to count bytes.
ByteCount = protocol.ByteCount ByteCount = protocol.ByteCount
// ECN is the ECN value
ECN = protocol.ECN
// A ConnectionID is a QUIC Connection ID. // A ConnectionID is a QUIC Connection ID.
ConnectionID = protocol.ConnectionID ConnectionID = protocol.ConnectionID
// An ArbitraryLenConnectionID is a QUIC Connection ID that can be up to 255 bytes long. // An ArbitraryLenConnectionID is a QUIC Connection ID that can be up to 255 bytes long.
@ -58,6 +57,19 @@ type (
RTTStats = utils.RTTStats RTTStats = utils.RTTStats
) )
const (
// ECNUnsupported means that no ECN value was set / received
ECNUnsupported = protocol.ECNUnsupported
// ECTNot is Not-ECT
ECTNot = protocol.ECNNon
// ECT0 is ECT(0)
ECT0 = protocol.ECT0
// ECT1 is ECT(1)
ECT1 = protocol.ECT1
// ECNCE is CE
ECNCE = protocol.ECNCE
)
const ( const (
// KeyPhaseZero is key phase bit 0 // KeyPhaseZero is key phase bit 0
KeyPhaseZero KeyPhaseBit = protocol.KeyPhaseZero KeyPhaseZero KeyPhaseBit = protocol.KeyPhaseZero
@ -97,43 +109,3 @@ type ShortHeader struct {
PacketNumberLen protocol.PacketNumberLen PacketNumberLen protocol.PacketNumberLen
KeyPhase KeyPhaseBit KeyPhase KeyPhaseBit
} }
// A Tracer traces events.
type Tracer interface {
SentPacket(net.Addr, *Header, ByteCount, []Frame)
SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason)
}
// A ConnectionTracer records events.
type ConnectionTracer interface {
StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID)
NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber)
ClosedConnection(error)
SentTransportParameters(*TransportParameters)
ReceivedTransportParameters(*TransportParameters)
RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT
SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame)
SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame)
ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber)
ReceivedRetry(*Header)
ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame)
ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, frames []Frame)
BufferedPacket(PacketType, ByteCount)
DroppedPacket(PacketType, ByteCount, PacketDropReason)
UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int)
AcknowledgedPacket(EncryptionLevel, PacketNumber)
LostPacket(EncryptionLevel, PacketNumber, PacketLossReason)
UpdatedCongestionState(CongestionState)
UpdatedPTOCount(value uint32)
UpdatedKeyFromTLS(EncryptionLevel, Perspective)
UpdatedKey(generation KeyPhase, remote bool)
DroppedEncryptionLevel(EncryptionLevel)
DroppedKey(generation KeyPhase)
SetLossTimer(TimerType, EncryptionLevel, time.Time)
LossTimerExpired(TimerType, EncryptionLevel)
LossTimerCanceled()
// Close is called when the connection is closed.
Close()
Debug(name, msg string)
}

View File

@ -1,4 +0,0 @@
package logging
//go:generate sh -c "go run github.com/golang/mock/mockgen -package logging -self_package github.com/quic-go/quic-go/logging -destination mock_connection_tracer_test.go github.com/quic-go/quic-go/logging ConnectionTracer"
//go:generate sh -c "go run github.com/golang/mock/mockgen -package logging -self_package github.com/quic-go/quic-go/logging -destination mock_tracer_test.go github.com/quic-go/quic-go/logging Tracer"

View File

@ -1,226 +0,0 @@
package logging
import (
"net"
"time"
)
type tracerMultiplexer struct {
tracers []Tracer
}
var _ Tracer = &tracerMultiplexer{}
// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers.
func NewMultiplexedTracer(tracers ...Tracer) Tracer {
if len(tracers) == 0 {
return nil
}
if len(tracers) == 1 {
return tracers[0]
}
return &tracerMultiplexer{tracers}
}
func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) {
for _, t := range m.tracers {
t.SentPacket(remote, hdr, size, frames)
}
}
func (m *tracerMultiplexer) SentVersionNegotiationPacket(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) {
for _, t := range m.tracers {
t.SentVersionNegotiationPacket(remote, dest, src, versions)
}
}
func (m *tracerMultiplexer) DroppedPacket(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) {
for _, t := range m.tracers {
t.DroppedPacket(remote, typ, size, reason)
}
}
type connTracerMultiplexer struct {
tracers []ConnectionTracer
}
var _ ConnectionTracer = &connTracerMultiplexer{}
// NewMultiplexedConnectionTracer creates a new connection tracer that multiplexes events to multiple tracers.
func NewMultiplexedConnectionTracer(tracers ...ConnectionTracer) ConnectionTracer {
if len(tracers) == 0 {
return nil
}
if len(tracers) == 1 {
return tracers[0]
}
return &connTracerMultiplexer{tracers: tracers}
}
func (m *connTracerMultiplexer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) {
for _, t := range m.tracers {
t.StartedConnection(local, remote, srcConnID, destConnID)
}
}
func (m *connTracerMultiplexer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) {
for _, t := range m.tracers {
t.NegotiatedVersion(chosen, clientVersions, serverVersions)
}
}
func (m *connTracerMultiplexer) ClosedConnection(e error) {
for _, t := range m.tracers {
t.ClosedConnection(e)
}
}
func (m *connTracerMultiplexer) SentTransportParameters(tp *TransportParameters) {
for _, t := range m.tracers {
t.SentTransportParameters(tp)
}
}
func (m *connTracerMultiplexer) ReceivedTransportParameters(tp *TransportParameters) {
for _, t := range m.tracers {
t.ReceivedTransportParameters(tp)
}
}
func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParameters) {
for _, t := range m.tracers {
t.RestoredTransportParameters(tp)
}
}
func (m *connTracerMultiplexer) SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) {
for _, t := range m.tracers {
t.SentLongHeaderPacket(hdr, size, ack, frames)
}
}
func (m *connTracerMultiplexer) SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame) {
for _, t := range m.tracers {
t.SentShortHeaderPacket(hdr, size, ack, frames)
}
}
func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, versions []VersionNumber) {
for _, t := range m.tracers {
t.ReceivedVersionNegotiationPacket(dest, src, versions)
}
}
func (m *connTracerMultiplexer) ReceivedRetry(hdr *Header) {
for _, t := range m.tracers {
t.ReceivedRetry(hdr)
}
}
func (m *connTracerMultiplexer) ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) {
for _, t := range m.tracers {
t.ReceivedLongHeaderPacket(hdr, size, frames)
}
}
func (m *connTracerMultiplexer) ReceivedShortHeaderPacket(hdr *ShortHeader, size ByteCount, frames []Frame) {
for _, t := range m.tracers {
t.ReceivedShortHeaderPacket(hdr, size, frames)
}
}
func (m *connTracerMultiplexer) BufferedPacket(typ PacketType, size ByteCount) {
for _, t := range m.tracers {
t.BufferedPacket(typ, size)
}
}
func (m *connTracerMultiplexer) DroppedPacket(typ PacketType, size ByteCount, reason PacketDropReason) {
for _, t := range m.tracers {
t.DroppedPacket(typ, size, reason)
}
}
func (m *connTracerMultiplexer) UpdatedCongestionState(state CongestionState) {
for _, t := range m.tracers {
t.UpdatedCongestionState(state)
}
}
func (m *connTracerMultiplexer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFLight ByteCount, packetsInFlight int) {
for _, t := range m.tracers {
t.UpdatedMetrics(rttStats, cwnd, bytesInFLight, packetsInFlight)
}
}
func (m *connTracerMultiplexer) AcknowledgedPacket(encLevel EncryptionLevel, pn PacketNumber) {
for _, t := range m.tracers {
t.AcknowledgedPacket(encLevel, pn)
}
}
func (m *connTracerMultiplexer) LostPacket(encLevel EncryptionLevel, pn PacketNumber, reason PacketLossReason) {
for _, t := range m.tracers {
t.LostPacket(encLevel, pn, reason)
}
}
func (m *connTracerMultiplexer) UpdatedPTOCount(value uint32) {
for _, t := range m.tracers {
t.UpdatedPTOCount(value)
}
}
func (m *connTracerMultiplexer) UpdatedKeyFromTLS(encLevel EncryptionLevel, perspective Perspective) {
for _, t := range m.tracers {
t.UpdatedKeyFromTLS(encLevel, perspective)
}
}
func (m *connTracerMultiplexer) UpdatedKey(generation KeyPhase, remote bool) {
for _, t := range m.tracers {
t.UpdatedKey(generation, remote)
}
}
func (m *connTracerMultiplexer) DroppedEncryptionLevel(encLevel EncryptionLevel) {
for _, t := range m.tracers {
t.DroppedEncryptionLevel(encLevel)
}
}
func (m *connTracerMultiplexer) DroppedKey(generation KeyPhase) {
for _, t := range m.tracers {
t.DroppedKey(generation)
}
}
func (m *connTracerMultiplexer) SetLossTimer(typ TimerType, encLevel EncryptionLevel, exp time.Time) {
for _, t := range m.tracers {
t.SetLossTimer(typ, encLevel, exp)
}
}
func (m *connTracerMultiplexer) LossTimerExpired(typ TimerType, encLevel EncryptionLevel) {
for _, t := range m.tracers {
t.LossTimerExpired(typ, encLevel)
}
}
func (m *connTracerMultiplexer) LossTimerCanceled() {
for _, t := range m.tracers {
t.LossTimerCanceled()
}
}
func (m *connTracerMultiplexer) Debug(name, msg string) {
for _, t := range m.tracers {
t.Debug(name, msg)
}
}
func (m *connTracerMultiplexer) Close() {
for _, t := range m.tracers {
t.Close()
}
}

View File

@ -1,58 +0,0 @@
package logging
import (
"net"
"time"
)
// The NullTracer is a Tracer that does nothing.
// It is useful for embedding.
type NullTracer struct{}
var _ Tracer = &NullTracer{}
func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {}
func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) {
}
func (n NullTracer) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) {}
// The NullConnectionTracer is a ConnectionTracer that does nothing.
// It is useful for embedding.
type NullConnectionTracer struct{}
var _ ConnectionTracer = &NullConnectionTracer{}
func (n NullConnectionTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) {
}
func (n NullConnectionTracer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) {
}
func (n NullConnectionTracer) ClosedConnection(err error) {}
func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) SentLongHeaderPacket(*ExtendedHeader, ByteCount, *AckFrame, []Frame) {}
func (n NullConnectionTracer) SentShortHeaderPacket(*ShortHeader, ByteCount, *AckFrame, []Frame) {}
func (n NullConnectionTracer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) {
}
func (n NullConnectionTracer) ReceivedRetry(*Header) {}
func (n NullConnectionTracer) ReceivedLongHeaderPacket(*ExtendedHeader, ByteCount, []Frame) {}
func (n NullConnectionTracer) ReceivedShortHeaderPacket(*ShortHeader, ByteCount, []Frame) {}
func (n NullConnectionTracer) BufferedPacket(PacketType, ByteCount) {}
func (n NullConnectionTracer) DroppedPacket(PacketType, ByteCount, PacketDropReason) {}
func (n NullConnectionTracer) UpdatedMetrics(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int) {
}
func (n NullConnectionTracer) AcknowledgedPacket(EncryptionLevel, PacketNumber) {}
func (n NullConnectionTracer) LostPacket(EncryptionLevel, PacketNumber, PacketLossReason) {}
func (n NullConnectionTracer) UpdatedCongestionState(CongestionState) {}
func (n NullConnectionTracer) UpdatedPTOCount(uint32) {}
func (n NullConnectionTracer) UpdatedKeyFromTLS(EncryptionLevel, Perspective) {}
func (n NullConnectionTracer) UpdatedKey(keyPhase KeyPhase, remote bool) {}
func (n NullConnectionTracer) DroppedEncryptionLevel(EncryptionLevel) {}
func (n NullConnectionTracer) DroppedKey(KeyPhase) {}
func (n NullConnectionTracer) SetLossTimer(TimerType, EncryptionLevel, time.Time) {}
func (n NullConnectionTracer) LossTimerExpired(timerType TimerType, level EncryptionLevel) {}
func (n NullConnectionTracer) LossTimerCanceled() {}
func (n NullConnectionTracer) Close() {}
func (n NullConnectionTracer) Debug(name, msg string) {}

43
vendor/github.com/quic-go/quic-go/logging/tracer.go generated vendored Normal file
View File

@ -0,0 +1,43 @@
package logging
import "net"
// A Tracer traces events.
type Tracer struct {
SentPacket func(net.Addr, *Header, ByteCount, []Frame)
SentVersionNegotiationPacket func(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber)
DroppedPacket func(net.Addr, PacketType, ByteCount, PacketDropReason)
}
// NewMultiplexedTracer creates a new tracer that multiplexes events to multiple tracers.
func NewMultiplexedTracer(tracers ...*Tracer) *Tracer {
if len(tracers) == 0 {
return nil
}
if len(tracers) == 1 {
return tracers[0]
}
return &Tracer{
SentPacket: func(remote net.Addr, hdr *Header, size ByteCount, frames []Frame) {
for _, t := range tracers {
if t.SentPacket != nil {
t.SentPacket(remote, hdr, size, frames)
}
}
},
SentVersionNegotiationPacket: func(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) {
for _, t := range tracers {
if t.SentVersionNegotiationPacket != nil {
t.SentVersionNegotiationPacket(remote, dest, src, versions)
}
}
},
DroppedPacket: func(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) {
for _, t := range tracers {
if t.DroppedPacket != nil {
t.DroppedPacket(remote, typ, size, reason)
}
}
},
}
}

View File

@ -92,3 +92,37 @@ const (
// CongestionStateApplicationLimited means that the congestion controller is application limited // CongestionStateApplicationLimited means that the congestion controller is application limited
CongestionStateApplicationLimited CongestionStateApplicationLimited
) )
// ECNState is the state of the ECN state machine (see Appendix A.4 of RFC 9000)
type ECNState uint8
const (
// ECNStateTesting is the testing state
ECNStateTesting ECNState = 1 + iota
// ECNStateUnknown is the unknown state
ECNStateUnknown
// ECNStateFailed is the failed state
ECNStateFailed
// ECNStateCapable is the capable state
ECNStateCapable
)
// ECNStateTrigger is a trigger for an ECN state transition.
type ECNStateTrigger uint8
const (
ECNTriggerNoTrigger ECNStateTrigger = iota
// ECNFailedNoECNCounts is emitted when an ACK acknowledges ECN-marked packets,
// but doesn't contain any ECN counts
ECNFailedNoECNCounts
// ECNFailedDecreasedECNCounts is emitted when an ACK frame decreases ECN counts
ECNFailedDecreasedECNCounts
// ECNFailedLostAllTestingPackets is emitted when all ECN testing packets are declared lost
ECNFailedLostAllTestingPackets
// ECNFailedMoreECNCountsThanSent is emitted when an ACK contains more ECN counts than ECN-marked packets were sent
ECNFailedMoreECNCountsThanSent
// ECNFailedTooFewECNCounts is emitted when an ACK contains fewer ECN counts than it acknowledges packets
ECNFailedTooFewECNCounts
// ECNFailedManglingDetected is emitted when the path marks all ECN-marked packets as CE
ECNFailedManglingDetected
)

View File

@ -2,76 +2,73 @@
package quic package quic
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_conn_test.go github.com/quic-go/quic-go SendConn" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_conn_test.go github.com/quic-go/quic-go SendConn"
type SendConn = sendConn type SendConn = sendConn
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_raw_conn_test.go github.com/quic-go/quic-go RawConn" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_raw_conn_test.go github.com/quic-go/quic-go RawConn"
type RawConn = rawConn type RawConn = rawConn
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sender_test.go github.com/quic-go/quic-go Sender" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sender_test.go github.com/quic-go/quic-go Sender"
type Sender = sender type Sender = sender
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_internal_test.go github.com/quic-go/quic-go StreamI" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_internal_test.go github.com/quic-go/quic-go StreamI"
type StreamI = streamI type StreamI = streamI
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_stream_test.go github.com/quic-go/quic-go CryptoStream" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_stream_test.go github.com/quic-go/quic-go CryptoStream"
type CryptoStream = cryptoStream type CryptoStream = cryptoStream
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_receive_stream_internal_test.go github.com/quic-go/quic-go ReceiveStreamI" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_receive_stream_internal_test.go github.com/quic-go/quic-go ReceiveStreamI"
type ReceiveStreamI = receiveStreamI type ReceiveStreamI = receiveStreamI
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_send_stream_internal_test.go github.com/quic-go/quic-go SendStreamI"
type SendStreamI = sendStreamI type SendStreamI = sendStreamI
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_getter_test.go github.com/quic-go/quic-go StreamGetter"
type StreamGetter = streamGetter type StreamGetter = streamGetter
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_sender_test.go github.com/quic-go/quic-go StreamSender"
type StreamSender = streamSender type StreamSender = streamSender
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_data_handler_test.go github.com/quic-go/quic-go CryptoDataHandler" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_crypto_data_handler_test.go github.com/quic-go/quic-go CryptoDataHandler"
type CryptoDataHandler = cryptoDataHandler type CryptoDataHandler = cryptoDataHandler
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_frame_source_test.go github.com/quic-go/quic-go FrameSource"
type FrameSource = frameSource type FrameSource = frameSource
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_ack_frame_source_test.go github.com/quic-go/quic-go AckFrameSource" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_ack_frame_source_test.go github.com/quic-go/quic-go AckFrameSource"
type AckFrameSource = ackFrameSource type AckFrameSource = ackFrameSource
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_stream_manager_test.go github.com/quic-go/quic-go StreamManager"
type StreamManager = streamManager type StreamManager = streamManager
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sealing_manager_test.go github.com/quic-go/quic-go SealingManager" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_sealing_manager_test.go github.com/quic-go/quic-go SealingManager"
type SealingManager = sealingManager type SealingManager = sealingManager
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_unpacker_test.go github.com/quic-go/quic-go Unpacker" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_unpacker_test.go github.com/quic-go/quic-go Unpacker"
type Unpacker = unpacker type Unpacker = unpacker
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packer_test.go github.com/quic-go/quic-go Packer" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packer_test.go github.com/quic-go/quic-go Packer"
type Packer = packer type Packer = packer
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_mtu_discoverer_test.go github.com/quic-go/quic-go MTUDiscoverer" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_mtu_discoverer_test.go github.com/quic-go/quic-go MTUDiscoverer"
type MTUDiscoverer = mtuDiscoverer type MTUDiscoverer = mtuDiscoverer
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_conn_runner_test.go github.com/quic-go/quic-go ConnRunner" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_conn_runner_test.go github.com/quic-go/quic-go ConnRunner"
type ConnRunner = connRunner type ConnRunner = connRunner
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_quic_conn_test.go github.com/quic-go/quic-go QUICConn" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_quic_conn_test.go github.com/quic-go/quic-go QUICConn"
type QUICConn = quicConn type QUICConn = quicConn
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_test.go github.com/quic-go/quic-go PacketHandler"
type PacketHandler = packetHandler type PacketHandler = packetHandler
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_unknown_packet_handler_test.go github.com/quic-go/quic-go UnknownPacketHandler" //go:generate sh -c "go run go.uber.org/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
type UnknownPacketHandler = unknownPacketHandler
//go:generate sh -c "go run github.com/golang/mock/mockgen -build_flags=\"-tags=gomock\" -package quic -self_package github.com/quic-go/quic-go -destination mock_packet_handler_manager_test.go github.com/quic-go/quic-go PacketHandlerManager"
type PacketHandlerManager = packetHandlerManager type PacketHandlerManager = packetHandlerManager
// Need to use source mode for the batchConn, since reflect mode follows type aliases. // Need to use source mode for the batchConn, since reflect mode follows type aliases.
// See https://github.com/golang/mock/issues/244 for details. // See https://github.com/golang/mock/issues/244 for details.
// //
//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/quic-go/quic-go -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn" //go:generate sh -c "go run go.uber.org/mock/mockgen -package quic -self_package github.com/quic-go/quic-go -source sys_conn_oob.go -destination mock_batch_conn_test.go -mock_names batchConn=MockBatchConn"
//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_token_store_test.go github.com/quic-go/quic-go TokenStore" //go:generate sh -c "go run go.uber.org/mock/mockgen -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_token_store_test.go github.com/quic-go/quic-go TokenStore"
//go:generate sh -c "go run github.com/golang/mock/mockgen -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn" //go:generate sh -c "go run go.uber.org/mock/mockgen -package quic -self_package github.com/quic-go/quic-go -self_package github.com/quic-go/quic-go -destination mock_packetconn_test.go net PacketConn"

View File

@ -4,7 +4,6 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"errors"
"hash" "hash"
"io" "io"
"net" "net"
@ -21,14 +20,17 @@ type connCapabilities struct {
DF bool DF bool
// GSO (Generic Segmentation Offload) supported // GSO (Generic Segmentation Offload) supported
GSO bool GSO bool
// ECN (Explicit Congestion Notifications) supported
ECN bool
} }
// rawConn is a connection that allow reading of a receivedPackeh. // rawConn is a connection that allow reading of a receivedPackeh.
type rawConn interface { type rawConn interface {
ReadPacket() (receivedPacket, error) ReadPacket() (receivedPacket, error)
// WritePacket writes a packet on the wire. // WritePacket writes a packet on the wire.
// If GSO is enabled, it's the caller's responsibility to set the correct control message. // gsoSize is the size of a single packet, or 0 to disable GSO.
WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) // It is invalid to set gsoSize if capabilities.GSO is not set.
WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error)
LocalAddr() net.Addr LocalAddr() net.Addr
SetReadDeadline(time.Time) error SetReadDeadline(time.Time) error
io.Closer io.Closer
@ -42,13 +44,6 @@ type closePacket struct {
info packetInfo info packetInfo
} }
type unknownPacketHandler interface {
handlePacket(receivedPacket)
setCloseError(error)
}
var errListenerAlreadySet = errors.New("listener already set")
type packetHandlerMap struct { type packetHandlerMap struct {
mutex sync.Mutex mutex sync.Mutex
handlers map[protocol.ConnectionID]packetHandler handlers map[protocol.ConnectionID]packetHandler

View File

@ -1,9 +1,13 @@
package quic package quic
import ( import (
crand "crypto/rand"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/exp/rand"
"github.com/quic-go/quic-go/internal/ackhandler" "github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
@ -67,6 +71,11 @@ type coalescedPacket struct {
shortHdrPacket *shortHeaderPacket shortHdrPacket *shortHeaderPacket
} }
// IsOnlyShortHeaderPacket says if this packet only contains a short header packet (and no long header packets).
func (p *coalescedPacket) IsOnlyShortHeaderPacket() bool {
return len(p.longHdrPackets) == 0 && p.shortHdrPacket != nil
}
func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel { func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel {
//nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data). //nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data).
switch p.header.Type { switch p.header.Type {
@ -122,6 +131,7 @@ type packetPacker struct {
acks ackFrameSource acks ackFrameSource
datagramQueue *datagramQueue datagramQueue *datagramQueue
retransmissionQueue *retransmissionQueue retransmissionQueue *retransmissionQueue
rand rand.Rand
numNonAckElicitingAcks int numNonAckElicitingAcks int
} }
@ -140,6 +150,9 @@ func newPacketPacker(
datagramQueue *datagramQueue, datagramQueue *datagramQueue,
perspective protocol.Perspective, perspective protocol.Perspective,
) *packetPacker { ) *packetPacker {
var b [8]byte
_, _ = crand.Read(b[:])
return &packetPacker{ return &packetPacker{
cryptoSetup: cryptoSetup, cryptoSetup: cryptoSetup,
getDestConnID: getDestConnID, getDestConnID: getDestConnID,
@ -151,6 +164,7 @@ func newPacketPacker(
perspective: perspective, perspective: perspective,
framer: framer, framer: framer,
acks: acks, acks: acks,
rand: *rand.New(rand.NewSource(binary.BigEndian.Uint64(b[:]))),
pnManager: packetNumberManager, pnManager: packetNumberManager,
} }
} }
@ -832,6 +846,8 @@ func (p *packetPacker) appendShortHeaderPacket(
}, nil }, nil
} }
// appendPacketPayload serializes the payload of a packet into the raw byte slice.
// It modifies the order of payload.frames.
func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) { func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) {
payloadOffset := len(raw) payloadOffset := len(raw)
if pl.ack != nil { if pl.ack != nil {
@ -844,6 +860,11 @@ func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen pr
if paddingLen > 0 { if paddingLen > 0 {
raw = append(raw, make([]byte, paddingLen)...) raw = append(raw, make([]byte, paddingLen)...)
} }
// Randomize the order of the control frames.
// This makes sure that the receiver doesn't rely on the order in which frames are packed.
if len(pl.frames) > 1 {
p.rand.Shuffle(len(pl.frames), func(i, j int) { pl.frames[i], pl.frames[j] = pl.frames[j], pl.frames[i] })
}
for _, f := range pl.frames { for _, f := range pl.frames {
var err error var err error
raw, err = f.Frame.Append(raw, v) raw, err = f.Frame.Append(raw, v)

View File

@ -1,8 +1,6 @@
package quic package quic
import ( import (
"fmt"
"math"
"net" "net"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
@ -11,7 +9,7 @@ import (
// A sendConn allows sending using a simple Write() on a non-connected packet conn. // A sendConn allows sending using a simple Write() on a non-connected packet conn.
type sendConn interface { type sendConn interface {
Write(b []byte, size protocol.ByteCount) error Write(b []byte, gsoSize uint16, ecn protocol.ECN) error
Close() error Close() error
LocalAddr() net.Addr LocalAddr() net.Addr
RemoteAddr() net.Addr RemoteAddr() net.Addr
@ -27,8 +25,7 @@ type sconn struct {
logger utils.Logger logger utils.Logger
info packetInfo packetInfoOOB []byte
oob []byte
// If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled. // If GSO enabled, and we receive a GSO error for this remote address, GSO is disabled.
gotGSOError bool gotGSOError bool
} }
@ -46,33 +43,20 @@ func newSendConn(c rawConn, remote net.Addr, info packetInfo, logger utils.Logge
} }
oob := info.OOB() oob := info.OOB()
// add 32 bytes, so we can add the UDP_SEGMENT msg // increase oob slice capacity, so we can add the UDP_SEGMENT and ECN control messages without allocating
l := len(oob) l := len(oob)
oob = append(oob, make([]byte, 32)...) oob = append(oob, make([]byte, 64)...)[:l]
oob = oob[:l]
return &sconn{ return &sconn{
rawConn: c, rawConn: c,
localAddr: localAddr, localAddr: localAddr,
remoteAddr: remote, remoteAddr: remote,
info: info, packetInfoOOB: oob,
oob: oob, logger: logger,
logger: logger,
} }
} }
func (c *sconn) Write(p []byte, size protocol.ByteCount) error { func (c *sconn) Write(p []byte, gsoSize uint16, ecn protocol.ECN) error {
if !c.capabilities().GSO { _, err := c.WritePacket(p, c.remoteAddr, c.packetInfoOOB, gsoSize, ecn)
if protocol.ByteCount(len(p)) != size {
panic(fmt.Sprintf("inconsistent packet size (%d vs %d)", len(p), size))
}
_, err := c.WritePacket(p, c.remoteAddr, c.oob)
return err
}
// GSO is supported. Append the control message and send.
if size > math.MaxUint16 {
panic("size overflow")
}
_, err := c.WritePacket(p, c.remoteAddr, appendUDPSegmentSizeMsg(c.oob, uint16(size)))
if err != nil && isGSOError(err) { if err != nil && isGSOError(err) {
// disable GSO for future calls // disable GSO for future calls
c.gotGSOError = true c.gotGSOError = true
@ -82,10 +66,10 @@ func (c *sconn) Write(p []byte, size protocol.ByteCount) error {
// send out the packets one by one // send out the packets one by one
for len(p) > 0 { for len(p) > 0 {
l := len(p) l := len(p)
if l > int(size) { if l > int(gsoSize) {
l = int(size) l = int(gsoSize)
} }
if _, err := c.WritePacket(p[:l], c.remoteAddr, c.oob); err != nil { if _, err := c.WritePacket(p[:l], c.remoteAddr, c.packetInfoOOB, 0, ecn); err != nil {
return err return err
} }
p = p[l:] p = p[l:]

View File

@ -3,7 +3,7 @@ package quic
import "github.com/quic-go/quic-go/internal/protocol" import "github.com/quic-go/quic-go/internal/protocol"
type sender interface { type sender interface {
Send(p *packetBuffer, packetSize protocol.ByteCount) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN)
Run() error Run() error
WouldBlock() bool WouldBlock() bool
Available() <-chan struct{} Available() <-chan struct{}
@ -11,8 +11,9 @@ type sender interface {
} }
type queueEntry struct { type queueEntry struct {
buf *packetBuffer buf *packetBuffer
size protocol.ByteCount gsoSize uint16
ecn protocol.ECN
} }
type sendQueue struct { type sendQueue struct {
@ -40,9 +41,9 @@ func newSendQueue(conn sendConn) sender {
// Send sends out a packet. It's guaranteed to not block. // Send sends out a packet. It's guaranteed to not block.
// Callers need to make sure that there's actually space in the send queue by calling WouldBlock. // Callers need to make sure that there's actually space in the send queue by calling WouldBlock.
// Otherwise Send will panic. // Otherwise Send will panic.
func (h *sendQueue) Send(p *packetBuffer, size protocol.ByteCount) { func (h *sendQueue) Send(p *packetBuffer, gsoSize uint16, ecn protocol.ECN) {
select { select {
case h.queue <- queueEntry{buf: p, size: size}: case h.queue <- queueEntry{buf: p, gsoSize: gsoSize, ecn: ecn}:
// clear available channel if we've reached capacity // clear available channel if we've reached capacity
if len(h.queue) == sendQueueCapacity { if len(h.queue) == sendQueueCapacity {
select { select {
@ -77,7 +78,7 @@ func (h *sendQueue) Run() error {
// make sure that all queued packets are actually sent out // make sure that all queued packets are actually sent out
shouldClose = true shouldClose = true
case e := <-h.queue: case e := <-h.queue:
if err := h.conn.Write(e.buf.Data, e.size); err != nil { if err := h.conn.Write(e.buf.Data, e.gsoSize, e.ecn); err != nil {
// This additional check enables: // This additional check enables:
// 1. Checking for "datagram too large" message from the kernel, as such, // 1. Checking for "datagram too large" message from the kernel, as such,
// 2. Path MTU discovery,and // 2. Path MTU discovery,and

View File

@ -2,7 +2,6 @@ package quic
import ( import (
"context" "context"
"crypto/rand"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -59,7 +58,8 @@ type zeroRTTQueue struct {
type baseServer struct { type baseServer struct {
mutex sync.Mutex mutex sync.Mutex
acceptEarlyConns bool disableVersionNegotiation bool
acceptEarlyConns bool
tlsConf *tls.Config tlsConf *tls.Config
config *Config config *Config
@ -67,6 +67,7 @@ type baseServer struct {
conn rawConn conn rawConn
tokenGenerator *handshake.TokenGenerator tokenGenerator *handshake.TokenGenerator
maxTokenAge time.Duration
connIDGenerator ConnectionIDGenerator connIDGenerator ConnectionIDGenerator
connHandler packetHandlerManager connHandler packetHandlerManager
@ -92,7 +93,7 @@ type baseServer struct {
*tls.Config, *tls.Config,
*handshake.TokenGenerator, *handshake.TokenGenerator,
bool, /* client address validated by an address validation token */ bool, /* client address validated by an address validation token */
logging.ConnectionTracer, *logging.ConnectionTracer,
uint64, uint64,
utils.Logger, utils.Logger,
protocol.VersionNumber, protocol.VersionNumber,
@ -108,7 +109,7 @@ type baseServer struct {
connQueue chan quicConn connQueue chan quicConn
connQueueLen int32 // to be used as an atomic connQueueLen int32 // to be used as an atomic
tracer logging.Tracer tracer *logging.Tracer
logger utils.Logger logger utils.Logger
} }
@ -224,32 +225,33 @@ func newServer(
connIDGenerator ConnectionIDGenerator, connIDGenerator ConnectionIDGenerator,
tlsConf *tls.Config, tlsConf *tls.Config,
config *Config, config *Config,
tracer logging.Tracer, tracer *logging.Tracer,
onClose func(), onClose func(),
tokenGeneratorKey TokenGeneratorKey,
maxTokenAge time.Duration,
disableVersionNegotiation bool,
acceptEarly bool, acceptEarly bool,
) (*baseServer, error) { ) *baseServer {
tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader)
if err != nil {
return nil, err
}
s := &baseServer{ s := &baseServer{
conn: conn, conn: conn,
tlsConf: tlsConf, tlsConf: tlsConf,
config: config, config: config,
tokenGenerator: tokenGenerator, tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
connIDGenerator: connIDGenerator, maxTokenAge: maxTokenAge,
connHandler: connHandler, connIDGenerator: connIDGenerator,
connQueue: make(chan quicConn), connHandler: connHandler,
errorChan: make(chan struct{}), connQueue: make(chan quicConn),
running: make(chan struct{}), errorChan: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets), running: make(chan struct{}),
versionNegotiationQueue: make(chan receivedPacket, 4), receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
invalidTokenQueue: make(chan receivedPacket, 4), versionNegotiationQueue: make(chan receivedPacket, 4),
newConn: newConnection, invalidTokenQueue: make(chan receivedPacket, 4),
tracer: tracer, newConn: newConnection,
logger: utils.DefaultLogger.WithPrefix("server"), tracer: tracer,
acceptEarlyConns: acceptEarly, logger: utils.DefaultLogger.WithPrefix("server"),
onClose: onClose, acceptEarlyConns: acceptEarly,
disableVersionNegotiation: disableVersionNegotiation,
onClose: onClose,
} }
if acceptEarly { if acceptEarly {
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
@ -257,7 +259,7 @@ func newServer(
go s.run() go s.run()
go s.runSendQueue() go s.runSendQueue()
s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
return s, nil return s
} }
func (s *baseServer) run() { func (s *baseServer) run() {
@ -350,7 +352,7 @@ func (s *baseServer) handlePacket(p receivedPacket) {
case s.receivedPackets <- p: case s.receivedPackets <- p:
default: default:
s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size()) s.logger.Debugf("Dropping packet from %s (%d bytes). Server receive queue full.", p.remoteAddr, p.Size())
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
} }
} }
@ -363,7 +365,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
if wire.IsVersionNegotiationPacket(p.data) { if wire.IsVersionNegotiationPacket(p.data) {
s.logger.Debugf("Dropping Version Negotiation packet.") s.logger.Debugf("Dropping Version Negotiation packet.")
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropUnexpectedPacket)
} }
return false return false
@ -376,20 +378,20 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
// drop the packet if we failed to parse the protocol version // drop the packet if we failed to parse the protocol version
if err != nil { if err != nil {
s.logger.Debugf("Dropping a packet with an unknown version") s.logger.Debugf("Dropping a packet with an unknown version")
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
} }
return false return false
} }
// send a Version Negotiation Packet if the client is speaking a different protocol version // send a Version Negotiation Packet if the client is speaking a different protocol version
if !protocol.IsSupportedVersion(s.config.Versions, v) { if !protocol.IsSupportedVersion(s.config.Versions, v) {
if s.config.DisableVersionNegotiationPackets { if s.disableVersionNegotiation {
return false return false
} }
if p.Size() < protocol.MinUnknownVersionPacketSize { if p.Size() < protocol.MinUnknownVersionPacketSize {
s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size()) s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size())
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
} }
return false return false
@ -399,7 +401,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
if wire.Is0RTTPacket(p.data) { if wire.Is0RTTPacket(p.data) {
if !s.acceptEarlyConns { if !s.acceptEarlyConns {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket)
} }
return false return false
@ -411,7 +413,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
// The header will then be parsed again. // The header will then be parsed again.
hdr, _, _, err := wire.ParsePacket(p.data) hdr, _, _, err := wire.ParsePacket(p.data)
if err != nil { if err != nil {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
} }
s.logger.Debugf("Error parsing packet: %s", err) s.logger.Debugf("Error parsing packet: %s", err)
@ -419,7 +421,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
} }
if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize {
s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size())
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
} }
return false return false
@ -430,7 +432,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
// There's little point in sending a Stateless Reset, since the client // There's little point in sending a Stateless Reset, since the client
// might not have received the token yet. // might not have received the token yet.
s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data))
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket)
} }
return false return false
@ -449,7 +451,7 @@ func (s *baseServer) handlePacketImpl(p receivedPacket) bool /* is the buffer st
func (s *baseServer) handle0RTTPacket(p receivedPacket) bool { func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
connID, err := wire.ParseConnectionID(p.data, 0) connID, err := wire.ParseConnectionID(p.data, 0)
if err != nil { if err != nil {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError)
} }
return false return false
@ -463,7 +465,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
if q, ok := s.zeroRTTQueues[connID]; ok { if q, ok := s.zeroRTTQueues[connID]; ok {
if len(q.packets) >= protocol.Max0RTTQueueLen { if len(q.packets) >= protocol.Max0RTTQueueLen {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
} }
return false return false
@ -473,7 +475,7 @@ func (s *baseServer) handle0RTTPacket(p receivedPacket) bool {
} }
if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues { if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
} }
return false return false
@ -501,7 +503,7 @@ func (s *baseServer) cleanupZeroRTTQueues(now time.Time) {
continue continue
} }
for _, p := range q.packets { for _, p := range q.packets {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention)
} }
p.buffer.Release() p.buffer.Release()
@ -525,10 +527,10 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
if !token.ValidateRemoteAddr(addr) { if !token.ValidateRemoteAddr(addr) {
return false return false
} }
if !token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxTokenAge { if !token.IsRetryToken && time.Since(token.SentTime) > s.maxTokenAge {
return false return false
} }
if token.IsRetryToken && time.Since(token.SentTime) > s.config.MaxRetryTokenAge { if token.IsRetryToken && time.Since(token.SentTime) > s.config.maxRetryTokenAge() {
return false return false
} }
return true return true
@ -537,7 +539,7 @@ func (s *baseServer) validateToken(token *handshake.Token, addr net.Addr) bool {
func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error { func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error {
if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
p.buffer.Release() p.buffer.Release()
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket)
} }
return errors.New("too short connection ID") return errors.New("too short connection ID")
@ -622,7 +624,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
} }
config = populateConfig(conf) config = populateConfig(conf)
} }
var tracer logging.ConnectionTracer var tracer *logging.ConnectionTracer
if config.Tracer != nil { if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback. // Use the same connection ID that is passed to the client's GetLogWriter callback.
connID := hdr.DestConnectionID connID := hdr.DestConnectionID
@ -739,10 +741,10 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info packe
// append the Retry integrity tag // append the Retry integrity tag
tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version) tag := handshake.GetRetryIntegrityTag(buf.Data, hdr.DestConnectionID, hdr.Version)
buf.Data = append(buf.Data, tag[:]...) buf.Data = append(buf.Data, tag[:]...)
if s.tracer != nil { if s.tracer != nil && s.tracer.SentPacket != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil) s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(buf.Data)), nil)
} }
_, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB()) _, err = s.conn.WritePacket(buf.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported)
return err return err
} }
@ -760,7 +762,7 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) {
hdr, _, _, err := wire.ParsePacket(p.data) hdr, _, _, err := wire.ParsePacket(p.data)
if err != nil { if err != nil {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
} }
s.logger.Debugf("Error parsing packet: %s", err) s.logger.Debugf("Error parsing packet: %s", err)
@ -775,14 +777,14 @@ func (s *baseServer) maybeSendInvalidToken(p receivedPacket) {
// Only send INVALID_TOKEN if we can unprotect the packet. // Only send INVALID_TOKEN if we can unprotect the packet.
// This makes sure that we won't send it for packets that were corrupted. // This makes sure that we won't send it for packets that were corrupted.
if err != nil { if err != nil {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError)
} }
return return
} }
hdrLen := extHdr.ParsedLen() hdrLen := extHdr.ParsedLen()
if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil {
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError)
} }
return return
@ -838,10 +840,10 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
replyHdr.Log(s.logger) replyHdr.Log(s.logger)
wire.LogFrame(s.logger, ccf, true) wire.LogFrame(s.logger, ccf, true)
if s.tracer != nil { if s.tracer != nil && s.tracer.SentPacket != nil {
s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf}) s.tracer.SentPacket(remoteAddr, &replyHdr.Header, protocol.ByteCount(len(b.Data)), []logging.Frame{ccf})
} }
_, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB()) _, err = s.conn.WritePacket(b.Data, remoteAddr, info.OOB(), 0, protocol.ECNUnsupported)
return err return err
} }
@ -867,7 +869,7 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
_, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data)
if err != nil { // should never happen if err != nil { // should never happen
s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs")
if s.tracer != nil { if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket)
} }
return return
@ -876,10 +878,10 @@ func (s *baseServer) maybeSendVersionNegotiationPacket(p receivedPacket) {
s.logger.Debugf("Client offered version %s, sending Version Negotiation", v) s.logger.Debugf("Client offered version %s, sending Version Negotiation", v)
data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions)
if s.tracer != nil { if s.tracer != nil && s.tracer.SentVersionNegotiationPacket != nil {
s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions)
} }
if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
s.logger.Debugf("Error sending Version Negotiation: %s", err) s.logger.Debugf("Error sending Version Negotiation: %s", err)
} }
} }

View File

@ -104,7 +104,13 @@ func (c *basicConn) ReadPacket() (receivedPacket, error) {
}, nil }, nil
} }
func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte) (n int, err error) { func (c *basicConn) WritePacket(b []byte, addr net.Addr, _ []byte, gsoSize uint16, ecn protocol.ECN) (n int, err error) {
if gsoSize != 0 {
panic("cannot use GSO with a basicConn")
}
if ecn != protocol.ECNUnsupported {
panic("cannot use ECN with a basicConn")
}
return c.PacketConn.WriteTo(b, addr) return c.PacketConn.WriteTo(b, addr)
} }

View File

@ -15,6 +15,8 @@ const (
ipv4PKTINFO = unix.IP_RECVPKTINFO ipv4PKTINFO = unix.IP_RECVPKTINFO
) )
const ecnIPv4DataLen = 4
// ReadBatch only returns a single packet on OSX, // ReadBatch only returns a single packet on OSX,
// see https://godoc.org/golang.org/x/net/ipv4#PacketConn.ReadBatch. // see https://godoc.org/golang.org/x/net/ipv4#PacketConn.ReadBatch.
const batchSize = 1 const batchSize = 1

View File

@ -14,6 +14,8 @@ const (
ipv4PKTINFO = 0x7 ipv4PKTINFO = 0x7
) )
const ecnIPv4DataLen = 4
const batchSize = 8 const batchSize = 8
func parseIPv4PktInfo(body []byte) (ip netip.Addr, _ uint32, ok bool) { func parseIPv4PktInfo(body []byte) (ip netip.Addr, _ uint32, ok bool) {

View File

@ -19,6 +19,8 @@ const (
ipv4PKTINFO = unix.IP_PKTINFO ipv4PKTINFO = unix.IP_PKTINFO
) )
const ecnIPv4DataLen = 4
const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed) const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed)
func forceSetReceiveBuffer(c syscall.RawConn, bytes int) error { func forceSetReceiveBuffer(c syscall.RawConn, bytes int) error {

View File

@ -8,9 +8,12 @@ import (
"log" "log"
"net" "net"
"net/netip" "net/netip"
"os"
"strconv"
"sync" "sync"
"syscall" "syscall"
"time" "time"
"unsafe"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
@ -56,6 +59,11 @@ func inspectWriteBuffer(c syscall.RawConn) (int, error) {
return size, serr return size, serr
} }
func isECNDisabled() bool {
disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_ECN"))
return err == nil && disabled
}
type oobConn struct { type oobConn struct {
OOBCapablePacketConn OOBCapablePacketConn
batchConn batchConn batchConn batchConn
@ -140,6 +148,7 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) {
cap: connCapabilities{ cap: connCapabilities{
DF: supportsDF, DF: supportsDF,
GSO: isGSOSupported(rawConn), GSO: isGSOSupported(rawConn),
ECN: !isECNDisabled(),
}, },
} }
for i := 0; i < batchSize; i++ { for i := 0; i < batchSize; i++ {
@ -188,7 +197,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
if hdr.Level == unix.IPPROTO_IP { if hdr.Level == unix.IPPROTO_IP {
switch hdr.Type { switch hdr.Type {
case msgTypeIPTOS: case msgTypeIPTOS:
p.ecn = protocol.ECN(body[0] & ecnMask) p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask)
case ipv4PKTINFO: case ipv4PKTINFO:
ip, ifIndex, ok := parseIPv4PktInfo(body) ip, ifIndex, ok := parseIPv4PktInfo(body)
if ok { if ok {
@ -205,7 +214,7 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
if hdr.Level == unix.IPPROTO_IPV6 { if hdr.Level == unix.IPPROTO_IPV6 {
switch hdr.Type { switch hdr.Type {
case unix.IPV6_TCLASS: case unix.IPV6_TCLASS:
p.ecn = protocol.ECN(body[0] & ecnMask) p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask)
case unix.IPV6_PKTINFO: case unix.IPV6_PKTINFO:
// struct in6_pktinfo { // struct in6_pktinfo {
// struct in6_addr ipi6_addr; /* src/dst IPv6 address */ // struct in6_addr ipi6_addr; /* src/dst IPv6 address */
@ -228,8 +237,26 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
} }
// WritePacket writes a new packet. // WritePacket writes a new packet.
// If the connection supports GSO, it's the caller's responsibility to append the right control mesage. func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) {
func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (int, error) { oob := packetInfoOOB
if gsoSize > 0 {
if !c.capabilities().GSO {
panic("GSO disabled")
}
oob = appendUDPSegmentSizeMsg(oob, gsoSize)
}
if ecn != protocol.ECNUnsupported {
if !c.capabilities().ECN {
panic("tried to send a ECN-marked packet although ECN is disabled")
}
if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok {
if remoteUDPAddr.IP.To4() != nil {
oob = appendIPv4ECNMsg(oob, ecn)
} else {
oob = appendIPv6ECNMsg(oob, ecn)
}
}
}
n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr)) n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr))
return n, err return n, err
} }
@ -273,3 +300,32 @@ func (info *packetInfo) OOB() []byte {
} }
return nil return nil
} }
func appendIPv4ECNMsg(b []byte, val protocol.ECN) []byte {
startLen := len(b)
b = append(b, make([]byte, unix.CmsgSpace(ecnIPv4DataLen))...)
h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen]))
h.Level = syscall.IPPROTO_IP
h.Type = unix.IP_TOS
h.SetLen(unix.CmsgLen(ecnIPv4DataLen))
// UnixRights uses the private `data` method, but I *think* this achieves the same goal.
offset := startLen + unix.CmsgSpace(0)
b[offset] = val.ToHeaderBits()
return b
}
func appendIPv6ECNMsg(b []byte, val protocol.ECN) []byte {
startLen := len(b)
const dataLen = 4
b = append(b, make([]byte, unix.CmsgSpace(dataLen))...)
h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen]))
h.Level = syscall.IPPROTO_IPV6
h.Type = unix.IPV6_TCLASS
h.SetLen(unix.CmsgLen(dataLen))
// UnixRights uses the private `data` method, but I *think* this achieves the same goal.
offset := startLen + unix.CmsgSpace(0)
b[offset] = val.ToHeaderBits()
return b
}

View File

@ -3,6 +3,6 @@
package quic package quic
import ( import (
_ "github.com/golang/mock/mockgen"
_ "github.com/onsi/ginkgo/v2/ginkgo" _ "github.com/onsi/ginkgo/v2/ginkgo"
_ "go.uber.org/mock/mockgen"
) )

View File

@ -16,6 +16,8 @@ import (
"github.com/quic-go/quic-go/logging" "github.com/quic-go/quic-go/logging"
) )
var errListenerAlreadySet = errors.New("listener already set")
// The Transport is the central point to manage incoming and outgoing QUIC connections. // The Transport is the central point to manage incoming and outgoing QUIC connections.
// QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple. // QUIC demultiplexes connections based on their QUIC Connection IDs, not based on the 4-tuple.
// This means that a single UDP socket can be used for listening for incoming connections, as well as // This means that a single UDP socket can be used for listening for incoming connections, as well as
@ -57,8 +59,26 @@ type Transport struct {
// See section 10.3 of RFC 9000 for details. // See section 10.3 of RFC 9000 for details.
StatelessResetKey *StatelessResetKey StatelessResetKey *StatelessResetKey
// The TokenGeneratorKey is used to encrypt session resumption tokens.
// If no key is configured, a random key will be generated.
// If multiple servers are authoritative for the same domain, they should use the same key,
// see section 8.1.3 of RFC 9000 for details.
TokenGeneratorKey *TokenGeneratorKey
// MaxTokenAge is the maximum age of the resumption token presented during the handshake.
// These tokens allow skipping address resumption when resuming a QUIC connection,
// and are especially useful when using 0-RTT.
// If not set, it defaults to 24 hours.
// See section 8.1.3 of RFC 9000 for details.
MaxTokenAge time.Duration
// DisableVersionNegotiationPackets disables the sending of Version Negotiation packets.
// This can be useful if version information is exchanged out-of-band.
// It has no effect for clients.
DisableVersionNegotiationPackets bool
// A Tracer traces events that don't belong to a single QUIC connection. // A Tracer traces events that don't belong to a single QUIC connection.
Tracer logging.Tracer Tracer *logging.Tracer
handlerMap packetHandlerManager handlerMap packetHandlerManager
@ -73,7 +93,7 @@ type Transport struct {
// If no ConnectionIDGenerator is set, this is set to a default. // If no ConnectionIDGenerator is set, this is set to a default.
connIDGenerator ConnectionIDGenerator connIDGenerator ConnectionIDGenerator
server unknownPacketHandler server *baseServer
conn rawConn conn rawConn
@ -95,28 +115,10 @@ type Transport struct {
// There can only be a single listener on any net.PacketConn. // There can only be a single listener on any net.PacketConn.
// Listen may only be called again after the current Listener was closed. // Listen may only be called again after the current Listener was closed.
func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) { func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error) {
if tlsConf == nil { s, err := t.createServer(tlsConf, conf, false)
return nil, errors.New("quic: tls.Config not set")
}
if err := validateConfig(conf); err != nil {
return nil, err
}
t.mutex.Lock()
defer t.mutex.Unlock()
if t.server != nil {
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(false); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.server = s
return &Listener{baseServer: s}, nil return &Listener{baseServer: s}, nil
} }
@ -124,6 +126,14 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error)
// There can only be a single listener on any net.PacketConn. // There can only be a single listener on any net.PacketConn.
// Listen may only be called again after the current Listener was closed. // Listen may only be called again after the current Listener was closed.
func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) { func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListener, error) {
s, err := t.createServer(tlsConf, conf, true)
if err != nil {
return nil, err
}
return &EarlyListener{baseServer: s}, nil
}
func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bool) (*baseServer, error) {
if tlsConf == nil { if tlsConf == nil {
return nil, errors.New("quic: tls.Config not set") return nil, errors.New("quic: tls.Config not set")
} }
@ -141,12 +151,21 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen
if err := t.init(false); err != nil { if err := t.init(false); err != nil {
return nil, err return nil, err
} }
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true) s := newServer(
if err != nil { t.conn,
return nil, err t.handlerMap,
} t.connIDGenerator,
tlsConf,
conf,
t.Tracer,
t.closeServer,
*t.TokenGeneratorKey,
t.MaxTokenAge,
t.DisableVersionNegotiationPackets,
allow0RTT,
)
t.server = s t.server = s
return &EarlyListener{baseServer: s}, nil return s, nil
} }
// Dial dials a new connection to a remote host (not using 0-RTT). // Dial dials a new connection to a remote host (not using 0-RTT).
@ -198,6 +217,14 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
t.closeQueue = make(chan closePacket, 4) t.closeQueue = make(chan closePacket, 4)
t.statelessResetQueue = make(chan receivedPacket, 4) t.statelessResetQueue = make(chan receivedPacket, 4)
if t.TokenGeneratorKey == nil {
var key TokenGeneratorKey
if _, err := rand.Read(key[:]); err != nil {
t.initErr = err
return
}
t.TokenGeneratorKey = &key
}
if t.ConnectionIDGenerator != nil { if t.ConnectionIDGenerator != nil {
t.connIDGenerator = t.ConnectionIDGenerator t.connIDGenerator = t.ConnectionIDGenerator
@ -223,7 +250,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
if err := t.init(false); err != nil { if err := t.init(false); err != nil {
return 0, err return 0, err
} }
return t.conn.WritePacket(b, addr, nil) return t.conn.WritePacket(b, addr, nil, 0, protocol.ECNUnsupported)
} }
func (t *Transport) enqueueClosePacket(p closePacket) { func (t *Transport) enqueueClosePacket(p closePacket) {
@ -241,7 +268,7 @@ func (t *Transport) runSendQueue() {
case <-t.listening: case <-t.listening:
return return
case p := <-t.closeQueue: case p := <-t.closeQueue:
t.conn.WritePacket(p.payload, p.addr, p.info.OOB()) t.conn.WritePacket(p.payload, p.addr, p.info.OOB(), 0, protocol.ECNUnsupported)
case p := <-t.statelessResetQueue: case p := <-t.statelessResetQueue:
t.sendStatelessReset(p) t.sendStatelessReset(p)
} }
@ -346,7 +373,7 @@ func (t *Transport) handlePacket(p receivedPacket) {
connID, err := wire.ParseConnectionID(p.data, t.connIDLen) connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
if err != nil { if err != nil {
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
if t.Tracer != nil { if t.Tracer != nil && t.Tracer.DroppedPacket != nil {
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
} }
p.buffer.MaybeRelease() p.buffer.MaybeRelease()
@ -409,7 +436,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) {
rand.Read(data) rand.Read(data)
data[0] = (data[0] & 0x7f) | 0x40 data[0] = (data[0] & 0x7f) | 0x40
data = append(data, token[:]...) data = append(data, token[:]...)
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB(), 0, protocol.ECNUnsupported); err != nil {
t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err) t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err)
} }
} }
@ -441,7 +468,7 @@ func (t *Transport) handleNonQUICPacket(p receivedPacket) {
select { select {
case t.nonQUICPackets <- p: case t.nonQUICPackets <- p:
default: default:
if t.Tracer != nil { if t.Tracer != nil && t.Tracer.DroppedPacket != nil {
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
} }
} }

116
vendor/go.uber.org/mock/mockgen/generic_go118.go generated vendored Normal file
View File

@ -0,0 +1,116 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build go1.18
// +build go1.18
package main
import (
"fmt"
"go/ast"
"strings"
"go.uber.org/mock/mockgen/model"
)
func getTypeSpecTypeParams(ts *ast.TypeSpec) []*ast.Field {
if ts == nil || ts.TypeParams == nil {
return nil
}
return ts.TypeParams.List
}
func (p *fileParser) parseGenericType(pkg string, typ ast.Expr, tps map[string]model.Type) (model.Type, error) {
switch v := typ.(type) {
case *ast.IndexExpr:
m, err := p.parseType(pkg, v.X, tps)
if err != nil {
return nil, err
}
nm, ok := m.(*model.NamedType)
if !ok {
return m, nil
}
t, err := p.parseType(pkg, v.Index, tps)
if err != nil {
return nil, err
}
nm.TypeParams = &model.TypeParametersType{TypeParameters: []model.Type{t}}
return m, nil
case *ast.IndexListExpr:
m, err := p.parseType(pkg, v.X, tps)
if err != nil {
return nil, err
}
nm, ok := m.(*model.NamedType)
if !ok {
return m, nil
}
var ts []model.Type
for _, expr := range v.Indices {
t, err := p.parseType(pkg, expr, tps)
if err != nil {
return nil, err
}
ts = append(ts, t)
}
nm.TypeParams = &model.TypeParametersType{TypeParameters: ts}
return m, nil
}
return nil, nil
}
func getIdentTypeParams(decl any) string {
if decl == nil {
return ""
}
ts, ok := decl.(*ast.TypeSpec)
if !ok {
return ""
}
if ts.TypeParams == nil || len(ts.TypeParams.List) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString("[")
for i, v := range ts.TypeParams.List {
if i != 0 {
sb.WriteString(", ")
}
sb.WriteString(v.Names[0].Name)
}
sb.WriteString("]")
return sb.String()
}
func (p *fileParser) parseGenericMethod(field *ast.Field, it *namedInterface, iface *model.Interface, pkg string, tps map[string]model.Type) ([]*model.Method, error) {
var indices []ast.Expr
var typ ast.Expr
switch v := field.Type.(type) {
case *ast.IndexExpr:
indices = []ast.Expr{v.Index}
typ = v.X
case *ast.IndexListExpr:
indices = v.Indices
typ = v.X
default:
return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
}
nf := &ast.Field{
Doc: field.Comment,
Names: field.Names,
Type: typ,
Tag: field.Tag,
Comment: field.Comment,
}
it.embeddedInstTypeParams = indices
return p.parseMethod(nf, it, iface, pkg, tps)
}

41
vendor/go.uber.org/mock/mockgen/generic_notgo118.go generated vendored Normal file
View File

@ -0,0 +1,41 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !go1.18
// +build !go1.18
package main
import (
"fmt"
"go/ast"
"go.uber.org/mock/mockgen/model"
)
func getTypeSpecTypeParams(ts *ast.TypeSpec) []*ast.Field {
return nil
}
func (p *fileParser) parseGenericType(pkg string, typ ast.Expr, tps map[string]model.Type) (model.Type, error) {
return nil, nil
}
func getIdentTypeParams(decl any) string {
return ""
}
func (p *fileParser) parseGenericMethod(field *ast.Field, it *namedInterface, iface *model.Interface, pkg string, tps map[string]model.Type) ([]*model.Method, error) {
return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
}

View File

@ -21,11 +21,11 @@ package main
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"flag" "flag"
"fmt" "fmt"
"go/token" "go/token"
"io" "io"
"io/ioutil"
"log" "log"
"os" "os"
"os/exec" "os/exec"
@ -36,14 +36,14 @@ import (
"strings" "strings"
"unicode" "unicode"
"github.com/golang/mock/mockgen/model"
"golang.org/x/mod/modfile" "golang.org/x/mod/modfile"
toolsimports "golang.org/x/tools/imports" toolsimports "golang.org/x/tools/imports"
"go.uber.org/mock/mockgen/model"
) )
const ( const (
gomockImportPath = "github.com/golang/mock/gomock" gomockImportPath = "go.uber.org/mock/gomock"
) )
var ( var (
@ -53,13 +53,19 @@ var (
) )
var ( var (
source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.") source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.")
destination = flag.String("destination", "", "Output file; defaults to stdout.") destination = flag.String("destination", "", "Output file; defaults to stdout.")
mockNames = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.") mockNames = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.")
packageOut = flag.String("package", "", "Package of the generated code; defaults to the package of the input with a 'mock_' prefix.") packageOut = flag.String("package", "", "Package of the generated code; defaults to the package of the input with a 'mock_' prefix.")
selfPackage = flag.String("self_package", "", "The full package import path for the generated code. The purpose of this flag is to prevent import cycles in the generated code by trying to include its own package. This can happen if the mock's package is set to one of its inputs (usually the main one) and the output is stdio so mockgen cannot detect the final output package. Setting this flag will then tell mockgen which import to exclude.") selfPackage = flag.String("self_package", "", "The full package import path for the generated code. The purpose of this flag is to prevent import cycles in the generated code by trying to include its own package. This can happen if the mock's package is set to one of its inputs (usually the main one) and the output is stdio so mockgen cannot detect the final output package. Setting this flag will then tell mockgen which import to exclude.")
writePkgComment = flag.Bool("write_package_comment", true, "Writes package documentation comment (godoc) if true.") writePkgComment = flag.Bool("write_package_comment", true, "Writes package documentation comment (godoc) if true.")
copyrightFile = flag.String("copyright_file", "", "Copyright file used to add copyright header") writeSourceComment = flag.Bool("write_source_comment", true, "Writes original file (source mode) or interface names (reflect mode) comment if true.")
writeGenerateDirective = flag.Bool("write_generate_directive", false, "Add //go:generate directive to regenerate the mock")
copyrightFile = flag.String("copyright_file", "", "Copyright file used to add copyright header")
typed = flag.Bool("typed", false, "Generate Type-safe 'Return', 'Do', 'DoAndReturn' function")
imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.")
auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.")
excludeInterfaces = flag.String("exclude_interfaces", "", "Comma-separated names of interfaces to be excluded")
debugParser = flag.Bool("debug_parser", false, "Print out parser results only.") debugParser = flag.Bool("debug_parser", false, "Print out parser results only.")
showVersion = flag.Bool("version", false, "Print version.") showVersion = flag.Bool("version", false, "Print version.")
@ -107,19 +113,6 @@ func main() {
return return
} }
dst := os.Stdout
if len(*destination) > 0 {
if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil {
log.Fatalf("Unable to create directory: %v", err)
}
f, err := os.Create(*destination)
if err != nil {
log.Fatalf("Failed opening destination file: %v", err)
}
defer f.Close()
dst = f
}
outputPackageName := *packageOut outputPackageName := *packageOut
if outputPackageName == "" { if outputPackageName == "" {
// pkg.Name in reflect mode is the base name of the import path, // pkg.Name in reflect mode is the base name of the import path,
@ -161,7 +154,7 @@ func main() {
g.mockNames = parseMockNames(*mockNames) g.mockNames = parseMockNames(*mockNames)
} }
if *copyrightFile != "" { if *copyrightFile != "" {
header, err := ioutil.ReadFile(*copyrightFile) header, err := os.ReadFile(*copyrightFile)
if err != nil { if err != nil {
log.Fatalf("Failed reading copyright file: %v", err) log.Fatalf("Failed reading copyright file: %v", err)
} }
@ -171,7 +164,27 @@ func main() {
if err := g.Generate(pkg, outputPackageName, outputPackagePath); err != nil { if err := g.Generate(pkg, outputPackageName, outputPackagePath); err != nil {
log.Fatalf("Failed generating mock: %v", err) log.Fatalf("Failed generating mock: %v", err)
} }
if _, err := dst.Write(g.Output()); err != nil { output := g.Output()
dst := os.Stdout
if len(*destination) > 0 {
if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil {
log.Fatalf("Unable to create directory: %v", err)
}
existing, err := os.ReadFile(*destination)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Fatalf("Failed reading pre-exiting destination file: %v", err)
}
if len(existing) == len(output) && bytes.Equal(existing, output) {
return
}
f, err := os.Create(*destination)
if err != nil {
log.Fatalf("Failed opening destination file: %v", err)
}
defer f.Close()
dst = f
}
if _, err := dst.Write(output); err != nil {
log.Fatalf("Failed writing to destination: %v", err) log.Fatalf("Failed writing to destination: %v", err)
} }
} }
@ -188,6 +201,24 @@ func parseMockNames(names string) map[string]string {
return mocksMap return mocksMap
} }
func parseExcludeInterfaces(names string) map[string]struct{} {
splitNames := strings.Split(names, ",")
namesSet := make(map[string]struct{}, len(splitNames))
for _, name := range splitNames {
if name == "" {
continue
}
namesSet[name] = struct{}{}
}
if len(namesSet) == 0 {
return nil
}
return namesSet
}
func usage() { func usage() {
_, _ = io.WriteString(os.Stderr, usageText) _, _ = io.WriteString(os.Stderr, usageText)
flag.PrintDefaults() flag.PrintDefaults()
@ -222,7 +253,7 @@ type generator struct {
packageMap map[string]string // map from import path to package name packageMap map[string]string // map from import path to package name
} }
func (g *generator) p(format string, args ...interface{}) { func (g *generator) p(format string, args ...any) {
fmt.Fprintf(&g.buf, g.indent+format+"\n", args...) fmt.Fprintf(&g.buf, g.indent+format+"\n", args...)
} }
@ -274,12 +305,17 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
} }
g.p("// Code generated by MockGen. DO NOT EDIT.") g.p("// Code generated by MockGen. DO NOT EDIT.")
if g.filename != "" { if *writeSourceComment {
g.p("// Source: %v", g.filename) if g.filename != "" {
} else { g.p("// Source: %v", g.filename)
g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces) } else {
g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces)
}
} }
g.p("") g.p("//")
g.p("// Generated by this command:")
// only log the name of the executable, not the full path
g.p("// %v", strings.Join(append([]string{filepath.Base(os.Args[0])}, os.Args[1:]...), " "))
// Get all required imports, and generate unique names for them all. // Get all required imports, and generate unique names for them all.
im := pkg.Imports() im := pkg.Imports()
@ -305,6 +341,16 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
packagesName := createPackageMap(sortedPaths) packagesName := createPackageMap(sortedPaths)
definedImports := make(map[string]string, len(im))
if *imports != "" {
for _, kv := range strings.Split(*imports, ",") {
eq := strings.Index(kv, "=")
if k, v := kv[:eq], kv[eq+1:]; k != "." {
definedImports[v] = k
}
}
}
g.packageMap = make(map[string]string, len(im)) g.packageMap = make(map[string]string, len(im))
localNames := make(map[string]bool, len(im)) localNames := make(map[string]bool, len(im))
for _, pth := range sortedPaths { for _, pth := range sortedPaths {
@ -316,9 +362,14 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
// Local names for an imported package can usually be the basename of the import path. // Local names for an imported package can usually be the basename of the import path.
// A couple of situations don't permit that, such as duplicate local names // A couple of situations don't permit that, such as duplicate local names
// (e.g. importing "html/template" and "text/template"), or where the basename is // (e.g. importing "html/template" and "text/template"), or where the basename is
// a keyword (e.g. "foo/case"). // a keyword (e.g. "foo/case") or when defining a name for that by using the -imports flag.
// try base0, base1, ... // try base0, base1, ...
pkgName := base pkgName := base
if _, ok := definedImports[base]; ok {
pkgName = definedImports[base]
}
i := 0 i := 0
for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() { for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() {
pkgName = base + strconv.Itoa(i) pkgName = base + strconv.Itoa(i)
@ -353,6 +404,10 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac
g.out() g.out()
g.p(")") g.p(")")
if *writeGenerateDirective {
g.p("//go:generate %v", strings.Join(os.Args, " "))
}
for _, intf := range pkg.Interfaces { for _, intf := range pkg.Interfaces {
if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil { if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil {
return err return err
@ -371,32 +426,58 @@ func (g *generator) mockName(typeName string) string {
return "Mock" + typeName return "Mock" + typeName
} }
// formattedTypeParams returns a long and short form of type param info used for
// printing. If analyzing a interface with type param [I any, O any] the result
// will be:
// "[I any, O any]", "[I, O]"
func (g *generator) formattedTypeParams(it *model.Interface, pkgOverride string) (string, string) {
if len(it.TypeParams) == 0 {
return "", ""
}
var long, short strings.Builder
long.WriteString("[")
short.WriteString("[")
for i, v := range it.TypeParams {
if i != 0 {
long.WriteString(", ")
short.WriteString(", ")
}
long.WriteString(v.Name)
short.WriteString(v.Name)
long.WriteString(fmt.Sprintf(" %s", v.Type.String(g.packageMap, pkgOverride)))
}
long.WriteString("]")
short.WriteString("]")
return long.String(), short.String()
}
func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error { func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error {
mockType := g.mockName(intf.Name) mockType := g.mockName(intf.Name)
longTp, shortTp := g.formattedTypeParams(intf, outputPackagePath)
g.p("") g.p("")
g.p("// %v is a mock of %v interface.", mockType, intf.Name) g.p("// %v is a mock of %v interface.", mockType, intf.Name)
g.p("type %v struct {", mockType) g.p("type %v%v struct {", mockType, longTp)
g.in() g.in()
g.p("ctrl *gomock.Controller") g.p("ctrl *gomock.Controller")
g.p("recorder *%vMockRecorder", mockType) g.p("recorder *%vMockRecorder%v", mockType, shortTp)
g.out() g.out()
g.p("}") g.p("}")
g.p("") g.p("")
g.p("// %vMockRecorder is the mock recorder for %v.", mockType, mockType) g.p("// %vMockRecorder is the mock recorder for %v.", mockType, mockType)
g.p("type %vMockRecorder struct {", mockType) g.p("type %vMockRecorder%v struct {", mockType, longTp)
g.in() g.in()
g.p("mock *%v", mockType) g.p("mock *%v%v", mockType, shortTp)
g.out() g.out()
g.p("}") g.p("}")
g.p("") g.p("")
g.p("// New%v creates a new mock instance.", mockType) g.p("// New%v creates a new mock instance.", mockType)
g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType) g.p("func New%v%v(ctrl *gomock.Controller) *%v%v {", mockType, longTp, mockType, shortTp)
g.in() g.in()
g.p("mock := &%v{ctrl: ctrl}", mockType) g.p("mock := &%v%v{ctrl: ctrl}", mockType, shortTp)
g.p("mock.recorder = &%vMockRecorder{mock}", mockType) g.p("mock.recorder = &%vMockRecorder%v{mock}", mockType, shortTp)
g.p("return mock") g.p("return mock")
g.out() g.out()
g.p("}") g.p("}")
@ -404,13 +485,13 @@ func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePa
// XXX: possible name collision here if someone has EXPECT in their interface. // XXX: possible name collision here if someone has EXPECT in their interface.
g.p("// EXPECT returns an object that allows the caller to indicate expected use.") g.p("// EXPECT returns an object that allows the caller to indicate expected use.")
g.p("func (m *%v) EXPECT() *%vMockRecorder {", mockType, mockType) g.p("func (m *%v%v) EXPECT() *%vMockRecorder%v {", mockType, shortTp, mockType, shortTp)
g.in() g.in()
g.p("return m.recorder") g.p("return m.recorder")
g.out() g.out()
g.p("}") g.p("}")
g.GenerateMockMethods(mockType, intf, outputPackagePath) g.GenerateMockMethods(mockType, intf, outputPackagePath, longTp, shortTp, *typed)
return nil return nil
} }
@ -421,13 +502,17 @@ func (b byMethodName) Len() int { return len(b) }
func (b byMethodName) Swap(i, j int) { b[i], b[j] = b[j], b[i] } func (b byMethodName) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
func (b byMethodName) Less(i, j int) bool { return b[i].Name < b[j].Name } func (b byMethodName) Less(i, j int) bool { return b[i].Name < b[j].Name }
func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) { func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride, longTp, shortTp string, typed bool) {
sort.Sort(byMethodName(intf.Methods)) sort.Sort(byMethodName(intf.Methods))
for _, m := range intf.Methods { for _, m := range intf.Methods {
g.p("") g.p("")
_ = g.GenerateMockMethod(mockType, m, pkgOverride) _ = g.GenerateMockMethod(mockType, m, pkgOverride, shortTp)
g.p("") g.p("")
_ = g.GenerateMockRecorderMethod(mockType, m) _ = g.GenerateMockRecorderMethod(intf, mockType, m, shortTp, typed)
if typed {
g.p("")
_ = g.GenerateMockReturnCallMethod(intf, m, pkgOverride, longTp, shortTp)
}
} }
} }
@ -446,9 +531,9 @@ func makeArgString(argNames, argTypes []string) string {
// GenerateMockMethod generates a mock method implementation. // GenerateMockMethod generates a mock method implementation.
// If non-empty, pkgOverride is the package in which unqualified types reside. // If non-empty, pkgOverride is the package in which unqualified types reside.
func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error { func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride, shortTp string) error {
argNames := g.getArgNames(m) argNames := g.getArgNames(m, true /* in */)
argTypes := g.getArgTypes(m, pkgOverride) argTypes := g.getArgTypes(m, pkgOverride, true /* in */)
argString := makeArgString(argNames, argTypes) argString := makeArgString(argNames, argTypes)
rets := make([]string, len(m.Out)) rets := make([]string, len(m.Out))
@ -467,7 +552,7 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver
idRecv := ia.allocateIdentifier("m") idRecv := ia.allocateIdentifier("m")
g.p("// %v mocks base method.", m.Name) g.p("// %v mocks base method.", m.Name)
g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString) g.p("func (%v *%v%v) %v(%v)%v {", idRecv, mockType, shortTp, m.Name, argString, retString)
g.in() g.in()
g.p("%s.ctrl.T.Helper()", idRecv) g.p("%s.ctrl.T.Helper()", idRecv)
@ -477,11 +562,11 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver
callArgs = ", " + strings.Join(argNames, ", ") callArgs = ", " + strings.Join(argNames, ", ")
} }
} else { } else {
// Non-trivial. The generated code must build a []interface{}, // Non-trivial. The generated code must build a []any,
// but the variadic argument may be any type. // but the variadic argument may be any type.
idVarArgs := ia.allocateIdentifier("varargs") idVarArgs := ia.allocateIdentifier("varargs")
idVArg := ia.allocateIdentifier("a") idVArg := ia.allocateIdentifier("a")
g.p("%s := []interface{}{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", ")) g.p("%s := []any{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", "))
g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1]) g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1])
g.in() g.in()
g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg) g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg)
@ -511,8 +596,8 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver
return nil return nil
} }
func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error { func (g *generator) GenerateMockRecorderMethod(intf *model.Interface, mockType string, m *model.Method, shortTp string, typed bool) error {
argNames := g.getArgNames(m) argNames := g.getArgNames(m, true)
var argString string var argString string
if m.Variadic == nil { if m.Variadic == nil {
@ -521,21 +606,26 @@ func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method)
argString = strings.Join(argNames[:len(argNames)-1], ", ") argString = strings.Join(argNames[:len(argNames)-1], ", ")
} }
if argString != "" { if argString != "" {
argString += " interface{}" argString += " any"
} }
if m.Variadic != nil { if m.Variadic != nil {
if argString != "" { if argString != "" {
argString += ", " argString += ", "
} }
argString += fmt.Sprintf("%s ...interface{}", argNames[len(argNames)-1]) argString += fmt.Sprintf("%s ...any", argNames[len(argNames)-1])
} }
ia := newIdentifierAllocator(argNames) ia := newIdentifierAllocator(argNames)
idRecv := ia.allocateIdentifier("mr") idRecv := ia.allocateIdentifier("mr")
g.p("// %v indicates an expected call of %v.", m.Name, m.Name) g.p("// %v indicates an expected call of %v.", m.Name, m.Name)
g.p("func (%s *%vMockRecorder) %v(%v) *gomock.Call {", idRecv, mockType, m.Name, argString) if typed {
g.p("func (%s *%vMockRecorder%v) %v(%v) *%s%sCall%s {", idRecv, mockType, shortTp, m.Name, argString, intf.Name, m.Name, shortTp)
} else {
g.p("func (%s *%vMockRecorder%v) %v(%v) *gomock.Call {", idRecv, mockType, shortTp, m.Name, argString)
}
g.in() g.in()
g.p("%s.mock.ctrl.T.Helper()", idRecv) g.p("%s.mock.ctrl.T.Helper()", idRecv)
@ -551,42 +641,121 @@ func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method)
} else { } else {
// Hard: create a temporary slice. // Hard: create a temporary slice.
idVarArgs := ia.allocateIdentifier("varargs") idVarArgs := ia.allocateIdentifier("varargs")
g.p("%s := append([]interface{}{%s}, %s...)", g.p("%s := append([]any{%s}, %s...)",
idVarArgs, idVarArgs,
strings.Join(argNames[:len(argNames)-1], ", "), strings.Join(argNames[:len(argNames)-1], ", "),
argNames[len(argNames)-1]) argNames[len(argNames)-1])
callArgs = ", " + idVarArgs + "..." callArgs = ", " + idVarArgs + "..."
} }
} }
g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, m.Name, callArgs) if typed {
g.p(`call := %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs)
g.p(`return &%s%sCall%s{Call: call}`, intf.Name, m.Name, shortTp)
} else {
g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs)
}
g.out() g.out()
g.p("}") g.p("}")
return nil return nil
} }
func (g *generator) getArgNames(m *model.Method) []string { func (g *generator) GenerateMockReturnCallMethod(intf *model.Interface, m *model.Method, pkgOverride, longTp, shortTp string) error {
argNames := make([]string, len(m.In)) argNames := g.getArgNames(m, true /* in */)
for i, p := range m.In { retNames := g.getArgNames(m, false /* out */)
argTypes := g.getArgTypes(m, pkgOverride, true /* in */)
retTypes := g.getArgTypes(m, pkgOverride, false /* out */)
argString := strings.Join(argTypes, ", ")
rets := make([]string, len(m.Out))
for i, p := range m.Out {
rets[i] = p.Type.String(g.packageMap, pkgOverride)
}
var retString string
switch {
case len(rets) == 1:
retString = " " + rets[0]
case len(rets) > 1:
retString = " (" + strings.Join(rets, ", ") + ")"
}
ia := newIdentifierAllocator(argNames)
idRecv := ia.allocateIdentifier("c")
recvStructName := intf.Name + m.Name
g.p("// %s%sCall wrap *gomock.Call", intf.Name, m.Name)
g.p("type %s%sCall%s struct{", intf.Name, m.Name, longTp)
g.in()
g.p("*gomock.Call")
g.out()
g.p("}")
g.p("// Return rewrite *gomock.Call.Return")
g.p("func (%s *%sCall%s) Return(%v) *%sCall%s {", idRecv, recvStructName, shortTp, makeArgString(retNames, retTypes), recvStructName, shortTp)
g.in()
var retArgs string
if len(retNames) > 0 {
retArgs = strings.Join(retNames, ", ")
}
g.p(`%s.Call = %v.Call.Return(%v)`, idRecv, idRecv, retArgs)
g.p("return %s", idRecv)
g.out()
g.p("}")
g.p("// Do rewrite *gomock.Call.Do")
g.p("func (%s *%sCall%s) Do(f func(%v)%v) *%sCall%s {", idRecv, recvStructName, shortTp, argString, retString, recvStructName, shortTp)
g.in()
g.p(`%s.Call = %v.Call.Do(f)`, idRecv, idRecv)
g.p("return %s", idRecv)
g.out()
g.p("}")
g.p("// DoAndReturn rewrite *gomock.Call.DoAndReturn")
g.p("func (%s *%sCall%s) DoAndReturn(f func(%v)%v) *%sCall%s {", idRecv, recvStructName, shortTp, argString, retString, recvStructName, shortTp)
g.in()
g.p(`%s.Call = %v.Call.DoAndReturn(f)`, idRecv, idRecv)
g.p("return %s", idRecv)
g.out()
g.p("}")
return nil
}
func (g *generator) getArgNames(m *model.Method, in bool) []string {
var params []*model.Parameter
if in {
params = m.In
} else {
params = m.Out
}
argNames := make([]string, len(params))
for i, p := range params {
name := p.Name name := p.Name
if name == "" || name == "_" { if name == "" || name == "_" {
name = fmt.Sprintf("arg%d", i) name = fmt.Sprintf("arg%d", i)
} }
argNames[i] = name argNames[i] = name
} }
if m.Variadic != nil { if m.Variadic != nil && in {
name := m.Variadic.Name name := m.Variadic.Name
if name == "" { if name == "" {
name = fmt.Sprintf("arg%d", len(m.In)) name = fmt.Sprintf("arg%d", len(params))
} }
argNames = append(argNames, name) argNames = append(argNames, name)
} }
return argNames return argNames
} }
func (g *generator) getArgTypes(m *model.Method, pkgOverride string) []string { func (g *generator) getArgTypes(m *model.Method, pkgOverride string, in bool) []string {
argTypes := make([]string, len(m.In)) var params []*model.Parameter
for i, p := range m.In { if in {
params = m.In
} else {
params = m.Out
}
argTypes := make([]string, len(params))
for i, p := range params {
argTypes[i] = p.Type.String(g.packageMap, pkgOverride) argTypes[i] = p.Type.String(g.packageMap, pkgOverride)
} }
if m.Variadic != nil { if m.Variadic != nil {
@ -670,7 +839,7 @@ func parsePackageImport(srcDir string) (string, error) {
if moduleMode != "off" { if moduleMode != "off" {
currentDir := srcDir currentDir := srcDir
for { for {
dat, err := ioutil.ReadFile(filepath.Join(currentDir, "go.mod")) dat, err := os.ReadFile(filepath.Join(currentDir, "go.mod"))
if os.IsNotExist(err) { if os.IsNotExist(err) {
if currentDir == filepath.Dir(currentDir) { if currentDir == filepath.Dir(currentDir) {
// at the root // at the root

View File

@ -24,7 +24,7 @@ import (
) )
// pkgPath is the importable path for package model // pkgPath is the importable path for package model
const pkgPath = "github.com/golang/mock/mockgen/model" const pkgPath = "go.uber.org/mock/mockgen/model"
// Package is a Go package. It may be a subset. // Package is a Go package. It may be a subset.
type Package struct { type Package struct {
@ -47,14 +47,18 @@ func (pkg *Package) Imports() map[string]bool {
im := make(map[string]bool) im := make(map[string]bool)
for _, intf := range pkg.Interfaces { for _, intf := range pkg.Interfaces {
intf.addImports(im) intf.addImports(im)
for _, tp := range intf.TypeParams {
tp.Type.addImports(im)
}
} }
return im return im
} }
// Interface is a Go interface. // Interface is a Go interface.
type Interface struct { type Interface struct {
Name string Name string
Methods []*Method Methods []*Method
TypeParams []*Parameter
} }
// Print writes the interface name and its methods. // Print writes the interface name and its methods.
@ -143,12 +147,14 @@ type Type interface {
} }
func init() { func init() {
gob.Register(&ArrayType{}) // Call gob.RegisterName with pkgPath as prefix to avoid conflicting with
gob.Register(&ChanType{}) // github.com/golang/mock/mockgen/model 's registration.
gob.Register(&FuncType{}) gob.RegisterName(pkgPath+".ArrayType", &ArrayType{})
gob.Register(&MapType{}) gob.RegisterName(pkgPath+".ChanType", &ChanType{})
gob.Register(&NamedType{}) gob.RegisterName(pkgPath+".FuncType", &FuncType{})
gob.Register(&PointerType{}) gob.RegisterName(pkgPath+".MapType", &MapType{})
gob.RegisterName(pkgPath+".NamedType", &NamedType{})
gob.RegisterName(pkgPath+".PointerType", &PointerType{})
// Call gob.RegisterName to make sure it has the consistent name registered // Call gob.RegisterName to make sure it has the consistent name registered
// for both gob decoder and encoder. // for both gob decoder and encoder.
@ -156,7 +162,7 @@ func init() {
// For a non-pointer type, gob.Register will try to get package full path by // For a non-pointer type, gob.Register will try to get package full path by
// calling rt.PkgPath() for a name to register. If your project has vendor // calling rt.PkgPath() for a name to register. If your project has vendor
// directory, it is possible that PkgPath will get a path like this: // directory, it is possible that PkgPath will get a path like this:
// ../../../vendor/github.com/golang/mock/mockgen/model // ../../../vendor/go.uber.org/mock/mockgen/model
gob.RegisterName(pkgPath+".PredeclaredType", PredeclaredType("")) gob.RegisterName(pkgPath+".PredeclaredType", PredeclaredType(""))
} }
@ -259,26 +265,28 @@ func (mt *MapType) addImports(im map[string]bool) {
// NamedType is an exported type in a package. // NamedType is an exported type in a package.
type NamedType struct { type NamedType struct {
Package string // may be empty Package string // may be empty
Type string Type string
TypeParams *TypeParametersType
} }
func (nt *NamedType) String(pm map[string]string, pkgOverride string) string { func (nt *NamedType) String(pm map[string]string, pkgOverride string) string {
if pkgOverride == nt.Package { if pkgOverride == nt.Package {
return nt.Type return nt.Type + nt.TypeParams.String(pm, pkgOverride)
} }
prefix := pm[nt.Package] prefix := pm[nt.Package]
if prefix != "" { if prefix != "" {
return prefix + "." + nt.Type return prefix + "." + nt.Type + nt.TypeParams.String(pm, pkgOverride)
} }
return nt.Type return nt.Type + nt.TypeParams.String(pm, pkgOverride)
} }
func (nt *NamedType) addImports(im map[string]bool) { func (nt *NamedType) addImports(im map[string]bool) {
if nt.Package != "" { if nt.Package != "" {
im[nt.Package] = true im[nt.Package] = true
} }
nt.TypeParams.addImports(im)
} }
// PointerType is a pointer to another type. // PointerType is a pointer to another type.
@ -297,6 +305,36 @@ type PredeclaredType string
func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) } func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) }
func (pt PredeclaredType) addImports(map[string]bool) {} func (pt PredeclaredType) addImports(map[string]bool) {}
// TypeParametersType contains type paramters for a NamedType.
type TypeParametersType struct {
TypeParameters []Type
}
func (tp *TypeParametersType) String(pm map[string]string, pkgOverride string) string {
if tp == nil || len(tp.TypeParameters) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString("[")
for i, v := range tp.TypeParameters {
if i != 0 {
sb.WriteString(", ")
}
sb.WriteString(v.String(pm, pkgOverride))
}
sb.WriteString("]")
return sb.String()
}
func (tp *TypeParametersType) addImports(im map[string]bool) {
if tp == nil {
return
}
for _, v := range tp.TypeParameters {
v.addImports(im)
}
}
// The following code is intended to be called by the program generated by ../reflect.go. // The following code is intended to be called by the program generated by ../reflect.go.
// InterfaceFromInterfaceType returns a pointer to an interface for the // InterfaceFromInterfaceType returns a pointer to an interface for the
@ -431,7 +469,7 @@ func typeFromType(t reflect.Type) (Type, error) {
case reflect.Interface: case reflect.Interface:
// Two special interfaces. // Two special interfaces.
if t.NumMethod() == 0 { if t.NumMethod() == 0 {
return PredeclaredType("interface{}"), nil return PredeclaredType("any"), nil
} }
if t == errorType { if t == errorType {
return PredeclaredType("error"), nil return PredeclaredType("error"), nil

View File

@ -18,7 +18,6 @@ package main
import ( import (
"errors" "errors"
"flag"
"fmt" "fmt"
"go/ast" "go/ast"
"go/build" "go/build"
@ -26,19 +25,14 @@ import (
"go/parser" "go/parser"
"go/token" "go/token"
"go/types" "go/types"
"io/ioutil"
"log" "log"
"os"
"path" "path"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"github.com/golang/mock/mockgen/model" "go.uber.org/mock/mockgen/model"
)
var (
imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.")
auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.")
) )
// sourceMode generates mocks via source file. // sourceMode generates mocks via source file.
@ -62,8 +56,8 @@ func sourceMode(source string) (*model.Package, error) {
p := &fileParser{ p := &fileParser{
fileSet: fs, fileSet: fs,
imports: make(map[string]importedPackage), imports: make(map[string]importedPackage),
importedInterfaces: make(map[string]map[string]*ast.InterfaceType), importedInterfaces: newInterfaceCache(),
auxInterfaces: make(map[string]map[string]*ast.InterfaceType), auxInterfaces: newInterfaceCache(),
srcDir: srcDir, srcDir: srcDir,
} }
@ -81,6 +75,10 @@ func sourceMode(source string) (*model.Package, error) {
} }
} }
if *excludeInterfaces != "" {
p.excludeNamesSet = parseExcludeInterfaces(*excludeInterfaces)
}
// Handle -aux_files. // Handle -aux_files.
if err := p.parseAuxFiles(*auxFiles); err != nil { if err := p.parseAuxFiles(*auxFiles); err != nil {
return nil, err return nil, err
@ -127,21 +125,55 @@ func (d duplicateImport) Error() string {
func (d duplicateImport) Path() string { log.Fatal(d.Error()); return "" } func (d duplicateImport) Path() string { log.Fatal(d.Error()); return "" }
func (d duplicateImport) Parser() *fileParser { log.Fatal(d.Error()); return nil } func (d duplicateImport) Parser() *fileParser { log.Fatal(d.Error()); return nil }
type fileParser struct { type interfaceCache struct {
fileSet *token.FileSet m map[string]map[string]*namedInterface
imports map[string]importedPackage // package name => imported package
importedInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface
auxFiles []*ast.File
auxInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface
srcDir string
} }
func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error { func newInterfaceCache() *interfaceCache {
return &interfaceCache{
m: make(map[string]map[string]*namedInterface),
}
}
func (i *interfaceCache) Set(pkg, name string, it *namedInterface) {
if _, ok := i.m[pkg]; !ok {
i.m[pkg] = make(map[string]*namedInterface)
}
i.m[pkg][name] = it
}
func (i *interfaceCache) Get(pkg, name string) *namedInterface {
if _, ok := i.m[pkg]; !ok {
return nil
}
return i.m[pkg][name]
}
func (i *interfaceCache) GetASTIface(pkg, name string) *ast.InterfaceType {
if _, ok := i.m[pkg]; !ok {
return nil
}
it, ok := i.m[pkg][name]
if !ok {
return nil
}
return it.it
}
type fileParser struct {
fileSet *token.FileSet
imports map[string]importedPackage // package name => imported package
importedInterfaces *interfaceCache
auxFiles []*ast.File
auxInterfaces *interfaceCache
srcDir string
excludeNamesSet map[string]struct{}
}
func (p *fileParser) errorf(pos token.Pos, format string, args ...any) error {
ps := p.fileSet.Position(pos) ps := p.fileSet.Position(pos)
format = "%s:%d:%d: " + format format = "%s:%d:%d: " + format
args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...) args = append([]any{ps.Filename, ps.Line, ps.Column}, args...)
return fmt.Errorf(format, args...) return fmt.Errorf(format, args...)
} }
@ -168,11 +200,8 @@ func (p *fileParser) parseAuxFiles(auxFiles string) error {
} }
func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) { func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) {
if _, ok := p.auxInterfaces[pkg]; !ok {
p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType)
}
for ni := range iterInterfaces(file) { for ni := range iterInterfaces(file) {
p.auxInterfaces[pkg][ni.name.Name] = ni.it p.auxInterfaces.Set(pkg, ni.name.Name, ni)
} }
} }
@ -199,7 +228,10 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag
var is []*model.Interface var is []*model.Interface
for ni := range iterInterfaces(file) { for ni := range iterInterfaces(file) {
i, err := p.parseInterface(ni.name.String(), importPath, ni.it) if _, ok := p.excludeNamesSet[ni.name.String()]; ok {
continue
}
i, err := p.parseInterface(ni.name.String(), importPath, ni)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -219,8 +251,8 @@ func (p *fileParser) parsePackage(path string) (*fileParser, error) {
newP := &fileParser{ newP := &fileParser{
fileSet: token.NewFileSet(), fileSet: token.NewFileSet(),
imports: make(map[string]importedPackage), imports: make(map[string]importedPackage),
importedInterfaces: make(map[string]map[string]*ast.InterfaceType), importedInterfaces: newInterfaceCache(),
auxInterfaces: make(map[string]map[string]*ast.InterfaceType), auxInterfaces: newInterfaceCache(),
srcDir: p.srcDir, srcDir: p.srcDir,
} }
@ -233,11 +265,8 @@ func (p *fileParser) parsePackage(path string) (*fileParser, error) {
for _, pkg := range pkgs { for _, pkg := range pkgs {
file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates) file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates)
if _, ok := newP.importedInterfaces[path]; !ok {
newP.importedInterfaces[path] = make(map[string]*ast.InterfaceType)
}
for ni := range iterInterfaces(file) { for ni := range iterInterfaces(file) {
newP.importedInterfaces[path][ni.name.Name] = ni.it newP.importedInterfaces.Set(path, ni.name.Name, ni)
} }
imports, _ := importsOfFile(file) imports, _ := importsOfFile(file)
for pkgName, pkgI := range imports { for pkgName, pkgI := range imports {
@ -247,9 +276,77 @@ func (p *fileParser) parsePackage(path string) (*fileParser, error) {
return newP, nil return newP, nil
} }
func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) { func (p *fileParser) constructInstParams(pkg string, params []*ast.Field, instParams []model.Type, embeddedInstParams []ast.Expr, tps map[string]model.Type) ([]model.Type, error) {
pm := make(map[string]int)
var i int
for _, v := range params {
for _, n := range v.Names {
pm[n.Name] = i
instParams = append(instParams, model.PredeclaredType(n.Name))
i++
}
}
var runtimeInstParams []model.Type
for _, instParam := range embeddedInstParams {
switch t := instParam.(type) {
case *ast.Ident:
if idx, ok := pm[t.Name]; ok {
runtimeInstParams = append(runtimeInstParams, instParams[idx])
continue
}
}
modelType, err := p.parseType(pkg, instParam, tps)
if err != nil {
return nil, err
}
runtimeInstParams = append(runtimeInstParams, modelType)
}
return runtimeInstParams, nil
}
func (p *fileParser) constructTps(it *namedInterface) (tps map[string]model.Type) {
tps = make(map[string]model.Type)
n := 0
for _, tp := range it.typeParams {
for _, tm := range tp.Names {
tps[tm.Name] = nil
if len(it.instTypes) != 0 {
tps[tm.Name] = it.instTypes[n]
n++
}
}
}
return tps
}
// parseInterface loads interface specified by pkg and name, parses it and returns
// a new model with the parsed.
func (p *fileParser) parseInterface(name, pkg string, it *namedInterface) (*model.Interface, error) {
iface := &model.Interface{Name: name} iface := &model.Interface{Name: name}
for _, field := range it.Methods.List { tps := p.constructTps(it)
tp, err := p.parseFieldList(pkg, it.typeParams, tps)
if err != nil {
return nil, fmt.Errorf("unable to parse interface type parameters: %v", name)
}
iface.TypeParams = tp
for _, field := range it.it.Methods.List {
var methods []*model.Method
if methods, err = p.parseMethod(field, it, iface, pkg, tps); err != nil {
return nil, err
}
for _, m := range methods {
iface.AddMethod(m)
}
}
return iface, nil
}
func (p *fileParser) parseMethod(field *ast.Field, it *namedInterface, iface *model.Interface, pkg string, tps map[string]model.Type) ([]*model.Method, error) {
// {} for git diff
{
switch v := field.Type.(type) { switch v := field.Type.(type) {
case *ast.FuncType: case *ast.FuncType:
if nn := len(field.Names); nn != 1 { if nn := len(field.Names); nn != 1 {
@ -259,37 +356,55 @@ func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*m
Name: field.Names[0].String(), Name: field.Names[0].String(),
} }
var err error var err error
m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v) m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v, tps)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iface.AddMethod(m) return []*model.Method{m}, nil
case *ast.Ident: case *ast.Ident:
// Embedded interface in this package. // Embedded interface in this package.
embeddedIfaceType := p.auxInterfaces[pkg][v.String()] embeddedIfaceType := p.auxInterfaces.Get(pkg, v.String())
if embeddedIfaceType == nil { if embeddedIfaceType == nil {
embeddedIfaceType = p.importedInterfaces[pkg][v.String()] embeddedIfaceType = p.importedInterfaces.Get(pkg, v.String())
} }
var embeddedIface *model.Interface var embeddedIface *model.Interface
if embeddedIfaceType != nil { if embeddedIfaceType != nil {
var err error var err error
embeddedIfaceType.instTypes, err = p.constructInstParams(pkg, it.typeParams, it.instTypes, it.embeddedInstTypeParams, tps)
if err != nil {
return nil, err
}
embeddedIface, err = p.parseInterface(v.String(), pkg, embeddedIfaceType) embeddedIface, err = p.parseInterface(v.String(), pkg, embeddedIfaceType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
// This is built-in error interface. // This is built-in error interface.
if v.String() == model.ErrorInterface.Name { if v.String() == model.ErrorInterface.Name {
embeddedIface = &model.ErrorInterface embeddedIface = &model.ErrorInterface
} else { } else {
return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String()) ip, err := p.parsePackage(pkg)
if err != nil {
return nil, p.errorf(v.Pos(), "could not parse package %s: %v", pkg, err)
}
if embeddedIfaceType = ip.importedInterfaces.Get(pkg, v.String()); embeddedIfaceType == nil {
return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", pkg, v.String())
}
embeddedIfaceType.instTypes, err = p.constructInstParams(pkg, it.typeParams, it.instTypes, it.embeddedInstTypeParams, tps)
if err != nil {
return nil, err
}
embeddedIface, err = ip.parseInterface(v.String(), pkg, embeddedIfaceType)
if err != nil {
return nil, err
}
} }
} }
// Copy the methods. return embeddedIface.Methods, nil
for _, m := range embeddedIface.Methods {
iface.AddMethod(m)
}
case *ast.SelectorExpr: case *ast.SelectorExpr:
// Embedded interface in another package. // Embedded interface in another package.
filePkg, sel := v.X.(*ast.Ident).String(), v.Sel.String() filePkg, sel := v.X.(*ast.Ident).String(), v.Sel.String()
@ -300,8 +415,12 @@ func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*m
var embeddedIface *model.Interface var embeddedIface *model.Interface
var err error var err error
embeddedIfaceType := p.auxInterfaces[filePkg][sel] embeddedIfaceType := p.auxInterfaces.Get(filePkg, sel)
if embeddedIfaceType != nil { if embeddedIfaceType != nil {
embeddedIfaceType.instTypes, err = p.constructInstParams(pkg, it.typeParams, it.instTypes, it.embeddedInstTypeParams, tps)
if err != nil {
return nil, err
}
embeddedIface, err = p.parseInterface(sel, filePkg, embeddedIfaceType) embeddedIface, err = p.parseInterface(sel, filePkg, embeddedIfaceType)
if err != nil { if err != nil {
return nil, err return nil, err
@ -320,46 +439,47 @@ func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*m
parser: parser, parser: parser,
} }
} }
if embeddedIfaceType = parser.importedInterfaces[path][sel]; embeddedIfaceType == nil { if embeddedIfaceType = parser.importedInterfaces.Get(path, sel); embeddedIfaceType == nil {
return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel) return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel)
} }
embeddedIfaceType.instTypes, err = p.constructInstParams(pkg, it.typeParams, it.instTypes, it.embeddedInstTypeParams, tps)
if err != nil {
return nil, err
}
embeddedIface, err = parser.parseInterface(sel, path, embeddedIfaceType) embeddedIface, err = parser.parseInterface(sel, path, embeddedIfaceType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
// Copy the methods.
// TODO: apply shadowing rules. // TODO: apply shadowing rules.
for _, m := range embeddedIface.Methods { return embeddedIface.Methods, nil
iface.AddMethod(m)
}
default: default:
return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) return p.parseGenericMethod(field, it, iface, pkg, tps)
} }
} }
return iface, nil
} }
func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (inParam []*model.Parameter, variadic *model.Parameter, outParam []*model.Parameter, err error) { func (p *fileParser) parseFunc(pkg string, f *ast.FuncType, tps map[string]model.Type) (inParam []*model.Parameter, variadic *model.Parameter, outParam []*model.Parameter, err error) {
if f.Params != nil { if f.Params != nil {
regParams := f.Params.List regParams := f.Params.List
if isVariadic(f) { if isVariadic(f) {
n := len(regParams) n := len(regParams)
varParams := regParams[n-1:] varParams := regParams[n-1:]
regParams = regParams[:n-1] regParams = regParams[:n-1]
vp, err := p.parseFieldList(pkg, varParams) vp, err := p.parseFieldList(pkg, varParams, tps)
if err != nil { if err != nil {
return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err) return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err)
} }
variadic = vp[0] variadic = vp[0]
} }
inParam, err = p.parseFieldList(pkg, regParams) inParam, err = p.parseFieldList(pkg, regParams, tps)
if err != nil { if err != nil {
return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err) return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err)
} }
} }
if f.Results != nil { if f.Results != nil {
outParam, err = p.parseFieldList(pkg, f.Results.List) outParam, err = p.parseFieldList(pkg, f.Results.List, tps)
if err != nil { if err != nil {
return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err) return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err)
} }
@ -367,7 +487,7 @@ func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (inParam []*model.Pa
return return
} }
func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) { func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field, tps map[string]model.Type) ([]*model.Parameter, error) {
nf := 0 nf := 0
for _, f := range fields { for _, f := range fields {
nn := len(f.Names) nn := len(f.Names)
@ -382,7 +502,7 @@ func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.P
ps := make([]*model.Parameter, nf) ps := make([]*model.Parameter, nf)
i := 0 // destination index i := 0 // destination index
for _, f := range fields { for _, f := range fields {
t, err := p.parseType(pkg, f.Type) t, err := p.parseType(pkg, f.Type, tps)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -401,44 +521,27 @@ func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.P
return ps, nil return ps, nil
} }
func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) { func (p *fileParser) parseType(pkg string, typ ast.Expr, tps map[string]model.Type) (model.Type, error) {
switch v := typ.(type) { switch v := typ.(type) {
case *ast.ArrayType: case *ast.ArrayType:
ln := -1 ln := -1
if v.Len != nil { if v.Len != nil {
var value string value, err := p.parseArrayLength(v.Len)
switch val := v.Len.(type) { if err != nil {
case (*ast.BasicLit): return nil, err
value = val.Value
case (*ast.Ident):
// when the length is a const defined locally
value = val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value
case (*ast.SelectorExpr):
// when the length is a const defined in an external package
usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X))
if err != nil {
return nil, p.errorf(v.Len.Pos(), "unknown package in array length: %v", err)
}
ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name)
if err != nil {
return nil, p.errorf(v.Len.Pos(), "unknown constant in array length: %v", err)
}
value = ev.Value.String()
} }
ln, err = strconv.Atoi(value)
x, err := strconv.Atoi(value)
if err != nil { if err != nil {
return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err) return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err)
} }
ln = x
} }
t, err := p.parseType(pkg, v.Elt) t, err := p.parseType(pkg, v.Elt, tps)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &model.ArrayType{Len: ln, Type: t}, nil return &model.ArrayType{Len: ln, Type: t}, nil
case *ast.ChanType: case *ast.ChanType:
t, err := p.parseType(pkg, v.Value) t, err := p.parseType(pkg, v.Value, tps)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -452,15 +555,16 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
return &model.ChanType{Dir: dir, Type: t}, nil return &model.ChanType{Dir: dir, Type: t}, nil
case *ast.Ellipsis: case *ast.Ellipsis:
// assume we're parsing a variadic argument // assume we're parsing a variadic argument
return p.parseType(pkg, v.Elt) return p.parseType(pkg, v.Elt, tps)
case *ast.FuncType: case *ast.FuncType:
in, variadic, out, err := p.parseFunc(pkg, v) in, variadic, out, err := p.parseFunc(pkg, v, tps)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil
case *ast.Ident: case *ast.Ident:
if v.IsExported() { it, ok := tps[v.Name]
if v.IsExported() && !ok {
// `pkg` may be an aliased imported pkg // `pkg` may be an aliased imported pkg
// if so, patch the import w/ the fully qualified import // if so, patch the import w/ the fully qualified import
maybeImportedPkg, ok := p.imports[pkg] maybeImportedPkg, ok := p.imports[pkg]
@ -470,20 +574,22 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
// assume type in this package // assume type in this package
return &model.NamedType{Package: pkg, Type: v.Name}, nil return &model.NamedType{Package: pkg, Type: v.Name}, nil
} }
if ok && it != nil {
return it, nil
}
// assume predeclared type // assume predeclared type
return model.PredeclaredType(v.Name), nil return model.PredeclaredType(v.Name), nil
case *ast.InterfaceType: case *ast.InterfaceType:
if v.Methods != nil && len(v.Methods.List) > 0 { if v.Methods != nil && len(v.Methods.List) > 0 {
return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types") return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types")
} }
return model.PredeclaredType("interface{}"), nil return model.PredeclaredType("any"), nil
case *ast.MapType: case *ast.MapType:
key, err := p.parseType(pkg, v.Key) key, err := p.parseType(pkg, v.Key, tps)
if err != nil { if err != nil {
return nil, err return nil, err
} }
value, err := p.parseType(pkg, v.Value) value, err := p.parseType(pkg, v.Value, tps)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -496,7 +602,7 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
} }
return &model.NamedType{Package: pkg.Path(), Type: v.Sel.String()}, nil return &model.NamedType{Package: pkg.Path(), Type: v.Sel.String()}, nil
case *ast.StarExpr: case *ast.StarExpr:
t, err := p.parseType(pkg, v.X) t, err := p.parseType(pkg, v.X, tps)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -507,12 +613,61 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
} }
return model.PredeclaredType("struct{}"), nil return model.PredeclaredType("struct{}"), nil
case *ast.ParenExpr: case *ast.ParenExpr:
return p.parseType(pkg, v.X) return p.parseType(pkg, v.X, tps)
default:
mt, err := p.parseGenericType(pkg, typ, tps)
if err != nil {
return nil, err
}
if mt == nil {
break
}
return mt, nil
} }
return nil, fmt.Errorf("don't know how to parse type %T", typ) return nil, fmt.Errorf("don't know how to parse type %T", typ)
} }
func (p *fileParser) parseArrayLength(expr ast.Expr) (string, error) {
switch val := expr.(type) {
case (*ast.BasicLit):
return val.Value, nil
case (*ast.Ident):
// when the length is a const defined locally
return val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value, nil
case (*ast.SelectorExpr):
// when the length is a const defined in an external package
usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X))
if err != nil {
return "", p.errorf(expr.Pos(), "unknown package in array length: %v", err)
}
ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name)
if err != nil {
return "", p.errorf(expr.Pos(), "unknown constant in array length: %v", err)
}
return ev.Value.String(), nil
case (*ast.ParenExpr):
return p.parseArrayLength(val.X)
case (*ast.BinaryExpr):
x, err := p.parseArrayLength(val.X)
if err != nil {
return "", err
}
y, err := p.parseArrayLength(val.Y)
if err != nil {
return "", err
}
biExpr := fmt.Sprintf("%s%v%s", x, val.Op, y)
tv, err := types.Eval(token.NewFileSet(), nil, token.NoPos, biExpr)
if err != nil {
return "", p.errorf(expr.Pos(), "invalid expression in array length: %v", err)
}
return tv.Value.String(), nil
default:
return "", p.errorf(expr.Pos(), "invalid expression in array length: %v", val)
}
}
// importsOfFile returns a map of package name to import path // importsOfFile returns a map of package name to import path
// of the imports in file. // of the imports in file.
func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, dotImports []string) { func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, dotImports []string) {
@ -575,13 +730,16 @@ func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, do
} }
type namedInterface struct { type namedInterface struct {
name *ast.Ident name *ast.Ident
it *ast.InterfaceType it *ast.InterfaceType
typeParams []*ast.Field
embeddedInstTypeParams []ast.Expr
instTypes []model.Type
} }
// Create an iterator over all interfaces in file. // Create an iterator over all interfaces in file.
func iterInterfaces(file *ast.File) <-chan namedInterface { func iterInterfaces(file *ast.File) <-chan *namedInterface {
ch := make(chan namedInterface) ch := make(chan *namedInterface)
go func() { go func() {
for _, decl := range file.Decls { for _, decl := range file.Decls {
gd, ok := decl.(*ast.GenDecl) gd, ok := decl.(*ast.GenDecl)
@ -598,7 +756,7 @@ func iterInterfaces(file *ast.File) <-chan namedInterface {
continue continue
} }
ch <- namedInterface{ts.Name, it} ch <- &namedInterface{name: ts.Name, it: it, typeParams: getTypeSpecTypeParams(ts)}
} }
} }
close(ch) close(ch)
@ -618,7 +776,7 @@ func isVariadic(f *ast.FuncType) bool {
// packageNameOfDir get package import path via dir // packageNameOfDir get package import path via dir
func packageNameOfDir(srcDir string) (string, error) { func packageNameOfDir(srcDir string) (string, error) {
files, err := ioutil.ReadDir(srcDir) files, err := os.ReadDir(srcDir)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -23,7 +23,6 @@ import (
"fmt" "fmt"
"go/build" "go/build"
"io" "io"
"io/ioutil"
"log" "log"
"os" "os"
"os/exec" "os/exec"
@ -32,7 +31,7 @@ import (
"strings" "strings"
"text/template" "text/template"
"github.com/golang/mock/mockgen/model" "go.uber.org/mock/mockgen/model"
) )
var ( var (
@ -92,7 +91,7 @@ func writeProgram(importPath string, symbols []string) ([]byte, error) {
// run the given program and parse the output as a model.Package. // run the given program and parse the output as a model.Package.
func run(program string) (*model.Package, error) { func run(program string) (*model.Package, error) {
f, err := ioutil.TempFile("", "") f, err := os.CreateTemp("", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -133,7 +132,7 @@ func run(program string) (*model.Package, error) {
// parses the output as a model.Package. // parses the output as a model.Package.
func runInDir(program []byte, dir string) (*model.Package, error) { func runInDir(program []byte, dir string) (*model.Package, error) {
// We use TempDir instead of TempFile so we can control the filename. // We use TempDir instead of TempFile so we can control the filename.
tmpDir, err := ioutil.TempDir(dir, "gomock_reflect_") tmpDir, err := os.MkdirTemp(dir, "gomock_reflect_")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -149,7 +148,7 @@ func runInDir(program []byte, dir string) (*model.Package, error) {
progBinary += ".exe" progBinary += ".exe"
} }
if err := ioutil.WriteFile(filepath.Join(tmpDir, progSource), program, 0600); err != nil { if err := os.WriteFile(filepath.Join(tmpDir, progSource), program, 0600); err != nil {
return nil, err return nil, err
} }
@ -169,8 +168,8 @@ func runInDir(program []byte, dir string) (*model.Package, error) {
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
sErr := buf.String() sErr := buf.String()
if strings.Contains(sErr, `cannot find package "."`) && if strings.Contains(sErr, `cannot find package "."`) &&
strings.Contains(sErr, "github.com/golang/mock/mockgen/model") { strings.Contains(sErr, "go.uber.org/mock/mockgen/model") {
fmt.Fprint(os.Stderr, "Please reference the steps in the README to fix this error:\n\thttps://github.com/golang/mock#reflect-vendoring-error.") fmt.Fprint(os.Stderr, "Please reference the steps in the README to fix this error:\n\thttps://go.uber.org/mock#reflect-vendoring-error.\n")
return nil, err return nil, err
} }
return nil, err return nil, err
@ -198,7 +197,7 @@ import (
"path" "path"
"reflect" "reflect"
"github.com/golang/mock/mockgen/model" "go.uber.org/mock/mockgen/model"
pkg_ {{printf "%q" .ImportPath}} pkg_ {{printf "%q" .ImportPath}}
) )

View File

@ -1,4 +1,4 @@
// Copyright 2019 Google LLC // Copyright 2022 Google LLC
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -11,9 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//
// +build go1.12
package main package main
@ -31,5 +28,4 @@ func printModuleVersion() {
"GO111MODULE=on when running 'go get' in order to use specific " + "GO111MODULE=on when running 'go get' in order to use specific " +
"version of the binary.") "version of the binary.")
} }
} }

View File

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build go1.11 && gc && !purego //go:build gc && !purego
// +build go1.11,gc,!purego // +build gc,!purego
package chacha20 package chacha20

View File

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build go1.11 && gc && !purego //go:build gc && !purego
// +build go1.11,gc,!purego // +build gc,!purego
#include "textflag.h" #include "textflag.h"

View File

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build (!arm64 && !s390x && !ppc64le) || (arm64 && !go1.11) || !gc || purego //go:build (!arm64 && !s390x && !ppc64le) || !gc || purego
// +build !arm64,!s390x,!ppc64le arm64,!go1.11 !gc purego // +build !arm64,!s390x,!ppc64le !gc purego
package chacha20 package chacha20

View File

@ -95,6 +95,11 @@ func (b *Builder) AddUint32(v uint32) {
b.add(byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) b.add(byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
} }
// AddUint48 appends a big-endian, 48-bit value to the byte string.
func (b *Builder) AddUint48(v uint64) {
b.add(byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
}
// AddUint64 appends a big-endian, 64-bit value to the byte string. // AddUint64 appends a big-endian, 64-bit value to the byte string.
func (b *Builder) AddUint64(v uint64) { func (b *Builder) AddUint64(v uint64) {
b.add(byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) b.add(byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), byte(v))

View File

@ -81,6 +81,17 @@ func (s *String) ReadUint32(out *uint32) bool {
return true return true
} }
// ReadUint48 decodes a big-endian, 48-bit value into out and advances over it.
// It reports whether the read was successful.
func (s *String) ReadUint48(out *uint64) bool {
v := s.read(6)
if v == nil {
return false
}
*out = uint64(v[0])<<40 | uint64(v[1])<<32 | uint64(v[2])<<24 | uint64(v[3])<<16 | uint64(v[4])<<8 | uint64(v[5])
return true
}
// ReadUint64 decodes a big-endian, 64-bit value into out and advances over it. // ReadUint64 decodes a big-endian, 64-bit value into out and advances over it.
// It reports whether the read was successful. // It reports whether the read was successful.
func (s *String) ReadUint64(out *uint64) bool { func (s *String) ReadUint64(out *uint64) bool {

221
vendor/golang.org/x/exp/rand/exp.go generated vendored Normal file
View File

@ -0,0 +1,221 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rand
import (
"math"
)
/*
* Exponential distribution
*
* See "The Ziggurat Method for Generating Random Variables"
* (Marsaglia & Tsang, 2000)
* http://www.jstatsoft.org/v05/i08/paper [pdf]
*/
const (
re = 7.69711747013104972
)
// ExpFloat64 returns an exponentially distributed float64 in the range
// (0, +math.MaxFloat64] with an exponential distribution whose rate parameter
// (lambda) is 1 and whose mean is 1/lambda (1).
// To produce a distribution with a different rate parameter,
// callers can adjust the output using:
//
// sample = ExpFloat64() / desiredRateParameter
func (r *Rand) ExpFloat64() float64 {
for {
j := r.Uint32()
i := j & 0xFF
x := float64(j) * float64(we[i])
if j < ke[i] {
return x
}
if i == 0 {
return re - math.Log(r.Float64())
}
if fe[i]+float32(r.Float64())*(fe[i-1]-fe[i]) < float32(math.Exp(-x)) {
return x
}
}
}
var ke = [256]uint32{
0xe290a139, 0x0, 0x9beadebc, 0xc377ac71, 0xd4ddb990,
0xde893fb8, 0xe4a8e87c, 0xe8dff16a, 0xebf2deab, 0xee49a6e8,
0xf0204efd, 0xf19bdb8e, 0xf2d458bb, 0xf3da104b, 0xf4b86d78,
0xf577ad8a, 0xf61de83d, 0xf6afb784, 0xf730a573, 0xf7a37651,
0xf80a5bb6, 0xf867189d, 0xf8bb1b4f, 0xf9079062, 0xf94d70ca,
0xf98d8c7d, 0xf9c8928a, 0xf9ff175b, 0xfa319996, 0xfa6085f8,
0xfa8c3a62, 0xfab5084e, 0xfadb36c8, 0xfaff0410, 0xfb20a6ea,
0xfb404fb4, 0xfb5e2951, 0xfb7a59e9, 0xfb95038c, 0xfbae44ba,
0xfbc638d8, 0xfbdcf892, 0xfbf29a30, 0xfc0731df, 0xfc1ad1ed,
0xfc2d8b02, 0xfc3f6c4d, 0xfc5083ac, 0xfc60ddd1, 0xfc708662,
0xfc7f8810, 0xfc8decb4, 0xfc9bbd62, 0xfca9027c, 0xfcb5c3c3,
0xfcc20864, 0xfccdd70a, 0xfcd935e3, 0xfce42ab0, 0xfceebace,
0xfcf8eb3b, 0xfd02c0a0, 0xfd0c3f59, 0xfd156b7b, 0xfd1e48d6,
0xfd26daff, 0xfd2f2552, 0xfd372af7, 0xfd3eeee5, 0xfd4673e7,
0xfd4dbc9e, 0xfd54cb85, 0xfd5ba2f2, 0xfd62451b, 0xfd68b415,
0xfd6ef1da, 0xfd750047, 0xfd7ae120, 0xfd809612, 0xfd8620b4,
0xfd8b8285, 0xfd90bcf5, 0xfd95d15e, 0xfd9ac10b, 0xfd9f8d36,
0xfda43708, 0xfda8bf9e, 0xfdad2806, 0xfdb17141, 0xfdb59c46,
0xfdb9a9fd, 0xfdbd9b46, 0xfdc170f6, 0xfdc52bd8, 0xfdc8ccac,
0xfdcc542d, 0xfdcfc30b, 0xfdd319ef, 0xfdd6597a, 0xfdd98245,
0xfddc94e5, 0xfddf91e6, 0xfde279ce, 0xfde54d1f, 0xfde80c52,
0xfdeab7de, 0xfded5034, 0xfdefd5be, 0xfdf248e3, 0xfdf4aa06,
0xfdf6f984, 0xfdf937b6, 0xfdfb64f4, 0xfdfd818d, 0xfdff8dd0,
0xfe018a08, 0xfe03767a, 0xfe05536c, 0xfe07211c, 0xfe08dfc9,
0xfe0a8fab, 0xfe0c30fb, 0xfe0dc3ec, 0xfe0f48b1, 0xfe10bf76,
0xfe122869, 0xfe1383b4, 0xfe14d17c, 0xfe1611e7, 0xfe174516,
0xfe186b2a, 0xfe19843e, 0xfe1a9070, 0xfe1b8fd6, 0xfe1c8289,
0xfe1d689b, 0xfe1e4220, 0xfe1f0f26, 0xfe1fcfbc, 0xfe2083ed,
0xfe212bc3, 0xfe21c745, 0xfe225678, 0xfe22d95f, 0xfe234ffb,
0xfe23ba4a, 0xfe241849, 0xfe2469f2, 0xfe24af3c, 0xfe24e81e,
0xfe25148b, 0xfe253474, 0xfe2547c7, 0xfe254e70, 0xfe25485a,
0xfe25356a, 0xfe251586, 0xfe24e88f, 0xfe24ae64, 0xfe2466e1,
0xfe2411df, 0xfe23af34, 0xfe233eb4, 0xfe22c02c, 0xfe22336b,
0xfe219838, 0xfe20ee58, 0xfe20358c, 0xfe1f6d92, 0xfe1e9621,
0xfe1daef0, 0xfe1cb7ac, 0xfe1bb002, 0xfe1a9798, 0xfe196e0d,
0xfe1832fd, 0xfe16e5fe, 0xfe15869d, 0xfe141464, 0xfe128ed3,
0xfe10f565, 0xfe0f478c, 0xfe0d84b1, 0xfe0bac36, 0xfe09bd73,
0xfe07b7b5, 0xfe059a40, 0xfe03644c, 0xfe011504, 0xfdfeab88,
0xfdfc26e9, 0xfdf98629, 0xfdf6c83b, 0xfdf3ec01, 0xfdf0f04a,
0xfdedd3d1, 0xfdea953d, 0xfde7331e, 0xfde3abe9, 0xfddffdfb,
0xfddc2791, 0xfdd826cd, 0xfdd3f9a8, 0xfdcf9dfc, 0xfdcb1176,
0xfdc65198, 0xfdc15bb3, 0xfdbc2ce2, 0xfdb6c206, 0xfdb117be,
0xfdab2a63, 0xfda4f5fd, 0xfd9e7640, 0xfd97a67a, 0xfd908192,
0xfd8901f2, 0xfd812182, 0xfd78d98e, 0xfd7022bb, 0xfd66f4ed,
0xfd5d4732, 0xfd530f9c, 0xfd48432b, 0xfd3cd59a, 0xfd30b936,
0xfd23dea4, 0xfd16349e, 0xfd07a7a3, 0xfcf8219b, 0xfce7895b,
0xfcd5c220, 0xfcc2aadb, 0xfcae1d5e, 0xfc97ed4e, 0xfc7fe6d4,
0xfc65ccf3, 0xfc495762, 0xfc2a2fc8, 0xfc07ee19, 0xfbe213c1,
0xfbb8051a, 0xfb890078, 0xfb5411a5, 0xfb180005, 0xfad33482,
0xfa839276, 0xfa263b32, 0xf9b72d1c, 0xf930a1a2, 0xf889f023,
0xf7b577d2, 0xf69c650c, 0xf51530f0, 0xf2cb0e3c, 0xeeefb15d,
0xe6da6ecf,
}
var we = [256]float32{
2.0249555e-09, 1.486674e-11, 2.4409617e-11, 3.1968806e-11,
3.844677e-11, 4.4228204e-11, 4.9516443e-11, 5.443359e-11,
5.905944e-11, 6.344942e-11, 6.7643814e-11, 7.1672945e-11,
7.556032e-11, 7.932458e-11, 8.298079e-11, 8.654132e-11,
9.0016515e-11, 9.3415074e-11, 9.674443e-11, 1.0001099e-10,
1.03220314e-10, 1.06377254e-10, 1.09486115e-10, 1.1255068e-10,
1.1557435e-10, 1.1856015e-10, 1.2151083e-10, 1.2442886e-10,
1.2731648e-10, 1.3017575e-10, 1.3300853e-10, 1.3581657e-10,
1.3860142e-10, 1.4136457e-10, 1.4410738e-10, 1.4683108e-10,
1.4953687e-10, 1.5222583e-10, 1.54899e-10, 1.5755733e-10,
1.6020171e-10, 1.6283301e-10, 1.6545203e-10, 1.6805951e-10,
1.7065617e-10, 1.732427e-10, 1.7581973e-10, 1.7838787e-10,
1.8094774e-10, 1.8349985e-10, 1.8604476e-10, 1.8858298e-10,
1.9111498e-10, 1.9364126e-10, 1.9616223e-10, 1.9867835e-10,
2.0119004e-10, 2.0369768e-10, 2.0620168e-10, 2.087024e-10,
2.1120022e-10, 2.136955e-10, 2.1618855e-10, 2.1867974e-10,
2.2116936e-10, 2.2365775e-10, 2.261452e-10, 2.2863202e-10,
2.311185e-10, 2.3360494e-10, 2.360916e-10, 2.3857874e-10,
2.4106667e-10, 2.4355562e-10, 2.4604588e-10, 2.485377e-10,
2.5103128e-10, 2.5352695e-10, 2.560249e-10, 2.585254e-10,
2.6102867e-10, 2.6353494e-10, 2.6604446e-10, 2.6855745e-10,
2.7107416e-10, 2.7359479e-10, 2.761196e-10, 2.7864877e-10,
2.8118255e-10, 2.8372119e-10, 2.8626485e-10, 2.888138e-10,
2.9136826e-10, 2.939284e-10, 2.9649452e-10, 2.9906677e-10,
3.016454e-10, 3.0423064e-10, 3.0682268e-10, 3.0942177e-10,
3.1202813e-10, 3.1464195e-10, 3.1726352e-10, 3.19893e-10,
3.2253064e-10, 3.251767e-10, 3.2783135e-10, 3.3049485e-10,
3.3316744e-10, 3.3584938e-10, 3.3854083e-10, 3.4124212e-10,
3.4395342e-10, 3.46675e-10, 3.4940711e-10, 3.5215003e-10,
3.5490397e-10, 3.5766917e-10, 3.6044595e-10, 3.6323455e-10,
3.660352e-10, 3.6884823e-10, 3.7167386e-10, 3.745124e-10,
3.773641e-10, 3.802293e-10, 3.8310827e-10, 3.860013e-10,
3.8890866e-10, 3.918307e-10, 3.9476775e-10, 3.9772008e-10,
4.0068804e-10, 4.0367196e-10, 4.0667217e-10, 4.09689e-10,
4.1272286e-10, 4.1577405e-10, 4.1884296e-10, 4.2192994e-10,
4.250354e-10, 4.281597e-10, 4.313033e-10, 4.3446652e-10,
4.3764986e-10, 4.408537e-10, 4.4407847e-10, 4.4732465e-10,
4.5059267e-10, 4.5388301e-10, 4.571962e-10, 4.6053267e-10,
4.6389292e-10, 4.6727755e-10, 4.70687e-10, 4.741219e-10,
4.7758275e-10, 4.810702e-10, 4.845848e-10, 4.8812715e-10,
4.9169796e-10, 4.9529775e-10, 4.989273e-10, 5.0258725e-10,
5.0627835e-10, 5.100013e-10, 5.1375687e-10, 5.1754584e-10,
5.21369e-10, 5.2522725e-10, 5.2912136e-10, 5.330522e-10,
5.370208e-10, 5.4102806e-10, 5.45075e-10, 5.491625e-10,
5.532918e-10, 5.5746385e-10, 5.616799e-10, 5.6594107e-10,
5.7024857e-10, 5.746037e-10, 5.7900773e-10, 5.834621e-10,
5.8796823e-10, 5.925276e-10, 5.971417e-10, 6.018122e-10,
6.065408e-10, 6.113292e-10, 6.1617933e-10, 6.2109295e-10,
6.260722e-10, 6.3111916e-10, 6.3623595e-10, 6.4142497e-10,
6.4668854e-10, 6.5202926e-10, 6.5744976e-10, 6.6295286e-10,
6.6854156e-10, 6.742188e-10, 6.79988e-10, 6.858526e-10,
6.9181616e-10, 6.978826e-10, 7.04056e-10, 7.103407e-10,
7.167412e-10, 7.2326256e-10, 7.2990985e-10, 7.366886e-10,
7.4360473e-10, 7.5066453e-10, 7.5787476e-10, 7.6524265e-10,
7.7277595e-10, 7.80483e-10, 7.883728e-10, 7.9645507e-10,
8.047402e-10, 8.1323964e-10, 8.219657e-10, 8.309319e-10,
8.401528e-10, 8.496445e-10, 8.594247e-10, 8.6951274e-10,
8.799301e-10, 8.9070046e-10, 9.018503e-10, 9.134092e-10,
9.254101e-10, 9.378904e-10, 9.508923e-10, 9.644638e-10,
9.786603e-10, 9.935448e-10, 1.0091913e-09, 1.025686e-09,
1.0431306e-09, 1.0616465e-09, 1.08138e-09, 1.1025096e-09,
1.1252564e-09, 1.1498986e-09, 1.1767932e-09, 1.206409e-09,
1.2393786e-09, 1.276585e-09, 1.3193139e-09, 1.3695435e-09,
1.4305498e-09, 1.508365e-09, 1.6160854e-09, 1.7921248e-09,
}
var fe = [256]float32{
1, 0.9381437, 0.90046996, 0.87170434, 0.8477855, 0.8269933,
0.8084217, 0.7915276, 0.77595687, 0.7614634, 0.7478686,
0.7350381, 0.72286767, 0.71127474, 0.70019263, 0.6895665,
0.67935055, 0.6695063, 0.66000086, 0.65080583, 0.6418967,
0.63325197, 0.6248527, 0.6166822, 0.60872537, 0.60096896,
0.5934009, 0.58601034, 0.5787874, 0.57172304, 0.5648092,
0.5580383, 0.5514034, 0.5448982, 0.5385169, 0.53225386,
0.5261042, 0.52006316, 0.5141264, 0.50828975, 0.5025495,
0.496902, 0.49134386, 0.485872, 0.48048335, 0.4751752,
0.46994483, 0.46478975, 0.45970762, 0.45469615, 0.44975325,
0.44487688, 0.44006512, 0.43531612, 0.43062815, 0.42599955,
0.42142874, 0.4169142, 0.41245446, 0.40804818, 0.403694,
0.3993907, 0.39513698, 0.39093173, 0.38677382, 0.38266218,
0.37859577, 0.37457356, 0.37059465, 0.3666581, 0.362763,
0.35890847, 0.35509375, 0.351318, 0.3475805, 0.34388044,
0.34021714, 0.3365899, 0.33299807, 0.32944095, 0.32591796,
0.3224285, 0.3189719, 0.31554767, 0.31215525, 0.30879408,
0.3054636, 0.3021634, 0.29889292, 0.2956517, 0.29243928,
0.28925523, 0.28609908, 0.28297043, 0.27986884, 0.27679393,
0.2737453, 0.2707226, 0.2677254, 0.26475343, 0.26180625,
0.25888354, 0.25598502, 0.2531103, 0.25025907, 0.24743107,
0.24462597, 0.24184346, 0.23908329, 0.23634516, 0.23362878,
0.23093392, 0.2282603, 0.22560766, 0.22297576, 0.22036438,
0.21777324, 0.21520215, 0.21265087, 0.21011916, 0.20760682,
0.20511365, 0.20263945, 0.20018397, 0.19774707, 0.19532852,
0.19292815, 0.19054577, 0.1881812, 0.18583426, 0.18350479,
0.1811926, 0.17889754, 0.17661946, 0.17435817, 0.17211354,
0.1698854, 0.16767362, 0.16547804, 0.16329853, 0.16113494,
0.15898713, 0.15685499, 0.15473837, 0.15263714, 0.15055119,
0.14848037, 0.14642459, 0.14438373, 0.14235765, 0.14034624,
0.13834943, 0.13636707, 0.13439907, 0.13244532, 0.13050574,
0.1285802, 0.12666863, 0.12477092, 0.12288698, 0.12101672,
0.119160056, 0.1173169, 0.115487166, 0.11367077, 0.11186763,
0.11007768, 0.10830083, 0.10653701, 0.10478614, 0.10304816,
0.101323, 0.09961058, 0.09791085, 0.09622374, 0.09454919,
0.09288713, 0.091237515, 0.08960028, 0.087975375, 0.08636274,
0.08476233, 0.083174095, 0.081597984, 0.08003395, 0.07848195,
0.076941945, 0.07541389, 0.07389775, 0.072393484, 0.07090106,
0.069420435, 0.06795159, 0.066494495, 0.06504912, 0.063615434,
0.062193416, 0.060783047, 0.059384305, 0.057997175,
0.05662164, 0.05525769, 0.053905312, 0.052564494, 0.051235236,
0.049917534, 0.048611384, 0.047316793, 0.046033762, 0.0447623,
0.043502413, 0.042254124, 0.041017443, 0.039792392,
0.038578995, 0.037377283, 0.036187284, 0.035009038,
0.033842582, 0.032687962, 0.031545233, 0.030414443, 0.02929566,
0.02818895, 0.027094385, 0.026012046, 0.024942026, 0.023884421,
0.022839336, 0.021806888, 0.020787204, 0.019780423, 0.0187867,
0.0178062, 0.016839107, 0.015885621, 0.014945968, 0.014020392,
0.013109165, 0.012212592, 0.011331013, 0.01046481, 0.009614414,
0.008780315, 0.007963077, 0.0071633533, 0.006381906,
0.0056196423, 0.0048776558, 0.004157295, 0.0034602648,
0.0027887989, 0.0021459677, 0.0015362998, 0.0009672693,
0.00045413437,
}

156
vendor/golang.org/x/exp/rand/normal.go generated vendored Normal file
View File

@ -0,0 +1,156 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rand
import (
"math"
)
/*
* Normal distribution
*
* See "The Ziggurat Method for Generating Random Variables"
* (Marsaglia & Tsang, 2000)
* http://www.jstatsoft.org/v05/i08/paper [pdf]
*/
const (
rn = 3.442619855899
)
func absInt32(i int32) uint32 {
if i < 0 {
return uint32(-i)
}
return uint32(i)
}
// NormFloat64 returns a normally distributed float64 in the range
// [-math.MaxFloat64, +math.MaxFloat64] with
// standard normal distribution (mean = 0, stddev = 1).
// To produce a different normal distribution, callers can
// adjust the output using:
//
// sample = NormFloat64() * desiredStdDev + desiredMean
func (r *Rand) NormFloat64() float64 {
for {
j := int32(r.Uint32()) // Possibly negative
i := j & 0x7F
x := float64(j) * float64(wn[i])
if absInt32(j) < kn[i] {
// This case should be hit better than 99% of the time.
return x
}
if i == 0 {
// This extra work is only required for the base strip.
for {
x = -math.Log(r.Float64()) * (1.0 / rn)
y := -math.Log(r.Float64())
if y+y >= x*x {
break
}
}
if j > 0 {
return rn + x
}
return -rn - x
}
if fn[i]+float32(r.Float64())*(fn[i-1]-fn[i]) < float32(math.Exp(-.5*x*x)) {
return x
}
}
}
var kn = [128]uint32{
0x76ad2212, 0x0, 0x600f1b53, 0x6ce447a6, 0x725b46a2,
0x7560051d, 0x774921eb, 0x789a25bd, 0x799045c3, 0x7a4bce5d,
0x7adf629f, 0x7b5682a6, 0x7bb8a8c6, 0x7c0ae722, 0x7c50cce7,
0x7c8cec5b, 0x7cc12cd6, 0x7ceefed2, 0x7d177e0b, 0x7d3b8883,
0x7d5bce6c, 0x7d78dd64, 0x7d932886, 0x7dab0e57, 0x7dc0dd30,
0x7dd4d688, 0x7de73185, 0x7df81cea, 0x7e07c0a3, 0x7e163efa,
0x7e23b587, 0x7e303dfd, 0x7e3beec2, 0x7e46db77, 0x7e51155d,
0x7e5aabb3, 0x7e63abf7, 0x7e6c222c, 0x7e741906, 0x7e7b9a18,
0x7e82adfa, 0x7e895c63, 0x7e8fac4b, 0x7e95a3fb, 0x7e9b4924,
0x7ea0a0ef, 0x7ea5b00d, 0x7eaa7ac3, 0x7eaf04f3, 0x7eb3522a,
0x7eb765a5, 0x7ebb4259, 0x7ebeeafd, 0x7ec2620a, 0x7ec5a9c4,
0x7ec8c441, 0x7ecbb365, 0x7ece78ed, 0x7ed11671, 0x7ed38d62,
0x7ed5df12, 0x7ed80cb4, 0x7eda175c, 0x7edc0005, 0x7eddc78e,
0x7edf6ebf, 0x7ee0f647, 0x7ee25ebe, 0x7ee3a8a9, 0x7ee4d473,
0x7ee5e276, 0x7ee6d2f5, 0x7ee7a620, 0x7ee85c10, 0x7ee8f4cd,
0x7ee97047, 0x7ee9ce59, 0x7eea0eca, 0x7eea3147, 0x7eea3568,
0x7eea1aab, 0x7ee9e071, 0x7ee98602, 0x7ee90a88, 0x7ee86d08,
0x7ee7ac6a, 0x7ee6c769, 0x7ee5bc9c, 0x7ee48a67, 0x7ee32efc,
0x7ee1a857, 0x7edff42f, 0x7ede0ffa, 0x7edbf8d9, 0x7ed9ab94,
0x7ed7248d, 0x7ed45fae, 0x7ed1585c, 0x7ece095f, 0x7eca6ccb,
0x7ec67be2, 0x7ec22eee, 0x7ebd7d1a, 0x7eb85c35, 0x7eb2c075,
0x7eac9c20, 0x7ea5df27, 0x7e9e769f, 0x7e964c16, 0x7e8d44ba,
0x7e834033, 0x7e781728, 0x7e6b9933, 0x7e5d8a1a, 0x7e4d9ded,
0x7e3b737a, 0x7e268c2f, 0x7e0e3ff5, 0x7df1aa5d, 0x7dcf8c72,
0x7da61a1e, 0x7d72a0fb, 0x7d30e097, 0x7cd9b4ab, 0x7c600f1a,
0x7ba90bdc, 0x7a722176, 0x77d664e5,
}
var wn = [128]float32{
1.7290405e-09, 1.2680929e-10, 1.6897518e-10, 1.9862688e-10,
2.2232431e-10, 2.4244937e-10, 2.601613e-10, 2.7611988e-10,
2.9073963e-10, 3.042997e-10, 3.1699796e-10, 3.289802e-10,
3.4035738e-10, 3.5121603e-10, 3.616251e-10, 3.7164058e-10,
3.8130857e-10, 3.9066758e-10, 3.9975012e-10, 4.08584e-10,
4.1719309e-10, 4.2559822e-10, 4.338176e-10, 4.418672e-10,
4.497613e-10, 4.5751258e-10, 4.651324e-10, 4.7263105e-10,
4.8001775e-10, 4.87301e-10, 4.944885e-10, 5.015873e-10,
5.0860405e-10, 5.155446e-10, 5.2241467e-10, 5.2921934e-10,
5.359635e-10, 5.426517e-10, 5.4928817e-10, 5.5587696e-10,
5.624219e-10, 5.6892646e-10, 5.753941e-10, 5.818282e-10,
5.882317e-10, 5.946077e-10, 6.00959e-10, 6.072884e-10,
6.135985e-10, 6.19892e-10, 6.2617134e-10, 6.3243905e-10,
6.386974e-10, 6.449488e-10, 6.511956e-10, 6.5744005e-10,
6.6368433e-10, 6.699307e-10, 6.7618144e-10, 6.824387e-10,
6.8870465e-10, 6.949815e-10, 7.012715e-10, 7.075768e-10,
7.1389966e-10, 7.202424e-10, 7.266073e-10, 7.329966e-10,
7.394128e-10, 7.4585826e-10, 7.5233547e-10, 7.58847e-10,
7.653954e-10, 7.719835e-10, 7.7861395e-10, 7.852897e-10,
7.920138e-10, 7.987892e-10, 8.0561924e-10, 8.125073e-10,
8.194569e-10, 8.2647167e-10, 8.3355556e-10, 8.407127e-10,
8.479473e-10, 8.55264e-10, 8.6266755e-10, 8.7016316e-10,
8.777562e-10, 8.8545243e-10, 8.932582e-10, 9.0117996e-10,
9.09225e-10, 9.174008e-10, 9.2571584e-10, 9.341788e-10,
9.427997e-10, 9.515889e-10, 9.605579e-10, 9.697193e-10,
9.790869e-10, 9.88676e-10, 9.985036e-10, 1.0085882e-09,
1.0189509e-09, 1.0296151e-09, 1.0406069e-09, 1.0519566e-09,
1.063698e-09, 1.0758702e-09, 1.0885183e-09, 1.1016947e-09,
1.1154611e-09, 1.1298902e-09, 1.1450696e-09, 1.1611052e-09,
1.1781276e-09, 1.1962995e-09, 1.2158287e-09, 1.2369856e-09,
1.2601323e-09, 1.2857697e-09, 1.3146202e-09, 1.347784e-09,
1.3870636e-09, 1.4357403e-09, 1.5008659e-09, 1.6030948e-09,
}
var fn = [128]float32{
1, 0.9635997, 0.9362827, 0.9130436, 0.89228165, 0.87324303,
0.8555006, 0.8387836, 0.8229072, 0.8077383, 0.793177,
0.7791461, 0.7655842, 0.7524416, 0.73967725, 0.7272569,
0.7151515, 0.7033361, 0.69178915, 0.68049186, 0.6694277,
0.658582, 0.6479418, 0.63749546, 0.6272325, 0.6171434,
0.6072195, 0.5974532, 0.58783704, 0.5783647, 0.56903,
0.5598274, 0.5507518, 0.54179835, 0.5329627, 0.52424055,
0.5156282, 0.50712204, 0.49871865, 0.49041483, 0.48220766,
0.4740943, 0.46607214, 0.4581387, 0.45029163, 0.44252872,
0.43484783, 0.427247, 0.41972435, 0.41227803, 0.40490642,
0.39760786, 0.3903808, 0.3832238, 0.37613547, 0.36911446,
0.3621595, 0.35526937, 0.34844297, 0.34167916, 0.33497685,
0.3283351, 0.3217529, 0.3152294, 0.30876362, 0.30235484,
0.29600215, 0.28970486, 0.2834622, 0.2772735, 0.27113807,
0.2650553, 0.25902456, 0.2530453, 0.24711695, 0.241239,
0.23541094, 0.22963232, 0.2239027, 0.21822165, 0.21258877,
0.20700371, 0.20146611, 0.19597565, 0.19053204, 0.18513499,
0.17978427, 0.17447963, 0.1692209, 0.16400786, 0.15884037,
0.15371831, 0.14864157, 0.14361008, 0.13862377, 0.13368265,
0.12878671, 0.12393598, 0.119130544, 0.11437051, 0.10965602,
0.104987256, 0.10036444, 0.095787846, 0.0912578, 0.08677467,
0.0823389, 0.077950984, 0.073611505, 0.06932112, 0.06508058,
0.06089077, 0.056752663, 0.0526674, 0.048636295, 0.044660863,
0.040742867, 0.03688439, 0.033087887, 0.029356318,
0.025693292, 0.022103304, 0.018592102, 0.015167298,
0.011839478, 0.008624485, 0.005548995, 0.0026696292,
}

372
vendor/golang.org/x/exp/rand/rand.go generated vendored Normal file
View File

@ -0,0 +1,372 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package rand implements pseudo-random number generators.
//
// Random numbers are generated by a Source. Top-level functions, such as
// Float64 and Int, use a default shared Source that produces a deterministic
// sequence of values each time a program is run. Use the Seed function to
// initialize the default Source if different behavior is required for each run.
// The default Source, a LockedSource, is safe for concurrent use by multiple
// goroutines, but Sources created by NewSource are not. However, Sources are small
// and it is reasonable to have a separate Source for each goroutine, seeded
// differently, to avoid locking.
//
// For random numbers suitable for security-sensitive work, see the crypto/rand
// package.
package rand
import "sync"
// A Source represents a source of uniformly-distributed
// pseudo-random int64 values in the range [0, 1<<64).
type Source interface {
Uint64() uint64
Seed(seed uint64)
}
// NewSource returns a new pseudo-random Source seeded with the given value.
func NewSource(seed uint64) Source {
var rng PCGSource
rng.Seed(seed)
return &rng
}
// A Rand is a source of random numbers.
type Rand struct {
src Source
// readVal contains remainder of 64-bit integer used for bytes
// generation during most recent Read call.
// It is saved so next Read call can start where the previous
// one finished.
readVal uint64
// readPos indicates the number of low-order bytes of readVal
// that are still valid.
readPos int8
}
// New returns a new Rand that uses random values from src
// to generate other random values.
func New(src Source) *Rand {
return &Rand{src: src}
}
// Seed uses the provided seed value to initialize the generator to a deterministic state.
// Seed should not be called concurrently with any other Rand method.
func (r *Rand) Seed(seed uint64) {
if lk, ok := r.src.(*LockedSource); ok {
lk.seedPos(seed, &r.readPos)
return
}
r.src.Seed(seed)
r.readPos = 0
}
// Uint64 returns a pseudo-random 64-bit integer as a uint64.
func (r *Rand) Uint64() uint64 { return r.src.Uint64() }
// Int63 returns a non-negative pseudo-random 63-bit integer as an int64.
func (r *Rand) Int63() int64 { return int64(r.src.Uint64() &^ (1 << 63)) }
// Uint32 returns a pseudo-random 32-bit value as a uint32.
func (r *Rand) Uint32() uint32 { return uint32(r.Uint64() >> 32) }
// Int31 returns a non-negative pseudo-random 31-bit integer as an int32.
func (r *Rand) Int31() int32 { return int32(r.Uint64() >> 33) }
// Int returns a non-negative pseudo-random int.
func (r *Rand) Int() int {
u := uint(r.Uint64())
return int(u << 1 >> 1) // clear sign bit.
}
const maxUint64 = (1 << 64) - 1
// Uint64n returns, as a uint64, a pseudo-random number in [0,n).
// It is guaranteed more uniform than taking a Source value mod n
// for any n that is not a power of 2.
func (r *Rand) Uint64n(n uint64) uint64 {
if n&(n-1) == 0 { // n is power of two, can mask
if n == 0 {
panic("invalid argument to Uint64n")
}
return r.Uint64() & (n - 1)
}
// If n does not divide v, to avoid bias we must not use
// a v that is within maxUint64%n of the top of the range.
v := r.Uint64()
if v > maxUint64-n { // Fast check.
ceiling := maxUint64 - maxUint64%n
for v >= ceiling {
v = r.Uint64()
}
}
return v % n
}
// Int63n returns, as an int64, a non-negative pseudo-random number in [0,n).
// It panics if n <= 0.
func (r *Rand) Int63n(n int64) int64 {
if n <= 0 {
panic("invalid argument to Int63n")
}
return int64(r.Uint64n(uint64(n)))
}
// Int31n returns, as an int32, a non-negative pseudo-random number in [0,n).
// It panics if n <= 0.
func (r *Rand) Int31n(n int32) int32 {
if n <= 0 {
panic("invalid argument to Int31n")
}
// TODO: Avoid some 64-bit ops to make it more efficient on 32-bit machines.
return int32(r.Uint64n(uint64(n)))
}
// Intn returns, as an int, a non-negative pseudo-random number in [0,n).
// It panics if n <= 0.
func (r *Rand) Intn(n int) int {
if n <= 0 {
panic("invalid argument to Intn")
}
// TODO: Avoid some 64-bit ops to make it more efficient on 32-bit machines.
return int(r.Uint64n(uint64(n)))
}
// Float64 returns, as a float64, a pseudo-random number in [0.0,1.0).
func (r *Rand) Float64() float64 {
// There is one bug in the value stream: r.Int63() may be so close
// to 1<<63 that the division rounds up to 1.0, and we've guaranteed
// that the result is always less than 1.0.
//
// We tried to fix this by mapping 1.0 back to 0.0, but since float64
// values near 0 are much denser than near 1, mapping 1 to 0 caused
// a theoretically significant overshoot in the probability of returning 0.
// Instead of that, if we round up to 1, just try again.
// Getting 1 only happens 1/2⁵³ of the time, so most clients
// will not observe it anyway.
again:
f := float64(r.Uint64n(1<<53)) / (1 << 53)
if f == 1.0 {
goto again // resample; this branch is taken O(never)
}
return f
}
// Float32 returns, as a float32, a pseudo-random number in [0.0,1.0).
func (r *Rand) Float32() float32 {
// We do not want to return 1.0.
// This only happens 1/2²⁴ of the time (plus the 1/2⁵³ of the time in Float64).
again:
f := float32(r.Float64())
if f == 1 {
goto again // resample; this branch is taken O(very rarely)
}
return f
}
// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers [0,n).
func (r *Rand) Perm(n int) []int {
m := make([]int, n)
// In the following loop, the iteration when i=0 always swaps m[0] with m[0].
// A change to remove this useless iteration is to assign 1 to i in the init
// statement. But Perm also effects r. Making this change will affect
// the final state of r. So this change can't be made for compatibility
// reasons for Go 1.
for i := 0; i < n; i++ {
j := r.Intn(i + 1)
m[i] = m[j]
m[j] = i
}
return m
}
// Shuffle pseudo-randomizes the order of elements.
// n is the number of elements. Shuffle panics if n < 0.
// swap swaps the elements with indexes i and j.
func (r *Rand) Shuffle(n int, swap func(i, j int)) {
if n < 0 {
panic("invalid argument to Shuffle")
}
// Fisher-Yates shuffle: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
// Shuffle really ought not be called with n that doesn't fit in 32 bits.
// Not only will it take a very long time, but with 2³¹! possible permutations,
// there's no way that any PRNG can have a big enough internal state to
// generate even a minuscule percentage of the possible permutations.
// Nevertheless, the right API signature accepts an int n, so handle it as best we can.
i := n - 1
for ; i > 1<<31-1-1; i-- {
j := int(r.Int63n(int64(i + 1)))
swap(i, j)
}
for ; i > 0; i-- {
j := int(r.Int31n(int32(i + 1)))
swap(i, j)
}
}
// Read generates len(p) random bytes and writes them into p. It
// always returns len(p) and a nil error.
// Read should not be called concurrently with any other Rand method unless
// the underlying source is a LockedSource.
func (r *Rand) Read(p []byte) (n int, err error) {
if lk, ok := r.src.(*LockedSource); ok {
return lk.Read(p, &r.readVal, &r.readPos)
}
return read(p, r.src, &r.readVal, &r.readPos)
}
func read(p []byte, src Source, readVal *uint64, readPos *int8) (n int, err error) {
pos := *readPos
val := *readVal
rng, _ := src.(*PCGSource)
for n = 0; n < len(p); n++ {
if pos == 0 {
if rng != nil {
val = rng.Uint64()
} else {
val = src.Uint64()
}
pos = 8
}
p[n] = byte(val)
val >>= 8
pos--
}
*readPos = pos
*readVal = val
return
}
/*
* Top-level convenience functions
*/
var globalRand = New(&LockedSource{src: *NewSource(1).(*PCGSource)})
// Type assert that globalRand's source is a LockedSource whose src is a PCGSource.
var _ PCGSource = globalRand.src.(*LockedSource).src
// Seed uses the provided seed value to initialize the default Source to a
// deterministic state. If Seed is not called, the generator behaves as
// if seeded by Seed(1).
// Seed, unlike the Rand.Seed method, is safe for concurrent use.
func Seed(seed uint64) { globalRand.Seed(seed) }
// Int63 returns a non-negative pseudo-random 63-bit integer as an int64
// from the default Source.
func Int63() int64 { return globalRand.Int63() }
// Uint32 returns a pseudo-random 32-bit value as a uint32
// from the default Source.
func Uint32() uint32 { return globalRand.Uint32() }
// Uint64 returns a pseudo-random 64-bit value as a uint64
// from the default Source.
func Uint64() uint64 { return globalRand.Uint64() }
// Int31 returns a non-negative pseudo-random 31-bit integer as an int32
// from the default Source.
func Int31() int32 { return globalRand.Int31() }
// Int returns a non-negative pseudo-random int from the default Source.
func Int() int { return globalRand.Int() }
// Int63n returns, as an int64, a non-negative pseudo-random number in [0,n)
// from the default Source.
// It panics if n <= 0.
func Int63n(n int64) int64 { return globalRand.Int63n(n) }
// Int31n returns, as an int32, a non-negative pseudo-random number in [0,n)
// from the default Source.
// It panics if n <= 0.
func Int31n(n int32) int32 { return globalRand.Int31n(n) }
// Intn returns, as an int, a non-negative pseudo-random number in [0,n)
// from the default Source.
// It panics if n <= 0.
func Intn(n int) int { return globalRand.Intn(n) }
// Float64 returns, as a float64, a pseudo-random number in [0.0,1.0)
// from the default Source.
func Float64() float64 { return globalRand.Float64() }
// Float32 returns, as a float32, a pseudo-random number in [0.0,1.0)
// from the default Source.
func Float32() float32 { return globalRand.Float32() }
// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers [0,n)
// from the default Source.
func Perm(n int) []int { return globalRand.Perm(n) }
// Shuffle pseudo-randomizes the order of elements using the default Source.
// n is the number of elements. Shuffle panics if n < 0.
// swap swaps the elements with indexes i and j.
func Shuffle(n int, swap func(i, j int)) { globalRand.Shuffle(n, swap) }
// Read generates len(p) random bytes from the default Source and
// writes them into p. It always returns len(p) and a nil error.
// Read, unlike the Rand.Read method, is safe for concurrent use.
func Read(p []byte) (n int, err error) { return globalRand.Read(p) }
// NormFloat64 returns a normally distributed float64 in the range
// [-math.MaxFloat64, +math.MaxFloat64] with
// standard normal distribution (mean = 0, stddev = 1)
// from the default Source.
// To produce a different normal distribution, callers can
// adjust the output using:
//
// sample = NormFloat64() * desiredStdDev + desiredMean
func NormFloat64() float64 { return globalRand.NormFloat64() }
// ExpFloat64 returns an exponentially distributed float64 in the range
// (0, +math.MaxFloat64] with an exponential distribution whose rate parameter
// (lambda) is 1 and whose mean is 1/lambda (1) from the default Source.
// To produce a distribution with a different rate parameter,
// callers can adjust the output using:
//
// sample = ExpFloat64() / desiredRateParameter
func ExpFloat64() float64 { return globalRand.ExpFloat64() }
// LockedSource is an implementation of Source that is concurrency-safe.
// A Rand using a LockedSource is safe for concurrent use.
//
// The zero value of LockedSource is valid, but should be seeded before use.
type LockedSource struct {
lk sync.Mutex
src PCGSource
}
func (s *LockedSource) Uint64() (n uint64) {
s.lk.Lock()
n = s.src.Uint64()
s.lk.Unlock()
return
}
func (s *LockedSource) Seed(seed uint64) {
s.lk.Lock()
s.src.Seed(seed)
s.lk.Unlock()
}
// seedPos implements Seed for a LockedSource without a race condiiton.
func (s *LockedSource) seedPos(seed uint64, readPos *int8) {
s.lk.Lock()
s.src.Seed(seed)
*readPos = 0
s.lk.Unlock()
}
// Read implements Read for a LockedSource.
func (s *LockedSource) Read(p []byte, readVal *uint64, readPos *int8) (n int, err error) {
s.lk.Lock()
n, err = read(p, &s.src, readVal, readPos)
s.lk.Unlock()
return
}

91
vendor/golang.org/x/exp/rand/rng.go generated vendored Normal file
View File

@ -0,0 +1,91 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rand
import (
"encoding/binary"
"io"
"math/bits"
)
// PCGSource is an implementation of a 64-bit permuted congruential
// generator as defined in
//
// PCG: A Family of Simple Fast Space-Efficient Statistically Good
// Algorithms for Random Number Generation
// Melissa E. ONeill, Harvey Mudd College
// http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf
//
// The generator here is the congruential generator PCG XSL RR 128/64 (LCG)
// as found in the software available at http://www.pcg-random.org/.
// It has period 2^128 with 128 bits of state, producing 64-bit values.
// Is state is represented by two uint64 words.
type PCGSource struct {
low uint64
high uint64
}
const (
maxUint32 = (1 << 32) - 1
multiplier = 47026247687942121848144207491837523525
mulHigh = multiplier >> 64
mulLow = multiplier & maxUint64
increment = 117397592171526113268558934119004209487
incHigh = increment >> 64
incLow = increment & maxUint64
// TODO: Use these?
initializer = 245720598905631564143578724636268694099
initHigh = initializer >> 64
initLow = initializer & maxUint64
)
// Seed uses the provided seed value to initialize the generator to a deterministic state.
func (pcg *PCGSource) Seed(seed uint64) {
pcg.low = seed
pcg.high = seed // TODO: What is right?
}
// Uint64 returns a pseudo-random 64-bit unsigned integer as a uint64.
func (pcg *PCGSource) Uint64() uint64 {
pcg.multiply()
pcg.add()
// XOR high and low 64 bits together and rotate right by high 6 bits of state.
return bits.RotateLeft64(pcg.high^pcg.low, -int(pcg.high>>58))
}
func (pcg *PCGSource) add() {
var carry uint64
pcg.low, carry = bits.Add64(pcg.low, incLow, 0)
pcg.high, _ = bits.Add64(pcg.high, incHigh, carry)
}
func (pcg *PCGSource) multiply() {
hi, lo := bits.Mul64(pcg.low, mulLow)
hi += pcg.high * mulLow
hi += pcg.low * mulHigh
pcg.low = lo
pcg.high = hi
}
// MarshalBinary returns the binary representation of the current state of the generator.
func (pcg *PCGSource) MarshalBinary() ([]byte, error) {
var buf [16]byte
binary.BigEndian.PutUint64(buf[:8], pcg.high)
binary.BigEndian.PutUint64(buf[8:], pcg.low)
return buf[:], nil
}
// UnmarshalBinary sets the state of the generator to the state represented in data.
func (pcg *PCGSource) UnmarshalBinary(data []byte) error {
if len(data) < 16 {
return io.ErrUnexpectedEOF
}
pcg.low = binary.BigEndian.Uint64(data[8:])
pcg.high = binary.BigEndian.Uint64(data[:8])
return nil
}

77
vendor/golang.org/x/exp/rand/zipf.go generated vendored Normal file
View File

@ -0,0 +1,77 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// W.Hormann, G.Derflinger:
// "Rejection-Inversion to Generate Variates
// from Monotone Discrete Distributions"
// http://eeyore.wu-wien.ac.at/papers/96-04-04.wh-der.ps.gz
package rand
import "math"
// A Zipf generates Zipf distributed variates.
type Zipf struct {
r *Rand
imax float64
v float64
q float64
s float64
oneminusQ float64
oneminusQinv float64
hxm float64
hx0minusHxm float64
}
func (z *Zipf) h(x float64) float64 {
return math.Exp(z.oneminusQ*math.Log(z.v+x)) * z.oneminusQinv
}
func (z *Zipf) hinv(x float64) float64 {
return math.Exp(z.oneminusQinv*math.Log(z.oneminusQ*x)) - z.v
}
// NewZipf returns a Zipf variate generator.
// The generator generates values k ∈ [0, imax]
// such that P(k) is proportional to (v + k) ** (-s).
// Requirements: s > 1 and v >= 1.
func NewZipf(r *Rand, s float64, v float64, imax uint64) *Zipf {
z := new(Zipf)
if s <= 1.0 || v < 1 {
return nil
}
z.r = r
z.imax = float64(imax)
z.v = v
z.q = s
z.oneminusQ = 1.0 - z.q
z.oneminusQinv = 1.0 / z.oneminusQ
z.hxm = z.h(z.imax + 0.5)
z.hx0minusHxm = z.h(0.5) - math.Exp(math.Log(z.v)*(-z.q)) - z.hxm
z.s = 1 - z.hinv(z.h(1.5)-math.Exp(-z.q*math.Log(z.v+1.0)))
return z
}
// Uint64 returns a value drawn from the Zipf distribution described
// by the Zipf object.
func (z *Zipf) Uint64() uint64 {
if z == nil {
panic("rand: nil Zipf")
}
k := 0.0
for {
r := z.r.Float64() // r on [0,1]
ur := z.hxm + r*z.hx0minusHxm
x := z.hinv(ur)
k = math.Floor(x + 0.5)
if k-x <= z.s {
break
}
if ur >= z.h(k+0.5)-math.Exp(-math.Log(k+z.v)*z.q) {
break
}
}
return uint64(k)
}

View File

@ -581,9 +581,11 @@ type serverConn struct {
advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client
curClientStreams uint32 // number of open streams initiated by the client curClientStreams uint32 // number of open streams initiated by the client
curPushedStreams uint32 // number of open streams initiated by server push curPushedStreams uint32 // number of open streams initiated by server push
curHandlers uint32 // number of running handler goroutines
maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests
maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes
streams map[uint32]*stream streams map[uint32]*stream
unstartedHandlers []unstartedHandler
initialStreamSendWindowSize int32 initialStreamSendWindowSize int32
maxFrameSize int32 maxFrameSize int32
peerMaxHeaderListSize uint32 // zero means unknown (default) peerMaxHeaderListSize uint32 // zero means unknown (default)
@ -981,6 +983,8 @@ func (sc *serverConn) serve() {
return return
case gracefulShutdownMsg: case gracefulShutdownMsg:
sc.startGracefulShutdownInternal() sc.startGracefulShutdownInternal()
case handlerDoneMsg:
sc.handlerDone()
default: default:
panic("unknown timer") panic("unknown timer")
} }
@ -1020,6 +1024,7 @@ var (
idleTimerMsg = new(serverMessage) idleTimerMsg = new(serverMessage)
shutdownTimerMsg = new(serverMessage) shutdownTimerMsg = new(serverMessage)
gracefulShutdownMsg = new(serverMessage) gracefulShutdownMsg = new(serverMessage)
handlerDoneMsg = new(serverMessage)
) )
func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) } func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
@ -1892,9 +1897,11 @@ func (st *stream) copyTrailersToHandlerRequest() {
// onReadTimeout is run on its own goroutine (from time.AfterFunc) // onReadTimeout is run on its own goroutine (from time.AfterFunc)
// when the stream's ReadTimeout has fired. // when the stream's ReadTimeout has fired.
func (st *stream) onReadTimeout() { func (st *stream) onReadTimeout() {
// Wrap the ErrDeadlineExceeded to avoid callers depending on us if st.body != nil {
// returning the bare error. // Wrap the ErrDeadlineExceeded to avoid callers depending on us
st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded)) // returning the bare error.
st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded))
}
} }
// onWriteTimeout is run on its own goroutine (from time.AfterFunc) // onWriteTimeout is run on its own goroutine (from time.AfterFunc)
@ -2012,13 +2019,10 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// (in Go 1.8), though. That's a more sane option anyway. // (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout != 0 { if sc.hs.ReadTimeout != 0 {
sc.conn.SetReadDeadline(time.Time{}) sc.conn.SetReadDeadline(time.Time{})
if st.body != nil { st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
} }
go sc.runHandler(rw, req, handler) return sc.scheduleHandler(id, rw, req, handler)
return nil
} }
func (sc *serverConn) upgradeRequest(req *http.Request) { func (sc *serverConn) upgradeRequest(req *http.Request) {
@ -2038,6 +2042,10 @@ func (sc *serverConn) upgradeRequest(req *http.Request) {
sc.conn.SetReadDeadline(time.Time{}) sc.conn.SetReadDeadline(time.Time{})
} }
// This is the first request on the connection,
// so start the handler directly rather than going
// through scheduleHandler.
sc.curHandlers++
go sc.runHandler(rw, req, sc.handler.ServeHTTP) go sc.runHandler(rw, req, sc.handler.ServeHTTP)
} }
@ -2278,8 +2286,62 @@ func (sc *serverConn) newResponseWriter(st *stream, req *http.Request) *response
return &responseWriter{rws: rws} return &responseWriter{rws: rws}
} }
type unstartedHandler struct {
streamID uint32
rw *responseWriter
req *http.Request
handler func(http.ResponseWriter, *http.Request)
}
// scheduleHandler starts a handler goroutine,
// or schedules one to start as soon as an existing handler finishes.
func (sc *serverConn) scheduleHandler(streamID uint32, rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) error {
sc.serveG.check()
maxHandlers := sc.advMaxStreams
if sc.curHandlers < maxHandlers {
sc.curHandlers++
go sc.runHandler(rw, req, handler)
return nil
}
if len(sc.unstartedHandlers) > int(4*sc.advMaxStreams) {
return sc.countError("too_many_early_resets", ConnectionError(ErrCodeEnhanceYourCalm))
}
sc.unstartedHandlers = append(sc.unstartedHandlers, unstartedHandler{
streamID: streamID,
rw: rw,
req: req,
handler: handler,
})
return nil
}
func (sc *serverConn) handlerDone() {
sc.serveG.check()
sc.curHandlers--
i := 0
maxHandlers := sc.advMaxStreams
for ; i < len(sc.unstartedHandlers); i++ {
u := sc.unstartedHandlers[i]
if sc.streams[u.streamID] == nil {
// This stream was reset before its goroutine had a chance to start.
continue
}
if sc.curHandlers >= maxHandlers {
break
}
sc.curHandlers++
go sc.runHandler(u.rw, u.req, u.handler)
sc.unstartedHandlers[i] = unstartedHandler{} // don't retain references
}
sc.unstartedHandlers = sc.unstartedHandlers[i:]
if len(sc.unstartedHandlers) == 0 {
sc.unstartedHandlers = nil
}
}
// Run on its own goroutine. // Run on its own goroutine.
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
defer sc.sendServeMsg(handlerDoneMsg)
didPanic := true didPanic := true
defer func() { defer func() {
rw.rws.stream.cancelCtx() rw.rws.stream.cancelCtx()

View File

@ -7,6 +7,6 @@
package cpu package cpu
const cacheLineSize = 32 const cacheLineSize = 64
func initOptions() {} func initOptions() {}

View File

@ -5,7 +5,7 @@
package cpu package cpu
import ( import (
"io/ioutil" "os"
) )
const ( const (
@ -39,7 +39,7 @@ func readHWCAP() error {
return nil return nil
} }
buf, err := ioutil.ReadFile(procAuxv) buf, err := os.ReadFile(procAuxv)
if err != nil { if err != nil {
// e.g. on android /proc/self/auxv is not accessible, so silently // e.g. on android /proc/self/auxv is not accessible, so silently
// ignore the error and leave Initialized = false. On some // ignore the error and leave Initialized = false. On some

View File

@ -1,30 +0,0 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package unsafeheader contains header declarations for the Go runtime's
// slice and string implementations.
//
// This package allows x/sys to use types equivalent to
// reflect.SliceHeader and reflect.StringHeader without introducing
// a dependency on the (relatively heavy) "reflect" package.
package unsafeheader
import (
"unsafe"
)
// Slice is the runtime representation of a slice.
// It cannot be used safely or portably and its representation may change in a later release.
type Slice struct {
Data unsafe.Pointer
Len int
Cap int
}
// String is the runtime representation of a string.
// It cannot be used safely or portably and its representation may change in a later release.
type String struct {
Data unsafe.Pointer
Len int
}

View File

@ -7,12 +7,6 @@
package unix package unix
import "unsafe"
func ptrace(request int, pid int, addr uintptr, data uintptr) error { func ptrace(request int, pid int, addr uintptr, data uintptr) error {
return ptrace1(request, pid, addr, data) return ptrace1(request, pid, addr, data)
} }
func ptracePtr(request int, pid int, addr uintptr, data unsafe.Pointer) error {
return ptrace1Ptr(request, pid, addr, data)
}

View File

@ -7,12 +7,6 @@
package unix package unix
import "unsafe"
func ptrace(request int, pid int, addr uintptr, data uintptr) (err error) { func ptrace(request int, pid int, addr uintptr, data uintptr) (err error) {
return ENOTSUP return ENOTSUP
} }
func ptracePtr(request int, pid int, addr uintptr, data unsafe.Pointer) (err error) {
return ENOTSUP
}

View File

@ -487,8 +487,6 @@ func Fsync(fd int) error {
//sys Unlinkat(dirfd int, path string, flags int) (err error) //sys Unlinkat(dirfd int, path string, flags int) (err error)
//sys Ustat(dev int, ubuf *Ustat_t) (err error) //sys Ustat(dev int, ubuf *Ustat_t) (err error)
//sys write(fd int, p []byte) (n int, err error) //sys write(fd int, p []byte) (n int, err error)
//sys readlen(fd int, p *byte, np int) (n int, err error) = read
//sys writelen(fd int, p *byte, np int) (n int, err error) = write
//sys Dup2(oldfd int, newfd int) (err error) //sys Dup2(oldfd int, newfd int) (err error)
//sys Fadvise(fd int, offset int64, length int64, advice int) (err error) = posix_fadvise64 //sys Fadvise(fd int, offset int64, length int64, advice int) (err error) = posix_fadvise64

View File

@ -644,189 +644,3 @@ func SysctlKinfoProcSlice(name string, args ...int) ([]KinfoProc, error) {
//sys write(fd int, p []byte) (n int, err error) //sys write(fd int, p []byte) (n int, err error)
//sys mmap(addr uintptr, length uintptr, prot int, flag int, fd int, pos int64) (ret uintptr, err error) //sys mmap(addr uintptr, length uintptr, prot int, flag int, fd int, pos int64) (ret uintptr, err error)
//sys munmap(addr uintptr, length uintptr) (err error) //sys munmap(addr uintptr, length uintptr) (err error)
//sys readlen(fd int, buf *byte, nbuf int) (n int, err error) = SYS_READ
//sys writelen(fd int, buf *byte, nbuf int) (n int, err error) = SYS_WRITE
/*
* Unimplemented
*/
// Profil
// Sigaction
// Sigprocmask
// Getlogin
// Sigpending
// Sigaltstack
// Ioctl
// Reboot
// Execve
// Vfork
// Sbrk
// Sstk
// Ovadvise
// Mincore
// Setitimer
// Swapon
// Select
// Sigsuspend
// Readv
// Writev
// Nfssvc
// Getfh
// Quotactl
// Csops
// Waitid
// Add_profil
// Kdebug_trace
// Sigreturn
// Atsocket
// Kqueue_from_portset_np
// Kqueue_portset
// Getattrlist
// Getdirentriesattr
// Searchfs
// Delete
// Copyfile
// Watchevent
// Waitevent
// Modwatch
// Fsctl
// Initgroups
// Posix_spawn
// Nfsclnt
// Fhopen
// Minherit
// Semsys
// Msgsys
// Shmsys
// Semctl
// Semget
// Semop
// Msgctl
// Msgget
// Msgsnd
// Msgrcv
// Shm_open
// Shm_unlink
// Sem_open
// Sem_close
// Sem_unlink
// Sem_wait
// Sem_trywait
// Sem_post
// Sem_getvalue
// Sem_init
// Sem_destroy
// Open_extended
// Umask_extended
// Stat_extended
// Lstat_extended
// Fstat_extended
// Chmod_extended
// Fchmod_extended
// Access_extended
// Settid
// Gettid
// Setsgroups
// Getsgroups
// Setwgroups
// Getwgroups
// Mkfifo_extended
// Mkdir_extended
// Identitysvc
// Shared_region_check_np
// Shared_region_map_np
// __pthread_mutex_destroy
// __pthread_mutex_init
// __pthread_mutex_lock
// __pthread_mutex_trylock
// __pthread_mutex_unlock
// __pthread_cond_init
// __pthread_cond_destroy
// __pthread_cond_broadcast
// __pthread_cond_signal
// Setsid_with_pid
// __pthread_cond_timedwait
// Aio_fsync
// Aio_return
// Aio_suspend
// Aio_cancel
// Aio_error
// Aio_read
// Aio_write
// Lio_listio
// __pthread_cond_wait
// Iopolicysys
// __pthread_kill
// __pthread_sigmask
// __sigwait
// __disable_threadsignal
// __pthread_markcancel
// __pthread_canceled
// __semwait_signal
// Proc_info
// sendfile
// Stat64_extended
// Lstat64_extended
// Fstat64_extended
// __pthread_chdir
// __pthread_fchdir
// Audit
// Auditon
// Getauid
// Setauid
// Getaudit
// Setaudit
// Getaudit_addr
// Setaudit_addr
// Auditctl
// Bsdthread_create
// Bsdthread_terminate
// Stack_snapshot
// Bsdthread_register
// Workq_open
// Workq_ops
// __mac_execve
// __mac_syscall
// __mac_get_file
// __mac_set_file
// __mac_get_link
// __mac_set_link
// __mac_get_proc
// __mac_set_proc
// __mac_get_fd
// __mac_set_fd
// __mac_get_pid
// __mac_get_lcid
// __mac_get_lctx
// __mac_set_lctx
// Setlcid
// Read_nocancel
// Write_nocancel
// Open_nocancel
// Close_nocancel
// Wait4_nocancel
// Recvmsg_nocancel
// Sendmsg_nocancel
// Recvfrom_nocancel
// Accept_nocancel
// Fcntl_nocancel
// Select_nocancel
// Fsync_nocancel
// Connect_nocancel
// Sigsuspend_nocancel
// Readv_nocancel
// Writev_nocancel
// Sendto_nocancel
// Pread_nocancel
// Pwrite_nocancel
// Waitid_nocancel
// Poll_nocancel
// Msgsnd_nocancel
// Msgrcv_nocancel
// Sem_wait_nocancel
// Aio_suspend_nocancel
// __sigwait_nocancel
// __semwait_signal_nocancel
// __mac_mount
// __mac_get_mount
// __mac_getfsstat

View File

@ -47,6 +47,5 @@ func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr,
//sys getfsstat(buf unsafe.Pointer, size uintptr, flags int) (n int, err error) = SYS_GETFSSTAT64 //sys getfsstat(buf unsafe.Pointer, size uintptr, flags int) (n int, err error) = SYS_GETFSSTAT64
//sys Lstat(path string, stat *Stat_t) (err error) = SYS_LSTAT64 //sys Lstat(path string, stat *Stat_t) (err error) = SYS_LSTAT64
//sys ptrace1(request int, pid int, addr uintptr, data uintptr) (err error) = SYS_ptrace //sys ptrace1(request int, pid int, addr uintptr, data uintptr) (err error) = SYS_ptrace
//sys ptrace1Ptr(request int, pid int, addr unsafe.Pointer, data uintptr) (err error) = SYS_ptrace
//sys Stat(path string, stat *Stat_t) (err error) = SYS_STAT64 //sys Stat(path string, stat *Stat_t) (err error) = SYS_STAT64
//sys Statfs(path string, stat *Statfs_t) (err error) = SYS_STATFS64 //sys Statfs(path string, stat *Statfs_t) (err error) = SYS_STATFS64

View File

@ -47,6 +47,5 @@ func Syscall9(num, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr,
//sys getfsstat(buf unsafe.Pointer, size uintptr, flags int) (n int, err error) = SYS_GETFSSTAT //sys getfsstat(buf unsafe.Pointer, size uintptr, flags int) (n int, err error) = SYS_GETFSSTAT
//sys Lstat(path string, stat *Stat_t) (err error) //sys Lstat(path string, stat *Stat_t) (err error)
//sys ptrace1(request int, pid int, addr uintptr, data uintptr) (err error) = SYS_ptrace //sys ptrace1(request int, pid int, addr uintptr, data uintptr) (err error) = SYS_ptrace
//sys ptrace1Ptr(request int, pid int, addr unsafe.Pointer, data uintptr) (err error) = SYS_ptrace
//sys Stat(path string, stat *Stat_t) (err error) //sys Stat(path string, stat *Stat_t) (err error)
//sys Statfs(path string, stat *Statfs_t) (err error) //sys Statfs(path string, stat *Statfs_t) (err error)

View File

@ -343,203 +343,5 @@ func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
//sys write(fd int, p []byte) (n int, err error) //sys write(fd int, p []byte) (n int, err error)
//sys mmap(addr uintptr, length uintptr, prot int, flag int, fd int, pos int64) (ret uintptr, err error) //sys mmap(addr uintptr, length uintptr, prot int, flag int, fd int, pos int64) (ret uintptr, err error)
//sys munmap(addr uintptr, length uintptr) (err error) //sys munmap(addr uintptr, length uintptr) (err error)
//sys readlen(fd int, buf *byte, nbuf int) (n int, err error) = SYS_READ
//sys writelen(fd int, buf *byte, nbuf int) (n int, err error) = SYS_WRITE
//sys accept4(fd int, rsa *RawSockaddrAny, addrlen *_Socklen, flags int) (nfd int, err error) //sys accept4(fd int, rsa *RawSockaddrAny, addrlen *_Socklen, flags int) (nfd int, err error)
//sys utimensat(dirfd int, path string, times *[2]Timespec, flags int) (err error) //sys utimensat(dirfd int, path string, times *[2]Timespec, flags int) (err error)
/*
* Unimplemented
* TODO(jsing): Update this list for DragonFly.
*/
// Profil
// Sigaction
// Sigprocmask
// Getlogin
// Sigpending
// Sigaltstack
// Reboot
// Execve
// Vfork
// Sbrk
// Sstk
// Ovadvise
// Mincore
// Setitimer
// Swapon
// Select
// Sigsuspend
// Readv
// Writev
// Nfssvc
// Getfh
// Quotactl
// Mount
// Csops
// Waitid
// Add_profil
// Kdebug_trace
// Sigreturn
// Atsocket
// Kqueue_from_portset_np
// Kqueue_portset
// Getattrlist
// Setattrlist
// Getdirentriesattr
// Searchfs
// Delete
// Copyfile
// Watchevent
// Waitevent
// Modwatch
// Getxattr
// Fgetxattr
// Setxattr
// Fsetxattr
// Removexattr
// Fremovexattr
// Listxattr
// Flistxattr
// Fsctl
// Initgroups
// Posix_spawn
// Nfsclnt
// Fhopen
// Minherit
// Semsys
// Msgsys
// Shmsys
// Semctl
// Semget
// Semop
// Msgctl
// Msgget
// Msgsnd
// Msgrcv
// Shmat
// Shmctl
// Shmdt
// Shmget
// Shm_open
// Shm_unlink
// Sem_open
// Sem_close
// Sem_unlink
// Sem_wait
// Sem_trywait
// Sem_post
// Sem_getvalue
// Sem_init
// Sem_destroy
// Open_extended
// Umask_extended
// Stat_extended
// Lstat_extended
// Fstat_extended
// Chmod_extended
// Fchmod_extended
// Access_extended
// Settid
// Gettid
// Setsgroups
// Getsgroups
// Setwgroups
// Getwgroups
// Mkfifo_extended
// Mkdir_extended
// Identitysvc
// Shared_region_check_np
// Shared_region_map_np
// __pthread_mutex_destroy
// __pthread_mutex_init
// __pthread_mutex_lock
// __pthread_mutex_trylock
// __pthread_mutex_unlock
// __pthread_cond_init
// __pthread_cond_destroy
// __pthread_cond_broadcast
// __pthread_cond_signal
// Setsid_with_pid
// __pthread_cond_timedwait
// Aio_fsync
// Aio_return
// Aio_suspend
// Aio_cancel
// Aio_error
// Aio_read
// Aio_write
// Lio_listio
// __pthread_cond_wait
// Iopolicysys
// __pthread_kill
// __pthread_sigmask
// __sigwait
// __disable_threadsignal
// __pthread_markcancel
// __pthread_canceled
// __semwait_signal
// Proc_info
// Stat64_extended
// Lstat64_extended
// Fstat64_extended
// __pthread_chdir
// __pthread_fchdir
// Audit
// Auditon
// Getauid
// Setauid
// Getaudit
// Setaudit
// Getaudit_addr
// Setaudit_addr
// Auditctl
// Bsdthread_create
// Bsdthread_terminate
// Stack_snapshot
// Bsdthread_register
// Workq_open
// Workq_ops
// __mac_execve
// __mac_syscall
// __mac_get_file
// __mac_set_file
// __mac_get_link
// __mac_set_link
// __mac_get_proc
// __mac_set_proc
// __mac_get_fd
// __mac_set_fd
// __mac_get_pid
// __mac_get_lcid
// __mac_get_lctx
// __mac_set_lctx
// Setlcid
// Read_nocancel
// Write_nocancel
// Open_nocancel
// Close_nocancel
// Wait4_nocancel
// Recvmsg_nocancel
// Sendmsg_nocancel
// Recvfrom_nocancel
// Accept_nocancel
// Fcntl_nocancel
// Select_nocancel
// Fsync_nocancel
// Connect_nocancel
// Sigsuspend_nocancel
// Readv_nocancel
// Writev_nocancel
// Sendto_nocancel
// Pread_nocancel
// Pwrite_nocancel
// Waitid_nocancel
// Msgsnd_nocancel
// Msgrcv_nocancel
// Sem_wait_nocancel
// Aio_suspend_nocancel
// __sigwait_nocancel
// __semwait_signal_nocancel
// __mac_mount
// __mac_get_mount
// __mac_getfsstat

View File

@ -449,197 +449,5 @@ func Dup3(oldfd, newfd, flags int) error {
//sys write(fd int, p []byte) (n int, err error) //sys write(fd int, p []byte) (n int, err error)
//sys mmap(addr uintptr, length uintptr, prot int, flag int, fd int, pos int64) (ret uintptr, err error) //sys mmap(addr uintptr, length uintptr, prot int, flag int, fd int, pos int64) (ret uintptr, err error)
//sys munmap(addr uintptr, length uintptr) (err error) //sys munmap(addr uintptr, length uintptr) (err error)
//sys readlen(fd int, buf *byte, nbuf int) (n int, err error) = SYS_READ
//sys writelen(fd int, buf *byte, nbuf int) (n int, err error) = SYS_WRITE
//sys accept4(fd int, rsa *RawSockaddrAny, addrlen *_Socklen, flags int) (nfd int, err error) //sys accept4(fd int, rsa *RawSockaddrAny, addrlen *_Socklen, flags int) (nfd int, err error)
//sys utimensat(dirfd int, path string, times *[2]Timespec, flags int) (err error) //sys utimensat(dirfd int, path string, times *[2]Timespec, flags int) (err error)
/*
* Unimplemented
*/
// Profil
// Sigaction
// Sigprocmask
// Getlogin
// Sigpending
// Sigaltstack
// Ioctl
// Reboot
// Execve
// Vfork
// Sbrk
// Sstk
// Ovadvise
// Mincore
// Setitimer
// Swapon
// Select
// Sigsuspend
// Readv
// Writev
// Nfssvc
// Getfh
// Quotactl
// Mount
// Csops
// Waitid
// Add_profil
// Kdebug_trace
// Sigreturn
// Atsocket
// Kqueue_from_portset_np
// Kqueue_portset
// Getattrlist
// Setattrlist
// Getdents
// Getdirentriesattr
// Searchfs
// Delete
// Copyfile
// Watchevent
// Waitevent
// Modwatch
// Fsctl
// Initgroups
// Posix_spawn
// Nfsclnt
// Fhopen
// Minherit
// Semsys
// Msgsys
// Shmsys
// Semctl
// Semget
// Semop
// Msgctl
// Msgget
// Msgsnd
// Msgrcv
// Shmat
// Shmctl
// Shmdt
// Shmget
// Shm_open
// Shm_unlink
// Sem_open
// Sem_close
// Sem_unlink
// Sem_wait
// Sem_trywait
// Sem_post
// Sem_getvalue
// Sem_init
// Sem_destroy
// Open_extended
// Umask_extended
// Stat_extended
// Lstat_extended
// Fstat_extended
// Chmod_extended
// Fchmod_extended
// Access_extended
// Settid
// Gettid
// Setsgroups
// Getsgroups
// Setwgroups
// Getwgroups
// Mkfifo_extended
// Mkdir_extended
// Identitysvc
// Shared_region_check_np
// Shared_region_map_np
// __pthread_mutex_destroy
// __pthread_mutex_init
// __pthread_mutex_lock
// __pthread_mutex_trylock
// __pthread_mutex_unlock
// __pthread_cond_init
// __pthread_cond_destroy
// __pthread_cond_broadcast
// __pthread_cond_signal
// Setsid_with_pid
// __pthread_cond_timedwait
// Aio_fsync
// Aio_return
// Aio_suspend
// Aio_cancel
// Aio_error
// Aio_read
// Aio_write
// Lio_listio
// __pthread_cond_wait
// Iopolicysys
// __pthread_kill
// __pthread_sigmask
// __sigwait
// __disable_threadsignal
// __pthread_markcancel
// __pthread_canceled
// __semwait_signal
// Proc_info
// Stat64_extended
// Lstat64_extended
// Fstat64_extended
// __pthread_chdir
// __pthread_fchdir
// Audit
// Auditon
// Getauid
// Setauid
// Getaudit
// Setaudit
// Getaudit_addr
// Setaudit_addr
// Auditctl
// Bsdthread_create
// Bsdthread_terminate
// Stack_snapshot
// Bsdthread_register
// Workq_open
// Workq_ops
// __mac_execve
// __mac_syscall
// __mac_get_file
// __mac_set_file
// __mac_get_link
// __mac_set_link
// __mac_get_proc
// __mac_set_proc
// __mac_get_fd
// __mac_set_fd
// __mac_get_pid
// __mac_get_lcid
// __mac_get_lctx
// __mac_set_lctx
// Setlcid
// Read_nocancel
// Write_nocancel
// Open_nocancel
// Close_nocancel
// Wait4_nocancel
// Recvmsg_nocancel
// Sendmsg_nocancel
// Recvfrom_nocancel
// Accept_nocancel
// Fcntl_nocancel
// Select_nocancel
// Fsync_nocancel
// Connect_nocancel
// Sigsuspend_nocancel
// Readv_nocancel
// Writev_nocancel
// Sendto_nocancel
// Pread_nocancel
// Pwrite_nocancel
// Waitid_nocancel
// Poll_nocancel
// Msgsnd_nocancel
// Msgrcv_nocancel
// Sem_wait_nocancel
// Aio_suspend_nocancel
// __sigwait_nocancel
// __semwait_signal_nocancel
// __mac_mount
// __mac_get_mount
// __mac_getfsstat

View File

@ -693,10 +693,10 @@ type SockaddrALG struct {
func (sa *SockaddrALG) sockaddr() (unsafe.Pointer, _Socklen, error) { func (sa *SockaddrALG) sockaddr() (unsafe.Pointer, _Socklen, error) {
// Leave room for NUL byte terminator. // Leave room for NUL byte terminator.
if len(sa.Type) > 13 { if len(sa.Type) > len(sa.raw.Type)-1 {
return nil, 0, EINVAL return nil, 0, EINVAL
} }
if len(sa.Name) > 63 { if len(sa.Name) > len(sa.raw.Name)-1 {
return nil, 0, EINVAL return nil, 0, EINVAL
} }
@ -704,17 +704,8 @@ func (sa *SockaddrALG) sockaddr() (unsafe.Pointer, _Socklen, error) {
sa.raw.Feat = sa.Feature sa.raw.Feat = sa.Feature
sa.raw.Mask = sa.Mask sa.raw.Mask = sa.Mask
typ, err := ByteSliceFromString(sa.Type) copy(sa.raw.Type[:], sa.Type)
if err != nil { copy(sa.raw.Name[:], sa.Name)
return nil, 0, err
}
name, err := ByteSliceFromString(sa.Name)
if err != nil {
return nil, 0, err
}
copy(sa.raw.Type[:], typ)
copy(sa.raw.Name[:], name)
return unsafe.Pointer(&sa.raw), SizeofSockaddrALG, nil return unsafe.Pointer(&sa.raw), SizeofSockaddrALG, nil
} }
@ -1988,8 +1979,6 @@ func Signalfd(fd int, sigmask *Sigset_t, flags int) (newfd int, err error) {
//sys Unshare(flags int) (err error) //sys Unshare(flags int) (err error)
//sys write(fd int, p []byte) (n int, err error) //sys write(fd int, p []byte) (n int, err error)
//sys exitThread(code int) (err error) = SYS_EXIT //sys exitThread(code int) (err error) = SYS_EXIT
//sys readlen(fd int, p *byte, np int) (n int, err error) = SYS_READ
//sys writelen(fd int, p *byte, np int) (n int, err error) = SYS_WRITE
//sys readv(fd int, iovs []Iovec) (n int, err error) = SYS_READV //sys readv(fd int, iovs []Iovec) (n int, err error) = SYS_READV
//sys writev(fd int, iovs []Iovec) (n int, err error) = SYS_WRITEV //sys writev(fd int, iovs []Iovec) (n int, err error) = SYS_WRITEV
//sys preadv(fd int, iovs []Iovec, offs_l uintptr, offs_h uintptr) (n int, err error) = SYS_PREADV //sys preadv(fd int, iovs []Iovec, offs_l uintptr, offs_h uintptr) (n int, err error) = SYS_PREADV
@ -2493,99 +2482,3 @@ func SchedGetAttr(pid int, flags uint) (*SchedAttr, error) {
} }
return attr, nil return attr, nil
} }
/*
* Unimplemented
*/
// AfsSyscall
// ArchPrctl
// Brk
// ClockNanosleep
// ClockSettime
// Clone
// EpollCtlOld
// EpollPwait
// EpollWaitOld
// Execve
// Fork
// Futex
// GetKernelSyms
// GetMempolicy
// GetRobustList
// GetThreadArea
// Getpmsg
// IoCancel
// IoDestroy
// IoGetevents
// IoSetup
// IoSubmit
// IoprioGet
// IoprioSet
// KexecLoad
// LookupDcookie
// Mbind
// MigratePages
// Mincore
// ModifyLdt
// Mount
// MovePages
// MqGetsetattr
// MqNotify
// MqOpen
// MqTimedreceive
// MqTimedsend
// MqUnlink
// Msgctl
// Msgget
// Msgrcv
// Msgsnd
// Nfsservctl
// Personality
// Pselect6
// Ptrace
// Putpmsg
// Quotactl
// Readahead
// Readv
// RemapFilePages
// RestartSyscall
// RtSigaction
// RtSigpending
// RtSigqueueinfo
// RtSigreturn
// RtSigsuspend
// RtSigtimedwait
// SchedGetPriorityMax
// SchedGetPriorityMin
// SchedGetparam
// SchedGetscheduler
// SchedRrGetInterval
// SchedSetparam
// SchedYield
// Security
// Semctl
// Semget
// Semop
// Semtimedop
// SetMempolicy
// SetRobustList
// SetThreadArea
// SetTidAddress
// Sigaltstack
// Swapoff
// Swapon
// Sysfs
// TimerCreate
// TimerDelete
// TimerGetoverrun
// TimerGettime
// TimerSettime
// Tkill (obsolete)
// Tuxcall
// Umount2
// Uselib
// Utimensat
// Vfork
// Vhangup
// Vserver
// _Sysctl

View File

@ -356,8 +356,6 @@ func Statvfs(path string, buf *Statvfs_t) (err error) {
//sys write(fd int, p []byte) (n int, err error) //sys write(fd int, p []byte) (n int, err error)
//sys mmap(addr uintptr, length uintptr, prot int, flag int, fd int, pos int64) (ret uintptr, err error) //sys mmap(addr uintptr, length uintptr, prot int, flag int, fd int, pos int64) (ret uintptr, err error)
//sys munmap(addr uintptr, length uintptr) (err error) //sys munmap(addr uintptr, length uintptr) (err error)
//sys readlen(fd int, buf *byte, nbuf int) (n int, err error) = SYS_READ
//sys writelen(fd int, buf *byte, nbuf int) (n int, err error) = SYS_WRITE
//sys utimensat(dirfd int, path string, times *[2]Timespec, flags int) (err error) //sys utimensat(dirfd int, path string, times *[2]Timespec, flags int) (err error)
const ( const (
@ -371,262 +369,3 @@ const (
func mremap(oldaddr uintptr, oldlength uintptr, newlength uintptr, flags int, newaddr uintptr) (uintptr, error) { func mremap(oldaddr uintptr, oldlength uintptr, newlength uintptr, flags int, newaddr uintptr) (uintptr, error) {
return mremapNetBSD(oldaddr, oldlength, newaddr, newlength, flags) return mremapNetBSD(oldaddr, oldlength, newaddr, newlength, flags)
} }
/*
* Unimplemented
*/
// ____semctl13
// __clone
// __fhopen40
// __fhstat40
// __fhstatvfs140
// __fstat30
// __getcwd
// __getfh30
// __getlogin
// __lstat30
// __mount50
// __msgctl13
// __msync13
// __ntp_gettime30
// __posix_chown
// __posix_fchown
// __posix_lchown
// __posix_rename
// __setlogin
// __shmctl13
// __sigaction_sigtramp
// __sigaltstack14
// __sigpending14
// __sigprocmask14
// __sigsuspend14
// __sigtimedwait
// __stat30
// __syscall
// __vfork14
// _ksem_close
// _ksem_destroy
// _ksem_getvalue
// _ksem_init
// _ksem_open
// _ksem_post
// _ksem_trywait
// _ksem_unlink
// _ksem_wait
// _lwp_continue
// _lwp_create
// _lwp_ctl
// _lwp_detach
// _lwp_exit
// _lwp_getname
// _lwp_getprivate
// _lwp_kill
// _lwp_park
// _lwp_self
// _lwp_setname
// _lwp_setprivate
// _lwp_suspend
// _lwp_unpark
// _lwp_unpark_all
// _lwp_wait
// _lwp_wakeup
// _pset_bind
// _sched_getaffinity
// _sched_getparam
// _sched_setaffinity
// _sched_setparam
// acct
// aio_cancel
// aio_error
// aio_fsync
// aio_read
// aio_return
// aio_suspend
// aio_write
// break
// clock_getres
// clock_gettime
// clock_settime
// compat_09_ogetdomainname
// compat_09_osetdomainname
// compat_09_ouname
// compat_10_omsgsys
// compat_10_osemsys
// compat_10_oshmsys
// compat_12_fstat12
// compat_12_getdirentries
// compat_12_lstat12
// compat_12_msync
// compat_12_oreboot
// compat_12_oswapon
// compat_12_stat12
// compat_13_sigaction13
// compat_13_sigaltstack13
// compat_13_sigpending13
// compat_13_sigprocmask13
// compat_13_sigreturn13
// compat_13_sigsuspend13
// compat_14___semctl
// compat_14_msgctl
// compat_14_shmctl
// compat_16___sigaction14
// compat_16___sigreturn14
// compat_20_fhstatfs
// compat_20_fstatfs
// compat_20_getfsstat
// compat_20_statfs
// compat_30___fhstat30
// compat_30___fstat13
// compat_30___lstat13
// compat_30___stat13
// compat_30_fhopen
// compat_30_fhstat
// compat_30_fhstatvfs1
// compat_30_getdents
// compat_30_getfh
// compat_30_ntp_gettime
// compat_30_socket
// compat_40_mount
// compat_43_fstat43
// compat_43_lstat43
// compat_43_oaccept
// compat_43_ocreat
// compat_43_oftruncate
// compat_43_ogetdirentries
// compat_43_ogetdtablesize
// compat_43_ogethostid
// compat_43_ogethostname
// compat_43_ogetkerninfo
// compat_43_ogetpagesize
// compat_43_ogetpeername
// compat_43_ogetrlimit
// compat_43_ogetsockname
// compat_43_okillpg
// compat_43_olseek
// compat_43_ommap
// compat_43_oquota
// compat_43_orecv
// compat_43_orecvfrom
// compat_43_orecvmsg
// compat_43_osend
// compat_43_osendmsg
// compat_43_osethostid
// compat_43_osethostname
// compat_43_osigblock
// compat_43_osigsetmask
// compat_43_osigstack
// compat_43_osigvec
// compat_43_otruncate
// compat_43_owait
// compat_43_stat43
// execve
// extattr_delete_fd
// extattr_delete_file
// extattr_delete_link
// extattr_get_fd
// extattr_get_file
// extattr_get_link
// extattr_list_fd
// extattr_list_file
// extattr_list_link
// extattr_set_fd
// extattr_set_file
// extattr_set_link
// extattrctl
// fchroot
// fdatasync
// fgetxattr
// fktrace
// flistxattr
// fork
// fremovexattr
// fsetxattr
// fstatvfs1
// fsync_range
// getcontext
// getitimer
// getvfsstat
// getxattr
// ktrace
// lchflags
// lchmod
// lfs_bmapv
// lfs_markv
// lfs_segclean
// lfs_segwait
// lgetxattr
// lio_listio
// listxattr
// llistxattr
// lremovexattr
// lseek
// lsetxattr
// lutimes
// madvise
// mincore
// minherit
// modctl
// mq_close
// mq_getattr
// mq_notify
// mq_open
// mq_receive
// mq_send
// mq_setattr
// mq_timedreceive
// mq_timedsend
// mq_unlink
// msgget
// msgrcv
// msgsnd
// nfssvc
// ntp_adjtime
// pmc_control
// pmc_get_info
// pollts
// preadv
// profil
// pselect
// pset_assign
// pset_create
// pset_destroy
// ptrace
// pwritev
// quotactl
// rasctl
// readv
// reboot
// removexattr
// sa_enable
// sa_preempt
// sa_register
// sa_setconcurrency
// sa_stacks
// sa_yield
// sbrk
// sched_yield
// semconfig
// semget
// semop
// setcontext
// setitimer
// setxattr
// shmat
// shmdt
// shmget
// sstk
// statvfs1
// swapctl
// sysarch
// syscall
// timer_create
// timer_delete
// timer_getoverrun
// timer_gettime
// timer_settime
// undelete
// utrace
// uuidgen
// vadvise
// vfork
// writev

View File

@ -326,78 +326,4 @@ func Uname(uname *Utsname) error {
//sys write(fd int, p []byte) (n int, err error) //sys write(fd int, p []byte) (n int, err error)
//sys mmap(addr uintptr, length uintptr, prot int, flag int, fd int, pos int64) (ret uintptr, err error) //sys mmap(addr uintptr, length uintptr, prot int, flag int, fd int, pos int64) (ret uintptr, err error)
//sys munmap(addr uintptr, length uintptr) (err error) //sys munmap(addr uintptr, length uintptr) (err error)
//sys readlen(fd int, buf *byte, nbuf int) (n int, err error) = SYS_READ
//sys writelen(fd int, buf *byte, nbuf int) (n int, err error) = SYS_WRITE
//sys utimensat(dirfd int, path string, times *[2]Timespec, flags int) (err error) //sys utimensat(dirfd int, path string, times *[2]Timespec, flags int) (err error)
/*
* Unimplemented
*/
// __getcwd
// __semctl
// __syscall
// __sysctl
// adjfreq
// break
// clock_getres
// clock_gettime
// clock_settime
// closefrom
// execve
// fhopen
// fhstat
// fhstatfs
// fork
// futimens
// getfh
// getgid
// getitimer
// getlogin
// getthrid
// ktrace
// lfs_bmapv
// lfs_markv
// lfs_segclean
// lfs_segwait
// mincore
// minherit
// mount
// mquery
// msgctl
// msgget
// msgrcv
// msgsnd
// nfssvc
// nnpfspioctl
// preadv
// profil
// pwritev
// quotactl
// readv
// reboot
// renameat
// rfork
// sched_yield
// semget
// semop
// setgroups
// setitimer
// setsockopt
// shmat
// shmctl
// shmdt
// shmget
// sigaction
// sigaltstack
// sigpending
// sigprocmask
// sigreturn
// sigsuspend
// sysarch
// syscall
// threxit
// thrsigdivert
// thrsleep
// thrwakeup
// vfork
// writev

View File

@ -698,24 +698,6 @@ func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err e
//sys setsockopt(s int, level int, name int, val unsafe.Pointer, vallen uintptr) (err error) = libsocket.setsockopt //sys setsockopt(s int, level int, name int, val unsafe.Pointer, vallen uintptr) (err error) = libsocket.setsockopt
//sys recvfrom(fd int, p []byte, flags int, from *RawSockaddrAny, fromlen *_Socklen) (n int, err error) = libsocket.recvfrom //sys recvfrom(fd int, p []byte, flags int, from *RawSockaddrAny, fromlen *_Socklen) (n int, err error) = libsocket.recvfrom
func readlen(fd int, buf *byte, nbuf int) (n int, err error) {
r0, _, e1 := sysvicall6(uintptr(unsafe.Pointer(&procread)), 3, uintptr(fd), uintptr(unsafe.Pointer(buf)), uintptr(nbuf), 0, 0, 0)
n = int(r0)
if e1 != 0 {
err = e1
}
return
}
func writelen(fd int, buf *byte, nbuf int) (n int, err error) {
r0, _, e1 := sysvicall6(uintptr(unsafe.Pointer(&procwrite)), 3, uintptr(fd), uintptr(unsafe.Pointer(buf)), uintptr(nbuf), 0, 0, 0)
n = int(r0)
if e1 != 0 {
err = e1
}
return
}
// Event Ports // Event Ports
type fileObjCookie struct { type fileObjCookie struct {

View File

@ -192,7 +192,6 @@ func (cmsg *Cmsghdr) SetLen(length int) {
//sys fcntl(fd int, cmd int, arg int) (val int, err error) //sys fcntl(fd int, cmd int, arg int) (val int, err error)
//sys read(fd int, p []byte) (n int, err error) //sys read(fd int, p []byte) (n int, err error)
//sys readlen(fd int, buf *byte, nbuf int) (n int, err error) = SYS_READ
//sys write(fd int, p []byte) (n int, err error) //sys write(fd int, p []byte) (n int, err error)
//sys accept(s int, rsa *RawSockaddrAny, addrlen *_Socklen) (fd int, err error) = SYS___ACCEPT_A //sys accept(s int, rsa *RawSockaddrAny, addrlen *_Socklen) (fd int, err error) = SYS___ACCEPT_A

View File

@ -2421,6 +2421,15 @@ const (
PR_PAC_GET_ENABLED_KEYS = 0x3d PR_PAC_GET_ENABLED_KEYS = 0x3d
PR_PAC_RESET_KEYS = 0x36 PR_PAC_RESET_KEYS = 0x36
PR_PAC_SET_ENABLED_KEYS = 0x3c PR_PAC_SET_ENABLED_KEYS = 0x3c
PR_RISCV_V_GET_CONTROL = 0x46
PR_RISCV_V_SET_CONTROL = 0x45
PR_RISCV_V_VSTATE_CTRL_CUR_MASK = 0x3
PR_RISCV_V_VSTATE_CTRL_DEFAULT = 0x0
PR_RISCV_V_VSTATE_CTRL_INHERIT = 0x10
PR_RISCV_V_VSTATE_CTRL_MASK = 0x1f
PR_RISCV_V_VSTATE_CTRL_NEXT_MASK = 0xc
PR_RISCV_V_VSTATE_CTRL_OFF = 0x1
PR_RISCV_V_VSTATE_CTRL_ON = 0x2
PR_SCHED_CORE = 0x3e PR_SCHED_CORE = 0x3e
PR_SCHED_CORE_CREATE = 0x1 PR_SCHED_CORE_CREATE = 0x1
PR_SCHED_CORE_GET = 0x0 PR_SCHED_CORE_GET = 0x0

View File

@ -326,10 +326,12 @@ const (
SO_NOFCS = 0x2b SO_NOFCS = 0x2b
SO_OOBINLINE = 0xa SO_OOBINLINE = 0xa
SO_PASSCRED = 0x10 SO_PASSCRED = 0x10
SO_PASSPIDFD = 0x4c
SO_PASSSEC = 0x22 SO_PASSSEC = 0x22
SO_PEEK_OFF = 0x2a SO_PEEK_OFF = 0x2a
SO_PEERCRED = 0x11 SO_PEERCRED = 0x11
SO_PEERGROUPS = 0x3b SO_PEERGROUPS = 0x3b
SO_PEERPIDFD = 0x4d
SO_PEERSEC = 0x1f SO_PEERSEC = 0x1f
SO_PREFER_BUSY_POLL = 0x45 SO_PREFER_BUSY_POLL = 0x45
SO_PROTOCOL = 0x26 SO_PROTOCOL = 0x26

View File

@ -327,10 +327,12 @@ const (
SO_NOFCS = 0x2b SO_NOFCS = 0x2b
SO_OOBINLINE = 0xa SO_OOBINLINE = 0xa
SO_PASSCRED = 0x10 SO_PASSCRED = 0x10
SO_PASSPIDFD = 0x4c
SO_PASSSEC = 0x22 SO_PASSSEC = 0x22
SO_PEEK_OFF = 0x2a SO_PEEK_OFF = 0x2a
SO_PEERCRED = 0x11 SO_PEERCRED = 0x11
SO_PEERGROUPS = 0x3b SO_PEERGROUPS = 0x3b
SO_PEERPIDFD = 0x4d
SO_PEERSEC = 0x1f SO_PEERSEC = 0x1f
SO_PREFER_BUSY_POLL = 0x45 SO_PREFER_BUSY_POLL = 0x45
SO_PROTOCOL = 0x26 SO_PROTOCOL = 0x26

View File

@ -333,10 +333,12 @@ const (
SO_NOFCS = 0x2b SO_NOFCS = 0x2b
SO_OOBINLINE = 0xa SO_OOBINLINE = 0xa
SO_PASSCRED = 0x10 SO_PASSCRED = 0x10
SO_PASSPIDFD = 0x4c
SO_PASSSEC = 0x22 SO_PASSSEC = 0x22
SO_PEEK_OFF = 0x2a SO_PEEK_OFF = 0x2a
SO_PEERCRED = 0x11 SO_PEERCRED = 0x11
SO_PEERGROUPS = 0x3b SO_PEERGROUPS = 0x3b
SO_PEERPIDFD = 0x4d
SO_PEERSEC = 0x1f SO_PEERSEC = 0x1f
SO_PREFER_BUSY_POLL = 0x45 SO_PREFER_BUSY_POLL = 0x45
SO_PROTOCOL = 0x26 SO_PROTOCOL = 0x26

View File

@ -323,10 +323,12 @@ const (
SO_NOFCS = 0x2b SO_NOFCS = 0x2b
SO_OOBINLINE = 0xa SO_OOBINLINE = 0xa
SO_PASSCRED = 0x10 SO_PASSCRED = 0x10
SO_PASSPIDFD = 0x4c
SO_PASSSEC = 0x22 SO_PASSSEC = 0x22
SO_PEEK_OFF = 0x2a SO_PEEK_OFF = 0x2a
SO_PEERCRED = 0x11 SO_PEERCRED = 0x11
SO_PEERGROUPS = 0x3b SO_PEERGROUPS = 0x3b
SO_PEERPIDFD = 0x4d
SO_PEERSEC = 0x1f SO_PEERSEC = 0x1f
SO_PREFER_BUSY_POLL = 0x45 SO_PREFER_BUSY_POLL = 0x45
SO_PROTOCOL = 0x26 SO_PROTOCOL = 0x26

View File

@ -118,6 +118,8 @@ const (
IUCLC = 0x200 IUCLC = 0x200
IXOFF = 0x1000 IXOFF = 0x1000
IXON = 0x400 IXON = 0x400
LASX_CTX_MAGIC = 0x41535801
LSX_CTX_MAGIC = 0x53580001
MAP_ANON = 0x20 MAP_ANON = 0x20
MAP_ANONYMOUS = 0x20 MAP_ANONYMOUS = 0x20
MAP_DENYWRITE = 0x800 MAP_DENYWRITE = 0x800
@ -317,10 +319,12 @@ const (
SO_NOFCS = 0x2b SO_NOFCS = 0x2b
SO_OOBINLINE = 0xa SO_OOBINLINE = 0xa
SO_PASSCRED = 0x10 SO_PASSCRED = 0x10
SO_PASSPIDFD = 0x4c
SO_PASSSEC = 0x22 SO_PASSSEC = 0x22
SO_PEEK_OFF = 0x2a SO_PEEK_OFF = 0x2a
SO_PEERCRED = 0x11 SO_PEERCRED = 0x11
SO_PEERGROUPS = 0x3b SO_PEERGROUPS = 0x3b
SO_PEERPIDFD = 0x4d
SO_PEERSEC = 0x1f SO_PEERSEC = 0x1f
SO_PREFER_BUSY_POLL = 0x45 SO_PREFER_BUSY_POLL = 0x45
SO_PROTOCOL = 0x26 SO_PROTOCOL = 0x26

View File

@ -326,10 +326,12 @@ const (
SO_NOFCS = 0x2b SO_NOFCS = 0x2b
SO_OOBINLINE = 0x100 SO_OOBINLINE = 0x100
SO_PASSCRED = 0x11 SO_PASSCRED = 0x11
SO_PASSPIDFD = 0x4c
SO_PASSSEC = 0x22 SO_PASSSEC = 0x22
SO_PEEK_OFF = 0x2a SO_PEEK_OFF = 0x2a
SO_PEERCRED = 0x12 SO_PEERCRED = 0x12
SO_PEERGROUPS = 0x3b SO_PEERGROUPS = 0x3b
SO_PEERPIDFD = 0x4d
SO_PEERSEC = 0x1e SO_PEERSEC = 0x1e
SO_PREFER_BUSY_POLL = 0x45 SO_PREFER_BUSY_POLL = 0x45
SO_PROTOCOL = 0x1028 SO_PROTOCOL = 0x1028

View File

@ -326,10 +326,12 @@ const (
SO_NOFCS = 0x2b SO_NOFCS = 0x2b
SO_OOBINLINE = 0x100 SO_OOBINLINE = 0x100
SO_PASSCRED = 0x11 SO_PASSCRED = 0x11
SO_PASSPIDFD = 0x4c
SO_PASSSEC = 0x22 SO_PASSSEC = 0x22
SO_PEEK_OFF = 0x2a SO_PEEK_OFF = 0x2a
SO_PEERCRED = 0x12 SO_PEERCRED = 0x12
SO_PEERGROUPS = 0x3b SO_PEERGROUPS = 0x3b
SO_PEERPIDFD = 0x4d
SO_PEERSEC = 0x1e SO_PEERSEC = 0x1e
SO_PREFER_BUSY_POLL = 0x45 SO_PREFER_BUSY_POLL = 0x45
SO_PROTOCOL = 0x1028 SO_PROTOCOL = 0x1028

Some files were not shown because too many files have changed in this diff Show More