move flags parsing into main()

This commit is contained in:
Alison Winters 2019-10-30 21:30:31 +08:00 committed by Frank Denis
parent 116f985b96
commit 2f7e057996
2 changed files with 47 additions and 38 deletions

View File

@ -194,6 +194,17 @@ type ServerSummary struct {
Stamp string `json:"stamp"` Stamp string `json:"stamp"`
} }
type ConfigFlags struct {
List *bool
ListAll *bool
JsonOutput *bool
Check *bool
ConfigFile *string
Child *bool
NetprobeTimeoutOverride *int
ShowCerts *bool
}
func findConfigFile(configFile *string) (string, error) { func findConfigFile(configFile *string) (string, error) {
if _, err := os.Stat(*configFile); os.IsNotExist(err) { if _, err := os.Stat(*configFile); os.IsNotExist(err) {
cdLocal() cdLocal()
@ -211,35 +222,10 @@ func findConfigFile(configFile *string) (string, error) {
return path.Join(pwd, *configFile), nil return path.Join(pwd, *configFile), nil
} }
func ConfigLoad(proxy *Proxy, svcFlag *string) error { func ConfigLoad(proxy *Proxy, flags *ConfigFlags) error {
version := flag.Bool("version", false, "print current proxy version") foundConfigFile, err := findConfigFile(flags.ConfigFile)
resolve := flag.String("resolve", "", "resolve a name using system libraries")
list := flag.Bool("list", false, "print the list of available resolvers for the enabled filters")
listAll := flag.Bool("list-all", false, "print the complete list of available resolvers, ignoring filters")
jsonOutput := flag.Bool("json", false, "output list as JSON")
check := flag.Bool("check", false, "check the configuration file and exit")
configFile := flag.String("config", DefaultConfigFileName, "Path to the configuration file")
child := flag.Bool("child", false, "Invokes program as a child process")
netprobeTimeoutOverride := flag.Int("netprobe-timeout", 60, "Override the netprobe timeout")
showCerts := flag.Bool("show-certs", false, "print DoH certificate chain hashes")
flag.Parse()
if *svcFlag == "stop" || *svcFlag == "uninstall" {
return nil
}
if *version {
fmt.Println(AppVersion)
os.Exit(0)
}
if resolve != nil && len(*resolve) > 0 {
Resolve(*resolve)
os.Exit(0)
}
foundConfigFile, err := findConfigFile(configFile)
if err != nil { if err != nil {
dlog.Fatalf("Unable to load the configuration file [%s] -- Maybe use the -config command-line switch?", *configFile) dlog.Fatalf("Unable to load the configuration file [%s] -- Maybe use the -config command-line switch?", *flags.ConfigFile)
} }
config := newConfig() config := newConfig()
md, err := toml.DecodeFile(foundConfigFile, &config) md, err := toml.DecodeFile(foundConfigFile, &config)
@ -261,7 +247,7 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
dlog.UseSyslog(true) dlog.UseSyslog(true)
} else if config.LogFile != nil { } else if config.LogFile != nil {
dlog.UseLogFile(*config.LogFile) dlog.UseLogFile(*config.LogFile)
if !*child { if !*flags.Child {
FileDescriptors = append(FileDescriptors, dlog.GetFileDescriptor()) FileDescriptors = append(FileDescriptors, dlog.GetFileDescriptor())
} else { } else {
FileDescriptorNum++ FileDescriptorNum++
@ -274,7 +260,7 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
proxy.userName = config.UserName proxy.userName = config.UserName
proxy.child = *child proxy.child = *flags.Child
proxy.xTransport = NewXTransport() proxy.xTransport = NewXTransport()
proxy.xTransport.tlsDisableSessionTickets = config.TLSDisableSessionTickets proxy.xTransport.tlsDisableSessionTickets = config.TLSDisableSessionTickets
proxy.xTransport.tlsCipherSuite = config.TLSCipherSuite proxy.xTransport.tlsCipherSuite = config.TLSCipherSuite
@ -451,7 +437,7 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
proxy.routes = &routes proxy.routes = &routes
} }
if *listAll { if *flags.ListAll {
config.ServerNames = nil config.ServerNames = nil
config.DisabledServerNames = nil config.DisabledServerNames = nil
config.SourceRequireDNSSEC = false config.SourceRequireDNSSEC = false
@ -465,8 +451,8 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
netprobeTimeout := config.NetprobeTimeout netprobeTimeout := config.NetprobeTimeout
flag.Visit(func(flag *flag.Flag) { flag.Visit(func(flag *flag.Flag) {
if flag.Name == "netprobe-timeout" && netprobeTimeoutOverride != nil { if flag.Name == "netprobe-timeout" && flags.NetprobeTimeoutOverride != nil {
netprobeTimeout = *netprobeTimeoutOverride netprobeTimeout = *flags.NetprobeTimeoutOverride
} }
}) })
netprobeAddress := DefaultNetprobeAddress netprobeAddress := DefaultNetprobeAddress
@ -475,7 +461,7 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
} else if len(config.FallbackResolver) > 0 { } else if len(config.FallbackResolver) > 0 {
netprobeAddress = config.FallbackResolver netprobeAddress = config.FallbackResolver
} }
proxy.showCerts = *showCerts || len(os.Getenv("SHOW_CERTS")) > 0 proxy.showCerts = *flags.ShowCerts || len(os.Getenv("SHOW_CERTS")) > 0
if proxy.showCerts { if proxy.showCerts {
proxy.listenAddresses = nil proxy.listenAddresses = nil
} }
@ -491,8 +477,8 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
return errors.New("No servers configured") return errors.New("No servers configured")
} }
} }
if *list || *listAll { if *flags.List || *flags.ListAll {
config.printRegisteredServers(proxy, *jsonOutput) config.printRegisteredServers(proxy, *flags.JsonOutput)
os.Exit(0) os.Exit(0)
} }
if proxy.routes != nil && len(*proxy.routes) > 0 { if proxy.routes != nil && len(*proxy.routes) > 0 {
@ -515,7 +501,7 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
} }
} }
} }
if *check { if *flags.Check {
dlog.Notice("Configuration successfully checked") dlog.Notice("Configuration successfully checked")
os.Exit(0) os.Exit(0)
} }

