Skip to content

Commit

Permalink
Add support rfc7523 in client credentials flow
Browse files Browse the repository at this point in the history
Implement JSON Web Token Profile for OAuth 2.0 Client Authentication in client credentials flow.

See https://tools.ietf.org/html/rfc7523
See https://openid.net/specs/openid-connect-core-1_0.html

Fixes golang#433
  • Loading branch information
SmotrovaLilit committed Oct 25, 2020
1 parent 5d25da1 commit d4025d6
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 0 deletions.
28 changes: 28 additions & 0 deletions clientcredentials/clientcredentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// server.
//
// See https://tools.ietf.org/html/rfc6749#section-4.4
// See https://tools.ietf.org/html/rfc7523
package clientcredentials // import "golang.org/x/oauth2/clientcredentials"

import (
Expand All @@ -19,6 +20,7 @@ import (
"net/http"
"net/url"
"strings"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/internal"
Expand Down Expand Up @@ -46,7 +48,25 @@ type Config struct {
// AuthStyle optionally specifies how the endpoint wants the
// client ID & client secret sent. The zero value means to
// auto-detect.
// See https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication.
AuthStyle oauth2.AuthStyle

// JWTExpires optionally specifies how long the jwt token is valid for.
JWTExpires time.Duration

// PrivateKey contains the contents of an RSA private key or the
// contents of a PEM file that contains a private key. The provided
// private key is used to sign JWT payloads.
// PEM containers with a passphrase are not supported.
// Use the following command to convert a PKCS 12 file into a PEM.
//
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
PrivateKey []byte

// KeyID contains an optional hint indicating which key is being
// used.
KeyID string
}

// Token uses client credentials to retrieve a token.
Expand Down Expand Up @@ -91,6 +111,14 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
v := url.Values{
"grant_type": {"client_credentials"},
}
if c.conf.AuthStyle == oauth2.AuthStylePrivateKeyJWT {
var err error
v, err = c.jwtAssertionValues()
if err != nil {
return nil, err
}

}
if len(c.conf.Scopes) > 0 {
v.Set("scope", strings.Join(c.conf.Scopes, " "))
}
Expand Down
136 changes: 136 additions & 0 deletions clientcredentials/clientcredentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@ package clientcredentials

import (
"context"
"encoding/base64"
"encoding/json"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/internal"
"golang.org/x/oauth2/jws"
)

func newConf(serverURL string) *Config {
Expand Down Expand Up @@ -113,6 +119,136 @@ func TestTokenRequest(t *testing.T) {
}
}

var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAx4fm7dngEmOULNmAs1IGZ9Apfzh+BkaQ1dzkmbUgpcoghucE
DZRnAGd2aPyB6skGMXUytWQvNYav0WTR00wFtX1ohWTfv68HGXJ8QXCpyoSKSSFY
fuP9X36wBSkSX9J5DVgiuzD5VBdzUISSmapjKm+DcbRALjz6OUIPEWi1Tjl6p5RK
1w41qdbmt7E5/kGhKLDuT7+M83g4VWhgIvaAXtnhklDAggilPPa8ZJ1IFe31lNlr
k4DRk38nc6sEutdf3RL7QoH7FBusI7uXV03DC6dwN1kP4GE7bjJhcRb/7jYt7CQ9
/E9Exz3c0yAp0yrTg0Fwh+qxfH9dKwN52S7SBwIDAQABAoIBAQCaCs26K07WY5Jt
3a2Cw3y2gPrIgTCqX6hJs7O5ByEhXZ8nBwsWANBUe4vrGaajQHdLj5OKfsIDrOvn
2NI1MqflqeAbu/kR32q3tq8/Rl+PPiwUsW3E6Pcf1orGMSNCXxeducF2iySySzh3
nSIhCG5uwJDWI7a4+9KiieFgK1pt/Iv30q1SQS8IEntTfXYwANQrfKUVMmVF9aIK
6/WZE2yd5+q3wVVIJ6jsmTzoDCX6QQkkJICIYwCkglmVy5AeTckOVwcXL0jqw5Kf
5/soZJQwLEyBoQq7Kbpa26QHq+CJONetPP8Ssy8MJJXBT+u/bSseMb3Zsr5cr43e
DJOhwsThAoGBAPY6rPKl2NT/K7XfRCGm1sbWjUQyDShscwuWJ5+kD0yudnT/ZEJ1
M3+KS/iOOAoHDdEDi9crRvMl0UfNa8MAcDKHflzxg2jg/QI+fTBjPP5GOX0lkZ9g
z6VePoVoQw2gpPFVNPPTxKfk27tEzbaffvOLGBEih0Kb7HTINkW8rIlzAoGBAM9y
1yr+jvfS1cGFtNU+Gotoihw2eMKtIqR03Yn3n0PK1nVCDKqwdUqCypz4+ml6cxRK
J8+Pfdh7D+ZJd4LEG6Y4QRDLuv5OA700tUoSHxMSNn3q9As4+T3MUyYxWKvTeu3U
f2NWP9ePU0lV8ttk7YlpVRaPQmc1qwooBA/z/8AdAoGAW9x0HWqmRICWTBnpjyxx
QGlW9rQ9mHEtUotIaRSJ6K/F3cxSGUEkX1a3FRnp6kPLcckC6NlqdNgNBd6rb2rA
cPl/uSkZP42Als+9YMoFPU/xrrDPbUhu72EDrj3Bllnyb168jKLa4VBOccUvggxr
Dm08I1hgYgdN5huzs7y6GeUCgYEAj+AZJSOJ6o1aXS6rfV3mMRve9bQ9yt8jcKXw
5HhOCEmMtaSKfnOF1Ziih34Sxsb7O2428DiX0mV/YHtBnPsAJidL0SdLWIapBzeg
KHArByIRkwE6IvJvwpGMdaex1PIGhx5i/3VZL9qiq/ElT05PhIb+UXgoWMabCp84
OgxDK20CgYAeaFo8BdQ7FmVX2+EEejF+8xSge6WVLtkaon8bqcn6P0O8lLypoOhd
mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ
5jgQ3slYLpqrGlcbLgUXBUgzEO684Wk/UV9DFPlHALVqCfXQ9dpJPg==
-----END RSA PRIVATE KEY-----`)

func TestTokenJWTRequest(t *testing.T) {
var assertion string
audience := "audience1"
scopes := "scope1 scope2"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() != "/token" {
t.Errorf("authenticate client request URL = %q; want %q", r.URL, "/token")
}
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
t.Errorf("Content-Type header = %q; want %q", got, want)
}
err := r.ParseForm()
if err != nil {
t.Fatal(err)
}

if got, want := r.Form.Get("scope"), scopes; got != want {
t.Errorf("scope = %q; want %q", got, want)
}
if got, want := r.Form.Get("audience"), audience; got != want {
t.Errorf("audience = %q; want %q", got, want)
}
if got, want := r.Form.Get("grant_type"), "client_credentials"; got != want {
t.Errorf("grant_type = %q; want %q", got, want)
}
expectedAssertionType := "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
if got, want := r.Form.Get("client_assertion_type"), expectedAssertionType; got != want {
t.Errorf("client_assertion_type = %q; want %q", got, want)
}

assertion = r.Form.Get("client_assertion")

w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{
"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
"token_type": "bearer",
"expires_in": 3600
}`))
}))
defer ts.Close()

for _, conf := range []*Config{
{
ClientID: "CLIENT_ID",
Scopes: strings.Split(scopes, " "),
TokenURL: ts.URL + "/token",
EndpointParams: url.Values{"audience": {audience}},
AuthStyle: oauth2.AuthStylePrivateKeyJWT,
PrivateKey: dummyPrivateKey,
KeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
},
{
ClientID: "CLIENT_ID_set_jwt_expiration_time",
Scopes: strings.Split(scopes, " "),
TokenURL: ts.URL + "/token",
EndpointParams: url.Values{"audience": {audience}},
AuthStyle: oauth2.AuthStylePrivateKeyJWT,
PrivateKey: dummyPrivateKey,
KeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
JWTExpires: time.Minute,
},
} {
t.Run(conf.ClientID, func(t *testing.T) {
_, err := conf.TokenSource(context.Background()).Token()
if err != nil {
t.Fatalf("Failed to fetch token: %v", err)
}
parts := strings.Split(assertion, ".")
if len(parts) != 3 {
t.Fatalf("assertion = %q; want 3 parts", assertion)
}
gotJson, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
t.Fatalf("invalid token payload; err = %v", err)
}
claimSet := jws.ClaimSet{}
if err := json.Unmarshal(gotJson, &claimSet); err != nil {
t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotJson, err)
}
if got, want := claimSet.Iss, conf.ClientID; got != want {
t.Errorf("payload iss = %q; want %q", got, want)
}
if claimSet.Jti == "" {
t.Errorf("payload jti is empty")
}
expectedDuration := time.Hour
if conf.JWTExpires > 0 {
expectedDuration = conf.JWTExpires
}
if got, want := claimSet.Exp, time.Now().Add(expectedDuration).Unix(); got != want {
t.Errorf("payload exp = %q; want %q", got, want)
}
if got, want := claimSet.Aud, conf.TokenURL; got != want {
t.Errorf("payload aud = %q; want %q", got, want)
}
if got, want := claimSet.Sub, conf.ClientID; got != want {
t.Errorf("payload sub = %q; want %q", got, want)
}
})
}
}

