diff --git a/dnscrypt-proxy/coldstart.go b/dnscrypt-proxy/coldstart.go index 09307f3e..209c3417 100644 --- a/dnscrypt-proxy/coldstart.go +++ b/dnscrypt-proxy/coldstart.go @@ -15,13 +15,17 @@ type CaptivePortalEntryIPs []net.IP type CaptivePortalMap map[string]CaptivePortalEntryIPs type CaptivePortalHandler struct { - cancelChannels []chan struct{} + cancelChannel chan struct{} + countChannel chan struct{} + channelCount int } func (captivePortalHandler *CaptivePortalHandler) Stop() { - for _, cancelChannel := range captivePortalHandler.cancelChannels { - close(cancelChannel) + close(captivePortalHandler.cancelChannel) + for len(captivePortalHandler.countChannel) < captivePortalHandler.channelCount { + time.Sleep(time.Millisecond) } + close(captivePortalHandler.countChannel) } func (ipsMap *CaptivePortalMap) GetEntry(msg *dns.Msg) (*dns.Question, *CaptivePortalEntryIPs) { @@ -119,7 +123,7 @@ func addColdStartListener( proxy *Proxy, ipsMap *CaptivePortalMap, listenAddrStr string, - cancelChannel chan struct{}, + captivePortalHandler *CaptivePortalHandler, ) error { listenUDPAddr, err := net.ResolveUDPAddr("udp", listenAddrStr) if err != nil { @@ -130,9 +134,10 @@ func addColdStartListener( return err } go func() { - for !handleColdStartClient(clientPc, cancelChannel, ipsMap) { + for !handleColdStartClient(clientPc, captivePortalHandler.cancelChannel, ipsMap) { } clientPc.Close() + captivePortalHandler.countChannel <- struct{}{} }() return nil } @@ -178,15 +183,15 @@ func ColdStart(proxy *Proxy) (*CaptivePortalHandler, error) { ipsMap[name] = ips } listenAddrStrs := proxy.listenAddresses - cancelChannels := make([]chan struct{}, 0) - for _, listenAddrStr := range listenAddrStrs { - cancelChannel := make(chan struct{}) - if err := addColdStartListener(proxy, &ipsMap, listenAddrStr, cancelChannel); err == nil { - cancelChannels = append(cancelChannels, cancelChannel) - } - } captivePortalHandler := CaptivePortalHandler{ - cancelChannels: cancelChannels, + cancelChannel: make(chan struct{}), + countChannel: make(chan struct{}, len(listenAddrStrs)), + channelCount: 0, + } + for _, listenAddrStr := range listenAddrStrs { + if err := addColdStartListener(proxy, &ipsMap, listenAddrStr, &captivePortalHandler); err == nil { + captivePortalHandler.channelCount++ + } } proxy.captivePortalMap = &ipsMap return &captivePortalHandler, nil