Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

authhandler: Add authCodeOpts param to TokenSource() #492

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions authhandler/authhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,25 @@ type AuthorizationHandler func(authCodeURL string) (code string, state string, e
// This token source will verify that the "state" is identical in the request
// and response before exchanging the auth code for OAuth token to prevent CSRF
// attacks.
func TokenSource(ctx context.Context, config *oauth2.Config, state string, authHandler AuthorizationHandler) oauth2.TokenSource {
return oauth2.ReuseTokenSource(nil, authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state})
func TokenSource(ctx context.Context, config *oauth2.Config,
state string, authHandler AuthorizationHandler, authCodeOpts ...oauth2.AuthCodeOption) oauth2.TokenSource {

return oauth2.ReuseTokenSource(
nil,
authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state, opts: authCodeOpts},
)
}

type authHandlerSource struct {
ctx context.Context
config *oauth2.Config
authHandler AuthorizationHandler
state string
opts []oauth2.AuthCodeOption
}

func (source authHandlerSource) Token() (*oauth2.Token, error) {
url := source.config.AuthCodeURL(source.state)
url := source.config.AuthCodeURL(source.state, source.opts...)
code, state, err := source.authHandler(url)
if err != nil {
return nil, err
Expand Down
82 changes: 55 additions & 27 deletions authhandler/authhandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,17 @@ import (
)

func TestTokenExchange_Success(t *testing.T) {
authhandler := func(authCodeURL string) (string, string, error) {
if authCodeURL == "testAuthCodeURL?client_id=testClientID&response_type=code&scope=pubsub&state=testState" {
return "testCode", "testState", nil
}
return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL)
}

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
if r.Form.Get("code") == "testCode" {
w.Header().Set("Content-Type", "application/json")

w.Write([]byte(`{
"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
"scope": "pubsub",
"token_type": "bearer",
"expires_in": 3600
"expires_in": 3600,
"duration": "permanent"
}`))
}
}))
Expand All @@ -45,26 +40,59 @@ func TestTokenExchange_Success(t *testing.T) {
},
}

tok, err := TokenSource(context.Background(), conf, "testState", authhandler).Token()
if err != nil {
t.Fatal(err)
}
if !tok.Valid() {
t.Errorf("got invalid token: %v", tok)
}
if got, want := tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want {
t.Errorf("access token = %q; want %q", got, want)
}
if got, want := tok.TokenType, "bearer"; got != want {
t.Errorf("token type = %q; want %q", got, want)
}
if got := tok.Expiry.IsZero(); got {
t.Errorf("token expiry is zero = %v, want false", got)
}
scope := tok.Extra("scope")
if got, want := scope, "pubsub"; got != want {
t.Errorf("scope = %q; want %q", got, want)
testExchange := func(authHandler AuthorizationHandler, opts ...oauth2.AuthCodeOption) {
tok, err := TokenSource(context.Background(), conf, "testState", authHandler, opts...).Token()
if err != nil {
t.Fatal(err)
}
if !tok.Valid() {
t.Errorf("got invalid token: %v", tok)
}
if got, want := tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want {
t.Errorf("access token = %q; want %q", got, want)
}
if got, want := tok.TokenType, "bearer"; got != want {
t.Errorf("token type = %q; want %q", got, want)
}
if got := tok.Expiry.IsZero(); got {
t.Errorf("token expiry is zero = %v, want false", got)
}
scope := tok.Extra("scope")
if got, want := scope, "pubsub"; got != want {
t.Errorf("scope = %q; want %q", got, want)
}

if opts != nil && len(opts) > 0 {
duration := tok.Extra("duration")
if got, want := duration, "permanent"; got != want {
t.Errorf("duration = %q; want %q", got, want)
}
}
}

t.Run("test no extra options", func(t *testing.T) {
authhandler := func(authCodeURL string) (string, string, error) {
if authCodeURL == "testAuthCodeURL?client_id=testClientID&response_type=code&scope=pubsub&state=testState" {
return "testCode", "testState", nil
}
return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL)
}

testExchange(authhandler)
})

t.Run("test with AuthCodeOptions", func(t *testing.T) {
authhandler := func(authCodeURL string) (string, string, error) {
if authCodeURL == "testAuthCodeURL?client_id=testClientID&duration=permanent&response_type=code&scope=pubsub&state=testState" {
return "testCode", "testState", nil
}
return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL)
}

testAuthOpt := oauth2.SetAuthURLParam("duration", "permanent")

testExchange(authhandler, testAuthOpt)
})
}

func TestTokenExchange_StateMismatch(t *testing.T) {
Expand Down