Skip to content

Commit

Permalink
authhandler: Add authCodeOpts param to TokenSource
Browse files Browse the repository at this point in the history
Certain implementations of OAuth2 (such as Reddit) have extra parameters
available as part of their OAuth2 flow. This commit adds the variadic
parameter `authCodeOpts` to the TokenSource function in authhandler,
allowing a caller to pass as many of these extra URL parameters as they
want or need.

Tests in the authhandler package have also been updated to reflect this
change.

**Does this commit introduce any breaking changes?**

No. The new parameter is variadic, as are its uses, which means that
existing code will not be affected by this change.
  • Loading branch information
jalavosus committed Apr 23, 2021
1 parent 5e61552 commit 05e614e
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 30 deletions.
12 changes: 9 additions & 3 deletions authhandler/authhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,25 @@ type AuthorizationHandler func(authCodeURL string) (code string, state string, e
// This token source will verify that the "state" is identical in the request
// and response before exchanging the auth code for OAuth token to prevent CSRF
// attacks.
func TokenSource(ctx context.Context, config *oauth2.Config, state string, authHandler AuthorizationHandler) oauth2.TokenSource {
return oauth2.ReuseTokenSource(nil, authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state})
func TokenSource(ctx context.Context, config *oauth2.Config,
state string, authHandler AuthorizationHandler, authCodeOpts ...oauth2.AuthCodeOption) oauth2.TokenSource {

return oauth2.ReuseTokenSource(
nil,
authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state, opts: authCodeOpts},
)
}

type authHandlerSource struct {
ctx context.Context
config *oauth2.Config
authHandler AuthorizationHandler
state string
opts []oauth2.AuthCodeOption
}

func (source authHandlerSource) Token() (*oauth2.Token, error) {
url := source.config.AuthCodeURL(source.state)
url := source.config.AuthCodeURL(source.state, source.opts...)
code, state, err := source.authHandler(url)
if err != nil {
return nil, err
Expand Down
82 changes: 55 additions & 27 deletions authhandler/authhandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,17 @@ import (
)

func TestTokenExchange_Success(t *testing.T) {
authhandler := func(authCodeURL string) (string, string, error) {
if authCodeURL == "testAuthCodeURL?client_id=testClientID&response_type=code&scope=pubsub&state=testState" {
return "testCode", "testState", nil
}
return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL)
}

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
if r.Form.Get("code") == "testCode" {
w.Header().Set("Content-Type", "application/json")

w.Write([]byte(`{
"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
"scope": "pubsub",
"token_type": "bearer",
"expires_in": 3600
"expires_in": 3600,
"duration": "permanent"
}`))
}
}))
Expand All @@ -45,26 +40,59 @@ func TestTokenExchange_Success(t *testing.T) {
},
}

tok, err := TokenSource(context.Background(), conf, "testState", authhandler).Token()
if err != nil {
t.Fatal(err)
}
if !tok.Valid() {
t.Errorf("got invalid token: %v", tok)
}
if got, want := tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want {
t.Errorf("access token = %q; want %q", got, want)
}
if got, want := tok.TokenType, "bearer"; got != want {
t.Errorf("token type = %q; want %q", got, want)
}
if got := tok.Expiry.IsZero(); got {
t.Errorf("token expiry is zero = %v, want false", got)
}
scope := tok.Extra("scope")
if got, want := scope, "pubsub"; got != want {
t.Errorf("scope = %q; want %q", got, want)
testExchange := func(authHandler AuthorizationHandler, opts ...oauth2.AuthCodeOption) {
tok, err := TokenSource(context.Background(), conf, "testState", authHandler, opts...).Token()
if err != nil {
t.Fatal(err)
}
if !tok.Valid() {
t.Errorf("got invalid token: %v", tok)
}
if got, want := tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want {
t.Errorf("access token = %q; want %q", got, want)
}
if got, want := tok.TokenType, "bearer"; got != want {
t.Errorf("token type = %q; want %q", got, want)
}
if got := tok.Expiry.IsZero(); got {
t.Errorf("token expiry is zero = %v, want false", got)
}
scope := tok.Extra("scope")
if got, want := scope, "pubsub"; got != want {
t.Errorf("scope = %q; want %q", got, want)
}

if opts != nil && len(opts) > 0 {
duration := tok.Extra("duration")
if got, want := duration, "permanent"; got != want {
t.Errorf("duration = %q; want %q", got, want)
}
}
}

t.Run("test no extra options", func(t *testing.T) {
authhandler := func(authCodeURL string) (string, string, error) {
if authCodeURL == "testAuthCodeURL?client_id=testClientID&response_type=code&scope=pubsub&state=testState" {
return "testCode", "testState", nil
}
return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL)
}

testExchange(authhandler)
})

t.Run("test with AuthCodeOptions", func(t *testing.T) {
authhandler := func(authCodeURL string) (string, string, error) {
if authCodeURL == "testAuthCodeURL?client_id=testClientID&duration=permanent&response_type=code&scope=pubsub&state=testState" {
return "testCode", "testState", nil
}
return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL)
}

testAuthOpt := oauth2.SetAuthURLParam("duration", "permanent")

testExchange(authhandler, testAuthOpt)
})
}

func TestTokenExchange_StateMismatch(t *testing.T) {
Expand Down

0 comments on commit 05e614e

Please sign in to comment.