package main import ( "crypto/sha256" "encoding/hex" "errors" "fmt" "io" "io/ioutil" "math/rand" "net" "net/url" "os" "strings" "sync" "time" "github.com/VividCortex/ewma" "github.com/jedisct1/dlog" "golang.org/x/crypto/ed25519" ) const ( RTTEwmaDecay = 10.0 DefaultPort = 443 ) type ServerInformalProperties uint64 const ( ServerInformalPropertyDNSSEC = ServerInformalProperties(1) << 0 ServerInformalPropertyNoLog = ServerInformalProperties(1) << 1 ServerInformalPropertyNoFilter = ServerInformalProperties(1) << 2 ) type RegisteredServer struct { name string stamp ServerStamp description string } type ServerInfo struct { sync.RWMutex Proto StampProtoType MagicQuery [8]byte ServerPk [32]byte SharedKey [32]byte CryptoConstruction CryptoConstruction Name string Timeout time.Duration URL *url.URL HostName string UDPAddr *net.UDPAddr TCPAddr *net.TCPAddr lastActionTS time.Time rtt ewma.MovingAverage initialRtt int useGet bool } type LBStrategy int const ( LBStrategyNone = LBStrategy(iota) LBStrategyP2 LBStrategyPH LBStrategyFastest LBStrategyRandom ) const DefaultLBStrategy = LBStrategyP2 type ServersInfo struct { sync.RWMutex inner []ServerInfo registeredServers []RegisteredServer lbStrategy LBStrategy } func (serversInfo *ServersInfo) registerServer(proxy *Proxy, name string, stamp ServerStamp) error { newRegisteredServer := RegisteredServer{name: name, stamp: stamp} serversInfo.Lock() defer serversInfo.Unlock() for i, oldRegisteredServer := range serversInfo.registeredServers { if oldRegisteredServer.name == name { serversInfo.registeredServers[i] = newRegisteredServer return nil } } serversInfo.registeredServers = append(serversInfo.registeredServers, newRegisteredServer) return nil } func (serversInfo *ServersInfo) refreshServer(proxy *Proxy, name string, stamp ServerStamp) error { serversInfo.Lock() defer serversInfo.Unlock() previousIndex := -1 for i, oldServer := range serversInfo.inner { if oldServer.Name == name { previousIndex = i break } } newServer, err := serversInfo.fetchServerInfo(proxy, name, stamp, previousIndex < 0) if err != nil { return err } if name != newServer.Name { dlog.Fatalf("[%s] != [%s]", name, newServer.Name) } newServer.rtt = ewma.NewMovingAverage(RTTEwmaDecay) if previousIndex >= 0 { serversInfo.inner[previousIndex] = newServer return nil } serversInfo.inner = append(serversInfo.inner, newServer) serversInfo.registeredServers = append(serversInfo.registeredServers, RegisteredServer{name: name, stamp: stamp}) return nil } func (serversInfo *ServersInfo) refresh(proxy *Proxy) (int, error) { dlog.Debug("Refreshing certificates") serversInfo.RLock() registeredServers := serversInfo.registeredServers serversInfo.RUnlock() liveServers := 0 var err error for _, registeredServer := range registeredServers { if err = serversInfo.refreshServer(proxy, registeredServer.name, registeredServer.stamp); err == nil { liveServers++ } } serversInfo.Lock() inner := serversInfo.inner innerLen := len(inner) for i := 0; i < innerLen; i++ { for j := i + 1; j < innerLen; j++ { if inner[j].initialRtt < inner[i].initialRtt { inner[j], inner[i] = inner[i], inner[j] } } } serversInfo.inner = inner if innerLen > 0 { dlog.Noticef("Server with the lowest initial latency: %s (rtt: %dms)", inner[0].Name, inner[0].initialRtt) proxy.certIgnoreTimestamp = false } serversInfo.Unlock() return liveServers, err } func (serversInfo *ServersInfo) liveServers() int { serversInfo.RLock() liveServers := len(serversInfo.inner) serversInfo.RUnlock() return liveServers } func (serversInfo *ServersInfo) getOne() *ServerInfo { serversInfo.Lock() defer serversInfo.Unlock() serversCount := len(serversInfo.inner) if serversCount <= 0 { return nil } candidate := rand.Intn(serversCount) if candidate == 0 { return &serversInfo.inner[candidate] } candidateRtt, currentBestRtt := serversInfo.inner[candidate].rtt.Value(), serversInfo.inner[0].rtt.Value() if currentBestRtt < 0 { currentBestRtt = candidateRtt serversInfo.inner[0].rtt.Set(currentBestRtt) } partialSort := false if candidateRtt < currentBestRtt { serversInfo.inner[candidate], serversInfo.inner[0] = serversInfo.inner[0], serversInfo.inner[candidate] partialSort = true dlog.Debugf("New prefered candidate: %v (rtt: %v vs previous: %v)", serversInfo.inner[0].Name, candidateRtt, currentBestRtt) } else if candidateRtt >= currentBestRtt*4.0 { if time.Since(serversInfo.inner[candidate].lastActionTS) > time.Duration(1*time.Minute) { serversInfo.inner[candidate].rtt.Add(MinF(MaxF(candidateRtt/2.0, currentBestRtt*2.0), candidateRtt)) partialSort = true } } if partialSort { for i := 1; i < serversCount; i++ { if serversInfo.inner[i-1].rtt.Value() > serversInfo.inner[i].rtt.Value() { serversInfo.inner[i-1], serversInfo.inner[i] = serversInfo.inner[i], serversInfo.inner[i-1] } } } switch serversInfo.lbStrategy { case LBStrategyFastest: candidate = 0 case LBStrategyPH: candidate = rand.Intn(Min(Min(serversCount, 2), serversCount/2)) case LBStrategyRandom: candidate = rand.Intn(serversCount) default: candidate = rand.Intn(Min(serversCount, 2)) } serverInfo := &serversInfo.inner[candidate] return serverInfo } func (serversInfo *ServersInfo) fetchServerInfo(proxy *Proxy, name string, stamp ServerStamp, isNew bool) (ServerInfo, error) { if stamp.proto == StampProtoTypeDNSCrypt { return serversInfo.fetchDNSCryptServerInfo(proxy, name, stamp, isNew) } else if stamp.proto == StampProtoTypeDoH { return serversInfo.fetchDoHServerInfo(proxy, name, stamp, isNew) } return ServerInfo{}, errors.New("Unsupported protocol") } func (serversInfo *ServersInfo) fetchDNSCryptServerInfo(proxy *Proxy, name string, stamp ServerStamp, isNew bool) (ServerInfo, error) { if len(stamp.serverPk) != ed25519.PublicKeySize { serverPk, err := hex.DecodeString(strings.Replace(string(stamp.serverPk), ":", "", -1)) if err != nil || len(serverPk) != ed25519.PublicKeySize { dlog.Fatalf("Unsupported public key for [%s]: [%s]", name, stamp.serverPk) } dlog.Warnf("Public key [%s] shouldn't be hex-encoded any more", string(stamp.serverPk)) stamp.serverPk = serverPk } certInfo, rtt, err := FetchCurrentDNSCryptCert(proxy, &name, proxy.mainProto, stamp.serverPk, stamp.serverAddrStr, stamp.providerName, isNew) if err != nil { return ServerInfo{}, err } remoteUDPAddr, err := net.ResolveUDPAddr("udp", stamp.serverAddrStr) if err != nil { return ServerInfo{}, err } remoteTCPAddr, err := net.ResolveTCPAddr("tcp", stamp.serverAddrStr) if err != nil { return ServerInfo{}, err } serverInfo := ServerInfo{ Proto: StampProtoTypeDNSCrypt, MagicQuery: certInfo.MagicQuery, ServerPk: certInfo.ServerPk, SharedKey: certInfo.SharedKey, CryptoConstruction: certInfo.CryptoConstruction, Name: name, Timeout: proxy.timeout, UDPAddr: remoteUDPAddr, TCPAddr: remoteTCPAddr, initialRtt: rtt, } return serverInfo, nil } func (serversInfo *ServersInfo) fetchDoHServerInfo(proxy *Proxy, name string, stamp ServerStamp, isNew bool) (ServerInfo, error) { if len(stamp.serverAddrStr) > 0 { addrStr := stamp.serverAddrStr ipOnly := addrStr[:strings.LastIndex(addrStr, ":")] proxy.xTransport.cachedIPs.Lock() proxy.xTransport.cachedIPs.cache[stamp.providerName] = ipOnly proxy.xTransport.cachedIPs.Unlock() } url := &url.URL{ Scheme: "https", Host: stamp.providerName, Path: stamp.path, } body := []byte{ 0xca, 0xfe, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x02, 0x00, 0x01, 0x00, 0x00, 0x29, 0x10, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, } useGet := false if _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil { useGet = true if _, _, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout); err != nil { return ServerInfo{}, err } dlog.Debugf("Server [%s] doesn't appear to support POST; falling back to GET requests", name) } resp, rtt, err := proxy.xTransport.DoHQuery(useGet, url, body, proxy.timeout) if err != nil { return ServerInfo{}, err } tls := resp.TLS if tls == nil || !tls.HandshakeComplete { return ServerInfo{}, errors.New("TLS handshake failed") } showCerts := len(os.Getenv("SHOW_CERTS")) > 0 found := false var wantedHash [32]byte for _, cert := range tls.PeerCertificates { h := sha256.Sum256(cert.RawTBSCertificate) if showCerts { dlog.Infof("Advertised cert: [%s] [%x]", cert.Subject, h) } else { dlog.Debugf("Advertised cert: [%s] [%x]", cert.Subject, h) } for _, hash := range stamp.hashes { if len(hash) == len(wantedHash) { copy(wantedHash[:], hash) if h == wantedHash { found = true break } } } if found { break } } if !found && len(stamp.hashes) > 0 { return ServerInfo{}, fmt.Errorf("Certificate hash [%x] not found for [%s]", wantedHash, name) } respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, MaxHTTPBodyLength)) if err != nil { return ServerInfo{}, err } if len(respBody) < MinDNSPacketSize || len(respBody) > MaxDNSPacketSize || respBody[0] != 0xca || respBody[1] != 0xfe || respBody[4] != 0x00 || respBody[5] != 0x01 { return ServerInfo{}, errors.New("Webserver returned an unexpected response") } if isNew { dlog.Noticef("[%s] OK (DoH) - rtt: %dms", name, rtt.Nanoseconds()/1000000) } else { dlog.Infof("[%s] OK (DoH) - rtt: %dms", name, rtt.Nanoseconds()/1000000) } serverInfo := ServerInfo{ Proto: StampProtoTypeDoH, Name: name, Timeout: proxy.timeout, URL: url, HostName: stamp.providerName, initialRtt: int(rtt.Nanoseconds() / 1000000), useGet: useGet, } return serverInfo, nil } func (serverInfo *ServerInfo) noticeFailure(proxy *Proxy) { serverInfo.Lock() serverInfo.rtt.Add(float64(proxy.timeout.Nanoseconds() / 1000000)) serverInfo.Unlock() } func (serverInfo *ServerInfo) noticeBegin(proxy *Proxy) { serverInfo.Lock() serverInfo.lastActionTS = time.Now() serverInfo.Unlock() } func (serverInfo *ServerInfo) noticeSuccess(proxy *Proxy) { now := time.Now() serverInfo.Lock() elapsed := now.Sub(serverInfo.lastActionTS) elapsedMs := elapsed.Nanoseconds() / 1000000 if elapsedMs > 0 && elapsed < proxy.timeout { serverInfo.rtt.Add(float64(elapsedMs)) } serverInfo.Unlock() }