mirror of
https://github.com/superseriousbusiness/gotosocial
synced 2025-06-05 21:59:39 +02:00
Implement Cobra CLI tooling, Viper config tooling (#336)
* start pulling out + replacing urfave and config * replace many many instances of config * move more stuff => viper * properly remove urfave * move some flags to root command * add testrig commands to root * alias config file keys * start adding cli parsing tests * reorder viper init * remove config path alias * fmt * change config file keys to non-nested * we're more or less in business now * tidy up the common func * go fmt * get tests passing again * add note about the cliparsing tests * reorganize * update docs with changes * structure cmd dir better * rename + move some files around * fix dangling comma
This commit is contained in:
@@ -23,7 +23,6 @@ import (
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
)
|
||||
|
||||
var corsConfig = cors.Config{
|
||||
@@ -81,7 +80,7 @@ var corsConfig = cors.Config{
|
||||
}
|
||||
|
||||
// useCors attaches the corsConfig above to the given gin engine
|
||||
func useCors(cfg *config.Config, engine *gin.Engine) error {
|
||||
func useCors(engine *gin.Engine) error {
|
||||
c := cors.New(corsConfig)
|
||||
engine.Use(c)
|
||||
return nil
|
||||
|
@@ -26,6 +26,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
@@ -58,7 +59,6 @@ type Router interface {
|
||||
type router struct {
|
||||
engine *gin.Engine
|
||||
srv *http.Server
|
||||
config *config.Config
|
||||
certManager *autocert.Manager
|
||||
}
|
||||
|
||||
@@ -69,10 +69,16 @@ func (r *router) AttachStaticFS(relativePath string, fs http.FileSystem) {
|
||||
|
||||
// Start starts the router nicely. It will serve two handlers if letsencrypt is enabled, and only the web/API handler if letsencrypt is not enabled.
|
||||
func (r *router) Start() {
|
||||
if r.config.LetsEncryptConfig.Enabled {
|
||||
keys := config.Keys
|
||||
leEnabled := viper.GetBool(keys.LetsEncryptEnabled)
|
||||
|
||||
if leEnabled {
|
||||
bindAddress := viper.GetString(keys.BindAddress)
|
||||
lePort := viper.GetInt(keys.LetsEncryptPort)
|
||||
|
||||
// serve the http handler on the selected letsencrypt port, for receiving letsencrypt requests and solving their devious riddles
|
||||
go func() {
|
||||
listen := fmt.Sprintf("%s:%d", r.config.BindAddress, r.config.LetsEncryptConfig.Port)
|
||||
listen := fmt.Sprintf("%s:%d", bindAddress, lePort)
|
||||
if err := http.ListenAndServe(listen, r.certManager.HTTPHandler(http.HandlerFunc(httpsRedirect))); err != nil && err != http.ErrServerClosed {
|
||||
logrus.Fatalf("listen: %s", err)
|
||||
}
|
||||
@@ -103,7 +109,9 @@ func (r *router) Stop(ctx context.Context) error {
|
||||
//
|
||||
// The given DB is only used in the New function for parsing config values, and is not otherwise
|
||||
// pinned to the router.
|
||||
func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) {
|
||||
func New(ctx context.Context, db db.DB) (Router, error) {
|
||||
keys := config.Keys
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
|
||||
// create the actual engine here -- this is the core request routing handler for gts
|
||||
@@ -116,12 +124,13 @@ func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) {
|
||||
engine.MaxMultipartMemory = 8 << 20
|
||||
|
||||
// set up IP forwarding via x-forward-* headers.
|
||||
if err := engine.SetTrustedProxies(cfg.TrustedProxies); err != nil {
|
||||
trustedProxies := viper.GetStringSlice(keys.TrustedProxies)
|
||||
if err := engine.SetTrustedProxies(trustedProxies); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// enable cors on the engine
|
||||
if err := useCors(cfg, engine); err != nil {
|
||||
if err := useCors(engine); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -129,17 +138,19 @@ func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) {
|
||||
loadTemplateFunctions(engine)
|
||||
|
||||
// load templates onto the engine
|
||||
if err := loadTemplates(cfg, engine); err != nil {
|
||||
if err := loadTemplates(engine); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// enable session store middleware on the engine
|
||||
if err := useSession(ctx, cfg, db, engine); err != nil {
|
||||
if err := useSession(ctx, db, engine); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// create the http server here, passing the gin engine as handler
|
||||
listen := fmt.Sprintf("%s:%d", cfg.BindAddress, cfg.Port)
|
||||
bindAddress := viper.GetString(keys.BindAddress)
|
||||
port := viper.GetInt(keys.Port)
|
||||
listen := fmt.Sprintf("%s:%d", bindAddress, port)
|
||||
s := &http.Server{
|
||||
Addr: listen,
|
||||
Handler: engine,
|
||||
@@ -151,15 +162,19 @@ func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) {
|
||||
|
||||
// We need to spawn the underlying server slightly differently depending on whether lets encrypt is enabled or not.
|
||||
// In either case, the gin engine will still be used for routing requests.
|
||||
leEnabled := viper.GetBool(keys.LetsEncryptEnabled)
|
||||
|
||||
var m *autocert.Manager
|
||||
if cfg.LetsEncryptConfig.Enabled {
|
||||
if leEnabled {
|
||||
// le IS enabled, so roll up an autocert manager for handling letsencrypt requests
|
||||
host := viper.GetString(keys.Host)
|
||||
leCertDir := viper.GetString(keys.LetsEncryptCertDir)
|
||||
leEmailAddress := viper.GetString(keys.LetsEncryptEmailAddress)
|
||||
m = &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
HostPolicy: autocert.HostWhitelist(cfg.Host),
|
||||
Cache: autocert.DirCache(cfg.LetsEncryptConfig.CertDir),
|
||||
Email: cfg.LetsEncryptConfig.EmailAddress,
|
||||
HostPolicy: autocert.HostWhitelist(host),
|
||||
Cache: autocert.DirCache(leCertDir),
|
||||
Email: leEmailAddress,
|
||||
}
|
||||
s.TLSConfig = m.TLSConfig()
|
||||
}
|
||||
@@ -167,7 +182,6 @@ func New(ctx context.Context, cfg *config.Config, db db.DB) (Router, error) {
|
||||
return &router{
|
||||
engine: engine,
|
||||
srv: s,
|
||||
config: cfg,
|
||||
certManager: m,
|
||||
}, nil
|
||||
}
|
||||
|
@@ -28,15 +28,16 @@ import (
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-contrib/sessions/memstore"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
)
|
||||
|
||||
// sessionOptions returns the standard set of options to use for each session.
|
||||
func sessionOptions(cfg *config.Config) sessions.Options {
|
||||
func sessionOptions() sessions.Options {
|
||||
return sessions.Options{
|
||||
Path: "/",
|
||||
Domain: cfg.Host,
|
||||
Domain: viper.GetString(config.Keys.Host),
|
||||
MaxAge: 120, // 2 minutes
|
||||
Secure: true, // only use cookie over https
|
||||
HttpOnly: true, // exclude javascript from inspecting cookie
|
||||
@@ -44,9 +45,12 @@ func sessionOptions(cfg *config.Config) sessions.Options {
|
||||
}
|
||||
}
|
||||
|
||||
func sessionName(cfg *config.Config) (string, error) {
|
||||
// SessionName is a utility function that derives an appropriate session name from the hostname.
|
||||
func SessionName() (string, error) {
|
||||
// parse the protocol + host
|
||||
u, err := url.Parse(fmt.Sprintf("%s://%s", cfg.Protocol, cfg.Host))
|
||||
protocol := viper.GetString(config.Keys.Protocol)
|
||||
host := viper.GetString(config.Keys.Host)
|
||||
u, err := url.Parse(fmt.Sprintf("%s://%s", protocol, host))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -54,13 +58,13 @@ func sessionName(cfg *config.Config) (string, error) {
|
||||
// take the hostname without any port attached
|
||||
strippedHostname := u.Hostname()
|
||||
if strippedHostname == "" {
|
||||
return "", fmt.Errorf("could not derive hostname without port from %s://%s", cfg.Protocol, cfg.Host)
|
||||
return "", fmt.Errorf("could not derive hostname without port from %s://%s", protocol, host)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("gotosocial-%s", strippedHostname), nil
|
||||
}
|
||||
|
||||
func useSession(ctx context.Context, cfg *config.Config, sessionDB db.Session, engine *gin.Engine) error {
|
||||
func useSession(ctx context.Context, sessionDB db.Session, engine *gin.Engine) error {
|
||||
// check if we have a saved router session already
|
||||
rs, err := sessionDB.GetSession(ctx)
|
||||
if err != nil {
|
||||
@@ -71,9 +75,9 @@ func useSession(ctx context.Context, cfg *config.Config, sessionDB db.Session, e
|
||||
}
|
||||
|
||||
store := memstore.NewStore(rs.Auth, rs.Crypt)
|
||||
store.Options(sessionOptions(cfg))
|
||||
store.Options(sessionOptions())
|
||||
|
||||
sessionName, err := sessionName(cfg)
|
||||
sessionName, err := SessionName()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -16,68 +16,68 @@
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package router
|
||||
package router_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/router"
|
||||
"github.com/superseriousbusiness/gotosocial/testrig"
|
||||
)
|
||||
|
||||
type SessionTestSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (suite *SessionTestSuite) TestDeriveSessionNameLocalhostWithPort() {
|
||||
cfg := &config.Config{
|
||||
Protocol: "http",
|
||||
Host: "localhost:8080",
|
||||
}
|
||||
func (suite *SessionTestSuite) SetupTest() {
|
||||
testrig.InitTestConfig()
|
||||
}
|
||||
|
||||
sessionName, err := sessionName(cfg)
|
||||
func (suite *SessionTestSuite) TestDeriveSessionNameLocalhostWithPort() {
|
||||
viper.Set(config.Keys.Protocol, "http")
|
||||
viper.Set(config.Keys.Host, "localhost:8080")
|
||||
|
||||
sessionName, err := router.SessionName()
|
||||
suite.NoError(err)
|
||||
suite.Equal("gotosocial-localhost", sessionName)
|
||||
}
|
||||
|
||||
func (suite *SessionTestSuite) TestDeriveSessionNameLocalhost() {
|
||||
cfg := &config.Config{
|
||||
Protocol: "http",
|
||||
Host: "localhost",
|
||||
}
|
||||
viper.Set(config.Keys.Protocol, "http")
|
||||
viper.Set(config.Keys.Host, "localhost")
|
||||
|
||||
sessionName, err := sessionName(cfg)
|
||||
sessionName, err := router.SessionName()
|
||||
suite.NoError(err)
|
||||
suite.Equal("gotosocial-localhost", sessionName)
|
||||
}
|
||||
|
||||
func (suite *SessionTestSuite) TestDeriveSessionNoProtocol() {
|
||||
cfg := &config.Config{
|
||||
Host: "localhost",
|
||||
}
|
||||
viper.Set(config.Keys.Protocol, "")
|
||||
viper.Set(config.Keys.Host, "localhost")
|
||||
|
||||
sessionName, err := sessionName(cfg)
|
||||
sessionName, err := router.SessionName()
|
||||
suite.EqualError(err, "parse \"://localhost\": missing protocol scheme")
|
||||
suite.Equal("", sessionName)
|
||||
}
|
||||
|
||||
func (suite *SessionTestSuite) TestDeriveSessionNoHost() {
|
||||
cfg := &config.Config{
|
||||
Protocol: "https",
|
||||
}
|
||||
viper.Set(config.Keys.Protocol, "https")
|
||||
viper.Set(config.Keys.Host, "")
|
||||
viper.Set(config.Keys.Port, 0)
|
||||
|
||||
sessionName, err := sessionName(cfg)
|
||||
sessionName, err := router.SessionName()
|
||||
suite.EqualError(err, "could not derive hostname without port from https://")
|
||||
suite.Equal("", sessionName)
|
||||
}
|
||||
|
||||
func (suite *SessionTestSuite) TestDeriveSessionOK() {
|
||||
cfg := &config.Config{
|
||||
Protocol: "https",
|
||||
Host: "example.org",
|
||||
}
|
||||
viper.Set(config.Keys.Protocol, "https")
|
||||
viper.Set(config.Keys.Host, "example.org")
|
||||
|
||||
sessionName, err := sessionName(cfg)
|
||||
sessionName, err := router.SessionName()
|
||||
suite.NoError(err)
|
||||
suite.Equal("gotosocial-example.org", sessionName)
|
||||
}
|
||||
|
@@ -26,18 +26,20 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
)
|
||||
|
||||
// loadTemplates loads html templates for use by the given engine
|
||||
func loadTemplates(cfg *config.Config, engine *gin.Engine) error {
|
||||
func loadTemplates(engine *gin.Engine) error {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting current working directory: %s", err)
|
||||
}
|
||||
|
||||
tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", cfg.TemplateConfig.BaseDir))
|
||||
templateBaseDir := viper.GetString(config.Keys.WebTemplateBaseDir)
|
||||
tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", templateBaseDir))
|
||||
|
||||
engine.LoadHTMLGlob(tmPath)
|
||||
return nil
|
||||
|
Reference in New Issue
Block a user