From 61bad01726f6b7439cbc73154158e6cbcadee21a Mon Sep 17 00:00:00 2001 From: Frank Denis Date: Tue, 30 Jan 2018 15:51:07 +0100 Subject: [PATCH] Import xtransport --- dnscrypt-proxy/xtransport.go | 147 +++++++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 dnscrypt-proxy/xtransport.go diff --git a/dnscrypt-proxy/xtransport.go b/dnscrypt-proxy/xtransport.go new file mode 100644 index 00000000..b1a7f9be --- /dev/null +++ b/dnscrypt-proxy/xtransport.go @@ -0,0 +1,147 @@ +package main + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/jedisct1/dlog" + "github.com/miekg/dns" +) + +const DefaultFallbackResolver = "9.9.9.9:53" + +type CachedIPs struct { + sync.RWMutex + cache map[string]string +} + +type XTransport struct { + transport *http.Transport + timeout time.Duration + cachedIPs CachedIPs + fallbackResolver string +} + +func NewXTransport(timeout time.Duration) *XTransport { + xTransport := XTransport{ + cachedIPs: CachedIPs{cache: make(map[string]string)}, + timeout: timeout, + fallbackResolver: DefaultFallbackResolver, + } + dialer := &net.Dialer{Timeout: timeout, KeepAlive: timeout, DualStack: true} + transport := &http.Transport{ + DisableKeepAlives: false, + DisableCompression: true, + MaxIdleConns: 1, + IdleConnTimeout: timeout, + ResponseHeaderTimeout: timeout, + ExpectContinueTimeout: timeout, + MaxResponseHeaderBytes: 4096, + DialContext: func(ctx context.Context, network, addrStr string) (net.Conn, error) { + host := addrStr[:strings.LastIndex(addrStr, ":")] + ipOnly := host + xTransport.cachedIPs.RLock() + cachedIP := xTransport.cachedIPs.cache[host] + xTransport.cachedIPs.RUnlock() + if len(cachedIP) > 0 { + ipOnly = cachedIP + } else { + dlog.Debugf("[%s] IP address was not cached", host) + } + addrStr = ipOnly + addrStr[strings.LastIndex(addrStr, ":"):] + return dialer.DialContext(ctx, network, addrStr) + }, + } + xTransport.transport = transport + return &xTransport +} + +func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string, contentType string, body *io.ReadCloser, timeout time.Duration) (*http.Response, time.Duration, error) { + if timeout <= 0 { + timeout = xTransport.timeout + } + client := http.Client{Transport: xTransport.transport, Timeout: timeout} + header := map[string][]string{"User-Agent": {"dnscrypt-proxy"}} + if len(accept) > 0 { + header["Accept"] = []string{accept} + } + if len(contentType) > 0 { + header["Content-Type"] = []string{contentType} + } + req := &http.Request{ + Method: method, + URL: url, + Header: header, + Close: false, + } + if body != nil { + req.Body = *body + } + start := time.Now() + resp, err := client.Do(req) + rtt := time.Since(start) + if err == nil { + if resp == nil { + err = errors.New("Webserver returned an error") + } else if resp.StatusCode < 200 || resp.StatusCode > 299 { + err = fmt.Errorf("Webserver returned code %d", resp.StatusCode) + } + return resp, rtt, err + } + host := url.Host + xTransport.cachedIPs.RLock() + cachedIP := xTransport.cachedIPs.cache[host] + xTransport.cachedIPs.RUnlock() + if len(cachedIP) > 0 { + dlog.Debugf("IP for [%s] was cached to [%s], but connection failed: [%s]", host, cachedIP, err) + return resp, rtt, err + } + dnsClient := new(dns.Client) + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(host), dns.TypeA) + msg.SetEdns0(4096, true) + dlog.Noticef("System DNS configuration not usable yet, exceptionally resolving [%s] using fallback resolver [%s]", host, xTransport.fallbackResolver) + in, _, err := dnsClient.Exchange(msg, xTransport.fallbackResolver) + if err != nil { + return resp, rtt, err + } + if len(in.Answer) <= 0 { + return resp, rtt, fmt.Errorf("No IP found for [%s]", host) + } + foundIP := in.Answer[0].(*dns.A).A.String() + xTransport.cachedIPs.Lock() + xTransport.cachedIPs.cache[host] = foundIP + xTransport.cachedIPs.Unlock() + dlog.Debugf("[%s] IP address [%s] added to the cache", host, foundIP) + + start = time.Now() + resp, err = client.Do(req) + rtt = time.Since(start) + if err == nil { + if resp == nil { + err = errors.New("Webserver returned an error") + } else if resp.StatusCode < 200 || resp.StatusCode > 299 { + err = fmt.Errorf("Webserver returned code %d", resp.StatusCode) + } + } + return resp, rtt, err +} + +func (xTransport *XTransport) Get(url *url.URL, timeout time.Duration) (*http.Response, time.Duration, error) { + return xTransport.Fetch("GET", url, "", "", nil, timeout) +} + +func (xTransport *XTransport) Post(url *url.URL, accept string, contentType string, body []byte, timeout time.Duration) (*http.Response, time.Duration, error) { + bc := ioutil.NopCloser(bytes.NewReader(body)) + return xTransport.Fetch("POST", url, accept, contentType, &bc, timeout) +}