diff --git a/dnscrypt-proxy/coldstart.go b/dnscrypt-proxy/coldstart.go index b2c3d8c2..8a2c459d 100644 --- a/dnscrypt-proxy/coldstart.go +++ b/dnscrypt-proxy/coldstart.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "strings" + "sync" "time" "github.com/jedisct1/dlog" @@ -15,16 +16,13 @@ type CaptivePortalEntryIPs []net.IP type CaptivePortalMap map[string]CaptivePortalEntryIPs type CaptivePortalHandler struct { + wg sync.WaitGroup cancelChannel chan struct{} - countChannel chan struct{} - waitChannel chan struct{} - channelCount int } func (captivePortalHandler *CaptivePortalHandler) Stop() { close(captivePortalHandler.cancelChannel) - <-captivePortalHandler.waitChannel - close(captivePortalHandler.countChannel) + captivePortalHandler.wg.Wait() } func (ipsMap *CaptivePortalMap) GetEntry(msg *dns.Msg) (*dns.Question, *CaptivePortalEntryIPs) { @@ -132,14 +130,12 @@ func addColdStartListener( if err != nil { return err } + captivePortalHandler.wg.Add(1) go func() { for !handleColdStartClient(clientPc, captivePortalHandler.cancelChannel, ipsMap) { } clientPc.Close() - captivePortalHandler.countChannel <- struct{}{} - if len(captivePortalHandler.countChannel) == captivePortalHandler.channelCount { - close(captivePortalHandler.waitChannel) - } + captivePortalHandler.wg.Done() }() return nil } @@ -187,15 +183,17 @@ func ColdStart(proxy *Proxy) (*CaptivePortalHandler, error) { listenAddrStrs := proxy.listenAddresses captivePortalHandler := CaptivePortalHandler{ cancelChannel: make(chan struct{}), - countChannel: make(chan struct{}, len(listenAddrStrs)), - waitChannel: make(chan struct{}), - channelCount: 0, } + ok := false for _, listenAddrStr := range listenAddrStrs { - if err := addColdStartListener(proxy, &ipsMap, listenAddrStr, &captivePortalHandler); err == nil { - captivePortalHandler.channelCount++ + err = addColdStartListener(proxy, &ipsMap, listenAddrStr, &captivePortalHandler) + if err == nil { + ok = true } } + if ok { + err = nil + } proxy.captivePortalMap = &ipsMap - return &captivePortalHandler, nil + return &captivePortalHandler, err }