Use a custom transport and a host->ip cache

maybe
Fixes #45
This commit is contained in:
Frank Denis 2018-01-29 03:58:39 +01:00
parent 16928b9954
commit cf12fb170a
3 changed files with 44 additions and 2 deletions

View File

@ -2,6 +2,7 @@ package main
import (
"bytes"
"context"
"crypto/rand"
"flag"
"fmt"
@ -10,6 +11,7 @@ import (
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
@ -21,6 +23,11 @@ import (
const AppVersion = "2.0.0beta11"
type CachedIPs struct {
sync.RWMutex
cache map[string]string
}
type Proxy struct {
proxyPublicKey [32]byte
proxySecretKey [32]byte
@ -57,6 +64,7 @@ type Proxy struct {
clientsCount uint32
maxClients uint32
httpTransport *http.Transport
cachedIPs CachedIPs
}
type App struct {
@ -116,6 +124,7 @@ func main() {
func (app *App) Start(service service.Service) error {
proxy := app.proxy
proxy.cachedIPs.cache = make(map[string]string)
if err := InitPluginsGlobals(&proxy.pluginsGlobals, &proxy); err != nil {
dlog.Fatal(err)
}
@ -155,6 +164,11 @@ func (proxy *Proxy) StartProxy() {
for _, registeredServer := range proxy.registeredServers {
proxy.serversInfo.registerServer(proxy, registeredServer.name, registeredServer.stamp)
}
dialer := &net.Dialer{
Timeout: proxy.timeout,
KeepAlive: proxy.timeout,
DualStack: true,
}
proxy.httpTransport = &http.Transport{
DisableKeepAlives: false,
DisableCompression: true,
@ -163,6 +177,21 @@ func (proxy *Proxy) StartProxy() {
ResponseHeaderTimeout: proxy.timeout,
ExpectContinueTimeout: proxy.timeout,
MaxResponseHeaderBytes: 4096,
DialContext: func(ctx context.Context, network, addrStr string) (net.Conn, error) {
host := addrStr[:strings.LastIndex(addrStr, ":")]
ipOnly := host
proxy.cachedIPs.RLock()
cachedIP := proxy.cachedIPs.cache[host]
proxy.cachedIPs.RUnlock()
if len(cachedIP) > 0 {
ipOnly = cachedIP
dlog.Infof("[%s] IP address was cached: [%s]", host, ipOnly)
} else {
dlog.Infof("[%s] IP address was not cached", host)
}
addrStr = ipOnly + addrStr[strings.LastIndex(addrStr, ":"):]
return dialer.DialContext(ctx, network, addrStr)
},
}
for _, listenAddrStr := range proxy.listenAddresses {
listenUDPAddr, err := net.ResolveUDPAddr("udp", listenAddrStr)

View File

@ -197,6 +197,13 @@ func (serversInfo *ServersInfo) fetchDNSCryptServerInfo(proxy *Proxy, name strin
}
func (serversInfo *ServersInfo) fetchDoHServerInfo(proxy *Proxy, name string, stamp ServerStamp) (ServerInfo, error) {
if len(stamp.serverAddrStr) > 0 {
addrStr := stamp.serverAddrStr
ipOnly := addrStr[:strings.LastIndex(addrStr, ":")]
proxy.cachedIPs.Lock()
proxy.cachedIPs.cache[stamp.providerName] = ipOnly
proxy.cachedIPs.Unlock()
}
url := &url.URL{
Scheme: "https",
Host: stamp.providerName,

View File

@ -70,7 +70,7 @@ func NewServerStampFromString(stampStr string) (ServerStamp, error) {
// id(u8)=0x02 props addrLen(1) serverAddr pkStrlen(1) pkStr providerNameLen(1) providerName
func newDNSCryptServerStamp(bin []byte) (ServerStamp, error) {
stamp := ServerStamp{proto:StampProtoTypeDNSCrypt}
stamp := ServerStamp{proto: StampProtoTypeDNSCrypt}
if len(bin) < 24 {
return stamp, errors.New("Stamp is too short")
}
@ -85,6 +85,9 @@ func newDNSCryptServerStamp(bin []byte) (ServerStamp, error) {
pos++
stamp.serverAddrStr = string(bin[pos : pos+len])
pos += len
if net.ParseIP(stamp.serverAddrStr) != nil {
stamp.serverAddrStr = fmt.Sprintf("%s:%d", stamp.serverAddrStr, DefaultPort)
}
len = int(bin[pos])
if len >= binLen-pos {
@ -111,7 +114,7 @@ func newDNSCryptServerStamp(bin []byte) (ServerStamp, error) {
// id(u8)=0x02 props addrLen(1) serverAddr hashLen(1) hash providerNameLen(1) providerName pathLen(1) path
func newDoHServerStamp(bin []byte) (ServerStamp, error) {
stamp := ServerStamp{proto:StampProtoTypeDoH}
stamp := ServerStamp{proto: StampProtoTypeDoH}
stamp.props = ServerInformalProperties(binary.LittleEndian.Uint64(bin[1:9]))
binLen := len(bin)
@ -124,6 +127,9 @@ func newDoHServerStamp(bin []byte) (ServerStamp, error) {
pos++
stamp.serverAddrStr = string(bin[pos : pos+len])
pos += len
if net.ParseIP(stamp.serverAddrStr) != nil {
stamp.serverAddrStr = fmt.Sprintf("%s:%d", stamp.serverAddrStr, DefaultPort)
}
len = int(bin[pos])
if len >= binLen-pos {