dnscrypt-proxy/dnscrypt-proxy/serversInfo.go

339 lines
9.7 KiB
Go

package main
import (
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"math/rand"
	"net"
	"net/url"
	"os"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/VividCortex/ewma"
	"github.com/cloudflare/p751sidh"
	"github.com/jedisct1/dlog"
	"golang.org/x/crypto/ed25519"
)
// RTTEwmaDecay is the parameter handed to ewma.NewMovingAverage whenever a
// server's RTT moving average is (re)created.
const RTTEwmaDecay = 10.0

// DefaultPort is 443.
// NOTE(review): not referenced in the code shown here; presumably the default
// server port used when a stamp omits one — confirm against callers.
const DefaultPort = 443
// ServerInformalProperties is a bit set of server properties; each constant
// below occupies exactly one bit, so flags can be combined with bitwise OR.
type ServerInformalProperties uint64

// Individual ServerInformalProperties flags.
const (
	ServerInformalPropertyDNSSEC ServerInformalProperties = 1 << iota
	ServerInformalPropertyNoLog
	ServerInformalPropertyNoFilter
)
// RegisteredServer is a server known to ServersInfo by name, together with the
// stamp that refresh passes to refreshServer to (re)fetch its connection data.
type RegisteredServer struct {
// name identifies the entry; registerServer replaces an existing entry with
// the same name instead of adding a second one.
name string
// stamp carries the parameters (protocol, addresses, keys, hashes, ...) used
// by fetchServerInfo.
stamp ServerStamp
// description is free-form text.
// NOTE(review): not set or read in the code shown here — confirm where it is
// populated.
description string
}
// ServerInfo describes one live upstream server, as built by
// fetchDNSCryptServerInfo or fetchDoHServerInfo.
type ServerInfo struct {
// The embedded RWMutex is taken by noticeBegin/noticeSuccess/noticeFailure
// around updates of lastActionTS and rtt.
sync.RWMutex
// Proto is StampProtoTypeDNSCrypt or StampProtoTypeDoH.
Proto StampProtoType
// MagicQuery, ServerPk, SharedKey and CryptoConstruction are copied from the
// DNSCrypt certificate; they are left zero for DoH servers.
MagicQuery [8]byte
ServerPk [32]byte
// ServerSIDHPk is not set in the code shown here.
// NOTE(review): presumably a post-quantum SIDH key from the certificate — confirm.
ServerSIDHPk *p751sidh.SIDHPublicKeyBob
SharedKey [32]byte
CryptoConstruction CryptoConstruction
// Name must match the name the server was registered/refreshed under.
Name string
// Timeout is copied from proxy.timeout when the entry is built.
Timeout time.Duration
// URL and HostName are only set for DoH servers.
URL *url.URL
HostName string
// UDPAddr and TCPAddr are only set for DNSCrypt servers.
UDPAddr *net.UDPAddr
TCPAddr *net.TCPAddr
// lastActionTS is set by noticeBegin; noticeSuccess measures elapsed time from it.
lastActionTS time.Time
// rtt is the moving average that getOne compares between servers; refreshServer
// recreates it, noticeSuccess adds samples, noticeFailure overwrites it.
rtt ewma.MovingAverage
// initialRtt is the round-trip time measured when the entry was fetched; refresh
// sorts servers by it. For DoH it is in milliseconds.
// NOTE(review): unit for DNSCrypt comes from FetchCurrentDNSCryptCert — presumably
// also milliseconds (refresh logs it as "ms"); confirm.
initialRtt int
// useGet is true for DoH servers whose POST probe failed but GET succeeded.
useGet bool
}
// LBStrategy selects how getOne chooses among the live servers.
type LBStrategy int

// Load-balancing strategies understood by getOne.
const (
	LBStrategyNone LBStrategy = iota
	LBStrategyP2
	LBStrategyPH
	LBStrategyFastest
	LBStrategyRandom
)

