update test configuration setting to ensure no settings linger
This commit is contained in:
parent
9939a980ca
commit
91dc4d4d13
|
@ -52,3 +52,9 @@ func LoadEarlyFlags(cmd *cobra.Command) error {
|
|||
func BindFlags(cmd *cobra.Command) error {
|
||||
return global.BindFlags(cmd)
|
||||
}
|
||||
|
||||
// Reset will totally clear global
|
||||
// ConfigState{}, loading defaults.
|
||||
func Reset() {
|
||||
global.Reset()
|
||||
}
|
||||
|
|
|
@ -37,25 +37,9 @@ type ConfigState struct { //nolint
|
|||
|
||||
// NewState returns a new initialized ConfigState instance.
|
||||
func NewState() *ConfigState {
|
||||
viper := viper.New()
|
||||
|
||||
// Flag 'some-flag-name' becomes env var 'GTS_SOME_FLAG_NAME'
|
||||
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
||||
viper.SetEnvPrefix("gts")
|
||||
|
||||
// Load appropriate named vals from env
|
||||
viper.AutomaticEnv()
|
||||
|
||||
// Create new ConfigState with defaults
|
||||
state := &ConfigState{
|
||||
viper: viper,
|
||||
config: Defaults,
|
||||
}
|
||||
|
||||
// Perform initial load into viper
|
||||
state.reloadToViper()
|
||||
|
||||
return state
|
||||
st := new(ConfigState)
|
||||
st.Reset()
|
||||
return st
|
||||
}
|
||||
|
||||
// Config provides safe access to the ConfigState's contained Configuration,
|
||||
|
@ -116,6 +100,32 @@ func (st *ConfigState) Reload() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// Reset will totally clear
|
||||
// ConfigState{}, loading defaults.
|
||||
func (st *ConfigState) Reset() {
|
||||
// Do within lock.
|
||||
st.mutex.Lock()
|
||||
defer st.mutex.Unlock()
|
||||
|
||||
// Create new viper.
|
||||
viper := viper.New()
|
||||
|
||||
// Flag 'some-flag-name' becomes env var 'GTS_SOME_FLAG_NAME'
|
||||
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
|
||||
viper.SetEnvPrefix("gts")
|
||||
|
||||
// Load appropriate
|
||||
// named vals from env.
|
||||
viper.AutomaticEnv()
|
||||
|
||||
// Reset variables.
|
||||
st.viper = viper
|
||||
st.config = Defaults
|
||||
|
||||
// Load into viper.
|
||||
st.reloadToViper()
|
||||
}
|
||||
|
||||
// reloadToViper will reload Configuration{} values into viper.
|
||||
func (st *ConfigState) reloadToViper() {
|
||||
raw, err := st.config.MarshalMap()
|
||||
|
|
|
@ -29,10 +29,10 @@ var (
|
|||
// global PostgreSQL driver instances.
|
||||
postgresDriver = pgx.GetDefaultDriver().(*pgx.Driver)
|
||||
|
||||
// check the postgres connection
|
||||
// conforms to our conn{} interface.
|
||||
// check the postgres driver types
|
||||
// conforms to our interface types.
|
||||
// (note SQLite doesn't export their
|
||||
// conn type, and gets checked in
|
||||
// driver types, and gets checked in
|
||||
// tests very regularly anywho).
|
||||
_ connIface = (*pgx.Conn)(nil)
|
||||
_ stmtIface = (*pgx.Stmt)(nil)
|
||||
|
|
|
@ -19,31 +19,31 @@ package testrig
|
|||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
// linkname shenanigans
|
||||
_ "unsafe"
|
||||
|
||||
"codeberg.org/gruf/go-bytesize"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/language"
|
||||
)
|
||||
|
||||
// InitTestConfig initializes viper configuration with test defaults.
|
||||
// initialize with test defaults.
|
||||
func init() { InitTestConfig() }
|
||||
|
||||
// InitTestConfig initializes viper
|
||||
// configuration with test defaults.
|
||||
func InitTestConfig() {
|
||||
config.Config(func(cfg *config.Configuration) {
|
||||
*cfg = testDefaults
|
||||
})
|
||||
config.Defaults = testDefaults()
|
||||
config.Reset()
|
||||
}
|
||||
|
||||
func logLevel() string {
|
||||
level := "error"
|
||||
if lv := os.Getenv("GTS_LOG_LEVEL"); lv != "" {
|
||||
level = lv
|
||||
}
|
||||
return level
|
||||
}
|
||||
|
||||
var testDefaults = config.Configuration{
|
||||
LogLevel: logLevel(),
|
||||
func testDefaults() config.Configuration {
|
||||
return config.Configuration{
|
||||
LogLevel: envStr("GTS_LOG_LEVEL", "error"),
|
||||
LogTimestampFormat: "02/01/2006 15:04:05.000",
|
||||
LogDbQueries: true,
|
||||
ApplicationName: "gotosocial",
|
||||
|
@ -55,16 +55,15 @@ var testDefaults = config.Configuration{
|
|||
BindAddress: "127.0.0.1",
|
||||
Port: 8080,
|
||||
TrustedProxies: []string{"127.0.0.1/32", "::1"},
|
||||
|
||||
DbType: "sqlite",
|
||||
DbAddress: ":memory:",
|
||||
DbPort: 5432,
|
||||
DbUser: "postgres",
|
||||
DbPassword: "postgres",
|
||||
DbDatabase: "postgres",
|
||||
DbTLSMode: "disable",
|
||||
DbTLSCACert: "",
|
||||
DbMaxOpenConnsMultiplier: 8,
|
||||
DbType: envStr("GTS_DB_TYPE", "sqlite"),
|
||||
DbAddress: envStr("GTS_DB_ADDRESS", ":memory:"),
|
||||
DbPort: envInt("GTS_DB_PORT", 5432),
|
||||
DbUser: envStr("GTS_DB_USER", "postgres"),
|
||||
DbPassword: envStr("GTS_DB_PASSWORD", "postgres"),
|
||||
DbDatabase: envStr("GTS_DB_DATABASE", "postgres"),
|
||||
DbTLSMode: envStr("GTS_DB_TLS_MODE", "disable"),
|
||||
DbTLSCACert: envStr("GTS_DB_TLS_CA_CERT", ""),
|
||||
DbMaxOpenConnsMultiplier: 1,
|
||||
DbSqliteJournalMode: "WAL",
|
||||
DbSqliteSynchronous: "NORMAL",
|
||||
DbSqliteCacheSize: 8 * bytesize.MiB,
|
||||
|
@ -158,4 +157,26 @@ var testDefaults = config.Configuration{
|
|||
|
||||
// simply use cache defaults.
|
||||
Cache: config.Defaults.Cache,
|
||||
}
|
||||
}
|
||||
|
||||
func envInt(key string, _default int) int {
|
||||
return env(key, _default, func(value string) int {
|
||||
i, _ := strconv.Atoi(value)
|
||||
return i
|
||||
})
|
||||
}
|
||||
|
||||
func envStr(key string, _default string) string {
|
||||
return env(key, _default, func(value string) string {
|
||||
return value
|
||||
})
|
||||
}
|
||||
|
||||
func env[T any](key string, _default T, parse func(string) T) T {
|
||||
value, ok := os.LookupEnv(key)
|
||||
if ok {
|
||||
return parse(value)
|
||||
}
|
||||
return _default
|
||||
}
|
||||
|
|
|
@ -19,10 +19,7 @@ package testrig
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
@ -84,22 +81,6 @@ var testModels = []interface{}{
|
|||
// If the environment variable GTS_DB_PORT is set, it will take that
|
||||
// value as the port instead.
|
||||
func NewTestDB(state *state.State) db.DB {
|
||||
if alternateAddress := os.Getenv("GTS_DB_ADDRESS"); alternateAddress != "" {
|
||||
config.SetDbAddress(alternateAddress)
|
||||
}
|
||||
|
||||
if alternateDBType := os.Getenv("GTS_DB_TYPE"); alternateDBType != "" {
|
||||
config.SetDbType(alternateDBType)
|
||||
}
|
||||
|
||||
if alternateDBPort := os.Getenv("GTS_DB_PORT"); alternateDBPort != "" {
|
||||
port, err := strconv.ParseUint(alternateDBPort, 10, 16)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
config.SetDbPort(int(port))
|
||||
}
|
||||
|
||||
state.Caches.Init()
|
||||
|
||||
testDB, err := bundb.NewBunDBService(context.Background(), state)
|
||||
|
|
Loading…
Reference in New Issue