diff --git a/authhandler/authhandler.go b/authhandler/authhandler.go index 69967cf87..19eb4c1d1 100644 --- a/authhandler/authhandler.go +++ b/authhandler/authhandler.go @@ -32,8 +32,13 @@ type AuthorizationHandler func(authCodeURL string) (code string, state string, e // This token source will verify that the "state" is identical in the request // and response before exchanging the auth code for OAuth token to prevent CSRF // attacks. -func TokenSource(ctx context.Context, config *oauth2.Config, state string, authHandler AuthorizationHandler) oauth2.TokenSource { - return oauth2.ReuseTokenSource(nil, authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state}) +func TokenSource(ctx context.Context, config *oauth2.Config, + state string, authHandler AuthorizationHandler, authCodeOpts ...oauth2.AuthCodeOption) oauth2.TokenSource { + + return oauth2.ReuseTokenSource( + nil, + authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state, opts: authCodeOpts}, + ) } type authHandlerSource struct { @@ -41,10 +46,11 @@ type authHandlerSource struct { config *oauth2.Config authHandler AuthorizationHandler state string + opts []oauth2.AuthCodeOption } func (source authHandlerSource) Token() (*oauth2.Token, error) { - url := source.config.AuthCodeURL(source.state) + url := source.config.AuthCodeURL(source.state, source.opts...) code, state, err := source.authHandler(url) if err != nil { return nil, err diff --git a/authhandler/authhandler_test.go b/authhandler/authhandler_test.go index 084198f4c..9e5d90b56 100644 --- a/authhandler/authhandler_test.go +++ b/authhandler/authhandler_test.go @@ -15,22 +15,17 @@ import ( ) func TestTokenExchange_Success(t *testing.T) { - authhandler := func(authCodeURL string) (string, string, error) { - if authCodeURL == "testAuthCodeURL?client_id=testClientID&response_type=code&scope=pubsub&state=testState" { - return "testCode", "testState", nil - } - return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL) - } - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.ParseForm() if r.Form.Get("code") == "testCode" { w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "pubsub", "token_type": "bearer", - "expires_in": 3600 + "expires_in": 3600, + "duration": "permanent" }`)) } })) @@ -45,26 +40,59 @@ func TestTokenExchange_Success(t *testing.T) { }, } - tok, err := TokenSource(context.Background(), conf, "testState", authhandler).Token() - if err != nil { - t.Fatal(err) - } - if !tok.Valid() { - t.Errorf("got invalid token: %v", tok) - } - if got, want := tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want { - t.Errorf("access token = %q; want %q", got, want) - } - if got, want := tok.TokenType, "bearer"; got != want { - t.Errorf("token type = %q; want %q", got, want) - } - if got := tok.Expiry.IsZero(); got { - t.Errorf("token expiry is zero = %v, want false", got) - } - scope := tok.Extra("scope") - if got, want := scope, "pubsub"; got != want { - t.Errorf("scope = %q; want %q", got, want) + testExchange := func(authHandler AuthorizationHandler, opts ...oauth2.AuthCodeOption) { + tok, err := TokenSource(context.Background(), conf, "testState", authHandler, opts...).Token() + if err != nil { + t.Fatal(err) + } + if !tok.Valid() { + t.Errorf("got invalid token: %v", tok) + } + if got, want := tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want { + t.Errorf("access token = %q; want %q", got, want) + } + if got, want := tok.TokenType, "bearer"; got != want { + t.Errorf("token type = %q; want %q", got, want) + } + if got := tok.Expiry.IsZero(); got { + t.Errorf("token expiry is zero = %v, want false", got) + } + scope := tok.Extra("scope") + if got, want := scope, "pubsub"; got != want { + t.Errorf("scope = %q; want %q", got, want) + } + + if opts != nil && len(opts) > 0 { + duration := tok.Extra("duration") + if got, want := duration, "permanent"; got != want { + t.Errorf("duration = %q; want %q", got, want) + } + } } + + t.Run("test no extra options", func(t *testing.T) { + authhandler := func(authCodeURL string) (string, string, error) { + if authCodeURL == "testAuthCodeURL?client_id=testClientID&response_type=code&scope=pubsub&state=testState" { + return "testCode", "testState", nil + } + return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL) + } + + testExchange(authhandler) + }) + + t.Run("test with AuthCodeOptions", func(t *testing.T) { + authhandler := func(authCodeURL string) (string, string, error) { + if authCodeURL == "testAuthCodeURL?client_id=testClientID&duration=permanent&response_type=code&scope=pubsub&state=testState" { + return "testCode", "testState", nil + } + return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL) + } + + testAuthOpt := oauth2.SetAuthURLParam("duration", "permanent") + + testExchange(authhandler, testAuthOpt) + }) } func TestTokenExchange_StateMismatch(t *testing.T) {