Merge pull request #2549 from lifenjoiner/wg

Optimize CaptivePortalHandler for clean code
This commit is contained in:
lifenjoiner 2023-12-17 18:59:05 +08:00 committed by GitHub
commit 3be53642fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 15 deletions

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync"
"time" "time"
"github.com/jedisct1/dlog" "github.com/jedisct1/dlog"
@ -15,16 +16,13 @@ type CaptivePortalEntryIPs []net.IP
type CaptivePortalMap map[string]CaptivePortalEntryIPs type CaptivePortalMap map[string]CaptivePortalEntryIPs
type CaptivePortalHandler struct { type CaptivePortalHandler struct {
wg sync.WaitGroup
cancelChannel chan struct{} cancelChannel chan struct{}
countChannel chan struct{}
waitChannel chan struct{}
channelCount int
} }
func (captivePortalHandler *CaptivePortalHandler) Stop() { func (captivePortalHandler *CaptivePortalHandler) Stop() {
close(captivePortalHandler.cancelChannel) close(captivePortalHandler.cancelChannel)
<-captivePortalHandler.waitChannel captivePortalHandler.wg.Wait()
close(captivePortalHandler.countChannel)
} }
func (ipsMap *CaptivePortalMap) GetEntry(msg *dns.Msg) (*dns.Question, *CaptivePortalEntryIPs) { func (ipsMap *CaptivePortalMap) GetEntry(msg *dns.Msg) (*dns.Question, *CaptivePortalEntryIPs) {
@ -132,14 +130,12 @@ func addColdStartListener(
if err != nil { if err != nil {
return err return err
} }
captivePortalHandler.wg.Add(1)
go func() { go func() {
for !handleColdStartClient(clientPc, captivePortalHandler.cancelChannel, ipsMap) { for !handleColdStartClient(clientPc, captivePortalHandler.cancelChannel, ipsMap) {
} }
clientPc.Close() clientPc.Close()
captivePortalHandler.countChannel <- struct{}{} captivePortalHandler.wg.Done()
if len(captivePortalHandler.countChannel) == captivePortalHandler.channelCount {
close(captivePortalHandler.waitChannel)
}
}() }()
return nil return nil
} }
@ -187,15 +183,17 @@ func ColdStart(proxy *Proxy) (*CaptivePortalHandler, error) {
listenAddrStrs := proxy.listenAddresses listenAddrStrs := proxy.listenAddresses
captivePortalHandler := CaptivePortalHandler{ captivePortalHandler := CaptivePortalHandler{
cancelChannel: make(chan struct{}), cancelChannel: make(chan struct{}),
countChannel: make(chan struct{}, len(listenAddrStrs)),
waitChannel: make(chan struct{}),
channelCount: 0,
} }
ok := false
for _, listenAddrStr := range listenAddrStrs { for _, listenAddrStr := range listenAddrStrs {
if err := addColdStartListener(proxy, &ipsMap, listenAddrStr, &captivePortalHandler); err == nil { err = addColdStartListener(proxy, &ipsMap, listenAddrStr, &captivePortalHandler)
captivePortalHandler.channelCount++ if err == nil {
ok = true
} }
} }
if ok {
err = nil
}
proxy.captivePortalMap = &ipsMap proxy.captivePortalMap = &ipsMap
return &captivePortalHandler, nil return &captivePortalHandler, err
} }