Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: only load config and rules once #1470

Merged
merged 6 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 57 additions & 50 deletions config/configLoadHelpers.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package config

import (
"bytes"
"crypto/md5"
"encoding/hex"
"encoding/json"
Expand Down Expand Up @@ -58,7 +57,7 @@ func formatFromResponse(resp *http.Response) Format {
}

// getReaderFor returns an io.ReadCloser for the given URL or filename.
func getReaderFor(u string) (io.ReadCloser, Format, error) {
func getReaderFor(u string) ([]byte, Format, error) {
if u == "" {
return nil, FormatUnknown, fmt.Errorf("empty url")
}
Expand All @@ -68,7 +67,7 @@ func getReaderFor(u string) (io.ReadCloser, Format, error) {
}
switch uu.Scheme {
case "file", "": // we treat an empty scheme as a filename
r, err := os.Open(uu.Path)
r, err := os.ReadFile(uu.Path)
if err != nil {
return nil, FormatUnknown, err
}
Expand Down Expand Up @@ -102,76 +101,86 @@ func getReaderFor(u string) (io.ReadCloser, Format, error) {
if format == FormatUnknown {
format = formatFromFilename(uu.Path)
}
return resp.Body, format, nil
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, FormatUnknown, err
}
return body, format, nil
default:
return nil, FormatUnknown, fmt.Errorf("unknown scheme %q", uu.Scheme)
}
}

func load(r io.Reader, format Format, into any) error {
func load(data []byte, format Format, into any) error {
switch format {
case FormatYAML:
decoder := yaml.NewDecoder(r)
err := decoder.Decode(into)
err := yaml.Unmarshal(data, into)
return err
case FormatTOML:
decoder := toml.NewDecoder(r)
err := decoder.Decode(into)
err := toml.Unmarshal(data, into)
return err
case FormatJSON:
decoder := json.NewDecoder(r)
err := decoder.Decode(into)
err := json.Unmarshal(data, into)
return err
default:
return fmt.Errorf("unable to determine data format")
}
}

type configReader struct {
body []byte
format Format
location string
}

func getReadersForLocations(locations []string) ([]configReader, error) {
readers := make([]configReader, len(locations))
for i, location := range locations {
// trim leading and trailing whitespace just in case
location := strings.TrimSpace(location)
body, format, err := getReaderFor(location)
if err != nil {
return nil, err
}
readers[i] = configReader{
body: body,
format: format,
location: location,
}
}
return readers, nil
}

// This loads all the named configs into destination in the order they are listed.
// It returns the MD5 hash of the collected configs as a string (if there's only one
// config, this is the hash of that config; if there are multiple, it's the hash of
// all of them concatenated together).
func loadConfigsInto(dest any, locations []string) (string, error) {
func loadConfigsInto(dest any, readers []configReader) (string, error) {
// start a hash of the configs we read
h := md5.New()
for _, location := range locations {
// trim leading and trailing whitespace just in case
location := strings.TrimSpace(location)
r, format, err := getReaderFor(location)
if err != nil {
return "", err
}
defer r.Close()
// write the data to the hash as we read it
rdr := io.TeeReader(r, h)
for _, reader := range readers {
// write the data to the hash
h.Write(reader.body)

// when working on a struct, load only overwrites destination values that are
// explicitly named. So we can just keep loading successive files into
// the same object without losing data we've already specified.
if err := load(rdr, format, dest); err != nil {
return "", fmt.Errorf("loadConfigsInto unable to load config %s: %w", location, err)
if err := load(reader.body, reader.format, dest); err != nil {
return "", fmt.Errorf("loadConfigsInto unable to load config %s: %w", reader.location, err)
}
}
hash := hex.EncodeToString(h.Sum(nil))
return hash, nil
}

func loadConfigsIntoMap(dest map[string]any, locations []string) error {
for _, location := range locations {
// trim leading and trailing whitespace just in case
location := strings.TrimSpace(location)
r, format, err := getReaderFor(location)
if err != nil {
return err
}
defer r.Close()

func loadConfigsIntoMap(dest map[string]any, readers []configReader) error {
for _, reader := range readers {
// when working on a map, when loading a nested object, load will overwrite the entire destination
// value, so we can't just keep loading successive files into the same object. Instead, we
// need to load into a new object and then merge it into the map.
temp := make(map[string]any)
if err := load(r, format, &temp); err != nil {
return fmt.Errorf("loadConfigsInto unable to load config %s: %w", location, err)
if err := load(reader.body, reader.format, &temp); err != nil {
return fmt.Errorf("loadConfigsInto unable to load config %s: %w", reader.location, err)
}
for k, v := range temp {
switch vm := v.(type) {
Expand Down Expand Up @@ -199,10 +208,10 @@ func loadConfigsIntoMap(dest map[string]any, locations []string) error {
// validateConfigs reads the configs from the given location and validates them.
// It returns a list of failures; if the list is empty, the config is valid.
// err is non-nil only for significant errors like a missing file.
func validateConfigs(opts *CmdEnv) ([]string, error) {
func validateConfigs(readers []configReader, opts *CmdEnv) ([]string, error) {
// first read the configs into a map so we can validate them
userData := make(map[string]any)
err := loadConfigsIntoMap(userData, opts.ConfigLocations)
err := loadConfigsIntoMap(userData, readers)
if err != nil {
return nil, err
}
Expand All @@ -220,7 +229,7 @@ func validateConfigs(opts *CmdEnv) ([]string, error) {
// Basic validation worked. Now we need to reload everything into our struct so that
// we can apply defaults and options, and then validate a second time.
var config configContents
_, err = loadConfigsInto(&config, opts.ConfigLocations)
_, err = loadConfigsInto(&config, readers)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -250,16 +259,14 @@ func validateConfigs(opts *CmdEnv) ([]string, error) {
}

// The validator needs a map[string]any to work with, so we need to
// write it out to a buffer (we always use YAML) and then reload it.
buf := new(bytes.Buffer)
encoder := yaml.NewEncoder(buf)
encoder.SetIndent(2)
if err := encoder.Encode(config); err != nil {
return nil, fmt.Errorf("readConfigInto unable to reencode config: %w", err)
// yaml bytes (we always use YAML) and then reload it.
data, err := yaml.Marshal(config)
if err != nil {
return nil, fmt.Errorf("readConfigInto unable to remarshal config: %w", err)
}

var rewrittenUserData map[string]any
if err := load(buf, FormatYAML, &rewrittenUserData); err != nil {
if err := load(data, FormatYAML, &rewrittenUserData); err != nil {
return nil, fmt.Errorf("validateConfig unable to reload hydrated config from buffer: %w", err)
}

Expand All @@ -268,10 +275,10 @@ func validateConfigs(opts *CmdEnv) ([]string, error) {
return failures, nil
}

func validateRules(locations []string) ([]string, error) {
func validateRules(readers []configReader) ([]string, error) {
// first read the configs into a map so we can validate them
userData := make(map[string]any)
err := loadConfigsIntoMap(userData, locations)
err := loadConfigsIntoMap(userData, readers)
if err != nil {
return nil, err
}
Expand All @@ -286,8 +293,8 @@ func validateRules(locations []string) ([]string, error) {
}

// readConfigInto reads the config from the given location and applies it to the given struct.
func readConfigInto(dest any, locations []string, opts *CmdEnv) (string, error) {
hash, err := loadConfigsInto(dest, locations)
func readConfigInto(dest any, readers []configReader, opts *CmdEnv) (string, error) {
hash, err := loadConfigsInto(dest, readers)
if err != nil {
return hash, err
}
Expand Down
17 changes: 11 additions & 6 deletions config/configLoadHelpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func Test_loadDuration(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := load(strings.NewReader(tt.text), tt.format, tt.into); (err != nil) != tt.wantErr {
if err := load([]byte(tt.text), tt.format, tt.into); (err != nil) != tt.wantErr {
t.Errorf("load() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(tt.into, tt.want) {
Expand Down Expand Up @@ -123,7 +123,7 @@ func Test_loadMemsize(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := load(strings.NewReader(tt.text), tt.format, tt.into); (err != nil) != tt.wantErr {
if err := load([]byte(tt.text), tt.format, tt.into); (err != nil) != tt.wantErr {
t.Errorf("load() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(tt.into, tt.want) {
Expand Down Expand Up @@ -198,9 +198,10 @@ func Test_loadConfigsInto(t *testing.T) {
cm1 := makeYAML("General.ConfigurationVersion", 2, "General.ConfigReloadInterval", Duration(1*time.Second), "Network.ListenAddr", "0.0.0.0:8080")
cm2 := makeYAML("General.ConfigReloadInterval", Duration(2*time.Second), "General.DatasetPrefix", "hello")
cfgfiles := createTempConfigs(t, cm1, cm2)

readers, err := getReadersForLocations(cfgfiles)
require.NoError(t, err)
cfg := configContents{}
hash, err := loadConfigsInto(&cfg, cfgfiles)
hash, err := loadConfigsInto(&cfg, readers)
require.NoError(t, err)
require.Equal(t, "2381a6563085f50ac56663b67ca85299", hash)
require.Equal(t, 2, cfg.General.ConfigurationVersion)
Expand All @@ -213,9 +214,11 @@ func Test_loadConfigsIntoMap(t *testing.T) {
cm1 := makeYAML("General.ConfigurationVersion", 2, "General.ConfigReloadInterval", Duration(1*time.Second), "Network.ListenAddr", "0.0.0.0:8080")
cm2 := makeYAML("General.ConfigReloadInterval", Duration(2*time.Second), "General.DatasetPrefix", "hello")
cfgfiles := createTempConfigs(t, cm1, cm2)
readers, err := getReadersForLocations(cfgfiles)
require.NoError(t, err)

cfg := map[string]any{}
err := loadConfigsIntoMap(cfg, cfgfiles)
err = loadConfigsIntoMap(cfg, readers)
require.NoError(t, err)
gen := cfg["General"].(map[string]any)
require.Equal(t, 2, gen["ConfigurationVersion"])
Expand Down Expand Up @@ -262,7 +265,9 @@ func Test_validateConfigs(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
cfgfiles := createTempConfigs(t, tt.cfgs...)
opts := &CmdEnv{ConfigLocations: cfgfiles}
got, err := validateConfigs(opts)
readers, err := getReadersForLocations(cfgfiles)
require.NoError(t, err)
got, err := validateConfigs(readers, opts)
if (err != nil) != tt.wantErr {
t.Errorf("validateConfigs() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down
17 changes: 13 additions & 4 deletions config/file_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,14 +453,23 @@ func (e *FileConfigError) Error() string {
// In order to do proper validation, we actually read the file twice -- once into
// a map, and once into the actual config object.
func newFileConfig(opts *CmdEnv) (*fileConfig, error) {
configReaders, err := getReadersForLocations(opts.ConfigLocations)
if err != nil {
return nil, err
}
rulesReaders, err := getReadersForLocations(opts.RulesLocations)
if err != nil {
return nil, err
}

// If we're not validating, skip this part
if !opts.NoValidate {
cfgFails, err := validateConfigs(opts)
cfgFails, err := validateConfigs(configReaders, opts)
if err != nil {
return nil, err
}

ruleFails, err := validateRules(opts.RulesLocations)
ruleFails, err := validateRules(rulesReaders)
if err != nil {
return nil, err
}
Expand All @@ -477,13 +486,13 @@ func newFileConfig(opts *CmdEnv) (*fileConfig, error) {

// Now load the files
mainconf := &configContents{}
mainhash, err := readConfigInto(mainconf, opts.ConfigLocations, opts)
mainhash, err := readConfigInto(mainconf, configReaders, opts)
if err != nil {
return nil, err
}

var rulesconf *V2SamplerConfig
ruleshash, err := readConfigInto(&rulesconf, opts.RulesLocations, nil)
ruleshash, err := readConfigInto(&rulesconf, rulesReaders, nil)
if err != nil {
return nil, err
}
Expand Down
Loading