diff --git a/dnscrypt-proxy/coldstart.go b/dnscrypt-proxy/coldstart.go index d47a3750..b2c3d8c2 100644 --- a/dnscrypt-proxy/coldstart.go +++ b/dnscrypt-proxy/coldstart.go @@ -17,14 +17,13 @@ type CaptivePortalMap map[string]CaptivePortalEntryIPs type CaptivePortalHandler struct { cancelChannel chan struct{} countChannel chan struct{} + waitChannel chan struct{} channelCount int } func (captivePortalHandler *CaptivePortalHandler) Stop() { close(captivePortalHandler.cancelChannel) - for len(captivePortalHandler.countChannel) < captivePortalHandler.channelCount { - time.Sleep(10 * time.Millisecond) - } + <-captivePortalHandler.waitChannel close(captivePortalHandler.countChannel) } @@ -138,6 +137,9 @@ func addColdStartListener( } clientPc.Close() captivePortalHandler.countChannel <- struct{}{} + if len(captivePortalHandler.countChannel) == captivePortalHandler.channelCount { + close(captivePortalHandler.waitChannel) + } }() return nil } @@ -186,6 +188,7 @@ func ColdStart(proxy *Proxy) (*CaptivePortalHandler, error) { captivePortalHandler := CaptivePortalHandler{ cancelChannel: make(chan struct{}), countChannel: make(chan struct{}, len(listenAddrStrs)), + waitChannel: make(chan struct{}), channelCount: 0, } for _, listenAddrStr := range listenAddrStrs { diff --git a/dnscrypt-proxy/proxy.go b/dnscrypt-proxy/proxy.go index c1b1e248..b2ef5e9a 100644 --- a/dnscrypt-proxy/proxy.go +++ b/dnscrypt-proxy/proxy.go @@ -517,7 +517,7 @@ func (proxy *Proxy) exchangeWithUDPServer( var pc net.Conn proxyDialer := proxy.xTransport.proxyDialer if proxyDialer == nil { - pc, err = net.DialUDP("udp", nil, upstreamAddr) + pc, err = net.DialTimeout("udp", upstreamAddr.String(), serverInfo.Timeout) } else { pc, err = (*proxyDialer).Dial("udp", upstreamAddr.String()) } @@ -560,7 +560,7 @@ func (proxy *Proxy) exchangeWithTCPServer( var pc net.Conn proxyDialer := proxy.xTransport.proxyDialer if proxyDialer == nil { - pc, err = net.DialTCP("tcp", nil, upstreamAddr) + pc, err = net.DialTimeout("tcp", upstreamAddr.String(), serverInfo.Timeout) } else { pc, err = (*proxyDialer).Dial("tcp", upstreamAddr.String()) }