Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support rfc7523 private_key_jwt in client credentials flow #450

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions clientcredentials/clientcredentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// server.
//
// See https://tools.ietf.org/html/rfc6749#section-4.4
// See https://tools.ietf.org/html/rfc7523
package clientcredentials // import "golang.org/x/oauth2/clientcredentials"

import (
Expand All @@ -19,6 +20,7 @@ import (
"net/http"
"net/url"
"strings"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/internal"
Expand Down Expand Up @@ -46,11 +48,29 @@ type Config struct {
// AuthStyle optionally specifies how the endpoint wants the
// client ID & client secret sent. The zero value means to
// auto-detect.
// See https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication.
AuthStyle oauth2.AuthStyle

// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
// the zero value (AuthStyleAutoDetect).
authStyleCache internal.LazyAuthStyleCache

// JWTExpires optionally specifies how long the jwt token is valid for.
JWTExpires time.Duration

// PrivateKey contains the contents of an RSA private key or the
// contents of a PEM file that contains a private key. The provided
// private key is used to sign JWT payloads.
// PEM containers with a passphrase are not supported.
// Use the following command to convert a PKCS 12 file into a PEM.
//
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
PrivateKey []byte

// KeyID contains an optional hint indicating which key is being
// used.
KeyID string
}

// Token uses client credentials to retrieve a token.
Expand Down Expand Up @@ -95,6 +115,14 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
v := url.Values{
"grant_type": {"client_credentials"},
}
if c.conf.AuthStyle == oauth2.AuthStylePrivateKeyJWT {
var err error
v, err = c.jwtAssertionValues()
if err != nil {
return nil, err
}

}
if len(c.conf.Scopes) > 0 {
v.Set("scope", strings.Join(c.conf.Scopes, " "))
}
Expand Down
145 changes: 145 additions & 0 deletions clientcredentials/clientcredentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,20 @@ package clientcredentials

import (
"context"
"encoding/base64"
"encoding/json"
"io"
"io/ioutil"
"math"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/jws"
)

func newConf(serverURL string) *Config {
Expand Down Expand Up @@ -111,6 +119,143 @@ func TestTokenRequest(t *testing.T) {
}
}

var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAx4fm7dngEmOULNmAs1IGZ9Apfzh+BkaQ1dzkmbUgpcoghucE
DZRnAGd2aPyB6skGMXUytWQvNYav0WTR00wFtX1ohWTfv68HGXJ8QXCpyoSKSSFY
fuP9X36wBSkSX9J5DVgiuzD5VBdzUISSmapjKm+DcbRALjz6OUIPEWi1Tjl6p5RK
1w41qdbmt7E5/kGhKLDuT7+M83g4VWhgIvaAXtnhklDAggilPPa8ZJ1IFe31lNlr
k4DRk38nc6sEutdf3RL7QoH7FBusI7uXV03DC6dwN1kP4GE7bjJhcRb/7jYt7CQ9
/E9Exz3c0yAp0yrTg0Fwh+qxfH9dKwN52S7SBwIDAQABAoIBAQCaCs26K07WY5Jt
3a2Cw3y2gPrIgTCqX6hJs7O5ByEhXZ8nBwsWANBUe4vrGaajQHdLj5OKfsIDrOvn
2NI1MqflqeAbu/kR32q3tq8/Rl+PPiwUsW3E6Pcf1orGMSNCXxeducF2iySySzh3
nSIhCG5uwJDWI7a4+9KiieFgK1pt/Iv30q1SQS8IEntTfXYwANQrfKUVMmVF9aIK
6/WZE2yd5+q3wVVIJ6jsmTzoDCX6QQkkJICIYwCkglmVy5AeTckOVwcXL0jqw5Kf
5/soZJQwLEyBoQq7Kbpa26QHq+CJONetPP8Ssy8MJJXBT+u/bSseMb3Zsr5cr43e
DJOhwsThAoGBAPY6rPKl2NT/K7XfRCGm1sbWjUQyDShscwuWJ5+kD0yudnT/ZEJ1
M3+KS/iOOAoHDdEDi9crRvMl0UfNa8MAcDKHflzxg2jg/QI+fTBjPP5GOX0lkZ9g
z6VePoVoQw2gpPFVNPPTxKfk27tEzbaffvOLGBEih0Kb7HTINkW8rIlzAoGBAM9y
1yr+jvfS1cGFtNU+Gotoihw2eMKtIqR03Yn3n0PK1nVCDKqwdUqCypz4+ml6cxRK
J8+Pfdh7D+ZJd4LEG6Y4QRDLuv5OA700tUoSHxMSNn3q9As4+T3MUyYxWKvTeu3U
f2NWP9ePU0lV8ttk7YlpVRaPQmc1qwooBA/z/8AdAoGAW9x0HWqmRICWTBnpjyxx
QGlW9rQ9mHEtUotIaRSJ6K/F3cxSGUEkX1a3FRnp6kPLcckC6NlqdNgNBd6rb2rA
cPl/uSkZP42Als+9YMoFPU/xrrDPbUhu72EDrj3Bllnyb168jKLa4VBOccUvggxr
Dm08I1hgYgdN5huzs7y6GeUCgYEAj+AZJSOJ6o1aXS6rfV3mMRve9bQ9yt8jcKXw
5HhOCEmMtaSKfnOF1Ziih34Sxsb7O2428DiX0mV/YHtBnPsAJidL0SdLWIapBzeg
KHArByIRkwE6IvJvwpGMdaex1PIGhx5i/3VZL9qiq/ElT05PhIb+UXgoWMabCp84
OgxDK20CgYAeaFo8BdQ7FmVX2+EEejF+8xSge6WVLtkaon8bqcn6P0O8lLypoOhd
mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ
5jgQ3slYLpqrGlcbLgUXBUgzEO684Wk/UV9DFPlHALVqCfXQ9dpJPg==
-----END RSA PRIVATE KEY-----`)

func TestTokenJWTRequest(t *testing.T) {
var assertion string
audience := "audience1"
scopes := "scope1 scope2"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() != "/token" {
t.Errorf("authenticate client request URL = %q; want %q", r.URL, "/token")
}
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
t.Errorf("Content-Type header = %q; want %q", got, want)
}
err := r.ParseForm()
if err != nil {
t.Fatal(err)
}

if got, want := r.Form.Get("scope"), scopes; got != want {
t.Errorf("scope = %q; want %q", got, want)
}
if got, want := r.Form.Get("audience"), audience; got != want {
t.Errorf("audience = %q; want %q", got, want)
}
if got, want := r.Form.Get("grant_type"), "client_credentials"; got != want {
t.Errorf("grant_type = %q; want %q", got, want)
}
expectedAssertionType := "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
if got, want := r.Form.Get("client_assertion_type"), expectedAssertionType; got != want {
t.Errorf("client_assertion_type = %q; want %q", got, want)
}

assertion = r.Form.Get("client_assertion")

w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{
"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
"token_type": "bearer",
"expires_in": 3600
}`))
}))
defer ts.Close()

