From 1e64068578fb7815a5191e67e57d05bae065fd53 Mon Sep 17 00:00:00 2001 From: Joe Adams Date: Mon, 17 Oct 2022 19:45:39 -0400 Subject: [PATCH] Fix AWS credential expiration Use the AWS cached credential provider to automatically handle credentials. The CredentialsCache will automatically handle refreshing expired credentials and keeping them cached as long as necessary. Replaces #634 as this offloads more of the work to the AWS SDK Signed-off-by: Joe Adams --- pkg/roundtripper/roundtripper.go | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/pkg/roundtripper/roundtripper.go b/pkg/roundtripper/roundtripper.go index 824b77ed..cb3d8d1b 100644 --- a/pkg/roundtripper/roundtripper.go +++ b/pkg/roundtripper/roundtripper.go @@ -36,7 +36,7 @@ const ( type AWSSigningTransport struct { t http.RoundTripper - creds aws.Credentials + creds aws.CredentialsProvider region string log log.Logger } @@ -48,12 +48,17 @@ func NewAWSSigningTransport(transport http.RoundTripper, region string, log log. return nil, err } - creds, err := cfg.Credentials.Retrieve(context.Background()) + // Run a single fetch credentials operation to ensure that the credentials + // are valid before returning the transport. + _, err = cfg.Credentials.Retrieve(context.Background()) if err != nil { _ = level.Error(log).Log("msg", "fail to retrive aws credentials", "err", err) return nil, err } + // Build a cached credentials provider to manage the credentials and prevent new credentials on every request. + creds := aws.NewCredentialsCache(cfg.Credentials) + return &AWSSigningTransport{ t: transport, region: region, @@ -69,8 +74,15 @@ func (a *AWSSigningTransport) RoundTrip(req *http.Request) (*http.Response, erro _ = level.Error(a.log).Log("msg", "fail to hash request body", "err", err) return nil, err } + + creds, err := a.creds.Retrieve(context.Background()) + if err != nil { + _ = level.Error(a.log).Log("msg", "fail to retrive aws credentials", "err", err) + return nil, err + } + req.Body = newReader - err = signer.SignHTTP(context.Background(), a.creds, req, payloadHash, service, a.region, time.Now()) + err = signer.SignHTTP(context.Background(), creds, req, payloadHash, service, a.region, time.Now()) if err != nil { _ = level.Error(a.log).Log("msg", "fail to sign request body", "err", err) return nil, err