refactoring of pull 980

follow up on https://github.com/DNSCrypt/dnscrypt-proxy/pull/980#issuecomment-548153169
This commit is contained in:
Vladimir Bauer 2019-10-31 16:02:20 +05:00 committed by Frank Denis
parent 9eae8de902
commit 6fa420a8e0
5 changed files with 71 additions and 71 deletions

View File

@ -8,7 +8,6 @@ import (
"math/rand" "math/rand"
"os" "os"
"os/signal" "os/signal"
"sync"
"syscall" "syscall"
"github.com/facebookgo/pidfile" "github.com/facebookgo/pidfile"
@ -22,8 +21,6 @@ const (
) )
type App struct { type App struct {
wg sync.WaitGroup
quit chan struct{}
proxy *Proxy proxy *Proxy
flags *ConfigFlags flags *ConfigFlags
} }
@ -71,7 +68,6 @@ func main() {
} }
app := &App{ app := &App{
quit: make(chan struct{}),
flags: &flags, flags: &flags,
} }
svc, err := service.New(app, svcConfig) svc, err := service.New(app, svcConfig)
@ -101,7 +97,6 @@ func main() {
} }
return return
} }
app.wg.Add(1)
if svc != nil { if svc != nil {
if err = svc.Run(); err != nil { if err = svc.Run(); err != nil {
dlog.Fatal(err) dlog.Fatal(err)
@ -113,7 +108,7 @@ func main() {
app.signalWatch() app.signalWatch()
app.Start(nil) app.Start(nil)
} }
app.wg.Wait() app.proxy.ConnCloseWait()
dlog.Notice("Stopped.") dlog.Notice("Stopped.")
} }
@ -135,14 +130,13 @@ func (app *App) Stop(service service.Service) error {
os.Remove(pidFilePath) os.Remove(pidFilePath)
} }
dlog.Notice("Quit signal received...") dlog.Notice("Quit signal received...")
close(app.quit) app.proxy.Stop()
return nil return nil
} }
func (app *App) appMain() { func (app *App) appMain() {
pidfile.Write() pidfile.Write()
app.proxy.StartProxy(app.quit) app.proxy.StartProxy()
app.wg.Done()
} }
func (app *App) signalWatch() { func (app *App) signalWatch() {
@ -151,6 +145,6 @@ func (app *App) signalWatch() {
go func() { go func() {
<-quit <-quit
signal.Stop(quit) signal.Stop(quit)
close(app.quit) app.proxy.Stop()
}() }()
} }

View File

@ -7,6 +7,7 @@ import (
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -72,9 +73,13 @@ type Proxy struct {
queryMeta []string queryMeta []string
routes *map[string][]string routes *map[string][]string
showCerts bool showCerts bool
wg *sync.WaitGroup
quit chan struct{}
} }
func (proxy *Proxy) StartProxy(quit <-chan struct{}) { // StartProxy is blocking
func (proxy *Proxy) StartProxy() {
proxy.questionSizeEstimator = NewQuestionSizeEstimator() proxy.questionSizeEstimator = NewQuestionSizeEstimator()
if _, err := crypto_rand.Read(proxy.proxySecretKey[:]); err != nil { if _, err := crypto_rand.Read(proxy.proxySecretKey[:]); err != nil {
dlog.Fatal(err) dlog.Fatal(err)
@ -96,16 +101,12 @@ func (proxy *Proxy) StartProxy(quit <-chan struct{}) {
// if 'userName' is not set, continue as before // if 'userName' is not set, continue as before
if !(len(proxy.userName) > 0) { if !(len(proxy.userName) > 0) {
udpCloser, err := proxy.udpListenerFromAddr(listenUDPAddr) if err := proxy.udpListenerFromAddr(listenUDPAddr); err != nil {
if err != nil {
dlog.Fatal(err) dlog.Fatal(err)
} }
tcpCloser, err := proxy.tcpListenerFromAddr(listenTCPAddr) if err := proxy.tcpListenerFromAddr(listenTCPAddr); err != nil {
if err != nil {
dlog.Fatal(err) dlog.Fatal(err)
} }
defer udpCloser.Close()
defer tcpCloser.Close()
} else { } else {
// if 'userName' is set and we are the parent process // if 'userName' is set and we are the parent process
if !proxy.child { if !proxy.child {
@ -127,8 +128,13 @@ func (proxy *Proxy) StartProxy(quit <-chan struct{}) {
if err != nil { if err != nil {
dlog.Fatalf("Unable to switch to a different user: %v", err) dlog.Fatalf("Unable to switch to a different user: %v", err)
} }
defer listenerUDP.Close() proxy.wg.Add(1)
defer listenerTCP.Close() go func() {
defer proxy.wg.Done()
<-proxy.quit
listenerUDP.Close()
listenerTCP.Close()
}()
FileDescriptors = append(FileDescriptors, fdUDP) FileDescriptors = append(FileDescriptors, fdUDP)
FileDescriptors = append(FileDescriptors, fdTCP) FileDescriptors = append(FileDescriptors, fdTCP)
@ -148,9 +154,11 @@ func (proxy *Proxy) StartProxy(quit <-chan struct{}) {
FileDescriptorNum++ FileDescriptorNum++
dlog.Noticef("Now listening to %v [UDP]", listenUDPAddr) dlog.Noticef("Now listening to %v [UDP]", listenUDPAddr)
proxy.wg.Add(1)
go proxy.udpListener(listenerUDP.(*net.UDPConn)) go proxy.udpListener(listenerUDP.(*net.UDPConn))
dlog.Noticef("Now listening to %v [TCP]", listenAddrStr) dlog.Noticef("Now listening to %v [TCP]", listenAddrStr)
proxy.wg.Add(1)
go proxy.tcpListener(listenerTCP.(*net.TCPListener)) go proxy.tcpListener(listenerTCP.(*net.TCPListener))
} }
} }
@ -160,11 +168,9 @@ func (proxy *Proxy) StartProxy(quit <-chan struct{}) {
if len(proxy.userName) > 0 && !proxy.child { if len(proxy.userName) > 0 && !proxy.child {
proxy.dropPrivilege(proxy.userName, FileDescriptors) proxy.dropPrivilege(proxy.userName, FileDescriptors)
} }
sdc, err := proxy.SystemDListeners() if err := proxy.SystemDListeners(); err != nil {
if err != nil {
dlog.Fatal(err) dlog.Fatal(err)
} }
defer sdc.Close()
liveServers, err := proxy.serversInfo.refresh(proxy) liveServers, err := proxy.serversInfo.refresh(proxy)
if liveServers > 0 { if liveServers > 0 {
proxy.certIgnoreTimestamp = false proxy.certIgnoreTimestamp = false
@ -185,21 +191,18 @@ func (proxy *Proxy) StartProxy(quit <-chan struct{}) {
} }
go proxy.prefetcher() go proxy.prefetcher()
if len(proxy.serversInfo.registeredServers) > 0 { if len(proxy.serversInfo.registeredServers) > 0 {
go func() { for {
for { delay := proxy.certRefreshDelay
delay := proxy.certRefreshDelay if liveServers == 0 {
if liveServers == 0 { delay = proxy.certRefreshDelayAfterFailure
delay = proxy.certRefreshDelayAfterFailure
}
clocksmith.Sleep(delay)
liveServers, _ = proxy.serversInfo.refresh(proxy)
if liveServers > 0 {
proxy.certIgnoreTimestamp = false
}
} }
}() clocksmith.Sleep(delay)
liveServers, _ = proxy.serversInfo.refresh(proxy)
if liveServers > 0 {
proxy.certIgnoreTimestamp = false
}
}
} }
<-quit
} }
func (proxy *Proxy) prefetcher() { func (proxy *Proxy) prefetcher() {
@ -220,6 +223,11 @@ func (proxy *Proxy) prefetcher() {
} }
func (proxy *Proxy) udpListener(clientPc *net.UDPConn) { func (proxy *Proxy) udpListener(clientPc *net.UDPConn) {
go func() {
defer proxy.wg.Done()
<-proxy.quit
clientPc.Close()
}()
for { for {
buffer := make([]byte, MaxDNSPacketSize-1) buffer := make([]byte, MaxDNSPacketSize-1)
length, clientAddr, err := clientPc.ReadFrom(buffer) length, clientAddr, err := clientPc.ReadFrom(buffer)
@ -239,17 +247,23 @@ func (proxy *Proxy) udpListener(clientPc *net.UDPConn) {
} }
} }
func (proxy *Proxy) udpListenerFromAddr(listenAddr *net.UDPAddr) (io.Closer, error) { func (proxy *Proxy) udpListenerFromAddr(listenAddr *net.UDPAddr) error {
clientPc, err := net.ListenUDP("udp", listenAddr) clientPc, err := net.ListenUDP("udp", listenAddr)
if err != nil { if err != nil {
return nil, err return err
} }
dlog.Noticef("Now listening to %v [UDP]", listenAddr) dlog.Noticef("Now listening to %v [UDP]", listenAddr)
proxy.wg.Add(1)
go proxy.udpListener(clientPc) go proxy.udpListener(clientPc)
return clientPc, nil return nil
} }
func (proxy *Proxy) tcpListener(acceptPc *net.TCPListener) { func (proxy *Proxy) tcpListener(acceptPc *net.TCPListener) {
go func() {
defer proxy.wg.Done()
<-proxy.quit
acceptPc.Close()
}()
for { for {
clientPc, err := acceptPc.Accept() clientPc, err := acceptPc.Accept()
if err != nil { if err != nil {
@ -274,14 +288,15 @@ func (proxy *Proxy) tcpListener(acceptPc *net.TCPListener) {
} }
} }
func (proxy *Proxy) tcpListenerFromAddr(listenAddr *net.TCPAddr) (io.Closer, error) { func (proxy *Proxy) tcpListenerFromAddr(listenAddr *net.TCPAddr) error {
acceptPc, err := net.ListenTCP("tcp", listenAddr) acceptPc, err := net.ListenTCP("tcp", listenAddr)
if err != nil { if err != nil {
return nil, err return err
} }
dlog.Noticef("Now listening to %v [TCP]", listenAddr) dlog.Noticef("Now listening to %v [TCP]", listenAddr)
proxy.wg.Add(1)
go proxy.tcpListener(acceptPc) go proxy.tcpListener(acceptPc)
return acceptPc, nil return nil
} }
func (proxy *Proxy) prepareForRelay(ip net.IP, port int, encryptedQuery *[]byte) { func (proxy *Proxy) prepareForRelay(ip net.IP, port int, encryptedQuery *[]byte) {
@ -510,8 +525,20 @@ func (proxy *Proxy) processIncomingQuery(serverInfo *ServerInfo, clientProto str
} }
} }
func (proxy *Proxy) Stop() {
if proxy.quit != nil {
close(proxy.quit)
}
}
func (proxy *Proxy) ConnCloseWait() {
proxy.wg.Wait()
}
func NewProxy() *Proxy { func NewProxy() *Proxy {
return &Proxy{ return &Proxy{
serversInfo: NewServersInfo(), serversInfo: NewServersInfo(),
wg: new(sync.WaitGroup),
quit: make(chan struct{}),
} }
} }

View File

@ -2,11 +2,6 @@
package main package main
import ( func (proxy *Proxy) SystemDListeners() error {
"io" return nil
"io/ioutil"
)
func (proxy *Proxy) SystemDListeners() (io.Closer, error) {
return ioutil.NopCloser(nil), nil
} }

View File

@ -2,11 +2,6 @@
package main package main
import ( func (proxy *Proxy) SystemDListeners() error {
"io" return nil
"io/ioutil"
)
func (proxy *Proxy) SystemDListeners() (io.Closer, error) {
return ioutil.NopCloser(nil), nil
} }

View File

@ -4,23 +4,13 @@ package main
import ( import (
"fmt" "fmt"
"io"
"net" "net"
"github.com/coreos/go-systemd/activation" "github.com/coreos/go-systemd/activation"
"github.com/jedisct1/dlog" "github.com/jedisct1/dlog"
) )
type multiCloser []io.Closer func (proxy *Proxy) SystemDListeners() error {
func (mc multiCloser) Close() (err error) {
for _, c := range mc {
err = c.Close()
}
return err
}
func (proxy *Proxy) SystemDListeners() (io.Closer, error) {
files := activation.Files(true) files := activation.Files(true)
if len(files) > 0 { if len(files) > 0 {
@ -29,25 +19,24 @@ func (proxy *Proxy) SystemDListeners() (io.Closer, error) {
} }
dlog.Warn("Systemd sockets are untested and unsupported - use at your own risk") dlog.Warn("Systemd sockets are untested and unsupported - use at your own risk")
} }
var mc multiCloser
for i, file := range files { for i, file := range files {
defer file.Close() defer file.Close()
ok := false ok := false
if listener, err := net.FileListener(file); err == nil { if listener, err := net.FileListener(file); err == nil {
dlog.Noticef("Wiring systemd TCP socket #%d, %s, %s", i, file.Name(), listener.Addr()) dlog.Noticef("Wiring systemd TCP socket #%d, %s, %s", i, file.Name(), listener.Addr())
ok = true ok = true
mc = append(mc, listener) proxy.wg.Add(1)
go proxy.tcpListener(listener.(*net.TCPListener)) go proxy.tcpListener(listener.(*net.TCPListener))
} else if pc, err := net.FilePacketConn(file); err == nil { } else if pc, err := net.FilePacketConn(file); err == nil {
dlog.Noticef("Wiring systemd UDP socket #%d, %s, %s", i, file.Name(), pc.LocalAddr()) dlog.Noticef("Wiring systemd UDP socket #%d, %s, %s", i, file.Name(), pc.LocalAddr())
ok = true ok = true
mc = append(mc, pc) proxy.wg.Add(1)
go proxy.udpListener(pc.(*net.UDPConn)) go proxy.udpListener(pc.(*net.UDPConn))
} }
if !ok { if !ok {
return nil, fmt.Errorf("Could not wire systemd socket #%d, %s", i, file.Name()) return fmt.Errorf("Could not wire systemd socket #%d, %s", i, file.Name())
} }
} }
return mc, nil return nil
} }