for _, conf := range []*Config{
{
ClientID: "CLIENT_ID",
Scopes: strings.Split(scopes, " "),
TokenURL: ts.URL + "/token",
EndpointParams: url.Values{"audience": {audience}},
AuthStyle: oauth2.AuthStylePrivateKeyJWT,
PrivateKey: dummyPrivateKey,
KeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
},
{
ClientID: "CLIENT_ID_set_jwt_expiration_time",
Scopes: strings.Split(scopes, " "),
TokenURL: ts.URL + "/token",
EndpointParams: url.Values{"audience": {audience}},
AuthStyle: oauth2.AuthStylePrivateKeyJWT,
PrivateKey: dummyPrivateKey,
KeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
JWTExpires: time.Minute,
},
} {
t.Run(conf.ClientID, func(t *testing.T) {
_, err := conf.TokenSource(context.Background()).Token()
if err != nil {
t.Fatalf("Failed to fetch token: %v", err)
}
parts := strings.Split(assertion, ".")
if len(parts) != 3 {
t.Fatalf("assertion = %q; want 3 parts", assertion)
}
gotJson, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
t.Fatalf("invalid token payload; err = %v", err)
}
claimSet := jws.ClaimSet{}
if err := json.Unmarshal(gotJson, &claimSet); err != nil {
t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotJson, err)
}
if got, want := claimSet.Iss, conf.ClientID; got != want {
t.Errorf("payload iss = %q; want %q", got, want)
}
if claimSet.Jti == "" {
t.Errorf("payload jti is empty")
}
expectedDuration := time.Hour
if conf.JWTExpires > 0 {
expectedDuration = conf.JWTExpires
}

if got, want := claimSet.Exp, time.Now().Add(expectedDuration).Unix(); got != want {
t.Errorf("payload exp = %q; want %q", got, want)
}

errorMarginInSeconds := 5.0
if got, want := claimSet.Exp, time.Now().Add(expectedDuration).Unix(); math.Abs(float64(got-want)) > errorMarginInSeconds {
t.Errorf("payload exp is not within the acceptable range: got %q, want around %q", got, want)
}

if got, want := claimSet.Aud, conf.TokenURL; got != want {
t.Errorf("payload aud = %q; want %q", got, want)
}
if got, want := claimSet.Sub, conf.ClientID; got != want {
t.Errorf("payload sub = %q; want %q", got, want)
}
})
}
}

func TestTokenRefreshRequest(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/somethingelse" {
Expand Down
73 changes: 73 additions & 0 deletions clientcredentials/jwt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package clientcredentials

import (
"crypto/rand"
"math/big"
"net/url"
"time"

"golang.org/x/oauth2/internal"
"golang.org/x/oauth2/jws"
)

const (
clientAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
)

var (
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
)

func randJWTID(n int) (string, error) {
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
ret := make([]byte, n)
for i := 0; i < n; i++ {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
ret = append(ret, letters[num.Int64()])
}

return string(ret), nil
}

func (c *tokenSource) jwtAssertionValues() (url.Values, error) {
v := url.Values{
"grant_type": {"client_credentials"},
}
pk, err := internal.ParseKey(c.conf.PrivateKey)
if err != nil {
return nil, err
}
claimSet := &jws.ClaimSet{
Iss: c.conf.ClientID,
Sub: c.conf.ClientID,
Aud: c.conf.TokenURL,
}

claimSet.Jti, err = randJWTID(36)
if err != nil {
return nil, err
}
if t := c.conf.JWTExpires; t > 0 {
claimSet.Exp = time.Now().Add(t).Unix()
} else {
claimSet.Exp = time.Now().Add(time.Hour).Unix()
}

h := *defaultHeader
h.KeyID = c.conf.KeyID
payload, err := jws.Encode(&h, claimSet, pk)
if err != nil {
return nil, err
}
v.Set("client_assertion", payload)
v.Set("client_assertion_type", clientAssertionType)

return v, nil
}
4 changes: 4 additions & 0 deletions jws/jws.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ type ClaimSet struct {
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
// This array is marshalled using custom code (see (c *ClaimSet) encode()).
PrivateClaims map[string]interface{} `json:"-"`

// See https://tools.ietf.org/html/rfc7523#section-3.
// Unique identifier for the jwt token.
Jti string `json:"jti"`
}

func (c *ClaimSet) encode() (string, error) {
Expand Down
5 changes: 5 additions & 0 deletions oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ const (
// using HTTP Basic Authorization. This is an optional style
// described in the OAuth2 RFC 6749 section 2.3.1.
AuthStyleInHeader AuthStyle = 2

// AuthStylePrivateKeyJWT send jwt token signed by private key.
// See https://openid.net/specs/openid-connect-core-1_0.html.
// See https://tools.ietf.org/html/rfc7523.
AuthStylePrivateKeyJWT AuthStyle = 3
)

var (
Expand Down