diff --git a/cmd/gotosocial/admincommands.go b/cmd/gotosocial/admincommands.go index 5d505fe77..a70693b2c 100644 --- a/cmd/gotosocial/admincommands.go +++ b/cmd/gotosocial/admincommands.go @@ -25,7 +25,7 @@ import ( "github.com/urfave/cli/v2" ) -func adminCommands() []*cli.Command { +func adminCommands(allFlags []cli.Flag) []*cli.Command { return []*cli.Command{ { Name: "admin", @@ -56,7 +56,7 @@ func adminCommands() []*cli.Command { }, }, Action: func(c *cli.Context) error { - return runAction(c, account.Create) + return runAction(c, allFlags, account.Create) }, }, { @@ -70,7 +70,7 @@ func adminCommands() []*cli.Command { }, }, Action: func(c *cli.Context) error { - return runAction(c, account.Confirm) + return runAction(c, allFlags, account.Confirm) }, }, { @@ -84,7 +84,7 @@ func adminCommands() []*cli.Command { }, }, Action: func(c *cli.Context) error { - return runAction(c, account.Promote) + return runAction(c, allFlags, account.Promote) }, }, { @@ -98,7 +98,7 @@ func adminCommands() []*cli.Command { }, }, Action: func(c *cli.Context) error { - return runAction(c, account.Demote) + return runAction(c, allFlags, account.Demote) }, }, { @@ -112,7 +112,7 @@ func adminCommands() []*cli.Command { }, }, Action: func(c *cli.Context) error { - return runAction(c, account.Disable) + return runAction(c, allFlags, account.Disable) }, }, { @@ -126,7 +126,7 @@ func adminCommands() []*cli.Command { }, }, Action: func(c *cli.Context) error { - return runAction(c, account.Suspend) + return runAction(c, allFlags, account.Suspend) }, }, { @@ -145,7 +145,7 @@ func adminCommands() []*cli.Command { }, }, Action: func(c *cli.Context) error { - return runAction(c, account.Password) + return runAction(c, allFlags, account.Password) }, }, }, @@ -161,7 +161,7 @@ func adminCommands() []*cli.Command { }, }, Action: func(c *cli.Context) error { - return runAction(c, trans.Export) + return runAction(c, allFlags, trans.Export) }, }, { @@ -175,7 +175,7 @@ func adminCommands() []*cli.Command { }, }, Action: func(c *cli.Context) error { - return runAction(c, trans.Import) + return runAction(c, allFlags, trans.Import) }, }, }, diff --git a/cmd/gotosocial/commands.go b/cmd/gotosocial/commands.go index 0c16d1fb5..9b61a66ec 100644 --- a/cmd/gotosocial/commands.go +++ b/cmd/gotosocial/commands.go @@ -22,12 +22,12 @@ import ( "github.com/urfave/cli/v2" ) -func getCommands() []*cli.Command { +func getCommands(allFlags []cli.Flag) []*cli.Command { commands := []*cli.Command{} commandSets := [][]*cli.Command{ - serverCommands(), - adminCommands(), - testrigCommands(), + serverCommands(allFlags), + adminCommands(allFlags), + testrigCommands(allFlags), } for _, cs := range commandSets { commands = append(commands, cs...) diff --git a/cmd/gotosocial/main.go b/cmd/gotosocial/main.go index 3d41e0fda..87c44487c 100644 --- a/cmd/gotosocial/main.go +++ b/cmd/gotosocial/main.go @@ -42,11 +42,12 @@ func main() { v = Version + " " + Commit[:7] } + flagsSlice := getFlags() app := &cli.App{ Version: v, Usage: "a fediverse social media server", - Flags: getFlags(), - Commands: getCommands(), + Flags: flagsSlice, + Commands: getCommands(flagsSlice), } if err := app.Run(os.Args); err != nil { diff --git a/cmd/gotosocial/runaction.go b/cmd/gotosocial/runaction.go index c8af9ddbe..96c4edaf6 100644 --- a/cmd/gotosocial/runaction.go +++ b/cmd/gotosocial/runaction.go @@ -27,17 +27,48 @@ import ( "github.com/urfave/cli/v2" ) +type MonkeyPatchedCLIContext struct { + CLIContext *cli.Context + AllFlags []cli.Flag +} + +func (f MonkeyPatchedCLIContext) Bool(k string) bool { return f.CLIContext.Bool(k) } +func (f MonkeyPatchedCLIContext) String(k string) string { return f.CLIContext.String(k) } +func (f MonkeyPatchedCLIContext) StringSlice(k string) []string { return f.CLIContext.StringSlice(k) } +func (f MonkeyPatchedCLIContext) Int(k string) int { return f.CLIContext.Int(k) } +func (f MonkeyPatchedCLIContext) IsSet(k string) bool { + for _, flag := range f.AllFlags { + flagNames := flag.Names() + for _, name := range flagNames { + if name == k { + return flag.IsSet() + } + } + + } + return false +} + // runAction builds up the config and logger necessary for any // gotosocial action, and then executes the action. -func runAction(c *cli.Context, a cliactions.GTSAction) error { +func runAction(c *cli.Context, allFlags []cli.Flag, a cliactions.GTSAction) error { // create a new *config.Config based on the config path provided... conf, err := config.FromFile(c.String(config.GetFlagNames().ConfigPath)) if err != nil { return fmt.Errorf("error creating config: %s", err) } + // ... and the flags set on the *cli.Context by urfave - if err := conf.ParseCLIFlags(c, c.App.Version); err != nil { + // + // The IsSet function on the cli.Context object `c` here appears to have some issues right now, it always returns false in my tests. + // However we can re-create the behaviour we want by simply referencing the flag objects we created previously + // https://picopublish.sequentialread.com/files/chatlog_2021_11_18.txt + monkeyPatchedCLIContext := MonkeyPatchedCLIContext{ + CLIContext: c, + AllFlags: allFlags, + } + if err := conf.ParseCLIFlags(monkeyPatchedCLIContext, c.App.Version); err != nil { return fmt.Errorf("error parsing config: %s", err) } diff --git a/cmd/gotosocial/servercommands.go b/cmd/gotosocial/servercommands.go index 7a1692b79..fb6574216 100644 --- a/cmd/gotosocial/servercommands.go +++ b/cmd/gotosocial/servercommands.go @@ -23,7 +23,7 @@ import ( "github.com/urfave/cli/v2" ) -func serverCommands() []*cli.Command { +func serverCommands(allFlags []cli.Flag) []*cli.Command { return []*cli.Command{ { Name: "server", @@ -33,7 +33,7 @@ func serverCommands() []*cli.Command { Name: "start", Usage: "start the gotosocial server", Action: func(c *cli.Context) error { - return runAction(c, server.Start) + return runAction(c, allFlags, server.Start) }, }, }, diff --git a/cmd/gotosocial/testrigcommands.go b/cmd/gotosocial/testrigcommands.go index 9b9aa3806..aabe04267 100644 --- a/cmd/gotosocial/testrigcommands.go +++ b/cmd/gotosocial/testrigcommands.go @@ -23,7 +23,7 @@ import ( "github.com/urfave/cli/v2" ) -func testrigCommands() []*cli.Command { +func testrigCommands(allFlags []cli.Flag) []*cli.Command { return []*cli.Command{ { Name: "testrig", @@ -33,7 +33,7 @@ func testrigCommands() []*cli.Command { Name: "start", Usage: "start the gotosocial testrig", Action: func(c *cli.Context) error { - return runAction(c, testrig.Start) + return runAction(c, allFlags, testrig.Start) }, }, }, diff --git a/internal/api/client/auth/auth_test.go b/internal/api/client/auth/auth_test.go index ca606e8f2..ae58ffbbb 100644 --- a/internal/api/client/auth/auth_test.go +++ b/internal/api/client/auth/auth_test.go @@ -47,7 +47,7 @@ type AuthTestSuite struct { // SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout func (suite *AuthTestSuite) SetupSuite() { - c := config.Empty() + c := config.Default() // we're running on localhost without https so set the protocol to http c.Protocol = "http" // just for testing diff --git a/internal/config/config.go b/internal/config/config.go index 9daeba0f0..eb9e6385b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -83,24 +83,7 @@ func FromFile(path string) (*Config, error) { } return c, nil } - return Empty(), nil -} - -// Empty just returns a new empty config -func Empty() *Config { - return &Config{ - DBConfig: &DBConfig{}, - TemplateConfig: &TemplateConfig{}, - AccountsConfig: &AccountsConfig{}, - MediaConfig: &MediaConfig{}, - StorageConfig: &StorageConfig{}, - StatusesConfig: &StatusesConfig{}, - LetsEncryptConfig: &LetsEncryptConfig{}, - OIDCConfig: &OIDCConfig{}, - SMTPConfig: &SMTPConfig{}, - AccountCLIFlags: make(map[string]string), - ExportCLIFlags: make(map[string]string), - } + return Default(), nil } // loadFromFile takes a path to a yaml file and attempts to load a Config object from it @@ -110,7 +93,7 @@ func loadFromFile(path string) (*Config, error) { return nil, fmt.Errorf("could not read file at path %s: %s", path, err) } - config := Empty() + config := Default() if err := yaml.Unmarshal(bytes, config); err != nil { return nil, fmt.Errorf("could not unmarshal file at path %s: %s", path, err) } @@ -131,87 +114,87 @@ func (c *Config) ParseCLIFlags(f KeyedFlags, version string) error { // as a command-line argument or an env variable, which takes priority. // general flags - if c.LogLevel == "" || f.IsSet(fn.LogLevel) { + if f.IsSet(fn.LogLevel) { c.LogLevel = f.String(fn.LogLevel) } - if c.ApplicationName == "" || f.IsSet(fn.ApplicationName) { + if f.IsSet(fn.ApplicationName) { c.ApplicationName = f.String(fn.ApplicationName) } - if c.Host == "" || f.IsSet(fn.Host) { + if f.IsSet(fn.Host) { c.Host = f.String(fn.Host) } if c.Host == "" { return errors.New("host was not set") } - if c.AccountDomain == "" || f.IsSet(fn.AccountDomain) { + if f.IsSet(fn.AccountDomain) { c.AccountDomain = f.String(fn.AccountDomain) } if c.AccountDomain == "" { c.AccountDomain = c.Host // default to whatever the host is, if this is empty } - if c.Protocol == "" || f.IsSet(fn.Protocol) { + if f.IsSet(fn.Protocol) { c.Protocol = f.String(fn.Protocol) } if c.Protocol == "" { return errors.New("protocol was not set") } - if c.BindAddress == "" || f.IsSet(fn.BindAddress) { + if f.IsSet(fn.BindAddress) { c.BindAddress = f.String(fn.BindAddress) } - if c.Port == 0 || f.IsSet(fn.Port) { + if f.IsSet(fn.Port) { c.Port = f.Int(fn.Port) } - if len(c.TrustedProxies) == 0 || f.IsSet(fn.TrustedProxies) { + if f.IsSet(fn.TrustedProxies) { c.TrustedProxies = f.StringSlice(fn.TrustedProxies) } // db flags - if c.DBConfig.Type == "" || f.IsSet(fn.DbType) { + if f.IsSet(fn.DbType) { c.DBConfig.Type = f.String(fn.DbType) } - if c.DBConfig.Address == "" || f.IsSet(fn.DbAddress) { + if f.IsSet(fn.DbAddress) { c.DBConfig.Address = f.String(fn.DbAddress) } - if c.DBConfig.Port == 0 || f.IsSet(fn.DbPort) { + if f.IsSet(fn.DbPort) { c.DBConfig.Port = f.Int(fn.DbPort) } - if c.DBConfig.User == "" || f.IsSet(fn.DbUser) { + if f.IsSet(fn.DbUser) { c.DBConfig.User = f.String(fn.DbUser) } - if c.DBConfig.Password == "" || f.IsSet(fn.DbPassword) { + if f.IsSet(fn.DbPassword) { c.DBConfig.Password = f.String(fn.DbPassword) } - if c.DBConfig.Database == "" || f.IsSet(fn.DbDatabase) { + if f.IsSet(fn.DbDatabase) { c.DBConfig.Database = f.String(fn.DbDatabase) } - if c.DBConfig.TLSMode == DBTLSModeUnset || f.IsSet(fn.DbTLSMode) { + if f.IsSet(fn.DbTLSMode) { c.DBConfig.TLSMode = DBTLSMode(f.String(fn.DbTLSMode)) } - if c.DBConfig.TLSCACert == "" || f.IsSet(fn.DbTLSCACert) { + if f.IsSet(fn.DbTLSCACert) { c.DBConfig.TLSCACert = f.String(fn.DbTLSCACert) } // template flags - if c.TemplateConfig.BaseDir == "" || f.IsSet(fn.TemplateBaseDir) { + if f.IsSet(fn.TemplateBaseDir) { c.TemplateConfig.BaseDir = f.String(fn.TemplateBaseDir) } // template flags - if c.TemplateConfig.AssetBaseDir == "" || f.IsSet(fn.AssetBaseDir) { + if f.IsSet(fn.AssetBaseDir) { c.TemplateConfig.AssetBaseDir = f.String(fn.AssetBaseDir) } @@ -225,57 +208,57 @@ func (c *Config) ParseCLIFlags(f KeyedFlags, version string) error { } // media flags - if c.MediaConfig.MaxImageSize == 0 || f.IsSet(fn.MediaMaxImageSize) { + if f.IsSet(fn.MediaMaxImageSize) { c.MediaConfig.MaxImageSize = f.Int(fn.MediaMaxImageSize) } - if c.MediaConfig.MaxVideoSize == 0 || f.IsSet(fn.MediaMaxVideoSize) { + if f.IsSet(fn.MediaMaxVideoSize) { c.MediaConfig.MaxVideoSize = f.Int(fn.MediaMaxVideoSize) } - if c.MediaConfig.MinDescriptionChars == 0 || f.IsSet(fn.MediaMinDescriptionChars) { + if f.IsSet(fn.MediaMinDescriptionChars) { c.MediaConfig.MinDescriptionChars = f.Int(fn.MediaMinDescriptionChars) } - if c.MediaConfig.MaxDescriptionChars == 0 || f.IsSet(fn.MediaMaxDescriptionChars) { + if f.IsSet(fn.MediaMaxDescriptionChars) { c.MediaConfig.MaxDescriptionChars = f.Int(fn.MediaMaxDescriptionChars) } // storage flags - if c.StorageConfig.Backend == "" || f.IsSet(fn.StorageBackend) { + if f.IsSet(fn.StorageBackend) { c.StorageConfig.Backend = f.String(fn.StorageBackend) } - if c.StorageConfig.BasePath == "" || f.IsSet(fn.StorageBasePath) { + if f.IsSet(fn.StorageBasePath) { c.StorageConfig.BasePath = f.String(fn.StorageBasePath) } - if c.StorageConfig.ServeProtocol == "" || f.IsSet(fn.StorageServeProtocol) { + if f.IsSet(fn.StorageServeProtocol) { c.StorageConfig.ServeProtocol = f.String(fn.StorageServeProtocol) } - if c.StorageConfig.ServeHost == "" || f.IsSet(fn.StorageServeHost) { + if f.IsSet(fn.StorageServeHost) { c.StorageConfig.ServeHost = f.String(fn.StorageServeHost) } - if c.StorageConfig.ServeBasePath == "" || f.IsSet(fn.StorageServeBasePath) { + if f.IsSet(fn.StorageServeBasePath) { c.StorageConfig.ServeBasePath = f.String(fn.StorageServeBasePath) } // statuses flags - if c.StatusesConfig.MaxChars == 0 || f.IsSet(fn.StatusesMaxChars) { + if f.IsSet(fn.StatusesMaxChars) { c.StatusesConfig.MaxChars = f.Int(fn.StatusesMaxChars) } - if c.StatusesConfig.CWMaxChars == 0 || f.IsSet(fn.StatusesCWMaxChars) { + if f.IsSet(fn.StatusesCWMaxChars) { c.StatusesConfig.CWMaxChars = f.Int(fn.StatusesCWMaxChars) } - if c.StatusesConfig.PollMaxOptions == 0 || f.IsSet(fn.StatusesPollMaxOptions) { + if f.IsSet(fn.StatusesPollMaxOptions) { c.StatusesConfig.PollMaxOptions = f.Int(fn.StatusesPollMaxOptions) } - if c.StatusesConfig.PollOptionMaxChars == 0 || f.IsSet(fn.StatusesPollOptionMaxChars) { + if f.IsSet(fn.StatusesPollOptionMaxChars) { c.StatusesConfig.PollOptionMaxChars = f.Int(fn.StatusesPollOptionMaxChars) } - if c.StatusesConfig.MaxMediaFiles == 0 || f.IsSet(fn.StatusesMaxMediaFiles) { + if f.IsSet(fn.StatusesMaxMediaFiles) { c.StatusesConfig.MaxMediaFiles = f.Int(fn.StatusesMaxMediaFiles) } @@ -284,15 +267,15 @@ func (c *Config) ParseCLIFlags(f KeyedFlags, version string) error { c.LetsEncryptConfig.Enabled = f.Bool(fn.LetsEncryptEnabled) } - if c.LetsEncryptConfig.Port == 0 || f.IsSet(fn.LetsEncryptPort) { + if f.IsSet(fn.LetsEncryptPort) { c.LetsEncryptConfig.Port = f.Int(fn.LetsEncryptPort) } - if c.LetsEncryptConfig.CertDir == "" || f.IsSet(fn.LetsEncryptCertDir) { + if f.IsSet(fn.LetsEncryptCertDir) { c.LetsEncryptConfig.CertDir = f.String(fn.LetsEncryptCertDir) } - if c.LetsEncryptConfig.EmailAddress == "" || f.IsSet(fn.LetsEncryptEmailAddress) { + if f.IsSet(fn.LetsEncryptEmailAddress) { c.LetsEncryptConfig.EmailAddress = f.String(fn.LetsEncryptEmailAddress) } @@ -301,7 +284,7 @@ func (c *Config) ParseCLIFlags(f KeyedFlags, version string) error { c.OIDCConfig.Enabled = f.Bool(fn.OIDCEnabled) } - if c.OIDCConfig.IDPName == "" || f.IsSet(fn.OIDCIdpName) { + if f.IsSet(fn.OIDCIdpName) { c.OIDCConfig.IDPName = f.String(fn.OIDCIdpName) } @@ -309,40 +292,40 @@ func (c *Config) ParseCLIFlags(f KeyedFlags, version string) error { c.OIDCConfig.SkipVerification = f.Bool(fn.OIDCSkipVerification) } - if c.OIDCConfig.Issuer == "" || f.IsSet(fn.OIDCIssuer) { + if f.IsSet(fn.OIDCIssuer) { c.OIDCConfig.Issuer = f.String(fn.OIDCIssuer) } - if c.OIDCConfig.ClientID == "" || f.IsSet(fn.OIDCClientID) { + if f.IsSet(fn.OIDCClientID) { c.OIDCConfig.ClientID = f.String(fn.OIDCClientID) } - if c.OIDCConfig.ClientSecret == "" || f.IsSet(fn.OIDCClientSecret) { + if f.IsSet(fn.OIDCClientSecret) { c.OIDCConfig.ClientSecret = f.String(fn.OIDCClientSecret) } - if len(c.OIDCConfig.Scopes) == 0 || f.IsSet(fn.OIDCScopes) { + if f.IsSet(fn.OIDCScopes) { c.OIDCConfig.Scopes = f.StringSlice(fn.OIDCScopes) } // smtp flags - if c.SMTPConfig.Host == "" || f.IsSet(fn.SMTPHost) { + if f.IsSet(fn.SMTPHost) { c.SMTPConfig.Host = f.String(fn.SMTPHost) } - if c.SMTPConfig.Port == 0 || f.IsSet(fn.SMTPPort) { + if f.IsSet(fn.SMTPPort) { c.SMTPConfig.Port = f.Int(fn.SMTPPort) } - if c.SMTPConfig.Username == "" || f.IsSet(fn.SMTPUsername) { + if f.IsSet(fn.SMTPUsername) { c.SMTPConfig.Username = f.String(fn.SMTPUsername) } - if c.SMTPConfig.Password == "" || f.IsSet(fn.SMTPPassword) { + if f.IsSet(fn.SMTPPassword) { c.SMTPConfig.Password = f.String(fn.SMTPPassword) } - if c.SMTPConfig.From == "" || f.IsSet(fn.SMTPFrom) { + if f.IsSet(fn.SMTPFrom) { c.SMTPConfig.From = f.String(fn.SMTPFrom) } diff --git a/internal/config/default.go b/internal/config/default.go index 996033ae4..6e8f63177 100644 --- a/internal/config/default.go +++ b/internal/config/default.go @@ -150,6 +150,8 @@ func Default() *Config { Password: defaults.SMTPPassword, From: defaults.SMTPFrom, }, + AccountCLIFlags: make(map[string]string), + ExportCLIFlags: make(map[string]string), } }