View File

@ -44,6 +44,29 @@ func main() {
WorkingDirectory: pwd, WorkingDirectory: pwd,
} }
svcFlag := flag.String("service", "", fmt.Sprintf("Control the system service: %q", service.ControlAction)) svcFlag := flag.String("service", "", fmt.Sprintf("Control the system service: %q", service.ControlAction))
version := flag.Bool("version", false, "print current proxy version")
resolve := flag.String("resolve", "", "resolve a name using system libraries")
flags := ConfigFlags{}
flags.List = flag.Bool("list", false, "print the list of available resolvers for the enabled filters")
flags.ListAll = flag.Bool("list-all", false, "print the complete list of available resolvers, ignoring filters")
flags.JsonOutput = flag.Bool("json", false, "output list as JSON")
flags.Check = flag.Bool("check", false, "check the configuration file and exit")
flags.ConfigFile = flag.String("config", DefaultConfigFileName, "Path to the configuration file")
flags.Child = flag.Bool("child", false, "Invokes program as a child process")
flags.NetprobeTimeoutOverride = flag.Int("netprobe-timeout", 60, "Override the netprobe timeout")
flags.ShowCerts = flag.Bool("show-certs", false, "print DoH certificate chain hashes")
flag.Parse()
if *version {
fmt.Println(AppVersion)
os.Exit(0)
}
if resolve != nil && len(*resolve) > 0 {
Resolve(*resolve)
os.Exit(0)
}
app := &App{} app := &App{}
svc, err := service.New(app, svcConfig) svc, err := service.New(app, svcConfig)
if err != nil { if err != nil {
@ -52,7 +75,7 @@ func main() {
} }
app.proxy = NewProxy() app.proxy = NewProxy()
_ = ServiceManagerStartNotify() _ = ServiceManagerStartNotify()
if err := ConfigLoad(app.proxy, svcFlag); err != nil { if err := ConfigLoad(app.proxy, &flags); err != nil {
dlog.Fatal(err) dlog.Fatal(err)
} }
if len(*svcFlag) != 0 { if len(*svcFlag) != 0 {