Refactor config package

This commit is contained in:
Kristóf Tóth 2021-10-25 18:28:36 +02:00
parent 0ce211b46d
commit 0bdbff89ac
3 changed files with 144 additions and 82 deletions

View File

@ -18,11 +18,11 @@ import (
type afterLockConfig struct { type afterLockConfig struct {
INITIAL_DELAY int INITIAL_DELAY int `optional`
LOOP_DELAY int LOOP_DELAY int `optional`
LOCK_PATH string LOCK_PATH string `optional`
LOG_PATH string LOG_PATH string
PRINT_STACKTRACES bool PRINT_STACKTRACES bool `optional`
} }
var configuration = afterLockConfig{ var configuration = afterLockConfig{
@ -46,7 +46,7 @@ func init() {
func main() { func main() {
defer handleErrors() defer handleErrors()
err := config.Build("AFL", &configuration) err := config.NewBuilder().Build("AFL_", &configuration)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -1,99 +1,95 @@
package config package config
import ( import (
"os"
"strings"
"fmt" "fmt"
"os"
"reflect" "reflect"
"errors"
"strconv" "strconv"
"strings"
) )
var ConfigFetchMethod = os.Environ type ConfigBuilder struct {
ConfigFetchMethod func()[]string
typeParsers map[reflect.Kind]func(string)(interface{}, error)
}
func NewBuilder() *ConfigBuilder {
func Build(prefix string, configStruct interface{}) (err error) { return &ConfigBuilder{
defer func() { ConfigFetchMethod: os.Environ,
if e := recover(); e != nil { typeParsers: map[reflect.Kind]func(string)(interface{}, error) {
switch r := e.(type) { reflect.String: func(value string) (interface{}, error) {
case string: return value, nil
err = errors.New(r) },
case error: reflect.Int: func(value string) (interface{}, error) {
err = r i, err := strconv.Atoi(value)
default: if err != nil {
err = errors.New("Undefined panic") return nil, err
}
return i, nil
},
reflect.Bool: func(value string) (interface{}, error) {
b, err := strconv.ParseBool(value)
if err != nil {
return nil, err
}
return b, nil
},
},
} }
} }
}()
func (cb *ConfigBuilder) Build(prefix string, configStruct interface{}) error {
assertIsPtr(configStruct) assertIsPtr(configStruct)
fieldNames := getConfigStructFieldNames(configStruct) fieldNames := getConfigStructFieldNames(configStruct)
config := loadConfig(prefix) config := cb.loadConfig(prefix)
structRef := reflect.ValueOf(configStruct) structRef := reflect.ValueOf(configStruct)
for _, fieldName := range fieldNames { for _, fieldName := range fieldNames {
value, ok := config[fieldName] value, ok := config[fieldName]
if !ok { if !ok {
// no such field loaded, avoid overwriting default // no such field loaded, avoid overwriting default if field is optional
field, _ := reflect.TypeOf(configStruct).Elem().FieldByName(fieldName)
if strings.Contains(string(field.Tag), "optional") {
continue continue
} }
return fmt.Errorf("no config found for required field '%s'", fieldName)
}
field := reflect.Indirect(structRef).FieldByName(fieldName) field := reflect.Indirect(structRef).FieldByName(fieldName)
typeRef := field.Type() typeRef := field.Type()
ret, err := tryParseString(value, typeRef.Kind()) typedValue, err := cb.tryParseString(value, typeRef.Kind())
if err != nil { if err != nil {
return err return err
} }
field.Set(reflect.ValueOf(ret).Convert(typeRef)) field.Set(reflect.ValueOf(typedValue).Convert(typeRef))
} }
return nil return nil
} }
func tryParseString(what string, toType reflect.Kind) (interface{}, error) { func assertIsPtr(configStruct interface{}) {
switch toType { r := reflect.ValueOf(configStruct)
case reflect.String:
return what, nil
case reflect.Int:
i, err := strconv.Atoi(what)
if err != nil {
return nil, err
}
return i, nil
case reflect.Bool:
b, err := strconv.ParseBool(what)
if err != nil {
return nil, err
}
return b, nil
default:
return nil, fmt.Errorf("Failed to parse value %v", what)
}
}
func assertIsPtr(what interface{}) {
r := reflect.ValueOf(what)
if r.Kind() != reflect.Ptr { if r.Kind() != reflect.Ptr {
panic(fmt.Errorf("Supplied value is not a pointer to a struct")) panic(fmt.Errorf("supplied value is not a pointer to a struct"))
} }
} }
func getConfigStructFieldNames(configStruct interface{}) []string { func getConfigStructFieldNames(configStruct interface{}) []string {
configFields := []string{} configFields := []string{}
structValue := reflect.ValueOf(configStruct).Elem()
val := reflect.ValueOf(configStruct).Elem() for i := 0; i < structValue.NumField(); i++ {
for i := 0; i < val.NumField(); i++ { configFields = append(configFields, structValue.Type().Field(i).Name)
configFields = append(configFields, val.Type().Field(i).Name)
} }
return configFields return configFields
} }
func loadConfig(prefix string) map[string]string { func (cb *ConfigBuilder) loadConfig(prefix string) map[string]string {
prefix = fmt.Sprintf("%s_", prefix)
config := map[string]string{} config := map[string]string{}
for _, env := range ConfigFetchMethod() {
for _, env := range cb.ConfigFetchMethod() {
parts := strings.SplitN(env, "=", 2) parts := strings.SplitN(env, "=", 2)
key, value := parts[0], parts[1] key, value := parts[0], parts[1]
if strings.HasPrefix(key, prefix) { if strings.HasPrefix(key, prefix) {
@ -101,15 +97,22 @@ func loadConfig(prefix string) map[string]string {
config[key] = value config[key] = value
} }
} }
return config return config
} }
// CheckKeysExist returns an error if not all keys are present in config map func (cb *ConfigBuilder) tryParseString(value string, toType reflect.Kind) (interface{}, error) {
func CheckKeysExist(config map[string]string, keys ...string) error { parser, ok := cb.typeParsers[toType]
for _, key := range keys { if !ok {
if _, ok := config[key]; ok { return nil, fmt.Errorf("no parser found for type %v", toType)
return fmt.Errorf("Config key '%s' is not set", key)
} }
parsed_val, err := parser(value)
if err != nil {
return nil, err
} }
return nil return parsed_val, nil
}
func (cb *ConfigBuilder) RegisterParser(toType reflect.Kind, parser func(string)(interface{}, error)) {
cb.typeParsers[toType] = parser
} }

View File

@ -1,9 +1,11 @@
package config_test package config_test
import ( import (
"testing"
"kdelsd/config" "kdelsd/config"
"reflect"
"strings"
"strconv"
"testing"
) )
@ -11,10 +13,11 @@ type TestConfig struct {
INT int INT int
BOOL bool BOOL bool
STR string STR string
OTHER_STR string OTHER_STR string `optional`
COMPLEX complex64 `optional`
} }
var testConfigInstance = TestConfig{ var testConfigReference = TestConfig{
INT: 42, INT: 42,
BOOL: true, BOOL: true,
STR: "sajtok", STR: "sajtok",
@ -27,23 +30,79 @@ var testEnv = []string {
"TEST_STR=sajtok", "TEST_STR=sajtok",
} }
func setup() { func configBuilder() *config.ConfigBuilder {
config.ConfigFetchMethod = func() []string { cb := config.NewBuilder()
cb.ConfigFetchMethod = func()[]string {
return testEnv return testEnv
} }
return cb
} }
func TestBuildConfig(t *testing.T) {
setup() func TestConfigBuilder(t *testing.T) {
c := TestConfig{} c := TestConfig{}
c.OTHER_STR = testConfigInstance.OTHER_STR c.OTHER_STR = testConfigReference.OTHER_STR
builder := configBuilder()
err := config.Build("TEST", &c) err := builder.Build("TEST_", &c)
if err != nil { if err != nil {
panic(err) t.Fatalf("Config building should not fail\n")
} }
if c != testConfigInstance { if c != testConfigReference {
t.Errorf("%v != %v", c, testConfigInstance) t.Fatalf("Parsed config shoudl equal reference config\n%v != %v\n", c, testConfigReference)
}
}
func TestConfigBuilderParsers(t *testing.T) {
c := TestConfig{}
builder := configBuilder()
builder.ConfigFetchMethod = func()[]string {
return append(testEnv, "TEST_COMPLEX=42+42i")
}
err := builder.Build("TEST_", &c)
if err == nil {
t.Fatalf("CONFIG should not be parseable, since type has no default parser\n")
}
if !strings.Contains(err.Error(), "complex64") {
t.Fatalf("Error should contain the type that failed to parse\n")
}
builder.RegisterParser(reflect.Complex64, func(value string)(interface{}, error) {
num, err := strconv.ParseComplex(value, 64)
if err != nil {
return nil, err
}
return num, nil
})
err = builder.Build("TEST_", &c)
if err != nil {
t.Fatalf("Parsing type should not fail after registering a parser for it\n")
}
}
func TestConfigBuilderOptional(t *testing.T) {
type RequiredConfig struct {
REQUIRED_STR string
}
c := RequiredConfig{}
builder := config.NewBuilder()
err := builder.Build("TEST_", &c)
if err == nil {
t.Fatalf("Fields not marked with the `optional` tag should be mandatory\n")
}
if !strings.Contains(err.Error(), "REQUIRED_STR") {
t.Fatalf("Error should contain the name of the field\n")
}
builder.ConfigFetchMethod = func()[]string {
return append(testEnv, "TEST_REQUIRED_STR=prrrr")
}
err = builder.Build("TEST_", &c)
if err != nil {
t.Fatalf("Config building should not fail after providing the mandatory field\n")
} }
} }