Make the xTransport functions return the HTTP body directly

This simplifies things, but also make RTT computation way more reliable
This commit is contained in:
Frank Denis 2020-02-21 22:33:34 +01:00
parent a6d946c41f
commit aa0e7f42d3
4 changed files with 24 additions and 27 deletions

View File

@ -3,8 +3,6 @@ package main
import (
crypto_rand "crypto/rand"
"encoding/binary"
"io"
"io/ioutil"
"net"
"os"
"sync/atomic"
@ -513,9 +511,9 @@ func (proxy *Proxy) processIncomingQuery(serverInfo *ServerInfo, clientProto str
tid := TransactionID(query)
SetTransactionID(query, 0)
serverInfo.noticeBegin(proxy)
resp, _, err := proxy.xTransport.DoHQuery(serverInfo.useGet, serverInfo.URL, query, proxy.timeout)
serverResponse, tls, _, err := proxy.xTransport.DoHQuery(serverInfo.useGet, serverInfo.URL, query, proxy.timeout)
SetTransactionID(query, tid)
if err == nil {
if err == nil || tls == nil || !tls.HandshakeComplete {
response = nil
} else if stale, ok := pluginsState.sessionData["stale"]; ok {
dlog.Debug("Serving stale response")
@ -528,7 +526,7 @@ func (proxy *Proxy) processIncomingQuery(serverInfo *ServerInfo, clientProto str
return
}
if response == nil {
response, err = ioutil.ReadAll(io.LimitReader(resp.Body, int64(MaxDNSPacketSize)))
response = serverResponse
}
if err != nil {
pluginsState.returnCode = PluginsReturnCodeNetworkError

View File

@ -6,8 +6,6 @@ import (
"encoding/hex"
"errors"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net"
"net/url"
@ -381,18 +379,17 @@ func fetchDoHServerInfo(proxy *Proxy, name string, stamp stamps.ServerStamp, isN
}
body := dohTestPacket(0xcafe)
useGet := false
if _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil {
if _, _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil {
useGet = true
if _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil {
if _, _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil {
return ServerInfo{}, err
}
dlog.Debugf("Server [%s] doesn't appear to support POST; falling back to GET requests", name)
}
resp, rtt, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout)
serverResponse, tls, rtt, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout)
if err != nil {
return ServerInfo{}, err
}
tls := resp.TLS
if tls == nil || !tls.HandshakeComplete {
return ServerInfo{}, errors.New("TLS handshake failed")
}
@ -428,7 +425,7 @@ func fetchDoHServerInfo(proxy *Proxy, name string, stamp stamps.ServerStamp, isN
if !found && len(stamp.Hashes) > 0 {
return ServerInfo{}, fmt.Errorf("Certificate hash [%x] not found for [%s]", wantedHash, name)
}
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, MaxHTTPBodyLength))
respBody := serverResponse
if err != nil {
return ServerInfo{}, err
}

View File

@ -3,9 +3,7 @@ package main
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
@ -132,12 +130,8 @@ func (source *Source) parseURLs(urls []string) {
}
func fetchFromURL(xTransport *XTransport, u *url.URL) (bin []byte, err error) {
var resp *http.Response
if resp, _, err = xTransport.Get(u, "", DefaultTimeout); err == nil {
bin, err = ioutil.ReadAll(io.LimitReader(resp.Body, MaxHTTPBodyLength))
resp.Body.Close()
}
return
bin, _, _, err = xTransport.Get(u, "", DefaultTimeout)
return bin, err
}
func (source *Source) fetchWithCache(xTransport *XTransport, now time.Time) (delay time.Duration, err error) {

View File

@ -8,6 +8,7 @@ import (
"encoding/base64"
"encoding/hex"
"errors"
"io"
"io/ioutil"
"math/rand"
"net"
@ -316,7 +317,7 @@ func (xTransport *XTransport) resolveAndUpdateCache(host string) error {
return nil
}
func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string, contentType string, body *[]byte, timeout time.Duration) (*http.Response, time.Duration, error) {
func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string, contentType string, body *[]byte, timeout time.Duration) ([]byte, *tls.ConnectionState, time.Duration, error) {
if timeout <= 0 {
timeout = xTransport.timeout
}
@ -338,11 +339,11 @@ func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string,
}
host, _ := ExtractHostAndPort(url.Host, 0)
if xTransport.proxyDialer == nil && strings.HasSuffix(host, ".onion") {
return nil, 0, errors.New("Onion service is not reachable without Tor")
return nil, nil, 0, errors.New("Onion service is not reachable without Tor")
}
if err := xTransport.resolveAndUpdateCache(host); err != nil {
dlog.Errorf("Unable to resolve [%v] - Make sure that the system resolver works, or that `fallback_resolver` has been set to a resolver that can be reached", host)
return nil, 0, err
return nil, nil, 0, err
}
req := &http.Request{
Method: method,
@ -373,19 +374,26 @@ func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string,
xTransport.tlsCipherSuite = nil
xTransport.rebuildTransport()
}
return nil, nil, 0, err
}
return resp, rtt, err
tls := resp.TLS
bin, err := ioutil.ReadAll(io.LimitReader(resp.Body, MaxHTTPBodyLength))
if err != nil {
return nil, tls, 0, err
}
resp.Body.Close()
return bin, tls, rtt, err
}
func (xTransport *XTransport) Get(url *url.URL, accept string, timeout time.Duration) (*http.Response, time.Duration, error) {
func (xTransport *XTransport) Get(url *url.URL, accept string, timeout time.Duration) ([]byte, *tls.ConnectionState, time.Duration, error) {
return xTransport.Fetch("GET", url, accept, "", nil, timeout)
}
func (xTransport *XTransport) Post(url *url.URL, accept string, contentType string, body *[]byte, timeout time.Duration) (*http.Response, time.Duration, error) {
func (xTransport *XTransport) Post(url *url.URL, accept string, contentType string, body *[]byte, timeout time.Duration) ([]byte, *tls.ConnectionState, time.Duration, error) {
return xTransport.Fetch("POST", url, accept, contentType, body, timeout)
}
func (xTransport *XTransport) DoHQuery(useGet bool, url *url.URL, body []byte, timeout time.Duration) (*http.Response, time.Duration, error) {
func (xTransport *XTransport) DoHQuery(useGet bool, url *url.URL, body []byte, timeout time.Duration) ([]byte, *tls.ConnectionState, time.Duration, error) {
dataType := "application/dns-message"
if useGet {
qs := url.Query()