From 36cb165f1cf82687f1c0e93573c31c8f4c84ce5a Mon Sep 17 00:00:00 2001 From: Wioletta Holownia Date: Thu, 23 Jul 2020 11:27:25 -0400 Subject: [PATCH] Allow to customize isRetriable fn on the ACME client. (#1) * Allow to customize isRetriable fn on the ACME client. --- acme/acme.go | 5 ++ acme/http.go | 83 ++++++++++++++++------------ acme/http_test.go | 137 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 190 insertions(+), 35 deletions(-) diff --git a/acme/acme.go b/acme/acme.go index 6e6c9d1319..9a5cd73547 100644 --- a/acme/acme.go +++ b/acme/acme.go @@ -100,6 +100,10 @@ type Client struct { // will have no effect. DirectoryURL string + // ShouldRetry reports whether a request can be retried based on the HTTP response status code. + // When ShouldRetry is nil, the default behavior will take precedence. + ShouldRetry func(resp *http.Response) bool + // RetryBackoff computes the duration after which the nth retry of a failed request // should occur. The value of n for the first call on failure is 1. // The values of r and resp are the request and response of the last failed attempt. @@ -108,6 +112,7 @@ type Client struct { // // Requests which result in a 4xx client error are not retried, // except for 400 Bad Request due to "bad nonce" errors and 429 Too Many Requests. + // If ShouldRetry is not nil, it will override the default behavior described above. // // If RetryBackoff is nil, a truncated exponential backoff algorithm // with the ceiling of 10 seconds is used, where each subsequent retry n diff --git a/acme/http.go b/acme/http.go index c51943e71a..588fa9e05a 100644 --- a/acme/http.go +++ b/acme/http.go @@ -19,9 +19,13 @@ import ( "time" ) -// retryTimer encapsulates common logic for retrying unsuccessful requests. +// retries encapsulates common logic for retrying unsuccessful requests. // It is not safe for concurrent use. -type retryTimer struct { +type retries struct { + // shouldRetry reports whether a request can be retried based on the HTTP response status code. + // See Client.ShouldRetry doc comment. + shouldRetry func(resp *http.Response) bool + // backoffFn provides backoff delay sequence for retries. // See Client.RetryBackoff doc comment. backoffFn func(n int, r *http.Request, res *http.Response) time.Duration @@ -29,15 +33,15 @@ type retryTimer struct { n int } -func (t *retryTimer) inc() { - t.n++ +func (r *retries) inc() { + r.n++ } // backoff pauses the current goroutine as described in Client.RetryBackoff. -func (t *retryTimer) backoff(ctx context.Context, r *http.Request, res *http.Response) error { - d := t.backoffFn(t.n, r, res) +func (r *retries) backoff(ctx context.Context, req *http.Request, resp *http.Response) error { + d := r.backoffFn(r.n, req, resp) if d <= 0 { - return fmt.Errorf("acme: no more retries for %s; tried %d time(s)", r.URL, t.n) + return fmt.Errorf("acme: no more retries for %s; tried %d time(s)", req.URL, r.n) } wakeup := time.NewTimer(d) defer wakeup.Stop() @@ -49,12 +53,21 @@ func (t *retryTimer) backoff(ctx context.Context, r *http.Request, res *http.Res } } -func (c *Client) retryTimer() *retryTimer { - f := c.RetryBackoff - if f == nil { - f = defaultBackoff +func (c *Client) retries() *retries { + backoff := c.RetryBackoff + if backoff == nil { + backoff = defaultBackoff + } + + shouldRetry := c.ShouldRetry + if shouldRetry == nil { + shouldRetry = defaultShouldRetry + } + + return &retries{ + backoffFn: backoff, + shouldRetry: shouldRetry, } - return &retryTimer{backoffFn: f} } // defaultBackoff provides default Client.RetryBackoff implementation @@ -127,30 +140,30 @@ func wantStatus(codes ...int) resOkay { // get retries unsuccessful attempts according to c.RetryBackoff // until the context is done or a non-retriable error is received. func (c *Client) get(ctx context.Context, url string, ok resOkay) (*http.Response, error) { - retry := c.retryTimer() + retry := c.retries() for { req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err } - res, err := c.doNoRetry(ctx, req) + resp, err := c.doNoRetry(ctx, req) switch { case err != nil: return nil, err - case ok(res): - return res, nil - case isRetriable(res.StatusCode): + case ok(resp): + return resp, nil + case retry.shouldRetry(resp): retry.inc() - resErr := responseError(res) - res.Body.Close() + resErr := responseError(resp) + resp.Body.Close() // Ignore the error value from retry.backoff // and return the one from last retry, as received from the CA. - if retry.backoff(ctx, req, res) != nil { + if retry.backoff(ctx, req, resp) != nil { return nil, resErr } default: - defer res.Body.Close() - return nil, responseError(res) + defer resp.Body.Close() + return nil, responseError(resp) } } } @@ -171,30 +184,30 @@ func (c *Client) postAsGet(ctx context.Context, url string, ok resOkay) (*http.R // until the context is done or a non-retriable error is received. // It uses postNoRetry to make individual requests. func (c *Client) post(ctx context.Context, key crypto.Signer, url string, body interface{}, ok resOkay) (*http.Response, error) { - retry := c.retryTimer() + retry := c.retries() for { - res, req, err := c.postNoRetry(ctx, key, url, body) + resp, req, err := c.postNoRetry(ctx, key, url, body) if err != nil { return nil, err } - if ok(res) { - return res, nil + if ok(resp) { + return resp, nil } - resErr := responseError(res) - res.Body.Close() + resErr := responseError(resp) + resp.Body.Close() switch { - // Check for bad nonce before isRetriable because it may have been returned + // Check for bad nonce before defaultShouldRetry because it may have been returned // with an unretriable response code such as 400 Bad Request. case isBadNonce(resErr): // Consider any previously stored nonce values to be invalid. c.clearNonces() - case !isRetriable(res.StatusCode): + case !retry.shouldRetry(resp): return nil, resErr } retry.inc() // Ignore the error value from retry.backoff // and return the one from last retry, as received from the CA. - if err := retry.backoff(ctx, req, res); err != nil { + if err := retry.backoff(ctx, req, resp); err != nil { return nil, resErr } } @@ -293,13 +306,13 @@ func isBadNonce(err error) bool { return ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce") } -// isRetriable reports whether a request can be retried -// based on the response status code. +// defaultShouldRetry reports whether a request can be retried +// based on the HTTP response status code. // // Note that a "bad nonce" error is returned with a non-retriable 400 Bad Request code. // Callers should parse the response and check with isBadNonce. -func isRetriable(code int) bool { - return code <= 399 || code >= 500 || code == http.StatusTooManyRequests +func defaultShouldRetry(resp *http.Response) bool { + return resp.StatusCode <= 399 || resp.StatusCode >= 500 || resp.StatusCode == http.StatusTooManyRequests } // responseError creates an error of Error type from resp. diff --git a/acme/http_test.go b/acme/http_test.go index 79095ccae6..2990e17e01 100644 --- a/acme/http_test.go +++ b/acme/http_test.go @@ -212,6 +212,143 @@ func TestRetryBackoffArgs(t *testing.T) { } } +func TestShouldRetry(t *testing.T) { + tt := []struct { + desc string + requestMethodType string + shouldRetry func(*http.Response) bool + responseHttpStatusCode int + expectedToRetry bool + expectedErr *Error + }{ + { + desc: "retries post request that returns response code that is retriable in defaultShouldRetry configuration", + requestMethodType: http.MethodPost, + responseHttpStatusCode: http.StatusOK, + expectedToRetry: true, + expectedErr: &Error{ + StatusCode: http.StatusOK, + }, + }, + { + desc: "does not retry post request that returns response code that is not retriable in defaultShouldRetry configuration", + requestMethodType: http.MethodPost, + responseHttpStatusCode: http.StatusUnprocessableEntity, + expectedErr: &Error{ + StatusCode: http.StatusUnprocessableEntity, + }, + }, + { + desc: "retries get request that returns response code that is retriable in defaultShouldRetry configuration", + requestMethodType: http.MethodGet, + responseHttpStatusCode: http.StatusTooManyRequests, + expectedToRetry: true, + expectedErr: &Error{ + StatusCode: http.StatusTooManyRequests, + }, + }, + { + desc: "does not retry get request that returns response code that is not retriable in defaultShouldRetry configuration", + requestMethodType: http.MethodGet, + responseHttpStatusCode: http.StatusNotFound, + expectedErr: &Error{ + StatusCode: http.StatusNotFound, + }, + }, + { + desc: "retries post request that returns response code that is retriable in the client.ShouldRetry configuration", + requestMethodType: http.MethodPost, + responseHttpStatusCode: http.StatusInternalServerError, + shouldRetry: func(resp *http.Response) bool { + return resp.StatusCode >= 500 + }, + expectedToRetry: true, + expectedErr: &Error{ + StatusCode: http.StatusInternalServerError, + }, + }, + { + desc: "does not retry post request that returns response code that is not retriable in the client.ShouldRetry configuration", + requestMethodType: http.MethodPost, + responseHttpStatusCode: http.StatusTooManyRequests, + shouldRetry: func(resp *http.Response) bool { + return false + }, + expectedErr: &Error{ + StatusCode: http.StatusTooManyRequests, + }, + }, + { + desc: "retries get request that returns response code that is retriable in the client.ShouldRetry configuration", + requestMethodType: http.MethodGet, + responseHttpStatusCode: http.StatusTooManyRequests, + shouldRetry: func(resp *http.Response) bool { + return resp.StatusCode == http.StatusTooManyRequests + }, + expectedToRetry: true, + expectedErr: &Error{ + StatusCode: http.StatusTooManyRequests, + }, + }, + { + desc: "does not retry get request that returns response code that is not retriable in the client.ShouldRetry configuration", + requestMethodType: http.MethodGet, + responseHttpStatusCode: http.StatusTooManyRequests, + shouldRetry: func(resp *http.Response) bool { + return resp.StatusCode <= 399 || resp.StatusCode >= 500 + }, + expectedErr: &Error{ + StatusCode: http.StatusTooManyRequests, + }, + }, + } + + for _, test := range tt { + t.Run(test.desc, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(test.responseHttpStatusCode) + })) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var retryNr int + + client := &Client{ + Key: testKey, + RetryBackoff: func(n int, r *http.Request, res *http.Response) time.Duration { + retryNr++ + if retryNr == 1 { + cancel() + } + return time.Millisecond + }, + ShouldRetry: test.shouldRetry, + dir: &Directory{AuthzURL: ts.URL}, + } + + var err error + switch test.requestMethodType { + case http.MethodPost: + _, err = client.post(ctx, nil, ts.URL, nil, wantStatus(http.StatusCreated)) + case http.MethodGet: + _, err = client.get(ctx, ts.URL, wantStatus(http.StatusOK)) + default: + } + + acmeError, ok := err.(*Error) + if !ok || test.expectedErr.StatusCode != acmeError.StatusCode { + t.Fatalf("err is %v (%T); want a non-nil *acme.Error %v", err, err, test.expectedErr) + } + + if test.expectedToRetry && retryNr != 1 { + t.Errorf("retryNr = %d; want %d", retryNr, 1) + } + }) + } +} + func TestUserAgent(t *testing.T) { for _, custom := range []string{"", "CUSTOM_UA"} { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {