diff --git a/dnscrypt-proxy/config.go b/dnscrypt-proxy/config.go index 5c8abbea..bc959dd2 100644 --- a/dnscrypt-proxy/config.go +++ b/dnscrypt-proxy/config.go @@ -358,22 +358,29 @@ func ConfigLoad(proxy *Proxy, flags *ConfigFlags) error { if len(config.ListenAddresses) == 0 && len(config.LocalDoH.ListenAddresses) == 0 { dlog.Debug("No local IP/port configured") } - - lbStrategy := DefaultLBStrategy - switch strings.ToLower(config.LBStrategy) { + lbStrategy := LBStrategy(DefaultLBStrategy) + switch lbStrategyStr := strings.ToLower(config.LBStrategy); lbStrategyStr { case "": // default case "p2": - lbStrategy = LBStrategyP2 + lbStrategy = LBStrategyP2{} case "ph": - lbStrategy = LBStrategyPH + lbStrategy = LBStrategyPH{} case "fastest": case "first": - lbStrategy = LBStrategyFirst + lbStrategy = LBStrategyFirst{} case "random": - lbStrategy = LBStrategyRandom + lbStrategy = LBStrategyRandom{} default: - dlog.Warnf("Unknown load balancing strategy: [%s]", config.LBStrategy) + if strings.HasPrefix(lbStrategyStr, "p") { + n, err := strconv.ParseInt(strings.TrimPrefix(lbStrategyStr, "p"), 10, 32) + if err != nil || n <= 0 { + dlog.Fatalf("Invalid load balancing strategy: [%s]", config.LBStrategy) + } + lbStrategy = LBStrategyPN{n: int(n)} + } else { + dlog.Warnf("Unknown load balancing strategy: [%s]", config.LBStrategy) + } } proxy.serversInfo.lbStrategy = lbStrategy proxy.serversInfo.lbEstimator = config.LBEstimator diff --git a/dnscrypt-proxy/serversInfo.go b/dnscrypt-proxy/serversInfo.go index 6327f7c0..66fb48ff 100644 --- a/dnscrypt-proxy/serversInfo.go +++ b/dnscrypt-proxy/serversInfo.go @@ -62,17 +62,41 @@ type ServerInfo struct { DOHClientCreds DOHClientCreds } -type LBStrategy int +type LBStrategy interface { + getCandidate(serversCount int) int +} -const ( - LBStrategyNone = LBStrategy(iota) - LBStrategyP2 - LBStrategyPH - LBStrategyFirst - LBStrategyRandom -) +type LBStrategyP2 struct{} -const DefaultLBStrategy = LBStrategyP2 +func (LBStrategyP2) getCandidate(serversCount int) int { + return rand.Intn(Min(serversCount, 2)) +} + +type LBStrategyPN struct{ n int } + +func (s LBStrategyPN) getCandidate(serversCount int) int { + return rand.Intn(Min(serversCount, s.n)) +} + +type LBStrategyPH struct{} + +func (LBStrategyPH) getCandidate(serversCount int) int { + return rand.Intn(Max(Min(serversCount, 2), serversCount/2)) +} + +type LBStrategyFirst struct{} + +func (LBStrategyFirst) getCandidate(int) int { + return 0 +} + +type LBStrategyRandom struct{} + +func (LBStrategyRandom) getCandidate(serversCount int) int { + return rand.Intn(serversCount) +} + +var DefaultLBStrategy = LBStrategyP2{} type ServersInfo struct { sync.RWMutex @@ -209,17 +233,7 @@ func (serversInfo *ServersInfo) getOne() *ServerInfo { if serversInfo.lbEstimator { serversInfo.estimatorUpdate() } - var candidate int - switch serversInfo.lbStrategy { - case LBStrategyFirst: - candidate = 0 - case LBStrategyPH: - candidate = rand.Intn(Max(Min(serversCount, 2), serversCount/2)) - case LBStrategyRandom: - candidate = rand.Intn(serversCount) - default: - candidate = rand.Intn(Min(serversCount, 2)) - } + candidate := serversInfo.lbStrategy.getCandidate(serversCount) serverInfo := serversInfo.inner[candidate] dlog.Debugf("Using candidate [%s] RTT: %d", (*serverInfo).Name, int((*serverInfo).rtt.Value())) serversInfo.Unlock()