mirror of
https://github.com/superseriousbusiness/gotosocial
synced 2025-06-05 21:59:39 +02:00
[feature] support nested configuration files, and setting ALL configuration variables by CLI and env (#4109)
This updates our configuration code generator to now also include map marshal and unmarshalers. So we now have much more control over how things get read from pflags, and stored / read from viper configuration. This allows us to set ALL configuration variables by CLI and environment now, AND support nested configuration files. e.g. ```yaml advanced: scraper-deterrence = true http-client: allow-ips = ["127.0.0.1"] ``` is the same as ```yaml advanced-scraper-deterrence = true http-client-allow-ips = ["127.0.0.1"] ``` This also starts cleaning up of our jumbled Configuration{} type by moving the advanced configuration options into their own nested structs, also as a way to show what it's capable of. It's worth noting however that nesting only works if the Go types are nested too (as this is how we hint to our code generator to generate the necessary flattening code :p). closes #3195 Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4109 Co-authored-by: kim <grufwub@gmail.com> Co-committed-by: kim <grufwub@gmail.com>
This commit is contained in:
@@ -19,11 +19,11 @@ package config
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.superseriousbusiness.org/gotosocial/internal/language"
|
||||
"codeberg.org/gruf/go-bytesize"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
// cfgtype is the reflected type information of Configuration{}.
|
||||
@@ -32,9 +32,15 @@ var cfgtype = reflect.TypeOf(Configuration{})
|
||||
// fieldtag will fetch the string value for the given tag name
|
||||
// on the given field name in the Configuration{} struct.
|
||||
func fieldtag(field, tag string) string {
|
||||
sfield, ok := cfgtype.FieldByName(field)
|
||||
if !ok {
|
||||
panic("unknown struct field")
|
||||
nextType := cfgtype
|
||||
var sfield reflect.StructField
|
||||
for _, field := range strings.Split(field, ".") {
|
||||
var ok bool
|
||||
sfield, ok = nextType.FieldByName(field)
|
||||
if !ok {
|
||||
panic("unknown struct field")
|
||||
}
|
||||
nextType = sfield.Type
|
||||
}
|
||||
return sfield.Tag.Get(tag)
|
||||
}
|
||||
@@ -45,21 +51,22 @@ func fieldtag(field, tag string) string {
|
||||
// will need to regenerate the global Getter/Setter helpers by running:
|
||||
// `go run ./internal/config/gen/ -out ./internal/config/helpers.gen.go`
|
||||
type Configuration struct {
|
||||
LogLevel string `name:"log-level" usage:"Log level to run at: [trace, debug, info, warn, fatal]"`
|
||||
LogTimestampFormat string `name:"log-timestamp-format" usage:"Format to use for the log timestamp, as supported by Go's time.Layout"`
|
||||
LogDbQueries bool `name:"log-db-queries" usage:"Log database queries verbosely when log-level is trace or debug"`
|
||||
LogClientIP bool `name:"log-client-ip" usage:"Include the client IP in logs"`
|
||||
ApplicationName string `name:"application-name" usage:"Name of the application, used in various places internally"`
|
||||
LandingPageUser string `name:"landing-page-user" usage:"the user that should be shown on the instance's landing page"`
|
||||
ConfigPath string `name:"config-path" usage:"Path to a file containing gotosocial configuration. Values set in this file will be overwritten by values set as env vars or arguments"`
|
||||
Host string `name:"host" usage:"Hostname to use for the server (eg., example.org, gotosocial.whatever.com). DO NOT change this on a server that's already run!"`
|
||||
AccountDomain string `name:"account-domain" usage:"Domain to use in account names (eg., example.org, whatever.com). If not set, will default to the setting for host. DO NOT change this on a server that's already run!"`
|
||||
Protocol string `name:"protocol" usage:"Protocol to use for the REST api of the server (only use http if you are debugging or behind a reverse proxy!)"`
|
||||
BindAddress string `name:"bind-address" usage:"Bind address to use for the GoToSocial server (eg., 0.0.0.0, 172.138.0.9, [::], localhost). For ipv6, enclose the address in square brackets, eg [2001:db8::fed1]. Default binds to all interfaces."`
|
||||
Port int `name:"port" usage:"Port to use for GoToSocial. Change this to 443 if you're running the binary directly on the host machine."`
|
||||
TrustedProxies []string `name:"trusted-proxies" usage:"Proxies to trust when parsing x-forwarded headers into real IPs."`
|
||||
SoftwareVersion string `name:"software-version" usage:""`
|
||||
LogLevel string `name:"log-level" usage:"Log level to run at: [trace, debug, info, warn, fatal]"`
|
||||
LogTimestampFormat string `name:"log-timestamp-format" usage:"Format to use for the log timestamp, as supported by Go's time.Layout"`
|
||||
LogDbQueries bool `name:"log-db-queries" usage:"Log database queries verbosely when log-level is trace or debug"`
|
||||
LogClientIP bool `name:"log-client-ip" usage:"Include the client IP in logs"`
|
||||
RequestIDHeader string `name:"request-id-header" usage:"Header to extract the Request ID from. Eg.,'X-Request-Id'."`
|
||||
|
||||
ConfigPath string `name:"config-path" usage:"Path to a file containing gotosocial configuration. Values set in this file will be overwritten by values set as env vars or arguments"`
|
||||
ApplicationName string `name:"application-name" usage:"Name of the application, used in various places internally"`
|
||||
LandingPageUser string `name:"landing-page-user" usage:"the user that should be shown on the instance's landing page"`
|
||||
Host string `name:"host" usage:"Hostname to use for the server (eg., example.org, gotosocial.whatever.com). DO NOT change this on a server that's already run!"`
|
||||
AccountDomain string `name:"account-domain" usage:"Domain to use in account names (eg., example.org, whatever.com). If not set, will default to the setting for host. DO NOT change this on a server that's already run!"`
|
||||
Protocol string `name:"protocol" usage:"Protocol to use for the REST api of the server (only use http if you are debugging or behind a reverse proxy!)"`
|
||||
BindAddress string `name:"bind-address" usage:"Bind address to use for the GoToSocial server (eg., 0.0.0.0, 172.138.0.9, [::], localhost). For ipv6, enclose the address in square brackets, eg [2001:db8::fed1]. Default binds to all interfaces."`
|
||||
Port int `name:"port" usage:"Port to use for GoToSocial. Change this to 443 if you're running the binary directly on the host machine."`
|
||||
TrustedProxies []string `name:"trusted-proxies" usage:"Proxies to trust when parsing x-forwarded headers into real IPs."`
|
||||
SoftwareVersion string `name:"software-version" usage:""`
|
||||
DbType string `name:"db-type" usage:"Database type: eg., postgres"`
|
||||
DbAddress string `name:"db-address" usage:"Database ipv4 address, hostname, or filename"`
|
||||
DbPort int `name:"db-port" usage:"Database port"`
|
||||
@@ -160,15 +167,8 @@ type Configuration struct {
|
||||
SyslogProtocol string `name:"syslog-protocol" usage:"Protocol to use when directing logs to syslog. Leave empty to connect to local syslog."`
|
||||
SyslogAddress string `name:"syslog-address" usage:"Address:port to send syslog logs to. Leave empty to connect to local syslog."`
|
||||
|
||||
AdvancedCookiesSamesite string `name:"advanced-cookies-samesite" usage:"'strict' or 'lax', see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie/SameSite"`
|
||||
AdvancedRateLimitRequests int `name:"advanced-rate-limit-requests" usage:"Amount of HTTP requests to permit within a 5 minute window. 0 or less turns rate limiting off."`
|
||||
AdvancedRateLimitExceptions IPPrefixes `name:"advanced-rate-limit-exceptions" usage:"Slice of CIDRs to exclude from rate limit restrictions."`
|
||||
AdvancedThrottlingMultiplier int `name:"advanced-throttling-multiplier" usage:"Multiplier to use per cpu for http request throttling. 0 or less turns throttling off."`
|
||||
AdvancedThrottlingRetryAfter time.Duration `name:"advanced-throttling-retry-after" usage:"Retry-After duration response to send for throttled requests."`
|
||||
AdvancedSenderMultiplier int `name:"advanced-sender-multiplier" usage:"Multiplier to use per cpu for batching outgoing fedi messages. 0 or less turns batching off (not recommended)."`
|
||||
AdvancedCSPExtraURIs []string `name:"advanced-csp-extra-uris" usage:"Additional URIs to allow when building content-security-policy for media + images."`
|
||||
AdvancedHeaderFilterMode string `name:"advanced-header-filter-mode" usage:"Set incoming request header filtering mode."`
|
||||
AdvancedScraperDeterrence bool `name:"advanced-scraper-deterrence" usage:"Enable proof-of-work based scraper deterrence on profile / status pages"`
|
||||
// Advanced flags.
|
||||
Advanced AdvancedConfig `name:"advanced"`
|
||||
|
||||
// HTTPClient configuration vars.
|
||||
HTTPClient HTTPClientConfiguration `name:"http-client"`
|
||||
@@ -177,15 +177,13 @@ type Configuration struct {
|
||||
Cache CacheConfiguration `name:"cache"`
|
||||
|
||||
// TODO: move these elsewhere, these are more ephemeral vs long-running flags like above
|
||||
AdminAccountUsername string `name:"username" usage:"the username to create/delete/etc"`
|
||||
AdminAccountEmail string `name:"email" usage:"the email address of this account"`
|
||||
AdminAccountPassword string `name:"password" usage:"the password to set for this account"`
|
||||
AdminTransPath string `name:"path" usage:"the path of the file to import from/export to"`
|
||||
AdminMediaPruneDryRun bool `name:"dry-run" usage:"perform a dry run and only log number of items eligible for pruning"`
|
||||
AdminMediaListLocalOnly bool `name:"local-only" usage:"list only local attachments/emojis; if specified then remote-only cannot also be true"`
|
||||
AdminMediaListRemoteOnly bool `name:"remote-only" usage:"list only remote attachments/emojis; if specified then local-only cannot also be true"`
|
||||
|
||||
RequestIDHeader string `name:"request-id-header" usage:"Header to extract the Request ID from. Eg.,'X-Request-Id'."`
|
||||
AdminAccountUsername string `name:"username" usage:"the username to create/delete/etc" ephemeral:"yes"`
|
||||
AdminAccountEmail string `name:"email" usage:"the email address of this account" ephemeral:"yes"`
|
||||
AdminAccountPassword string `name:"password" usage:"the password to set for this account" ephemeral:"yes"`
|
||||
AdminTransPath string `name:"path" usage:"the path of the file to import from/export to" ephemeral:"yes"`
|
||||
AdminMediaPruneDryRun bool `name:"dry-run" usage:"perform a dry run and only log number of items eligible for pruning" ephemeral:"yes"`
|
||||
AdminMediaListLocalOnly bool `name:"local-only" usage:"list only local attachments/emojis; if specified then remote-only cannot also be true" ephemeral:"yes"`
|
||||
AdminMediaListRemoteOnly bool `name:"remote-only" usage:"list only remote attachments/emojis; if specified then local-only cannot also be true" ephemeral:"yes"`
|
||||
}
|
||||
|
||||
type HTTPClientConfiguration struct {
|
||||
@@ -255,15 +253,27 @@ type CacheConfiguration struct {
|
||||
VisibilityMemRatio float64 `name:"visibility-mem-ratio"`
|
||||
}
|
||||
|
||||
// MarshalMap will marshal current Configuration into a map structure (useful for JSON/TOML/YAML).
|
||||
func (cfg *Configuration) MarshalMap() (map[string]interface{}, error) {
|
||||
var dst map[string]interface{}
|
||||
dec, _ := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
TagName: "name",
|
||||
Result: &dst,
|
||||
})
|
||||
if err := dec.Decode(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dst, nil
|
||||
type AdvancedConfig struct {
|
||||
CookiesSamesite string `name:"cookies-samesite" usage:"'strict' or 'lax', see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie/SameSite"`
|
||||
SenderMultiplier int `name:"sender-multiplier" usage:"Multiplier to use per cpu for batching outgoing fedi messages. 0 or less turns batching off (not recommended)."`
|
||||
CSPExtraURIs []string `name:"csp-extra-uris" usage:"Additional URIs to allow when building content-security-policy for media + images."`
|
||||
HeaderFilterMode string `name:"header-filter-mode" usage:"Set incoming request header filtering mode."`
|
||||
ScraperDeterrence bool `name:"scraper-deterrence" usage:"Enable proof-of-work based scraper deterrence on profile / status pages"`
|
||||
RateLimit RateLimitConfig `name:"rate-limit"`
|
||||
Throttling ThrottlingConfig `name:"throttling"`
|
||||
}
|
||||
|
||||
type RateLimitConfig struct {
|
||||
Requests int `name:"requests" usage:"Amount of HTTP requests to permit within a 5 minute window. 0 or less turns rate limiting off."`
|
||||
Exceptions IPPrefixes `name:"exceptions" usage:"Slice of CIDRs to exclude from rate limit restrictions."`
|
||||
}
|
||||
|
||||
type ThrottlingConfig struct {
|
||||
Multiplier int `name:"multiplier" usage:"Multiplier to use per cpu for http request throttling. 0 or less turns throttling off."`
|
||||
RetryAfter time.Duration `name:"retry-after" usage:"Retry-After duration response to send for throttled requests."`
|
||||
}
|
||||
|
||||
// type ScraperDeterrenceConfig struct {
|
||||
// Enabled bool `name:"enabled" usage:"Enable proof-of-work based scraper deterrence on profile / status pages"`
|
||||
// Difficulty uint8 `name:"difficulty" usage:"The proof-of-work difficulty, which determines how many leading zeros to try solve in hash solutions."`
|
||||
// }
|
||||
|
@@ -24,19 +24,18 @@ import (
|
||||
"testing"
|
||||
|
||||
"code.superseriousbusiness.org/gotosocial/internal/config"
|
||||
"codeberg.org/gruf/go-kv"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func expectedKV(kvpairs ...string) map[string]interface{} {
|
||||
ret := make(map[string]interface{}, len(kvpairs)/2)
|
||||
|
||||
for i := 0; i < len(kvpairs)-1; i += 2 {
|
||||
ret[kvpairs[i]] = kvpairs[i+1]
|
||||
func expectedKV(kvs ...kv.Field) map[string]interface{} {
|
||||
ret := make(map[string]interface{}, len(kvs))
|
||||
for _, kv := range kvs {
|
||||
ret[kv.K] = kv.V
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
@@ -61,7 +60,7 @@ func TestCLIParsing(t *testing.T) {
|
||||
expected map[string]interface{}
|
||||
}
|
||||
|
||||
defaults, _ := config.Defaults.MarshalMap()
|
||||
defaults := config.Defaults.MarshalMap()
|
||||
|
||||
testcases := map[string]testcase{
|
||||
"Make sure defaults are set correctly": {
|
||||
@@ -73,7 +72,7 @@ func TestCLIParsing(t *testing.T) {
|
||||
"--db-address", "some.db.address",
|
||||
},
|
||||
expected: expectedKV(
|
||||
"db-address", "some.db.address",
|
||||
kv.Field{"db-address", "some.db.address"},
|
||||
),
|
||||
},
|
||||
|
||||
@@ -82,7 +81,7 @@ func TestCLIParsing(t *testing.T) {
|
||||
"GTS_DB_ADDRESS=some.db.address",
|
||||
},
|
||||
expected: expectedKV(
|
||||
"db-address", "some.db.address",
|
||||
kv.Field{"db-address", "some.db.address"},
|
||||
),
|
||||
},
|
||||
|
||||
@@ -94,7 +93,7 @@ func TestCLIParsing(t *testing.T) {
|
||||
"GTS_DB_ADDRESS=some.other.db.address",
|
||||
},
|
||||
expected: expectedKV(
|
||||
"db-address", "some.db.address",
|
||||
kv.Field{"db-address", "some.db.address"},
|
||||
),
|
||||
},
|
||||
|
||||
@@ -119,8 +118,8 @@ func TestCLIParsing(t *testing.T) {
|
||||
},
|
||||
// only checking our overridden one and one non-default from the config file here instead of including all of test.yaml
|
||||
expected: expectedKV(
|
||||
"account-domain", "my.test.domain",
|
||||
"host", "gts.example.org",
|
||||
kv.Field{"account-domain", "my.test.domain"},
|
||||
kv.Field{"host", "gts.example.org"},
|
||||
),
|
||||
},
|
||||
|
||||
@@ -133,8 +132,8 @@ func TestCLIParsing(t *testing.T) {
|
||||
},
|
||||
// only checking our overridden one and one non-default from the config file here instead of including all of test.yaml
|
||||
expected: expectedKV(
|
||||
"account-domain", "my.test.domain",
|
||||
"host", "gts.example.org",
|
||||
kv.Field{"account-domain", "my.test.domain"},
|
||||
kv.Field{"host", "gts.example.org"},
|
||||
),
|
||||
},
|
||||
|
||||
@@ -148,8 +147,8 @@ func TestCLIParsing(t *testing.T) {
|
||||
},
|
||||
// only checking our overridden one and one non-default from the config file here instead of including all of test.yaml
|
||||
expected: expectedKV(
|
||||
"account-domain", "my.test.domain",
|
||||
"host", "gts.example.org",
|
||||
kv.Field{"account-domain", "my.test.domain"},
|
||||
kv.Field{"host", "gts.example.org"},
|
||||
),
|
||||
},
|
||||
|
||||
@@ -165,9 +164,19 @@ func TestCLIParsing(t *testing.T) {
|
||||
"--config-path", "testdata/test2.yaml",
|
||||
},
|
||||
expected: expectedKV(
|
||||
"log-level", "trace",
|
||||
"account-domain", "peepee.poopoo",
|
||||
"application-name", "gotosocial",
|
||||
kv.Field{"log-level", "trace"},
|
||||
kv.Field{"account-domain", "peepee.poopoo"},
|
||||
kv.Field{"application-name", "gotosocial"},
|
||||
),
|
||||
},
|
||||
|
||||
"Loading nested config file. This should also work the same": {
|
||||
cli: []string{
|
||||
"--config-path", "testdata/test3.yaml",
|
||||
},
|
||||
expected: expectedKV(
|
||||
kv.Field{"advanced-scraper-deterrence", true},
|
||||
kv.Field{"advanced-rate-limit-requests", 5000},
|
||||
),
|
||||
},
|
||||
}
|
||||
@@ -185,8 +194,7 @@ func TestCLIParsing(t *testing.T) {
|
||||
|
||||
state := config.NewState()
|
||||
cmd := cobra.Command{}
|
||||
state.AddGlobalFlags(&cmd)
|
||||
state.AddServerFlags(&cmd)
|
||||
config.RegisterGlobalFlags(&cmd)
|
||||
|
||||
if data.cli != nil {
|
||||
cmd.ParseFlags(data.cli)
|
||||
@@ -194,7 +202,7 @@ func TestCLIParsing(t *testing.T) {
|
||||
|
||||
state.BindFlags(&cmd)
|
||||
|
||||
state.Reload()
|
||||
state.LoadConfigFile()
|
||||
|
||||
state.Viper(func(v *viper.Viper) {
|
||||
for k, ev := range data.expected {
|
||||
|
@@ -130,15 +130,23 @@ var Defaults = Configuration{
|
||||
SyslogProtocol: "udp",
|
||||
SyslogAddress: "localhost:514",
|
||||
|
||||
AdvancedCookiesSamesite: "lax",
|
||||
AdvancedRateLimitRequests: 300, // 1 per second per 5 minutes
|
||||
AdvancedRateLimitExceptions: IPPrefixes{},
|
||||
AdvancedThrottlingMultiplier: 8, // 8 open requests per CPU
|
||||
AdvancedThrottlingRetryAfter: time.Second * 30,
|
||||
AdvancedSenderMultiplier: 2, // 2 senders per CPU
|
||||
AdvancedCSPExtraURIs: []string{},
|
||||
AdvancedHeaderFilterMode: RequestHeaderFilterModeDisabled,
|
||||
AdvancedScraperDeterrence: false,
|
||||
Advanced: AdvancedConfig{
|
||||
SenderMultiplier: 2, // 2 senders per CPU
|
||||
CSPExtraURIs: []string{},
|
||||
HeaderFilterMode: RequestHeaderFilterModeDisabled,
|
||||
CookiesSamesite: "lax",
|
||||
ScraperDeterrence: false,
|
||||
|
||||
RateLimit: RateLimitConfig{
|
||||
Requests: 300, // 1 per second per 5 minutes
|
||||
Exceptions: IPPrefixes{},
|
||||
},
|
||||
|
||||
Throttling: ThrottlingConfig{
|
||||
Multiplier: 8, // 8 open requests per CPU
|
||||
RetryAfter: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
|
||||
Cache: CacheConfiguration{
|
||||
// Rough memory target that the total
|
||||
|
@@ -23,150 +23,6 @@ import (
|
||||
|
||||
// TODO: consolidate these methods into the Configuration{} or ConfigState{} structs.
|
||||
|
||||
// AddGlobalFlags will attach global configuration flags to given cobra command, loading defaults from global config.
|
||||
func AddGlobalFlags(cmd *cobra.Command) {
|
||||
global.AddGlobalFlags(cmd)
|
||||
}
|
||||
|
||||
// AddGlobalFlags will attach global configuration flags to given cobra command, loading defaults from State.
|
||||
func (s *ConfigState) AddGlobalFlags(cmd *cobra.Command) {
|
||||
s.Config(func(cfg *Configuration) {
|
||||
// General
|
||||
cmd.PersistentFlags().String(ApplicationNameFlag(), cfg.ApplicationName, fieldtag("ApplicationName", "usage"))
|
||||
cmd.PersistentFlags().String(LandingPageUserFlag(), cfg.LandingPageUser, fieldtag("LandingPageUser", "usage"))
|
||||
cmd.PersistentFlags().String(HostFlag(), cfg.Host, fieldtag("Host", "usage"))
|
||||
cmd.PersistentFlags().String(AccountDomainFlag(), cfg.AccountDomain, fieldtag("AccountDomain", "usage"))
|
||||
cmd.PersistentFlags().String(ProtocolFlag(), cfg.Protocol, fieldtag("Protocol", "usage"))
|
||||
cmd.PersistentFlags().String(LogLevelFlag(), cfg.LogLevel, fieldtag("LogLevel", "usage"))
|
||||
cmd.PersistentFlags().String(LogTimestampFormatFlag(), cfg.LogTimestampFormat, fieldtag("LogTimestampFormat", "usage"))
|
||||
cmd.PersistentFlags().Bool(LogDbQueriesFlag(), cfg.LogDbQueries, fieldtag("LogDbQueries", "usage"))
|
||||
cmd.PersistentFlags().String(ConfigPathFlag(), cfg.ConfigPath, fieldtag("ConfigPath", "usage"))
|
||||
|
||||
// Database
|
||||
cmd.PersistentFlags().String(DbTypeFlag(), cfg.DbType, fieldtag("DbType", "usage"))
|
||||
cmd.PersistentFlags().String(DbAddressFlag(), cfg.DbAddress, fieldtag("DbAddress", "usage"))
|
||||
cmd.PersistentFlags().Int(DbPortFlag(), cfg.DbPort, fieldtag("DbPort", "usage"))
|
||||
cmd.PersistentFlags().String(DbUserFlag(), cfg.DbUser, fieldtag("DbUser", "usage"))
|
||||
cmd.PersistentFlags().String(DbPasswordFlag(), cfg.DbPassword, fieldtag("DbPassword", "usage"))
|
||||
cmd.PersistentFlags().String(DbDatabaseFlag(), cfg.DbDatabase, fieldtag("DbDatabase", "usage"))
|
||||
cmd.PersistentFlags().String(DbTLSModeFlag(), cfg.DbTLSMode, fieldtag("DbTLSMode", "usage"))
|
||||
cmd.PersistentFlags().String(DbTLSCACertFlag(), cfg.DbTLSCACert, fieldtag("DbTLSCACert", "usage"))
|
||||
cmd.PersistentFlags().Int(DbMaxOpenConnsMultiplierFlag(), cfg.DbMaxOpenConnsMultiplier, fieldtag("DbMaxOpenConnsMultiplier", "usage"))
|
||||
cmd.PersistentFlags().String(DbSqliteJournalModeFlag(), cfg.DbSqliteJournalMode, fieldtag("DbSqliteJournalMode", "usage"))
|
||||
cmd.PersistentFlags().String(DbSqliteSynchronousFlag(), cfg.DbSqliteSynchronous, fieldtag("DbSqliteSynchronous", "usage"))
|
||||
cmd.PersistentFlags().Uint64(DbSqliteCacheSizeFlag(), uint64(cfg.DbSqliteCacheSize), fieldtag("DbSqliteCacheSize", "usage"))
|
||||
cmd.PersistentFlags().Duration(DbSqliteBusyTimeoutFlag(), cfg.DbSqliteBusyTimeout, fieldtag("DbSqliteBusyTimeout", "usage"))
|
||||
|
||||
// HTTPClient
|
||||
cmd.PersistentFlags().StringSlice(HTTPClientAllowIPsFlag(), cfg.HTTPClient.AllowIPs, "no usage string")
|
||||
cmd.PersistentFlags().StringSlice(HTTPClientBlockIPsFlag(), cfg.HTTPClient.BlockIPs, "no usage string")
|
||||
cmd.PersistentFlags().Duration(HTTPClientTimeoutFlag(), cfg.HTTPClient.Timeout, "no usage string")
|
||||
cmd.PersistentFlags().Bool(HTTPClientTLSInsecureSkipVerifyFlag(), cfg.HTTPClient.TLSInsecureSkipVerify, "no usage string")
|
||||
})
|
||||
}
|
||||
|
||||
// AddServerFlags will attach server configuration flags to given cobra command, loading defaults from global config.
|
||||
func AddServerFlags(cmd *cobra.Command) {
|
||||
global.AddServerFlags(cmd)
|
||||
}
|
||||
|
||||
// AddServerFlags will attach server configuration flags to given cobra command, loading defaults from State.
|
||||
func (s *ConfigState) AddServerFlags(cmd *cobra.Command) {
|
||||
s.Config(func(cfg *Configuration) {
|
||||
// Router
|
||||
cmd.PersistentFlags().String(BindAddressFlag(), cfg.BindAddress, fieldtag("BindAddress", "usage"))
|
||||
cmd.PersistentFlags().Int(PortFlag(), cfg.Port, fieldtag("Port", "usage"))
|
||||
cmd.PersistentFlags().StringSlice(TrustedProxiesFlag(), cfg.TrustedProxies, fieldtag("TrustedProxies", "usage"))
|
||||
|
||||
// Template
|
||||
cmd.Flags().String(WebTemplateBaseDirFlag(), cfg.WebTemplateBaseDir, fieldtag("WebTemplateBaseDir", "usage"))
|
||||
cmd.Flags().String(WebAssetBaseDirFlag(), cfg.WebAssetBaseDir, fieldtag("WebAssetBaseDir", "usage"))
|
||||
|
||||
// Instance
|
||||
cmd.Flags().String(InstanceFederationModeFlag(), cfg.InstanceFederationMode, fieldtag("InstanceFederationMode", "usage"))
|
||||
cmd.Flags().Bool(InstanceFederationSpamFilterFlag(), cfg.InstanceFederationSpamFilter, fieldtag("InstanceFederationSpamFilter", "usage"))
|
||||
cmd.Flags().Bool(InstanceExposePeersFlag(), cfg.InstanceExposePeers, fieldtag("InstanceExposePeers", "usage"))
|
||||
cmd.Flags().Bool(InstanceExposeSuspendedFlag(), cfg.InstanceExposeSuspended, fieldtag("InstanceExposeSuspended", "usage"))
|
||||
cmd.Flags().Bool(InstanceExposeSuspendedWebFlag(), cfg.InstanceExposeSuspendedWeb, fieldtag("InstanceExposeSuspendedWeb", "usage"))
|
||||
cmd.Flags().Bool(InstanceDeliverToSharedInboxesFlag(), cfg.InstanceDeliverToSharedInboxes, fieldtag("InstanceDeliverToSharedInboxes", "usage"))
|
||||
cmd.Flags().StringSlice(InstanceLanguagesFlag(), cfg.InstanceLanguages.TagStrs(), fieldtag("InstanceLanguages", "usage"))
|
||||
cmd.Flags().String(InstanceSubscriptionsProcessFromFlag(), cfg.InstanceSubscriptionsProcessFrom, fieldtag("InstanceSubscriptionsProcessFrom", "usage"))
|
||||
cmd.Flags().Duration(InstanceSubscriptionsProcessEveryFlag(), cfg.InstanceSubscriptionsProcessEvery, fieldtag("InstanceSubscriptionsProcessEvery", "usage"))
|
||||
cmd.Flags().String(InstanceStatsModeFlag(), cfg.InstanceStatsMode, fieldtag("InstanceStatsMode", "usage"))
|
||||
cmd.Flags().Bool(InstanceAllowBackdatingStatusesFlag(), cfg.InstanceAllowBackdatingStatuses, fieldtag("InstanceAllowBackdatingStatuses", "usage"))
|
||||
|
||||
// Accounts
|
||||
cmd.Flags().Bool(AccountsRegistrationOpenFlag(), cfg.AccountsRegistrationOpen, fieldtag("AccountsRegistrationOpen", "usage"))
|
||||
cmd.Flags().Bool(AccountsReasonRequiredFlag(), cfg.AccountsReasonRequired, fieldtag("AccountsReasonRequired", "usage"))
|
||||
cmd.Flags().Bool(AccountsAllowCustomCSSFlag(), cfg.AccountsAllowCustomCSS, fieldtag("AccountsAllowCustomCSS", "usage"))
|
||||
|
||||
// Media
|
||||
cmd.Flags().Int(MediaDescriptionMinCharsFlag(), cfg.MediaDescriptionMinChars, fieldtag("MediaDescriptionMinChars", "usage"))
|
||||
cmd.Flags().Int(MediaDescriptionMaxCharsFlag(), cfg.MediaDescriptionMaxChars, fieldtag("MediaDescriptionMaxChars", "usage"))
|
||||
cmd.Flags().Int(MediaRemoteCacheDaysFlag(), cfg.MediaRemoteCacheDays, fieldtag("MediaRemoteCacheDays", "usage"))
|
||||
cmd.Flags().Uint64(MediaLocalMaxSizeFlag(), uint64(cfg.MediaLocalMaxSize), fieldtag("MediaLocalMaxSize", "usage"))
|
||||
cmd.Flags().Uint64(MediaRemoteMaxSizeFlag(), uint64(cfg.MediaRemoteMaxSize), fieldtag("MediaRemoteMaxSize", "usage"))
|
||||
cmd.Flags().Uint64(MediaEmojiLocalMaxSizeFlag(), uint64(cfg.MediaEmojiLocalMaxSize), fieldtag("MediaEmojiLocalMaxSize", "usage"))
|
||||
cmd.Flags().Uint64(MediaEmojiRemoteMaxSizeFlag(), uint64(cfg.MediaEmojiRemoteMaxSize), fieldtag("MediaEmojiRemoteMaxSize", "usage"))
|
||||
cmd.Flags().String(MediaCleanupFromFlag(), cfg.MediaCleanupFrom, fieldtag("MediaCleanupFrom", "usage"))
|
||||
cmd.Flags().Duration(MediaCleanupEveryFlag(), cfg.MediaCleanupEvery, fieldtag("MediaCleanupEvery", "usage"))
|
||||
|
||||
// Storage
|
||||
cmd.Flags().String(StorageBackendFlag(), cfg.StorageBackend, fieldtag("StorageBackend", "usage"))
|
||||
cmd.Flags().String(StorageLocalBasePathFlag(), cfg.StorageLocalBasePath, fieldtag("StorageLocalBasePath", "usage"))
|
||||
|
||||
// Statuses
|
||||
cmd.Flags().Int(StatusesMaxCharsFlag(), cfg.StatusesMaxChars, fieldtag("StatusesMaxChars", "usage"))
|
||||
cmd.Flags().Int(StatusesPollMaxOptionsFlag(), cfg.StatusesPollMaxOptions, fieldtag("StatusesPollMaxOptions", "usage"))
|
||||
cmd.Flags().Int(StatusesPollOptionMaxCharsFlag(), cfg.StatusesPollOptionMaxChars, fieldtag("StatusesPollOptionMaxChars", "usage"))
|
||||
cmd.Flags().Int(StatusesMediaMaxFilesFlag(), cfg.StatusesMediaMaxFiles, fieldtag("StatusesMediaMaxFiles", "usage"))
|
||||
|
||||
// LetsEncrypt
|
||||
cmd.Flags().Bool(LetsEncryptEnabledFlag(), cfg.LetsEncryptEnabled, fieldtag("LetsEncryptEnabled", "usage"))
|
||||
cmd.Flags().Int(LetsEncryptPortFlag(), cfg.LetsEncryptPort, fieldtag("LetsEncryptPort", "usage"))
|
||||
cmd.Flags().String(LetsEncryptCertDirFlag(), cfg.LetsEncryptCertDir, fieldtag("LetsEncryptCertDir", "usage"))
|
||||
cmd.Flags().String(LetsEncryptEmailAddressFlag(), cfg.LetsEncryptEmailAddress, fieldtag("LetsEncryptEmailAddress", "usage"))
|
||||
|
||||
// Manual TLS
|
||||
cmd.Flags().String(TLSCertificateChainFlag(), cfg.TLSCertificateChain, fieldtag("TLSCertificateChain", "usage"))
|
||||
cmd.Flags().String(TLSCertificateKeyFlag(), cfg.TLSCertificateKey, fieldtag("TLSCertificateKey", "usage"))
|
||||
|
||||
// OIDC
|
||||
cmd.Flags().Bool(OIDCEnabledFlag(), cfg.OIDCEnabled, fieldtag("OIDCEnabled", "usage"))
|
||||
cmd.Flags().String(OIDCIdpNameFlag(), cfg.OIDCIdpName, fieldtag("OIDCIdpName", "usage"))
|
||||
cmd.Flags().Bool(OIDCSkipVerificationFlag(), cfg.OIDCSkipVerification, fieldtag("OIDCSkipVerification", "usage"))
|
||||
cmd.Flags().String(OIDCIssuerFlag(), cfg.OIDCIssuer, fieldtag("OIDCIssuer", "usage"))
|
||||
cmd.Flags().String(OIDCClientIDFlag(), cfg.OIDCClientID, fieldtag("OIDCClientID", "usage"))
|
||||
cmd.Flags().String(OIDCClientSecretFlag(), cfg.OIDCClientSecret, fieldtag("OIDCClientSecret", "usage"))
|
||||
cmd.Flags().StringSlice(OIDCScopesFlag(), cfg.OIDCScopes, fieldtag("OIDCScopes", "usage"))
|
||||
|
||||
// SMTP
|
||||
cmd.Flags().String(SMTPHostFlag(), cfg.SMTPHost, fieldtag("SMTPHost", "usage"))
|
||||
cmd.Flags().Int(SMTPPortFlag(), cfg.SMTPPort, fieldtag("SMTPPort", "usage"))
|
||||
cmd.Flags().String(SMTPUsernameFlag(), cfg.SMTPUsername, fieldtag("SMTPUsername", "usage"))
|
||||
cmd.Flags().String(SMTPPasswordFlag(), cfg.SMTPPassword, fieldtag("SMTPPassword", "usage"))
|
||||
cmd.Flags().String(SMTPFromFlag(), cfg.SMTPFrom, fieldtag("SMTPFrom", "usage"))
|
||||
cmd.Flags().Bool(SMTPDiscloseRecipientsFlag(), cfg.SMTPDiscloseRecipients, fieldtag("SMTPDiscloseRecipients", "usage"))
|
||||
|
||||
// Syslog
|
||||
cmd.Flags().Bool(SyslogEnabledFlag(), cfg.SyslogEnabled, fieldtag("SyslogEnabled", "usage"))
|
||||
cmd.Flags().String(SyslogProtocolFlag(), cfg.SyslogProtocol, fieldtag("SyslogProtocol", "usage"))
|
||||
cmd.Flags().String(SyslogAddressFlag(), cfg.SyslogAddress, fieldtag("SyslogAddress", "usage"))
|
||||
|
||||
// Advanced flags
|
||||
cmd.Flags().String(AdvancedCookiesSamesiteFlag(), cfg.AdvancedCookiesSamesite, fieldtag("AdvancedCookiesSamesite", "usage"))
|
||||
cmd.Flags().Int(AdvancedRateLimitRequestsFlag(), cfg.AdvancedRateLimitRequests, fieldtag("AdvancedRateLimitRequests", "usage"))
|
||||
cmd.Flags().StringSlice(AdvancedRateLimitExceptionsFlag(), cfg.AdvancedRateLimitExceptions.Strings(), fieldtag("AdvancedRateLimitExceptions", "usage"))
|
||||
cmd.Flags().Int(AdvancedThrottlingMultiplierFlag(), cfg.AdvancedThrottlingMultiplier, fieldtag("AdvancedThrottlingMultiplier", "usage"))
|
||||
cmd.Flags().Duration(AdvancedThrottlingRetryAfterFlag(), cfg.AdvancedThrottlingRetryAfter, fieldtag("AdvancedThrottlingRetryAfter", "usage"))
|
||||
cmd.Flags().Int(AdvancedSenderMultiplierFlag(), cfg.AdvancedSenderMultiplier, fieldtag("AdvancedSenderMultiplier", "usage"))
|
||||
cmd.Flags().StringSlice(AdvancedCSPExtraURIsFlag(), cfg.AdvancedCSPExtraURIs, fieldtag("AdvancedCSPExtraURIs", "usage"))
|
||||
cmd.Flags().String(AdvancedHeaderFilterModeFlag(), cfg.AdvancedHeaderFilterMode, fieldtag("AdvancedHeaderFilterMode", "usage"))
|
||||
|
||||
cmd.Flags().String(RequestIDHeaderFlag(), cfg.RequestIDHeader, fieldtag("RequestIDHeader", "usage"))
|
||||
})
|
||||
}
|
||||
|
||||
// AddAdminAccount attaches flags pertaining to admin account actions.
|
||||
func AddAdminAccount(cmd *cobra.Command) {
|
||||
name := AdminAccountUsernameFlag()
|
||||
|
@@ -25,6 +25,7 @@ import (
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.superseriousbusiness.org/gotosocial/internal/config"
|
||||
)
|
||||
@@ -48,6 +49,11 @@ const license = `// GoToSocial
|
||||
|
||||
`
|
||||
|
||||
var durationType = reflect.TypeOf(time.Duration(0))
|
||||
var stringerType = reflect.TypeOf((*interface{ String() string })(nil)).Elem()
|
||||
var stringersType = reflect.TypeOf((*interface{ Strings() []string })(nil)).Elem()
|
||||
var flagSetType = reflect.TypeOf((*interface{ Set(string) error })(nil)).Elem()
|
||||
|
||||
func main() {
|
||||
var out string
|
||||
|
||||
@@ -61,41 +67,392 @@ func main() {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Fprint(output, "// THIS IS A GENERATED FILE, DO NOT EDIT BY HAND\n")
|
||||
fmt.Fprint(output, license)
|
||||
fmt.Fprint(output, "package config\n\n")
|
||||
fmt.Fprint(output, "import (\n")
|
||||
fmt.Fprint(output, "\t\"time\"\n\n")
|
||||
fmt.Fprint(output, "\t\"codeberg.org/gruf/go-bytesize\"\n")
|
||||
fmt.Fprint(output, "\t\"code.superseriousbusiness.org/gotosocial/internal/language\"\n")
|
||||
fmt.Fprint(output, ")\n\n")
|
||||
generateFields(output, nil, reflect.TypeOf(config.Configuration{}))
|
||||
_ = output.Close()
|
||||
_ = exec.Command("gofumpt", "-w", out).Run()
|
||||
configType := reflect.TypeOf(config.Configuration{})
|
||||
|
||||
// The plan here is that eventually we might be able
|
||||
// to generate an example configuration from struct tags
|
||||
// Parse our config type for usable fields.
|
||||
fields := loadConfigFields(nil, nil, configType)
|
||||
|
||||
fprintf(output, "// THIS IS A GENERATED FILE, DO NOT EDIT BY HAND\n")
|
||||
fprintf(output, license)
|
||||
fprintf(output, "package config\n\n")
|
||||
fprintf(output, "import (\n")
|
||||
fprintf(output, "\t\"fmt\"\n")
|
||||
fprintf(output, "\t\"time\"\n\n")
|
||||
fprintf(output, "\t\"codeberg.org/gruf/go-bytesize\"\n")
|
||||
fprintf(output, "\t\"code.superseriousbusiness.org/gotosocial/internal/language\"\n")
|
||||
fprintf(output, "\t\"github.com/spf13/pflag\"\n")
|
||||
fprintf(output, "\t\"github.com/spf13/cast\"\n")
|
||||
fprintf(output, ")\n")
|
||||
fprintf(output, "\n")
|
||||
generateFlagRegistering(output, fields)
|
||||
generateMapMarshaler(output, fields)
|
||||
generateMapUnmarshaler(output, fields)
|
||||
generateGetSetters(output, fields)
|
||||
generateMapFlattener(output, fields)
|
||||
must(output.Close())
|
||||
must(exec.Command("gofumpt", "-w", out).Run())
|
||||
}
|
||||
|
||||
func generateFields(output io.Writer, prefixes []string, t reflect.Type) {
|
||||
type ConfigField struct {
|
||||
// Any CLI flag prefixes,
|
||||
// i.e. with nested fields.
|
||||
Prefixes []string
|
||||
|
||||
// The base CLI flag
|
||||
// name of the field.
|
||||
Name string
|
||||
|
||||
// Path to struct field
|
||||
// in dot-separated form.
|
||||
Path string
|
||||
|
||||
// Usage string.
|
||||
Usage string
|
||||
|
||||
// The underlying Go type
|
||||
// of the config field.
|
||||
Type reflect.Type
|
||||
|
||||
// i.e. is this found in the configuration file?
|
||||
// or just used in specific CLI commands? in the
|
||||
// future we'll remove these from config struct.
|
||||
Ephemeral bool
|
||||
}
|
||||
|
||||
// Flag returns the combined "prefixes-name" CLI flag for config field.
|
||||
func (f ConfigField) Flag() string {
|
||||
flag := strings.Join(append(f.Prefixes, f.Name), "-")
|
||||
flag = strings.ToLower(flag)
|
||||
return flag
|
||||
}
|
||||
|
||||
// PossibleKeys returns a list of possible map key combinations
|
||||
// that this config field may be found under. The combined "prefixes-name"
|
||||
// will always be in the list, but also separates them out to account for
|
||||
// possible nesting. This allows us to support both nested and un-nested
|
||||
// configuration files, always prioritizing "prefixes-name" as its the CLI flag.
|
||||
func (f ConfigField) PossibleKeys() [][]string {
|
||||
if len(f.Prefixes) == 0 {
|
||||
return [][]string{{f.Name}}
|
||||
}
|
||||
|
||||
var keys [][]string
|
||||
|
||||
combined := f.Flag()
|
||||
keys = append(keys, []string{combined})
|
||||
|
||||
basePrefix := strings.TrimSuffix(combined, "-"+f.Name)
|
||||
keys = append(keys, []string{basePrefix, f.Name})
|
||||
|
||||
for i := len(f.Prefixes) - 1; i >= 0; i-- {
|
||||
prefix := f.Prefixes[i]
|
||||
|
||||
basePrefix = strings.TrimSuffix(basePrefix, prefix)
|
||||
basePrefix = strings.TrimSuffix(basePrefix, "-")
|
||||
if len(basePrefix) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
var key []string
|
||||
key = append(key, basePrefix)
|
||||
key = append(key, f.Prefixes[i:]...)
|
||||
key = append(key, f.Name)
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
func loadConfigFields(pathPrefixes, flagPrefixes []string, t reflect.Type) []ConfigField {
|
||||
var out []ConfigField
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
// Struct field at index.
|
||||
field := t.Field(i)
|
||||
|
||||
if ft := field.Type; ft.Kind() == reflect.Struct {
|
||||
// This is a struct field containing further nested config vars.
|
||||
generateFields(output, append(prefixes, field.Name), ft)
|
||||
// Get field's tagged name.
|
||||
name := field.Tag.Get("name")
|
||||
if name == "" || name == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get prefixed config variable name
|
||||
name := strings.Join(prefixes, "") + field.Name
|
||||
if ft := field.Type; ft.Kind() == reflect.Struct {
|
||||
// This is a nested struct, load nested fields.
|
||||
pathPrefixes := append(pathPrefixes, field.Name)
|
||||
flagPrefixes := append(flagPrefixes, name)
|
||||
out = append(out, loadConfigFields(pathPrefixes, flagPrefixes, ft)...)
|
||||
continue
|
||||
}
|
||||
|
||||
// Get period-separated (if nested) config variable "path"
|
||||
fieldPath := strings.Join(append(prefixes, field.Name), ".")
|
||||
// Get prefixed, period-separated, config variable struct "path".
|
||||
fieldPath := strings.Join(append(pathPrefixes, field.Name), ".")
|
||||
|
||||
// Get dash-separated config variable CLI flag "path"
|
||||
flagPath := strings.Join(append(prefixes, field.Tag.Get("name")), "-")
|
||||
flagPath = strings.ToLower(flagPath)
|
||||
// Append prepared ConfigField.
|
||||
out = append(out, ConfigField{
|
||||
Prefixes: flagPrefixes,
|
||||
Name: name,
|
||||
Path: fieldPath,
|
||||
Usage: field.Tag.Get("usage"),
|
||||
Ephemeral: field.Tag.Get("ephemeral") == "yes",
|
||||
Type: field.Type,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// func generateFlagConsts(out io.Writer, fields []ConfigField) {
|
||||
// fprintf(out, "const (\n")
|
||||
// for _, field := range fields {
|
||||
// name := strings.ReplaceAll(field.Path, ".", "")
|
||||
// fprintf(out, "\t%sFlag = \"%s\"\n", name, field.Flag())
|
||||
// }
|
||||
// fprintf(out, ")\n\n")
|
||||
// }
|
||||
|
||||
func generateFlagRegistering(out io.Writer, fields []ConfigField) {
|
||||
fprintf(out, "func (cfg *Configuration) RegisterFlags(flags *pflag.FlagSet) {\n")
|
||||
for _, field := range fields {
|
||||
if field.Ephemeral {
|
||||
// Skip registering
|
||||
// ephemeral flags.
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for easy cases of just regular primitive types.
|
||||
if field.Type.Kind().String() == field.Type.String() {
|
||||
typeName := field.Type.String()
|
||||
typeName = strings.ToUpper(typeName[:1]) + typeName[1:]
|
||||
fprintf(out, "\tflags.%s(\"%s\", cfg.%s, \"%s\")\n", typeName, field.Flag(), field.Path, field.Usage)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for easy cases of just
|
||||
// regular primitive slice types.
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
elem := field.Type.Elem()
|
||||
if elem.Kind().String() == elem.String() {
|
||||
typeName := elem.String()
|
||||
typeName = strings.ToUpper(typeName[:1]) + typeName[1:]
|
||||
fprintf(out, "\tflags.%sSlice(\"%s\", cfg.%s, \"%s\")\n", typeName, field.Flag(), field.Path, field.Usage)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Durations should get set directly
|
||||
// as their types as viper knows how
|
||||
// to deal with this type directly.
|
||||
if field.Type == durationType {
|
||||
fprintf(out, "\tflags.Duration(\"%s\", cfg.%s, \"%s\")\n", field.Flag(), field.Path, field.Usage)
|
||||
continue
|
||||
}
|
||||
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
// Check if the field supports Stringers{}.
|
||||
if field.Type.Implements(stringersType) {
|
||||
fprintf(out, "\tflags.StringSlice(\"%s\", cfg.%s.Strings(), \"%s\")\n", field.Flag(), field.Path, field.Usage)
|
||||
continue
|
||||
}
|
||||
|
||||
// Or the pointer type of the field value supports Stringers{}.
|
||||
if ptr := reflect.PointerTo(field.Type); ptr.Implements(stringersType) {
|
||||
fprintf(out, "\tflags.StringSlice(\"%s\", cfg.%s.Strings(), \"%s\")\n", field.Flag(), field.Path, field.Usage)
|
||||
continue
|
||||
}
|
||||
|
||||
fprintf(os.Stderr, "field %s doesn't implement %s!\n", field.Path, stringersType)
|
||||
} else {
|
||||
// Check if the field supports Stringer{}.
|
||||
if field.Type.Implements(stringerType) {
|
||||
fprintf(out, "\tflags.String(\"%s\", cfg.%s.String(), \"%s\")\n", field.Flag(), field.Path, field.Usage)
|
||||
continue
|
||||
}
|
||||
|
||||
// Or the pointer type of the field value supports Stringer{}.
|
||||
if ptr := reflect.PointerTo(field.Type); ptr.Implements(stringerType) {
|
||||
fprintf(out, "\tflags.String(\"%s\", cfg.%s.String(), \"%s\")\n", field.Flag(), field.Path, field.Usage)
|
||||
continue
|
||||
}
|
||||
|
||||
fprintf(os.Stderr, "field %s doesn't implement %s!\n", field.Path, stringerType)
|
||||
}
|
||||
}
|
||||
fprintf(out, "}\n\n")
|
||||
}
|
||||
|
||||
func generateMapMarshaler(out io.Writer, fields []ConfigField) {
|
||||
fprintf(out, "func (cfg *Configuration) MarshalMap() map[string]any {\n")
|
||||
fprintf(out, "\tcfgmap := make(map[string]any, %d)\n", len(fields))
|
||||
for _, field := range fields {
|
||||
// Check for easy cases of just regular primitive types.
|
||||
if field.Type.Kind().String() == field.Type.String() {
|
||||
fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s\n", field.Flag(), field.Path)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for easy cases of just
|
||||
// regular primitive slice types.
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
elem := field.Type.Elem()
|
||||
if elem.Kind().String() == elem.String() {
|
||||
fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s\n", field.Flag(), field.Path)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Durations should get set directly
|
||||
// as their types as viper knows how
|
||||
// to deal with this type directly.
|
||||
if field.Type == durationType {
|
||||
fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s\n", field.Flag(), field.Path)
|
||||
continue
|
||||
}
|
||||
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
// Either the field must support Stringers{}.
|
||||
if field.Type.Implements(stringersType) {
|
||||
fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s.Strings()\n", field.Flag(), field.Path)
|
||||
continue
|
||||
}
|
||||
|
||||
// Or the pointer type of the field value must support Stringers{}.
|
||||
if ptr := reflect.PointerTo(field.Type); ptr.Implements(stringersType) {
|
||||
fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s.Strings()\n", field.Flag(), field.Path)
|
||||
continue
|
||||
}
|
||||
|
||||
fprintf(os.Stderr, "field %s doesn't implement %s!\n", field.Path, stringersType)
|
||||
} else {
|
||||
// Either the field must support Stringer{}.
|
||||
if field.Type.Implements(stringerType) {
|
||||
fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s.String()\n", field.Flag(), field.Path)
|
||||
continue
|
||||
}
|
||||
|
||||
// Or the pointer type of the field value must support Stringer{}.
|
||||
if ptr := reflect.PointerTo(field.Type); ptr.Implements(stringerType) {
|
||||
fprintf(out, "\tcfgmap[\"%s\"] = cfg.%s.String()\n", field.Flag(), field.Path)
|
||||
continue
|
||||
}
|
||||
|
||||
fprintf(os.Stderr, "field %s doesn't implement %s!\n", field.Path, stringerType)
|
||||
}
|
||||
}
|
||||
fprintf(out, "\treturn cfgmap")
|
||||
fprintf(out, "}\n\n")
|
||||
}
|
||||
|
||||
func generateMapUnmarshaler(out io.Writer, fields []ConfigField) {
|
||||
fprintf(out, "func (cfg *Configuration) UnmarshalMap(cfgmap map[string]any) error {\n")
|
||||
fprintf(out, "// VERY IMPORTANT FIRST STEP!\n")
|
||||
fprintf(out, "// flatten to normalize map to\n")
|
||||
fprintf(out, "// entirely un-nested key values\n")
|
||||
fprintf(out, "flattenConfigMap(cfgmap)\n")
|
||||
fprintf(out, "\n")
|
||||
for _, field := range fields {
|
||||
// Check for easy cases of just regular primitive types.
|
||||
if field.Type.Kind().String() == field.Type.String() {
|
||||
generateUnmarshalerPrimitive(out, field)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for easy cases of just
|
||||
// regular primitive slice types.
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
elem := field.Type.Elem()
|
||||
if elem.Kind().String() == elem.String() {
|
||||
generateUnmarshalerPrimitive(out, field)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Durations should get set directly
|
||||
// as their types as viper knows how
|
||||
// to deal with this type directly.
|
||||
if field.Type == durationType {
|
||||
generateUnmarshalerPrimitive(out, field)
|
||||
continue
|
||||
}
|
||||
|
||||
// Either the field must support flag.Value{}.
|
||||
if field.Type.Implements(flagSetType) {
|
||||
generateUnmarshalerFlagType(out, field)
|
||||
continue
|
||||
}
|
||||
|
||||
// Or the pointer type of the field value must support flag.Value{}.
|
||||
if ptr := reflect.PointerTo(field.Type); ptr.Implements(flagSetType) {
|
||||
generateUnmarshalerFlagType(out, field)
|
||||
continue
|
||||
}
|
||||
|
||||
fprintf(os.Stderr, "field %s doesn't implement %s!\n", field.Path, flagSetType)
|
||||
}
|
||||
fprintf(out, "\treturn nil\n")
|
||||
fprintf(out, "}\n\n")
|
||||
}
|
||||
|
||||
func generateUnmarshalerPrimitive(out io.Writer, field ConfigField) {
|
||||
fprintf(out, "\t\tif ival, ok := cfgmap[\"%s\"]; ok {\n", field.Flag())
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
elem := field.Type.Elem()
|
||||
typeName := elem.String()
|
||||
if i := strings.IndexRune(typeName, '.'); i >= 0 {
|
||||
typeName = typeName[i+1:]
|
||||
}
|
||||
typeName = strings.ToUpper(typeName[:1]) + typeName[1:]
|
||||
fprintf(out, "\t\t\tvar err error\n")
|
||||
// note we specifically handle slice types ourselves to split by comma
|
||||
fprintf(out, "\t\t\tcfg.%s, err = to%sSlice(ival)\n", field.Path, typeName)
|
||||
fprintf(out, "\t\t\tif err != nil {\n")
|
||||
fprintf(out, "\t\t\t\treturn fmt.Errorf(\"error casting %%#v -> []%s for '%s': %%w\", ival, err)\n", elem.String(), field.Flag())
|
||||
fprintf(out, "\t\t\t}\n")
|
||||
} else {
|
||||
typeName := field.Type.String()
|
||||
if i := strings.IndexRune(typeName, '.'); i >= 0 {
|
||||
typeName = typeName[i+1:]
|
||||
}
|
||||
typeName = strings.ToUpper(typeName[:1]) + typeName[1:]
|
||||
fprintf(out, "\t\t\tvar err error\n")
|
||||
fprintf(out, "\t\t\tcfg.%s, err = cast.To%sE(ival)\n", field.Path, typeName)
|
||||
fprintf(out, "\t\t\tif err != nil {\n")
|
||||
fprintf(out, "\t\t\t\treturn fmt.Errorf(\"error casting %%#v -> %s for '%s': %%w\", ival, err)\n", field.Type.String(), field.Flag())
|
||||
fprintf(out, "\t\t\t}\n")
|
||||
}
|
||||
fprintf(out, "\t}\n")
|
||||
fprintf(out, "\n")
|
||||
}
|
||||
|
||||
func generateUnmarshalerFlagType(out io.Writer, field ConfigField) {
|
||||
fprintf(out, "\t\tif ival, ok := cfgmap[\"%s\"]; ok {\n", field.Flag())
|
||||
if field.Type.Kind() == reflect.Slice {
|
||||
// same as above re: slice types and splitting on comma
|
||||
fprintf(out, "\t\tt, err := toStringSlice(ival)\n")
|
||||
fprintf(out, "\t\tif err != nil {\n")
|
||||
fprintf(out, "\t\t\treturn fmt.Errorf(\"error casting %%#v -> []string for '%s': %%w\", ival, err)\n", field.Flag())
|
||||
fprintf(out, "\t\t}\n")
|
||||
fprintf(out, "\t\tcfg.%s = %s{}\n", field.Path, strings.TrimPrefix(field.Type.String(), "config."))
|
||||
fprintf(out, "\t\tfor _, in := range t {\n")
|
||||
fprintf(out, "\t\t\tif err := cfg.%s.Set(in); err != nil {\n", field.Path)
|
||||
fprintf(out, "\t\t\t\treturn fmt.Errorf(\"error parsing %%#v for '%s': %%w\", ival, err)\n", field.Flag())
|
||||
fprintf(out, "\t\t\t}\n")
|
||||
fprintf(out, "\t\t}\n")
|
||||
} else {
|
||||
fprintf(out, "\t\tt, err := cast.ToStringE(ival)\n")
|
||||
fprintf(out, "\t\tif err != nil {\n")
|
||||
fprintf(out, "\t\t\treturn fmt.Errorf(\"error casting %%#v -> string for '%s': %%w\", ival, err)\n", field.Flag())
|
||||
fprintf(out, "\t\t}\n")
|
||||
fprintf(out, "\t\tcfg.%s = %#v\n", field.Path, reflect.New(field.Type).Elem().Interface())
|
||||
fprintf(out, "\t\tif err := cfg.%s.Set(t); err != nil {\n", field.Path)
|
||||
fprintf(out, "\t\t\treturn fmt.Errorf(\"error parsing %%#v for '%s': %%w\", ival, err)\n", field.Flag())
|
||||
fprintf(out, "\t\t}\n")
|
||||
}
|
||||
fprintf(out, "\t}\n")
|
||||
fprintf(out, "\n")
|
||||
}
|
||||
|
||||
func generateGetSetters(out io.Writer, fields []ConfigField) {
|
||||
for _, field := range fields {
|
||||
// Get name from struct path, without periods.
|
||||
name := strings.ReplaceAll(field.Path, ".", "")
|
||||
|
||||
// Get type without "config." prefix.
|
||||
fieldType := strings.ReplaceAll(
|
||||
@@ -103,29 +460,67 @@ func generateFields(output io.Writer, prefixes []string, t reflect.Type) {
|
||||
"config.", "",
|
||||
)
|
||||
|
||||
fprintf(out, "// %sFlag returns the flag name for the '%s' field\n", name, field.Path)
|
||||
fprintf(out, "func %sFlag() string { return \"%s\" }\n\n", name, field.Flag())
|
||||
|
||||
// ConfigState structure helper methods
|
||||
fmt.Fprintf(output, "// Get%s safely fetches the Configuration value for state's '%s' field\n", name, fieldPath)
|
||||
fmt.Fprintf(output, "func (st *ConfigState) Get%s() (v %s) {\n", name, fieldType)
|
||||
fmt.Fprintf(output, "\tst.mutex.RLock()\n")
|
||||
fmt.Fprintf(output, "\tv = st.config.%s\n", fieldPath)
|
||||
fmt.Fprintf(output, "\tst.mutex.RUnlock()\n")
|
||||
fmt.Fprintf(output, "\treturn\n")
|
||||
fmt.Fprintf(output, "}\n\n")
|
||||
fmt.Fprintf(output, "// Set%s safely sets the Configuration value for state's '%s' field\n", name, fieldPath)
|
||||
fmt.Fprintf(output, "func (st *ConfigState) Set%s(v %s) {\n", name, fieldType)
|
||||
fmt.Fprintf(output, "\tst.mutex.Lock()\n")
|
||||
fmt.Fprintf(output, "\tdefer st.mutex.Unlock()\n")
|
||||
fmt.Fprintf(output, "\tst.config.%s = v\n", fieldPath)
|
||||
fmt.Fprintf(output, "\tst.reloadToViper()\n")
|
||||
fmt.Fprintf(output, "}\n\n")
|
||||
fprintf(out, "// Get%s safely fetches the Configuration value for state's '%s' field\n", name, field.Path)
|
||||
fprintf(out, "func (st *ConfigState) Get%s() (v %s) {\n", name, fieldType)
|
||||
fprintf(out, "\tst.mutex.RLock()\n")
|
||||
fprintf(out, "\tv = st.config.%s\n", field.Path)
|
||||
fprintf(out, "\tst.mutex.RUnlock()\n")
|
||||
fprintf(out, "\treturn\n")
|
||||
fprintf(out, "}\n\n")
|
||||
fprintf(out, "// Set%s safely sets the Configuration value for state's '%s' field\n", name, field.Path)
|
||||
fprintf(out, "func (st *ConfigState) Set%s(v %s) {\n", name, fieldType)
|
||||
fprintf(out, "\tst.mutex.Lock()\n")
|
||||
fprintf(out, "\tdefer st.mutex.Unlock()\n")
|
||||
fprintf(out, "\tst.config.%s = v\n", field.Path)
|
||||
fprintf(out, "\tst.reloadToViper()\n")
|
||||
fprintf(out, "}\n\n")
|
||||
|
||||
// Global ConfigState helper methods
|
||||
// TODO: remove when we pass around a ConfigState{}
|
||||
fmt.Fprintf(output, "// %sFlag returns the flag name for the '%s' field\n", name, fieldPath)
|
||||
fmt.Fprintf(output, "func %sFlag() string { return \"%s\" }\n\n", name, flagPath)
|
||||
fmt.Fprintf(output, "// Get%s safely fetches the value for global configuration '%s' field\n", name, fieldPath)
|
||||
fmt.Fprintf(output, "func Get%[1]s() %[2]s { return global.Get%[1]s() }\n\n", name, fieldType)
|
||||
fmt.Fprintf(output, "// Set%s safely sets the value for global configuration '%s' field\n", name, fieldPath)
|
||||
fmt.Fprintf(output, "func Set%[1]s(v %[2]s) { global.Set%[1]s(v) }\n\n", name, fieldType)
|
||||
fprintf(out, "// Get%s safely fetches the value for global configuration '%s' field\n", name, field.Path)
|
||||
fprintf(out, "func Get%[1]s() %[2]s { return global.Get%[1]s() }\n\n", name, fieldType)
|
||||
fprintf(out, "// Set%s safely sets the value for global configuration '%s' field\n", name, field.Path)
|
||||
fprintf(out, "func Set%[1]s(v %[2]s) { global.Set%[1]s(v) }\n\n", name, fieldType)
|
||||
}
|
||||
}
|
||||
|
||||
func generateMapFlattener(out io.Writer, fields []ConfigField) {
|
||||
fprintf(out, "func flattenConfigMap(cfgmap map[string]any) {\n")
|
||||
fprintf(out, "\tnestedKeys := make(map[string]struct{})\n")
|
||||
for _, field := range fields {
|
||||
keys := field.PossibleKeys()
|
||||
if len(keys) <= 1 {
|
||||
continue
|
||||
}
|
||||
fprintf(out, "\tfor _, key := range [][]string{\n")
|
||||
for _, key := range keys[1:] {
|
||||
fprintf(out, "\t\t{\"%s\"},\n", strings.Join(key, "\", \""))
|
||||
}
|
||||
fprintf(out, "\t} {\n")
|
||||
fprintf(out, "\t\tival, ok := mapGet(cfgmap, key...)\n")
|
||||
fprintf(out, "\t\tif ok {\n")
|
||||
fprintf(out, "\t\t\tcfgmap[\"%s\"] = ival\n", field.Flag())
|
||||
fprintf(out, "\t\t\tnestedKeys[key[0]] = struct{}{}\n")
|
||||
fprintf(out, "\t\t\tbreak\n")
|
||||
fprintf(out, "\t\t}\n")
|
||||
fprintf(out, "\t}\n\n")
|
||||
}
|
||||
fprintf(out, "\tfor key := range nestedKeys {\n")
|
||||
fprintf(out, "\t\tdelete(cfgmap, key)\n")
|
||||
fprintf(out, "\t}\n")
|
||||
fprintf(out, "}\n\n")
|
||||
}
|
||||
|
||||
func fprintf(out io.Writer, format string, args ...any) {
|
||||
_, err := fmt.Fprintf(out, format, args...)
|
||||
must(err)
|
||||
}
|
||||
|
||||
func must(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
@@ -32,29 +32,17 @@ func init() {
|
||||
// package, and instead pass the ConfigState round in a global gts state.
|
||||
|
||||
// Config provides you safe access to the global configuration.
|
||||
func Config(fn func(cfg *Configuration)) {
|
||||
global.Config(fn)
|
||||
}
|
||||
func Config(fn func(cfg *Configuration)) { global.Config(fn) }
|
||||
|
||||
// Reload will reload the current configuration values from file.
|
||||
func Reload() error {
|
||||
return global.Reload()
|
||||
}
|
||||
|
||||
// LoadEarlyFlags will bind specific flags from given Cobra command to global viper
|
||||
// instance, and load the current configuration values. This is useful for flags like
|
||||
// .ConfigPath which have to parsed first in order to perform early configuration load.
|
||||
func LoadEarlyFlags(cmd *cobra.Command) error {
|
||||
return global.LoadEarlyFlags(cmd)
|
||||
}
|
||||
// RegisterGlobalFlags ...
|
||||
func RegisterGlobalFlags(root *cobra.Command) { global.RegisterGlobalFlags(root) }
|
||||
|
||||
// BindFlags binds given command's pflags to the global viper instance.
|
||||
func BindFlags(cmd *cobra.Command) error {
|
||||
return global.BindFlags(cmd)
|
||||
}
|
||||
func BindFlags(cmd *cobra.Command) error { return global.BindFlags(cmd) }
|
||||
|
||||
// LoadConfigFile loads the currently set configuration file into the global viper instance.
|
||||
func LoadConfigFile() error { return global.LoadConfigFile() }
|
||||
|
||||
// Reset will totally clear global
|
||||
// ConfigState{}, loading defaults.
|
||||
func Reset() {
|
||||
global.Reset()
|
||||
}
|
||||
func Reset() { global.Reset() }
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -18,10 +18,11 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/go-viper/mapstructure/v2"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@@ -46,34 +47,25 @@ func NewState() *ConfigState {
|
||||
// and will reload the current Configuration back into viper settings.
|
||||
func (st *ConfigState) Config(fn func(*Configuration)) {
|
||||
st.mutex.Lock()
|
||||
defer func() {
|
||||
st.reloadToViper()
|
||||
st.mutex.Unlock()
|
||||
}()
|
||||
defer st.mutex.Unlock()
|
||||
fn(&st.config)
|
||||
st.reloadToViper()
|
||||
}
|
||||
|
||||
// Viper provides safe access to the ConfigState's contained viper instance,
|
||||
// and will reload the current viper setting state back into Configuration.
|
||||
func (st *ConfigState) Viper(fn func(*viper.Viper)) {
|
||||
st.mutex.Lock()
|
||||
defer func() {
|
||||
st.reloadFromViper()
|
||||
st.mutex.Unlock()
|
||||
}()
|
||||
defer st.mutex.Unlock()
|
||||
fn(st.viper)
|
||||
st.reloadFromViper()
|
||||
}
|
||||
|
||||
// LoadEarlyFlags will bind specific flags from given Cobra command to ConfigState's viper
|
||||
// instance, and load the current configuration values. This is useful for flags like
|
||||
// .ConfigPath which have to parsed first in order to perform early configuration load.
|
||||
func (st *ConfigState) LoadEarlyFlags(cmd *cobra.Command) (err error) {
|
||||
name := ConfigPathFlag()
|
||||
flag := cmd.Flags().Lookup(name)
|
||||
st.Viper(func(v *viper.Viper) {
|
||||
err = v.BindPFlag(name, flag)
|
||||
})
|
||||
return
|
||||
// RegisterGlobalFlags ...
|
||||
func (st *ConfigState) RegisterGlobalFlags(root *cobra.Command) {
|
||||
st.mutex.RLock()
|
||||
st.config.RegisterFlags(root.PersistentFlags())
|
||||
st.mutex.RUnlock()
|
||||
}
|
||||
|
||||
// BindFlags will bind given Cobra command's pflags to this ConfigState's viper instance.
|
||||
@@ -84,15 +76,21 @@ func (st *ConfigState) BindFlags(cmd *cobra.Command) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Reload will reload the Configuration values from ConfigState's viper instance, and from file if set.
|
||||
func (st *ConfigState) Reload() (err error) {
|
||||
// LoadConfigFile loads the currently set configuration file into this ConfigState's viper instance.
|
||||
func (st *ConfigState) LoadConfigFile() (err error) {
|
||||
st.Viper(func(v *viper.Viper) {
|
||||
if st.config.ConfigPath != "" {
|
||||
// Ensure configuration path is set
|
||||
v.SetConfigFile(st.config.ConfigPath)
|
||||
if path := st.config.ConfigPath; path != "" {
|
||||
var cfgmap map[string]any
|
||||
|
||||
// Read in configuration from file
|
||||
if err = v.ReadInConfig(); err != nil {
|
||||
// Read config map into memory.
|
||||
cfgmap, err := readConfigMap(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Merge the parsed config into viper.
|
||||
err = st.viper.MergeConfigMap(cfgmap)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -108,18 +106,17 @@ func (st *ConfigState) Reset() {
|
||||
defer st.mutex.Unlock()
|
||||
|
||||
// Create new viper.
|
||||
viper := viper.New()
|
||||
st.viper = viper.New()
|
||||
|
||||
// Flag 'some-flag-name' becomes env var 'GTS_SOME_FLAG_NAME'
|
||||
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
||||
viper.SetEnvPrefix("gts")
|
||||
st.viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
||||
st.viper.SetEnvPrefix("gts")
|
||||
|
||||
// Load appropriate
|
||||
// named vals from env.
|
||||
viper.AutomaticEnv()
|
||||
st.viper.AutomaticEnv()
|
||||
|
||||
// Reset variables.
|
||||
st.viper = viper
|
||||
// Set default config.
|
||||
st.config = Defaults
|
||||
|
||||
// Load into viper.
|
||||
@@ -128,31 +125,45 @@ func (st *ConfigState) Reset() {
|
||||
|
||||
// reloadToViper will reload Configuration{} values into viper.
|
||||
func (st *ConfigState) reloadToViper() {
|
||||
raw, err := st.config.MarshalMap()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := st.viper.MergeConfigMap(raw); err != nil {
|
||||
if err := st.viper.MergeConfigMap(st.config.MarshalMap()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// reloadFromViper will reload Configuration{} values from viper.
|
||||
func (st *ConfigState) reloadFromViper() {
|
||||
if err := st.viper.Unmarshal(&st.config, func(c *mapstructure.DecoderConfig) {
|
||||
c.TagName = "name"
|
||||
|
||||
// empty config before marshaling
|
||||
c.ZeroFields = true
|
||||
|
||||
oldhook := c.DecodeHook
|
||||
|
||||
// Use the TextUnmarshaler interface when decoding.
|
||||
c.DecodeHook = mapstructure.ComposeDecodeHookFunc(
|
||||
mapstructure.TextUnmarshallerHookFunc(),
|
||||
oldhook,
|
||||
)
|
||||
}); err != nil {
|
||||
if err := st.config.UnmarshalMap(st.viper.AllSettings()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// readConfigMap reads given configuration file into memory,
|
||||
// using viper's codec registry to handle decoding into a map,
|
||||
// flattening the result for standardization, returning this.
|
||||
// this ensures the stored config map in viper always has the
|
||||
// same level of nesting, given we support varying levels.
|
||||
func readConfigMap(file string) (map[string]any, error) {
|
||||
ext := path.Ext(file)
|
||||
ext = strings.TrimPrefix(ext, ".")
|
||||
|
||||
registry := viper.NewCodecRegistry()
|
||||
dec, err := registry.Decoder(ext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfgmap := make(map[string]any)
|
||||
|
||||
if err := dec.Decode(data, cfgmap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flattenConfigMap(cfgmap)
|
||||
|
||||
return cfgmap, nil
|
||||
}
|
||||
|
4
internal/config/testdata/test3.yaml
vendored
Normal file
4
internal/config/testdata/test3.yaml
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
advanced:
|
||||
scraper-deterrence: true
|
||||
rate-limit:
|
||||
requests: 5000
|
@@ -18,9 +18,8 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
|
||||
"codeberg.org/gruf/go-byteutil"
|
||||
)
|
||||
|
||||
// IPPrefixes is a type-alias for []netip.Prefix
|
||||
@@ -28,6 +27,9 @@ import (
|
||||
type IPPrefixes []netip.Prefix
|
||||
|
||||
func (p *IPPrefixes) Set(in string) error {
|
||||
if p == nil {
|
||||
return errors.New("nil receiver")
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(in)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -36,20 +38,6 @@ func (p *IPPrefixes) Set(in string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *IPPrefixes) String() string {
|
||||
if p == nil || len(*p) == 0 {
|
||||
return ""
|
||||
}
|
||||
var buf byteutil.Buffer
|
||||
for _, prefix := range *p {
|
||||
str := prefix.String()
|
||||
buf.B = append(buf.B, str...)
|
||||
buf.B = append(buf.B, ',')
|
||||
}
|
||||
buf.Truncate(1)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (p *IPPrefixes) Strings() []string {
|
||||
if p == nil || len(*p) == 0 {
|
||||
return nil
|
||||
|
74
internal/config/util.go
Normal file
74
internal/config/util.go
Normal file
@@ -0,0 +1,74 @@
|
||||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"codeberg.org/gruf/go-split"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func toStringSlice(a any) ([]string, error) {
|
||||
switch a := a.(type) {
|
||||
case []string:
|
||||
return a, nil
|
||||
case string:
|
||||
return split.SplitStrings[string](a)
|
||||
case []any:
|
||||
ss := make([]string, len(a))
|
||||
for i, a := range a {
|
||||
var err error
|
||||
ss[i], err = cast.ToStringE(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return ss, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("cannot cast %T to []string", a)
|
||||
}
|
||||
}
|
||||
|
||||
func mapGet(m map[string]any, keys ...string) (any, bool) {
|
||||
for len(keys) > 0 {
|
||||
key := keys[0]
|
||||
keys = keys[1:]
|
||||
|
||||
// Check for key.
|
||||
v, ok := m[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
// Has to be value.
|
||||
return v, true
|
||||
}
|
||||
|
||||
// Else, it needs to have
|
||||
// nesting to keep searching.
|
||||
switch t := v.(type) {
|
||||
case map[string]any:
|
||||
m = t
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
Reference in New Issue
Block a user