func TestTokenRefreshRequest(t *testing.T) {
internal.ResetAuthCache()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
73 changes: 73 additions & 0 deletions clientcredentials/jwt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package clientcredentials

import (
"crypto/rand"
"math/big"
"net/url"
"time"

"golang.org/x/oauth2/internal"
"golang.org/x/oauth2/jws"
)

const (
clientAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
)

var (
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
)

func randJWTID(n int) (string, error) {
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
ret := make([]byte, n)
for i := 0; i < n; i++ {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
ret = append(ret, letters[num.Int64()])
}

return string(ret), nil
}

func (c *tokenSource) jwtAssertionValues() (url.Values, error) {
v := url.Values{
"grant_type": {"client_credentials"},
}
pk, err := internal.ParseKey(c.conf.PrivateKey)
if err != nil {
return nil, err
}
claimSet := &jws.ClaimSet{
Iss: c.conf.ClientID,
Sub: c.conf.ClientID,
Aud: c.conf.TokenURL,
}

claimSet.Jti, err = randJWTID(36)
if err != nil {
return nil, err
}
if t := c.conf.JWTExpires; t > 0 {
claimSet.Exp = time.Now().Add(t).Unix()
} else {
claimSet.Exp = time.Now().Add(time.Hour).Unix()
}

h := *defaultHeader
h.KeyID = c.conf.KeyID
payload, err := jws.Encode(&h, claimSet, pk)
if err != nil {
return nil, err
}
v.Set("client_assertion", payload)
v.Set("client_assertion_type", clientAssertionType)

return v, nil
}
4 changes: 4 additions & 0 deletions jws/jws.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ type ClaimSet struct {
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
// This array is marshalled using custom code (see (c *ClaimSet) encode()).
PrivateClaims map[string]interface{} `json:"-"`

// See https://tools.ietf.org/html/rfc7523#section-3.
// Unique identifier for the jwt token.
Jti string `json:"jti"`
}

func (c *ClaimSet) encode() (string, error) {
Expand Down
5 changes: 5 additions & 0 deletions oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ const (
// using HTTP Basic Authorization. This is an optional style
// described in the OAuth2 RFC 6749 section 2.3.1.
AuthStyleInHeader AuthStyle = 2

// AuthStylePrivateKeyJWT send jwt token signed by private key.
// See https://openid.net/specs/openid-connect-core-1_0.html.
// See https://tools.ietf.org/html/rfc7523.
AuthStylePrivateKeyJWT AuthStyle = 3
)

var (
Expand Down

0 comments on commit d4025d6

Please sign in to comment.