Import xtransport
This commit is contained in:
parent
ecaf18f614
commit
61bad01726
|
@ -0,0 +1,147 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jedisct1/dlog"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const DefaultFallbackResolver = "9.9.9.9:53"
|
||||
|
||||
type CachedIPs struct {
|
||||
sync.RWMutex
|
||||
cache map[string]string
|
||||
}
|
||||
|
||||
type XTransport struct {
|
||||
transport *http.Transport
|
||||
timeout time.Duration
|
||||
cachedIPs CachedIPs
|
||||
fallbackResolver string
|
||||
}
|
||||
|
||||
func NewXTransport(timeout time.Duration) *XTransport {
|
||||
xTransport := XTransport{
|
||||
cachedIPs: CachedIPs{cache: make(map[string]string)},
|
||||
timeout: timeout,
|
||||
fallbackResolver: DefaultFallbackResolver,
|
||||
}
|
||||
dialer := &net.Dialer{Timeout: timeout, KeepAlive: timeout, DualStack: true}
|
||||
transport := &http.Transport{
|
||||
DisableKeepAlives: false,
|
||||
DisableCompression: true,
|
||||
MaxIdleConns: 1,
|
||||
IdleConnTimeout: timeout,
|
||||
ResponseHeaderTimeout: timeout,
|
||||
ExpectContinueTimeout: timeout,
|
||||
MaxResponseHeaderBytes: 4096,
|
||||
DialContext: func(ctx context.Context, network, addrStr string) (net.Conn, error) {
|
||||
host := addrStr[:strings.LastIndex(addrStr, ":")]
|
||||
ipOnly := host
|
||||
xTransport.cachedIPs.RLock()
|
||||
cachedIP := xTransport.cachedIPs.cache[host]
|
||||
xTransport.cachedIPs.RUnlock()
|
||||
if len(cachedIP) > 0 {
|
||||
ipOnly = cachedIP
|
||||
} else {
|
||||
dlog.Debugf("[%s] IP address was not cached", host)
|
||||
}
|
||||
addrStr = ipOnly + addrStr[strings.LastIndex(addrStr, ":"):]
|
||||
return dialer.DialContext(ctx, network, addrStr)
|
||||
},
|
||||
}
|
||||
xTransport.transport = transport
|
||||
return &xTransport
|
||||
}
|
||||
|
||||
func (xTransport *XTransport) Fetch(method string, url *url.URL, accept string, contentType string, body *io.ReadCloser, timeout time.Duration) (*http.Response, time.Duration, error) {
|
||||
if timeout <= 0 {
|
||||
timeout = xTransport.timeout
|
||||
}
|
||||
client := http.Client{Transport: xTransport.transport, Timeout: timeout}
|
||||
header := map[string][]string{"User-Agent": {"dnscrypt-proxy"}}
|
||||
if len(accept) > 0 {
|
||||
header["Accept"] = []string{accept}
|
||||
}
|
||||
if len(contentType) > 0 {
|
||||
header["Content-Type"] = []string{contentType}
|
||||
}
|
||||
req := &http.Request{
|
||||
Method: method,
|
||||
URL: url,
|
||||
Header: header,
|
||||
Close: false,
|
||||
}
|
||||
if body != nil {
|
||||
req.Body = *body
|
||||
}
|
||||
start := time.Now()
|
||||
resp, err := client.Do(req)
|
||||
rtt := time.Since(start)
|
||||
if err == nil {
|
||||
if resp == nil {
|
||||
err = errors.New("Webserver returned an error")
|
||||
} else if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
err = fmt.Errorf("Webserver returned code %d", resp.StatusCode)
|
||||
}
|
||||
return resp, rtt, err
|
||||
}
|
||||
host := url.Host
|
||||
xTransport.cachedIPs.RLock()
|
||||
cachedIP := xTransport.cachedIPs.cache[host]
|
||||
xTransport.cachedIPs.RUnlock()
|
||||
if len(cachedIP) > 0 {
|
||||
dlog.Debugf("IP for [%s] was cached to [%s], but connection failed: [%s]", host, cachedIP, err)
|
||||
return resp, rtt, err
|
||||
}
|
||||
dnsClient := new(dns.Client)
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
||||
msg.SetEdns0(4096, true)
|
||||
dlog.Noticef("System DNS configuration not usable yet, exceptionally resolving [%s] using fallback resolver [%s]", host, xTransport.fallbackResolver)
|
||||
in, _, err := dnsClient.Exchange(msg, xTransport.fallbackResolver)
|
||||
if err != nil {
|
||||
return resp, rtt, err
|
||||
}
|
||||
if len(in.Answer) <= 0 {
|
||||
return resp, rtt, fmt.Errorf("No IP found for [%s]", host)
|
||||
}
|
||||
foundIP := in.Answer[0].(*dns.A).A.String()
|
||||
xTransport.cachedIPs.Lock()
|
||||
xTransport.cachedIPs.cache[host] = foundIP
|
||||
xTransport.cachedIPs.Unlock()
|
||||
dlog.Debugf("[%s] IP address [%s] added to the cache", host, foundIP)
|
||||
|
||||
start = time.Now()
|
||||
resp, err = client.Do(req)
|
||||
rtt = time.Since(start)
|
||||
if err == nil {
|
||||
if resp == nil {
|
||||
err = errors.New("Webserver returned an error")
|
||||
} else if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
err = fmt.Errorf("Webserver returned code %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
return resp, rtt, err
|
||||
}
|
||||
|
||||
func (xTransport *XTransport) Get(url *url.URL, timeout time.Duration) (*http.Response, time.Duration, error) {
|
||||
return xTransport.Fetch("GET", url, "", "", nil, timeout)
|
||||
}
|
||||
|
||||
func (xTransport *XTransport) Post(url *url.URL, accept string, contentType string, body []byte, timeout time.Duration) (*http.Response, time.Duration, error) {
|
||||
bc := ioutil.NopCloser(bytes.NewReader(body))
|
||||
return xTransport.Fetch("POST", url, accept, contentType, &bc, timeout)
|
||||
}
|
Loading…
Reference in New Issue