Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

acme: Allow to customize isRetriable fn on the ACME client #149

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions acme/acme.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ type Client struct {
// will have no effect.
DirectoryURL string

// ShouldRetry reports whether a request can be retried based on the HTTP response status code.
// When ShouldRetry is nil, the default behavior will take precedence.
ShouldRetry func(resp *http.Response) bool

// RetryBackoff computes the duration after which the nth retry of a failed request
// should occur. The value of n for the first call on failure is 1.
// The values of r and resp are the request and response of the last failed attempt.
Expand All @@ -108,6 +112,7 @@ type Client struct {
//
// Requests which result in a 4xx client error are not retried,
// except for 400 Bad Request due to "bad nonce" errors and 429 Too Many Requests.
// If ShouldRetry is not nil, it will override the default behavior described above.
//
// If RetryBackoff is nil, a truncated exponential backoff algorithm
// with the ceiling of 10 seconds is used, where each subsequent retry n
Expand Down
83 changes: 48 additions & 35 deletions acme/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,29 @@ import (
"time"
)

// retryTimer encapsulates common logic for retrying unsuccessful requests.
// retries encapsulates common logic for retrying unsuccessful requests.
// It is not safe for concurrent use.
type retryTimer struct {
type retries struct {
// shouldRetry reports whether a request can be retried based on the HTTP response status code.
// See Client.ShouldRetry doc comment.
shouldRetry func(resp *http.Response) bool

// backoffFn provides backoff delay sequence for retries.
// See Client.RetryBackoff doc comment.
backoffFn func(n int, r *http.Request, res *http.Response) time.Duration
// n is the current retry attempt.
n int
}

func (t *retryTimer) inc() {
t.n++
func (r *retries) inc() {
r.n++
}

// backoff pauses the current goroutine as described in Client.RetryBackoff.
func (t *retryTimer) backoff(ctx context.Context, r *http.Request, res *http.Response) error {
d := t.backoffFn(t.n, r, res)
func (r *retries) backoff(ctx context.Context, req *http.Request, resp *http.Response) error {
d := r.backoffFn(r.n, req, resp)
if d <= 0 {
return fmt.Errorf("acme: no more retries for %s; tried %d time(s)", r.URL, t.n)
return fmt.Errorf("acme: no more retries for %s; tried %d time(s)", req.URL, r.n)
}
wakeup := time.NewTimer(d)
defer wakeup.Stop()
Expand All @@ -49,12 +53,21 @@ func (t *retryTimer) backoff(ctx context.Context, r *http.Request, res *http.Res
}
}

func (c *Client) retryTimer() *retryTimer {
f := c.RetryBackoff
if f == nil {
f = defaultBackoff
func (c *Client) retries() *retries {
backoff := c.RetryBackoff
if backoff == nil {
backoff = defaultBackoff
}

shouldRetry := c.ShouldRetry
if shouldRetry == nil {
shouldRetry = defaultShouldRetry
}

return &retries{
backoffFn: backoff,
shouldRetry: shouldRetry,
}
return &retryTimer{backoffFn: f}
}

// defaultBackoff provides default Client.RetryBackoff implementation
Expand Down Expand Up @@ -127,30 +140,30 @@ func wantStatus(codes ...int) resOkay {
// get retries unsuccessful attempts according to c.RetryBackoff
// until the context is done or a non-retriable error is received.
func (c *Client) get(ctx context.Context, url string, ok resOkay) (*http.Response, error) {
retry := c.retryTimer()
retry := c.retries()
for {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
res, err := c.doNoRetry(ctx, req)
resp, err := c.doNoRetry(ctx, req)
switch {
case err != nil:
return nil, err
case ok(res):
return res, nil
case isRetriable(res.StatusCode):
case ok(resp):
return resp, nil
case retry.shouldRetry(resp):
retry.inc()
resErr := responseError(res)
res.Body.Close()
resErr := responseError(resp)
resp.Body.Close()
// Ignore the error value from retry.backoff
// and return the one from last retry, as received from the CA.
if retry.backoff(ctx, req, res) != nil {
if retry.backoff(ctx, req, resp) != nil {
return nil, resErr
}
default:
defer res.Body.Close()
return nil, responseError(res)
defer resp.Body.Close()
return nil, responseError(resp)
}
}
}
Expand All @@ -171,30 +184,30 @@ func (c *Client) postAsGet(ctx context.Context, url string, ok resOkay) (*http.R
// until the context is done or a non-retriable error is received.
// It uses postNoRetry to make individual requests.
func (c *Client) post(ctx context.Context, key crypto.Signer, url string, body interface{}, ok resOkay) (*http.Response, error) {
retry := c.retryTimer()
retry := c.retries()
for {
res, req, err := c.postNoRetry(ctx, key, url, body)
resp, req, err := c.postNoRetry(ctx, key, url, body)
if err != nil {
return nil, err
}
if ok(res) {
return res, nil
if ok(resp) {
return resp, nil
}
resErr := responseError(res)
res.Body.Close()
resErr := responseError(resp)
resp.Body.Close()
switch {
// Check for bad nonce before isRetriable because it may have been returned
// Check for bad nonce before defaultShouldRetry because it may have been returned
// with an unretriable response code such as 400 Bad Request.
case isBadNonce(resErr):
// Consider any previously stored nonce values to be invalid.
c.clearNonces()
case !isRetriable(res.StatusCode):
case !retry.shouldRetry(resp):
return nil, resErr
}
retry.inc()
// Ignore the error value from retry.backoff
// and return the one from last retry, as received from the CA.
if err := retry.backoff(ctx, req, res); err != nil {
if err := retry.backoff(ctx, req, resp); err != nil {
return nil, resErr
}
}
Expand Down Expand Up @@ -293,13 +306,13 @@ func isBadNonce(err error) bool {
return ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce")
}

// isRetriable reports whether a request can be retried
// based on the response status code.
// defaultShouldRetry reports whether a request can be retried
// based on the HTTP response status code.
//
// Note that a "bad nonce" error is returned with a non-retriable 400 Bad Request code.
// Callers should parse the response and check with isBadNonce.
func isRetriable(code int) bool {
return code <= 399 || code >= 500 || code == http.StatusTooManyRequests
func defaultShouldRetry(resp *http.Response) bool {
return resp.StatusCode <= 399 || resp.StatusCode >= 500 || resp.StatusCode == http.StatusTooManyRequests
}

// responseError creates an error of Error type from resp.
Expand Down
137 changes: 137 additions & 0 deletions acme/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,143 @@ func TestRetryBackoffArgs(t *testing.T) {
}
}

func TestShouldRetry(t *testing.T) {
tt := []struct {
desc string
requestMethodType string
shouldRetry func(*http.Response) bool
responseHttpStatusCode int
expectedToRetry bool
expectedErr *Error
}{
{
desc: "retries post request that returns response code that is retriable in defaultShouldRetry configuration",
requestMethodType: http.MethodPost,
responseHttpStatusCode: http.StatusOK,
expectedToRetry: true,
expectedErr: &Error{
StatusCode: http.StatusOK,
},
},
{
desc: "does not retry post request that returns response code that is not retriable in defaultShouldRetry configuration",
requestMethodType: http.MethodPost,
responseHttpStatusCode: http.StatusUnprocessableEntity,
expectedErr: &Error{
StatusCode: http.StatusUnprocessableEntity,
},
},
{
desc: "retries get request that returns response code that is retriable in defaultShouldRetry configuration",
requestMethodType: http.MethodGet,
responseHttpStatusCode: http.StatusTooManyRequests,
expectedToRetry: true,
expectedErr: &Error{
StatusCode: http.StatusTooManyRequests,
},
},
{
desc: "does not retry get request that returns response code that is not retriable in defaultShouldRetry configuration",
requestMethodType: http.MethodGet,
responseHttpStatusCode: http.StatusNotFound,
expectedErr: &Error{
StatusCode: http.StatusNotFound,
},
},
{
desc: "retries post request that returns response code that is retriable in the client.ShouldRetry configuration",
requestMethodType: http.MethodPost,
responseHttpStatusCode: http.StatusInternalServerError,
shouldRetry: func(resp *http.Response) bool {
return resp.StatusCode >= 500
},
expectedToRetry: true,
expectedErr: &Error{
StatusCode: http.StatusInternalServerError,
},
},
{
desc: "does not retry post request that returns response code that is not retriable in the client.ShouldRetry configuration",
requestMethodType: http.MethodPost,
responseHttpStatusCode: http.StatusTooManyRequests,
shouldRetry: func(resp *http.Response) bool {
return false
},
expectedErr: &Error{
StatusCode: http.StatusTooManyRequests,
},
},
{
desc: "retries get request that returns response code that is retriable in the client.ShouldRetry configuration",
requestMethodType: http.MethodGet,
responseHttpStatusCode: http.StatusTooManyRequests,
shouldRetry: func(resp *http.Response) bool {
return resp.StatusCode == http.StatusTooManyRequests
},
expectedToRetry: true,
expectedErr: &Error{
StatusCode: http.StatusTooManyRequests,
},
},
{
desc: "does not retry get request that returns response code that is not retriable in the client.ShouldRetry configuration",
requestMethodType: http.MethodGet,
responseHttpStatusCode: http.StatusTooManyRequests,
shouldRetry: func(resp *http.Response) bool {
return resp.StatusCode <= 399 || resp.StatusCode >= 500
},
expectedErr: &Error{
StatusCode: http.StatusTooManyRequests,
},
},
}

for _, test := range tt {
t.Run(test.desc, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(test.responseHttpStatusCode)
}))
defer ts.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

var retryNr int

client := &Client{
Key: testKey,
RetryBackoff: func(n int, r *http.Request, res *http.Response) time.Duration {
retryNr++
if retryNr == 1 {
cancel()
}
return time.Millisecond
},
ShouldRetry: test.shouldRetry,
dir: &Directory{AuthzURL: ts.URL},
}

var err error
switch test.requestMethodType {
case http.MethodPost:
_, err = client.post(ctx, nil, ts.URL, nil, wantStatus(http.StatusCreated))
case http.MethodGet:
_, err = client.get(ctx, ts.URL, wantStatus(http.StatusOK))
default:
}

acmeError, ok := err.(*Error)
if !ok || test.expectedErr.StatusCode != acmeError.StatusCode {
t.Fatalf("err is %v (%T); want a non-nil *acme.Error %v", err, err, test.expectedErr)
}

if test.expectedToRetry && retryNr != 1 {
t.Errorf("retryNr = %d; want %d", retryNr, 1)
}
})
}
}

func TestUserAgent(t *testing.T) {
for _, custom := range []string{"", "CUSTOM_UA"} {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down