// DefaultLBStrategy is LBStrategyP2.
// NOTE(review): where the default is applied lies outside the code shown here.
const DefaultLBStrategy = LBStrategyP2
// ServersInfo holds the set of registered servers and the subset currently live.
type ServersInfo struct {
// The embedded RWMutex guards inner and registeredServers.
sync.RWMutex
// inner is the list of live servers. refresh sorts it by initialRtt, and getOne
// swaps entries in place; pointers returned by getOne alias its elements.
inner []ServerInfo
// registeredServers lists the servers refresh attempts to (re)fetch.
registeredServers []RegisteredServer
// lbStrategy selects the server-choice rule applied by getOne.
lbStrategy LBStrategy
}
// registerServer records stamp under name in the registered-server list,
// replacing an existing entry with the same name or appending a new one.
// The proxy argument is unused. The returned error is always nil.
func (serversInfo *ServersInfo) registerServer(proxy *Proxy, name string, stamp ServerStamp) error {
	entry := RegisteredServer{name: name, stamp: stamp}
	serversInfo.Lock()
	defer serversInfo.Unlock()
	for idx := range serversInfo.registeredServers {
		if serversInfo.registeredServers[idx].name != name {
			continue
		}
		serversInfo.registeredServers[idx] = entry
		return nil
	}
	serversInfo.registeredServers = append(serversInfo.registeredServers, entry)
	return nil
}
// refreshServer fetches fresh connection information for the server called
// name, using stamp, and installs it in the live list. It replaces the
// existing entry with that name or, if there is none, appends a new one. The
// new entry always starts with a fresh RTT moving average.
//
// The network fetch runs without holding the ServersInfo lock, so getOne (and
// therefore query processing) is not stalled while a server is contacted. The
// lock is only taken briefly to look up the entry and to install the result.
//
// A server appended to the live list is added to registeredServers only if no
// entry with that name is registered yet. registerServer normally adds it
// already, and appending again would make refresh visit it twice.
//
// On fetch error the live list is left untouched and the error is returned.
func (serversInfo *ServersInfo) refreshServer(proxy *Proxy, name string, stamp ServerStamp) error {
	// indexOf must be called with serversInfo locked. It indexes instead of
	// ranging by value so the mutex-bearing ServerInfo is not copied.
	indexOf := func() int {
		for i := range serversInfo.inner {
			if serversInfo.inner[i].Name == name {
				return i
			}
		}
		return -1
	}

	serversInfo.RLock()
	isNew := indexOf() < 0
	serversInfo.RUnlock()

	newServer, err := serversInfo.fetchServerInfo(proxy, name, stamp, isNew)
	if err != nil {
		return err
	}
	if name != newServer.Name {
		dlog.Fatalf("[%s] != [%s]", name, newServer.Name)
	}
	newServer.rtt = ewma.NewMovingAverage(RTTEwmaDecay)

	serversInfo.Lock()
	defer serversInfo.Unlock()
	// Look the entry up again: the list may have changed while unlocked.
	if previousIndex := indexOf(); previousIndex >= 0 {
		serversInfo.inner[previousIndex] = newServer
		return nil
	}
	serversInfo.inner = append(serversInfo.inner, newServer)
	for i := range serversInfo.registeredServers {
		if serversInfo.registeredServers[i].name == name {
			return nil
		}
	}
	serversInfo.registeredServers = append(serversInfo.registeredServers, RegisteredServer{name: name, stamp: stamp})
	return nil
}
// refresh re-fetches connection information for every registered server, then
// orders the live servers by ascending initial RTT so that inner[0] has the
// lowest latency measured at refresh time. Servers with equal RTT keep their
// relative order.
//
// It returns the number of servers refreshed successfully, together with the
// error from the last refresh attempt. That error is nil if the last attempt
// succeeded, even when earlier ones failed. When at least one server is live,
// proxy.certIgnoreTimestamp is cleared.
func (serversInfo *ServersInfo) refresh(proxy *Proxy) (int, error) {
	dlog.Debug("Refreshing certificates")
	// Iterate over a private copy. registerServer overwrites entries of the
	// shared slice in place, and this loop runs without the lock.
	serversInfo.RLock()
	registeredServers := append([]RegisteredServer(nil), serversInfo.registeredServers...)
	serversInfo.RUnlock()

	liveServers := 0
	var err error
	for _, registeredServer := range registeredServers {
		if err = serversInfo.refreshServer(proxy, registeredServer.name, registeredServer.stamp); err == nil {
			liveServers++
		}
	}

	serversInfo.Lock()
	defer serversInfo.Unlock()
	inner := serversInfo.inner
	sort.SliceStable(inner, func(i, j int) bool {
		return inner[i].initialRtt < inner[j].initialRtt
	})
	if len(inner) > 0 {
		dlog.Noticef("Server with the lowest initial latency: %s (rtt: %dms)", inner[0].Name, inner[0].initialRtt)
		proxy.certIgnoreTimestamp = false
	}
	return liveServers, err
}
// liveServers reports how many servers are currently in the live list.
func (serversInfo *ServersInfo) liveServers() int {
	serversInfo.RLock()
	defer serversInfo.RUnlock()
	return len(serversInfo.inner)
}
func (serversInfo *ServersInfo) getOne() *ServerInfo {
serversInfo.Lock()
defer serversInfo.Unlock()
serversCount := len(serversInfo.inner)
if serversCount <= 0 {
return nil
}
candidate := rand.Intn(serversCount)
if candidate == 0 {
return &serversInfo.inner[candidate]
}
if serversInfo.inner[candidate].rtt.Value() < serversInfo.inner[0].rtt.Value() {
serversInfo.inner[candidate], serversInfo.inner[0] = serversInfo.inner[0], serversInfo.inner[candidate]
}
switch serversInfo.lbStrategy {
case LBStrategyFastest:
candidate = 0
case LBStrategyPH:
candidate = rand.Intn(Min(Min(serversCount, 2), len(serversInfo.inner)/2))
case LBStrategyRandom:
candidate = rand.Intn(len(serversInfo.inner))
default:
candidate = rand.Intn(Min(Min(serversCount, 2), len(serversInfo.inner)))
}
serverInfo := &serversInfo.inner[candidate]
return serverInfo
}
// fetchServerInfo builds a ServerInfo for name by dispatching on the stamp's
// protocol. DNSCrypt stamps go to fetchDNSCryptServerInfo and DoH stamps go to
// fetchDoHServerInfo. Any other protocol yields an error. isNew is forwarded
// unchanged.
func (serversInfo *ServersInfo) fetchServerInfo(proxy *Proxy, name string, stamp ServerStamp, isNew bool) (ServerInfo, error) {
	switch stamp.proto {
	case StampProtoTypeDNSCrypt:
		return serversInfo.fetchDNSCryptServerInfo(proxy, name, stamp, isNew)
	case StampProtoTypeDoH:
		return serversInfo.fetchDoHServerInfo(proxy, name, stamp, isNew)
	default:
		return ServerInfo{}, errors.New("Unsupported protocol")
	}
}
// fetchDNSCryptServerInfo retrieves the current DNSCrypt certificate for the
// server described by stamp and resolves its UDP and TCP addresses. It returns
// a ServerInfo populated from the certificate, the measured RTT and
// proxy.timeout.
//
// A stamp public key whose length differs from ed25519.PublicKeySize is treated
// as the legacy hex form, optionally colon-separated, and decoded. A key that
// still cannot be decoded is reported through dlog.Fatalf, presumably
// terminating the process. The decoded key only replaces the key in this
// function's copy of stamp.
func (serversInfo *ServersInfo) fetchDNSCryptServerInfo(proxy *Proxy, name string, stamp ServerStamp, isNew bool) (ServerInfo, error) {
	if len(stamp.serverPk) != ed25519.PublicKeySize {
		hexKey := strings.Replace(string(stamp.serverPk), ":", "", -1)
		decodedPk, decodeErr := hex.DecodeString(hexKey)
		if decodeErr != nil || len(decodedPk) != ed25519.PublicKeySize {
			dlog.Fatalf("Unsupported public key for [%s]: [%s]", name, stamp.serverPk)
		}
		dlog.Warnf("Public key [%s] shouldn't be hex-encoded any more", string(stamp.serverPk))
		stamp.serverPk = decodedPk
	}

	// name is passed by pointer, so the Name recorded below is whatever value
	// it holds after this call.
	certInfo, rtt, err := FetchCurrentDNSCryptCert(proxy, &name, proxy.mainProto, stamp.serverPk, stamp.serverAddrStr, stamp.providerName, isNew)
	if err != nil {
		return ServerInfo{}, err
	}

	addr := stamp.serverAddrStr
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return ServerInfo{}, err
	}
	tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		return ServerInfo{}, err
	}

	return ServerInfo{
		Proto:              StampProtoTypeDNSCrypt,
		Name:               name,
		Timeout:            proxy.timeout,
		UDPAddr:            udpAddr,
		TCPAddr:            tcpAddr,
		MagicQuery:         certInfo.MagicQuery,
		ServerPk:           certInfo.ServerPk,
		SharedKey:          certInfo.SharedKey,
		CryptoConstruction: certInfo.CryptoConstruction,
		initialRtt:         rtt,
	}, nil
}
// fetchDoHServerInfo probes the DNS-over-HTTPS server described by stamp and
// returns a ServerInfo for it.
//
//  1. If the stamp carries a server address, the part before its last ':' is
//     stored in proxy.xTransport.cachedIPs under the provider name. An
//     address with no ':' is stored whole.
//  2. A fixed test query is sent with POST. If that fails it is retried with
//     GET and, on success, the server is marked as GET-only (useGet). If both
//     fail, the error is returned.
//  3. The query is sent once more with the chosen method. That round trip's
//     duration becomes the reported RTT.
//  4. The TLS handshake must have completed. When the stamp lists certificate
//     hashes, the SHA-256 of some peer certificate's TBS part must match one
//     of them.
//  5. The response must be a DNS packet within [MinDNSPacketSize,
//     MaxDNSPacketSize], start with the query ID 0xcafe, and have a question
//     count of 1.
//
// Every HTTP response body is closed, so connections and their resources are
// released. Setting SHOW_CERTS in the environment logs the advertised
// certificate hashes at info level instead of debug level.
func (serversInfo *ServersInfo) fetchDoHServerInfo(proxy *Proxy, name string, stamp ServerStamp, isNew bool) (ServerInfo, error) {
	if addrStr := stamp.serverAddrStr; len(addrStr) > 0 {
		// Drop the trailing ":port"; guard the no-colon case, where slicing at
		// LastIndex (-1) would panic.
		ipOnly := addrStr
		if idx := strings.LastIndex(addrStr, ":"); idx >= 0 {
			ipOnly = addrStr[:idx]
		}
		proxy.xTransport.cachedIPs.Lock()
		proxy.xTransport.cachedIPs.cache[stamp.providerName] = ipOnly
		proxy.xTransport.cachedIPs.Unlock()
	}
	dohURL := &url.URL{
		Scheme: "https",
		Host:   stamp.providerName,
		Path:   stamp.path,
	}
	// Test query: ID 0xcafe, RD set, one question (root zone, type NS, class IN)
	// plus an EDNS0 OPT record (4096-byte payload, DO bit set).
	body := []byte{
		0xca, 0xfe, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x02, 0x00, 0x01, 0x00, 0x00, 0x29, 0x10, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00,
	}

	// probe sends the test query and only reports whether it succeeded. The
	// response body is closed even on error, since a response may accompany one.
	probe := func(get bool) error {
		probeResp, _, err := proxy.xTransport.DoHQuery(get, dohURL, body, proxy.timeout)
		if probeResp != nil && probeResp.Body != nil {
			probeResp.Body.Close()
		}
		return err
	}
	useGet := false
	if err := probe(useGet); err != nil {
		useGet = true
		if err := probe(useGet); err != nil {
			return ServerInfo{}, err
		}
		dlog.Debugf("Server [%s] doesn't appear to support POST; falling back to GET requests", name)
	}

	// Measured query: sent on a warmed-up connection so the RTT excludes setup.
	resp, rtt, err := proxy.xTransport.DoHQuery(useGet, dohURL, body, proxy.timeout)
	if resp != nil && resp.Body != nil {
		defer resp.Body.Close()
	}
	if err != nil {
		return ServerInfo{}, err
	}

	tlsState := resp.TLS
	if tlsState == nil || !tlsState.HandshakeComplete {
		return ServerInfo{}, errors.New("TLS handshake failed")
	}
	showCerts := len(os.Getenv("SHOW_CERTS")) > 0
	found := false
	var wantedHash [32]byte
certLoop:
	for _, cert := range tlsState.PeerCertificates {
		h := sha256.Sum256(cert.RawTBSCertificate)
		if showCerts {
			dlog.Infof("Advertised cert: [%s] [%x]", cert.Subject, h)
		} else {
			dlog.Debugf("Advertised cert: [%s] [%x]", cert.Subject, h)
		}
		for _, hash := range stamp.hashes {
			// Hashes of the wrong length can never match a SHA-256 digest.
			if len(hash) != len(wantedHash) {
				continue
			}
			copy(wantedHash[:], hash)
			if h == wantedHash {
				found = true
				break certLoop
			}
		}
	}
	if !found && len(stamp.hashes) > 0 {
		return ServerInfo{}, fmt.Errorf("Certificate hash [%x] not found for [%s]", wantedHash, name)
	}

	respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, MaxHTTPBodyLength))
	if err != nil {
		return ServerInfo{}, err
	}
	if len(respBody) < MinDNSPacketSize || len(respBody) > MaxDNSPacketSize ||
		respBody[0] != 0xca || respBody[1] != 0xfe || respBody[4] != 0x00 || respBody[5] != 0x01 {
		return ServerInfo{}, errors.New("Webserver returned an unexpected response")
	}

	rttMs := rtt.Nanoseconds() / 1000000
	if isNew {
		dlog.Noticef("[%s] OK (DoH) - rtt: %dms", name, rttMs)
	} else {
		dlog.Infof("[%s] OK (DoH) - rtt: %dms", name, rttMs)
	}
	return ServerInfo{
		Proto:      StampProtoTypeDoH,
		Name:       name,
		Timeout:    proxy.timeout,
		URL:        dohURL,
		HostName:   stamp.providerName,
		initialRtt: int(rttMs),
		useGet:     useGet,
	}, nil
}
// noticeFailure penalises the server after a failed exchange by overwriting its
// RTT moving average with the configured timeout.
//
// The penalty is expressed in the same unit as the samples noticeSuccess adds:
// a duration divided by 1024, roughly microseconds. Writing raw nanoseconds
// here would make a failure weigh ~1000x more than an equally long success,
// and skew the comparisons getOne performs.
func (serverInfo *ServerInfo) noticeFailure(proxy *Proxy) {
	serverInfo.Lock()
	serverInfo.rtt.Set(float64(proxy.timeout / 1024))
	serverInfo.Unlock()
}
// noticeBegin records the current time as the start of an exchange with this
// server; noticeSuccess later measures the elapsed time from it. The proxy
// argument is unused.
func (serverInfo *ServerInfo) noticeBegin(proxy *Proxy) {
	serverInfo.Lock()
	defer serverInfo.Unlock()
	serverInfo.lastActionTS = time.Now()
}
// noticeSuccess feeds the time elapsed since the last noticeBegin into the
// server's RTT moving average. Each sample is the elapsed duration divided by
// 1024 (approximately microseconds). Non-positive samples are discarded. The
// proxy argument is unused.
func (serverInfo *ServerInfo) noticeSuccess(proxy *Proxy) {
	finished := time.Now()
	serverInfo.Lock()
	defer serverInfo.Unlock()
	sample := finished.Sub(serverInfo.lastActionTS) / 1024
	if sample <= 0 {
		return
	}
	serverInfo.rtt.Add(float64(sample))
}