// File: gotosocial/internal/config/state.go (149 lines, 4.1 KiB, Go)

// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package config
import (
"strings"
"sync"
"github.com/mitchellh/mapstructure"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
// ConfigState manages safe concurrent access to Configuration{} values,
// and provides ease of linking them (including reloading) via viper to
// environment, CLI and configuration file variables.
//
// Callers reach the contained values through the Config and Viper methods,
// which hold the mutex for the duration of the supplied callback and then
// synchronize the other representation.
type ConfigState struct { //nolint
// viper is the viper instance that environment variables, bound
// command flags and the configuration file are read through.
viper *viper.Viper
// config is the typed Configuration; it is pushed into viper by
// reloadToViper and refreshed from viper by reloadFromViper.
config Configuration
// mutex is write-locked by Config and Viper around their callbacks.
mutex sync.RWMutex
}
// NewState builds a ConfigState whose viper instance reads 'GTS_'-prefixed
// environment variables and whose Configuration starts out as Defaults.
// The defaults are merged into viper before the state is returned.
func NewState() *ConfigState {
	v := viper.New()

	// Environment mapping: flag 'some-flag-name' is
	// looked up as env var 'GTS_SOME_FLAG_NAME'.
	v.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
	v.SetEnvPrefix("gts")

	// Have viper consult the environment for named values.
	v.AutomaticEnv()

	// Assemble the state around the default configuration.
	st := new(ConfigState)
	st.viper = v
	st.config = Defaults

	// Seed viper with the initial configuration values.
	st.reloadToViper()

	return st
}
// Config provides safe access to the ConfigState's contained Configuration,
// and will reload the current Configuration back into viper settings.
//
// fn is called with a pointer to the live Configuration while the state's
// mutex is write-locked; any changes it makes are merged into viper once it
// returns. The mutex is released even if fn or the reload panics, so a
// recovered panic cannot leave the state permanently locked.
func (st *ConfigState) Config(fn func(*Configuration)) {
	st.mutex.Lock()

	// Deferred calls run last-in-first-out: the reload into viper
	// happens first (still under the lock), then the unlock. Keeping
	// them as separate defers guarantees the unlock runs even when
	// reloadToViper panics.
	defer st.mutex.Unlock()
	defer st.reloadToViper()

	fn(&st.config)
}
// Viper provides safe access to the ConfigState's contained viper instance,
// and will reload the current viper setting state back into Configuration.
//
// fn is called with the state's viper instance while the state's mutex is
// write-locked; once it returns, the Configuration is re-decoded from viper.
// The mutex is released even if fn or the reload panics, so a recovered
// panic cannot leave the state permanently locked.
func (st *ConfigState) Viper(fn func(*viper.Viper)) {
	st.mutex.Lock()

	// Deferred calls run last-in-first-out: the reload from viper
	// happens first (still under the lock), then the unlock. Keeping
	// them as separate defers guarantees the unlock runs even when
	// reloadFromViper panics.
	defer st.mutex.Unlock()
	defer st.reloadFromViper()

	fn(st.viper)
}
// LoadEarlyFlags binds the config-path flag of the given Cobra command to the
// ConfigState's viper instance and reloads the current configuration values.
// This is useful for flags like .ConfigPath, which have to be parsed first in
// order to perform an early configuration load.
//
// It returns whatever error viper reports when binding the flag.
func (st *ConfigState) LoadEarlyFlags(cmd *cobra.Command) (err error) {
	// Resolve the flag by name up front, before taking the state lock.
	key := ConfigPathFlag()
	pathFlag := cmd.Flags().Lookup(key)

	bind := func(v *viper.Viper) {
		err = v.BindPFlag(key, pathFlag)
	}
	st.Viper(bind)

	return err
}
// BindFlags binds every pflag of the given Cobra command to this
// ConfigState's viper instance, returning any error viper reports.
// The Configuration is reloaded from viper afterwards.
func (st *ConfigState) BindFlags(cmd *cobra.Command) (err error) {
	bindAll := func(v *viper.Viper) {
		err = v.BindPFlags(cmd.Flags())
	}
	st.Viper(bindAll)
	return err
}
// Reload refreshes the Configuration values from the ConfigState's viper
// instance. When a configuration path is set, the file at that path is read
// into viper first, and any read error is returned.
func (st *ConfigState) Reload() (err error) {
	st.Viper(func(v *viper.Viper) {
		// Nothing to read from disk without a configured path.
		if st.config.ConfigPath == "" {
			return
		}

		// Point viper at the configured file, then read it in.
		v.SetConfigFile(st.config.ConfigPath)
		err = v.ReadInConfig()
	})
	return err
}
// reloadToViper pushes the current Configuration{} values into viper by
// marshaling them to a map and merging that map into viper's settings.
// It panics if either step fails.
func (st *ConfigState) reloadToViper() {
	settings, marshalErr := st.config.MarshalMap()
	if marshalErr != nil {
		panic(marshalErr)
	}

	mergeErr := st.viper.MergeConfigMap(settings)
	if mergeErr != nil {
		panic(mergeErr)
	}
}
// reloadFromViper decodes viper's current settings back into the
// Configuration{} value. It panics if decoding fails.
func (st *ConfigState) reloadFromViper() {
	// configure adjusts the mapstructure decoder used by viper.
	configure := func(dc *mapstructure.DecoderConfig) {
		// Struct fields are matched by their "name" tag.
		dc.TagName = "name"

		// Zero the config's fields before unmarshaling into them.
		dc.ZeroFields = true

		// Use the TextUnmarshaler interface when decoding,
		// ahead of whatever hook was already configured.
		previous := dc.DecodeHook
		dc.DecodeHook = mapstructure.ComposeDecodeHookFunc(
			mapstructure.TextUnmarshallerHookFunc(),
			previous,
		)
	}

	err := st.viper.Unmarshal(&st.config, configure)
	if err != nil {
		panic(err)
	}
}