move flags parsing into main()

This commit is contained in:
Alison Winters 2019-10-30 21:30:31 +08:00 committed by Frank Denis
parent 116f985b96
commit 2f7e057996
2 changed files with 47 additions and 38 deletions

View File

@ -194,6 +194,17 @@ type ServerSummary struct {
Stamp string `json:"stamp"`
}
type ConfigFlags struct {
List *bool
ListAll *bool
JsonOutput *bool
Check *bool
ConfigFile *string
Child *bool
NetprobeTimeoutOverride *int
ShowCerts *bool
}
func findConfigFile(configFile *string) (string, error) {
if _, err := os.Stat(*configFile); os.IsNotExist(err) {
cdLocal()
@ -211,35 +222,10 @@ func findConfigFile(configFile *string) (string, error) {
return path.Join(pwd, *configFile), nil
}
func ConfigLoad(proxy *Proxy, svcFlag *string) error {
version := flag.Bool("version", false, "print current proxy version")
resolve := flag.String("resolve", "", "resolve a name using system libraries")
list := flag.Bool("list", false, "print the list of available resolvers for the enabled filters")
listAll := flag.Bool("list-all", false, "print the complete list of available resolvers, ignoring filters")
jsonOutput := flag.Bool("json", false, "output list as JSON")
check := flag.Bool("check", false, "check the configuration file and exit")
configFile := flag.String("config", DefaultConfigFileName, "Path to the configuration file")
child := flag.Bool("child", false, "Invokes program as a child process")
netprobeTimeoutOverride := flag.Int("netprobe-timeout", 60, "Override the netprobe timeout")
showCerts := flag.Bool("show-certs", false, "print DoH certificate chain hashes")
flag.Parse()
if *svcFlag == "stop" || *svcFlag == "uninstall" {
return nil
}
if *version {
fmt.Println(AppVersion)
os.Exit(0)
}
if resolve != nil && len(*resolve) > 0 {
Resolve(*resolve)
os.Exit(0)
}
foundConfigFile, err := findConfigFile(configFile)
func ConfigLoad(proxy *Proxy, flags *ConfigFlags) error {
foundConfigFile, err := findConfigFile(flags.ConfigFile)
if err != nil {
dlog.Fatalf("Unable to load the configuration file [%s] -- Maybe use the -config command-line switch?", *configFile)
dlog.Fatalf("Unable to load the configuration file [%s] -- Maybe use the -config command-line switch?", *flags.ConfigFile)
}
config := newConfig()
md, err := toml.DecodeFile(foundConfigFile, &config)
@ -261,7 +247,7 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
dlog.UseSyslog(true)
} else if config.LogFile != nil {
dlog.UseLogFile(*config.LogFile)
if !*child {
if !*flags.Child {
FileDescriptors = append(FileDescriptors, dlog.GetFileDescriptor())
} else {
FileDescriptorNum++
@ -274,7 +260,7 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
proxy.userName = config.UserName
proxy.child = *child
proxy.child = *flags.Child
proxy.xTransport = NewXTransport()
proxy.xTransport.tlsDisableSessionTickets = config.TLSDisableSessionTickets
proxy.xTransport.tlsCipherSuite = config.TLSCipherSuite
@ -451,7 +437,7 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
proxy.routes = &routes
}
if *listAll {
if *flags.ListAll {
config.ServerNames = nil
config.DisabledServerNames = nil
config.SourceRequireDNSSEC = false
@ -465,8 +451,8 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
netprobeTimeout := config.NetprobeTimeout
flag.Visit(func(flag *flag.Flag) {
if flag.Name == "netprobe-timeout" && netprobeTimeoutOverride != nil {
netprobeTimeout = *netprobeTimeoutOverride
if flag.Name == "netprobe-timeout" && flags.NetprobeTimeoutOverride != nil {
netprobeTimeout = *flags.NetprobeTimeoutOverride
}
})
netprobeAddress := DefaultNetprobeAddress
@ -475,7 +461,7 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
} else if len(config.FallbackResolver) > 0 {
netprobeAddress = config.FallbackResolver
}
proxy.showCerts = *showCerts || len(os.Getenv("SHOW_CERTS")) > 0
proxy.showCerts = *flags.ShowCerts || len(os.Getenv("SHOW_CERTS")) > 0
if proxy.showCerts {
proxy.listenAddresses = nil
}
@ -491,8 +477,8 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
return errors.New("No servers configured")
}
}
if *list || *listAll {
config.printRegisteredServers(proxy, *jsonOutput)
if *flags.List || *flags.ListAll {
config.printRegisteredServers(proxy, *flags.JsonOutput)
os.Exit(0)
}
if proxy.routes != nil && len(*proxy.routes) > 0 {
@ -515,7 +501,7 @@ func ConfigLoad(proxy *Proxy, svcFlag *string) error {
}
}
}
if *check {
if *flags.Check {
dlog.Notice("Configuration successfully checked")
os.Exit(0)
}

View File

@ -44,6 +44,29 @@ func main() {
WorkingDirectory: pwd,
}
svcFlag := flag.String("service", "", fmt.Sprintf("Control the system service: %q", service.ControlAction))
version := flag.Bool("version", false, "print current proxy version")
resolve := flag.String("resolve", "", "resolve a name using system libraries")
flags := ConfigFlags{}
flags.List = flag.Bool("list", false, "print the list of available resolvers for the enabled filters")
flags.ListAll = flag.Bool("list-all", false, "print the complete list of available resolvers, ignoring filters")
flags.JsonOutput = flag.Bool("json", false, "output list as JSON")
flags.Check = flag.Bool("check", false, "check the configuration file and exit")
flags.ConfigFile = flag.String("config", DefaultConfigFileName, "Path to the configuration file")
flags.Child = flag.Bool("child", false, "Invokes program as a child process")
flags.NetprobeTimeoutOverride = flag.Int("netprobe-timeout", 60, "Override the netprobe timeout")
flags.ShowCerts = flag.Bool("show-certs", false, "print DoH certificate chain hashes")
flag.Parse()
if *version {
fmt.Println(AppVersion)
os.Exit(0)
}
if resolve != nil && len(*resolve) > 0 {
Resolve(*resolve)
os.Exit(0)
}
app := &App{}
svc, err := service.New(app, svcConfig)
if err != nil {
@ -52,7 +75,7 @@ func main() {
}
app.proxy = NewProxy()
_ = ServiceManagerStartNotify()
if err := ConfigLoad(app.proxy, svcFlag); err != nil {
if err := ConfigLoad(app.proxy, &flags); err != nil {
dlog.Fatal(err)
}
if len(*svcFlag) != 0 {