diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 2459d069f..0933d4ab4 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -11,6 +11,7 @@ // server. // // See https://tools.ietf.org/html/rfc6749#section-4.4 +// See https://tools.ietf.org/html/rfc7523 package clientcredentials // import "golang.org/x/oauth2/clientcredentials" import ( @@ -19,6 +20,7 @@ import ( "net/http" "net/url" "strings" + "time" "golang.org/x/oauth2" "golang.org/x/oauth2/internal" @@ -46,11 +48,29 @@ type Config struct { // AuthStyle optionally specifies how the endpoint wants the // client ID & client secret sent. The zero value means to // auto-detect. + // See https://openid.net/specs/openid-connect-core-1_0.html#ClientAuthentication. AuthStyle oauth2.AuthStyle // authStyleCache caches which auth style to use when Endpoint.AuthStyle is // the zero value (AuthStyleAutoDetect). authStyleCache internal.LazyAuthStyleCache + + // JWTExpires optionally specifies how long the jwt token is valid for. + JWTExpires time.Duration + + // PrivateKey contains the contents of an RSA private key or the + // contents of a PEM file that contains a private key. The provided + // private key is used to sign JWT payloads. + // PEM containers with a passphrase are not supported. + // Use the following command to convert a PKCS 12 file into a PEM. + // + // $ openssl pkcs12 -in key.p12 -out key.pem -nodes + // + PrivateKey []byte + + // KeyID contains an optional hint indicating which key is being + // used. + KeyID string } // Token uses client credentials to retrieve a token. @@ -95,6 +115,14 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { v := url.Values{ "grant_type": {"client_credentials"}, } + if c.conf.AuthStyle == oauth2.AuthStylePrivateKeyJWT { + var err error + v, err = c.jwtAssertionValues() + if err != nil { + return nil, err + } + + } if len(c.conf.Scopes) > 0 { v.Set("scope", strings.Join(c.conf.Scopes, " ")) } diff --git a/clientcredentials/clientcredentials_test.go b/clientcredentials/clientcredentials_test.go index 078e75ec7..cd7b07d95 100644 --- a/clientcredentials/clientcredentials_test.go +++ b/clientcredentials/clientcredentials_test.go @@ -6,12 +6,20 @@ package clientcredentials import ( "context" + "encoding/base64" + "encoding/json" "io" "io/ioutil" + "math" "net/http" "net/http/httptest" "net/url" + "strings" "testing" + "time" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/jws" ) func newConf(serverURL string) *Config { @@ -111,6 +119,143 @@ func TestTokenRequest(t *testing.T) { } } +var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAx4fm7dngEmOULNmAs1IGZ9Apfzh+BkaQ1dzkmbUgpcoghucE +DZRnAGd2aPyB6skGMXUytWQvNYav0WTR00wFtX1ohWTfv68HGXJ8QXCpyoSKSSFY +fuP9X36wBSkSX9J5DVgiuzD5VBdzUISSmapjKm+DcbRALjz6OUIPEWi1Tjl6p5RK +1w41qdbmt7E5/kGhKLDuT7+M83g4VWhgIvaAXtnhklDAggilPPa8ZJ1IFe31lNlr +k4DRk38nc6sEutdf3RL7QoH7FBusI7uXV03DC6dwN1kP4GE7bjJhcRb/7jYt7CQ9 +/E9Exz3c0yAp0yrTg0Fwh+qxfH9dKwN52S7SBwIDAQABAoIBAQCaCs26K07WY5Jt +3a2Cw3y2gPrIgTCqX6hJs7O5ByEhXZ8nBwsWANBUe4vrGaajQHdLj5OKfsIDrOvn +2NI1MqflqeAbu/kR32q3tq8/Rl+PPiwUsW3E6Pcf1orGMSNCXxeducF2iySySzh3 +nSIhCG5uwJDWI7a4+9KiieFgK1pt/Iv30q1SQS8IEntTfXYwANQrfKUVMmVF9aIK +6/WZE2yd5+q3wVVIJ6jsmTzoDCX6QQkkJICIYwCkglmVy5AeTckOVwcXL0jqw5Kf +5/soZJQwLEyBoQq7Kbpa26QHq+CJONetPP8Ssy8MJJXBT+u/bSseMb3Zsr5cr43e +DJOhwsThAoGBAPY6rPKl2NT/K7XfRCGm1sbWjUQyDShscwuWJ5+kD0yudnT/ZEJ1 +M3+KS/iOOAoHDdEDi9crRvMl0UfNa8MAcDKHflzxg2jg/QI+fTBjPP5GOX0lkZ9g +z6VePoVoQw2gpPFVNPPTxKfk27tEzbaffvOLGBEih0Kb7HTINkW8rIlzAoGBAM9y +1yr+jvfS1cGFtNU+Gotoihw2eMKtIqR03Yn3n0PK1nVCDKqwdUqCypz4+ml6cxRK +J8+Pfdh7D+ZJd4LEG6Y4QRDLuv5OA700tUoSHxMSNn3q9As4+T3MUyYxWKvTeu3U +f2NWP9ePU0lV8ttk7YlpVRaPQmc1qwooBA/z/8AdAoGAW9x0HWqmRICWTBnpjyxx +QGlW9rQ9mHEtUotIaRSJ6K/F3cxSGUEkX1a3FRnp6kPLcckC6NlqdNgNBd6rb2rA +cPl/uSkZP42Als+9YMoFPU/xrrDPbUhu72EDrj3Bllnyb168jKLa4VBOccUvggxr +Dm08I1hgYgdN5huzs7y6GeUCgYEAj+AZJSOJ6o1aXS6rfV3mMRve9bQ9yt8jcKXw +5HhOCEmMtaSKfnOF1Ziih34Sxsb7O2428DiX0mV/YHtBnPsAJidL0SdLWIapBzeg +KHArByIRkwE6IvJvwpGMdaex1PIGhx5i/3VZL9qiq/ElT05PhIb+UXgoWMabCp84 +OgxDK20CgYAeaFo8BdQ7FmVX2+EEejF+8xSge6WVLtkaon8bqcn6P0O8lLypoOhd +mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ +5jgQ3slYLpqrGlcbLgUXBUgzEO684Wk/UV9DFPlHALVqCfXQ9dpJPg== +-----END RSA PRIVATE KEY-----`) + +func TestTokenJWTRequest(t *testing.T) { + var assertion string + audience := "audience1" + scopes := "scope1 scope2" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() != "/token" { + t.Errorf("authenticate client request URL = %q; want %q", r.URL, "/token") + } + if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want { + t.Errorf("Content-Type header = %q; want %q", got, want) + } + err := r.ParseForm() + if err != nil { + t.Fatal(err) + } + + if got, want := r.Form.Get("scope"), scopes; got != want { + t.Errorf("scope = %q; want %q", got, want) + } + if got, want := r.Form.Get("audience"), audience; got != want { + t.Errorf("audience = %q; want %q", got, want) + } + if got, want := r.Form.Get("grant_type"), "client_credentials"; got != want { + t.Errorf("grant_type = %q; want %q", got, want) + } + expectedAssertionType := "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + if got, want := r.Form.Get("client_assertion_type"), expectedAssertionType; got != want { + t.Errorf("client_assertion_type = %q; want %q", got, want) + } + + assertion = r.Form.Get("client_assertion") + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", + "token_type": "bearer", + "expires_in": 3600 + }`)) + })) + defer ts.Close() + + for _, conf := range []*Config{ + { + ClientID: "CLIENT_ID", + Scopes: strings.Split(scopes, " "), + TokenURL: ts.URL + "/token", + EndpointParams: url.Values{"audience": {audience}}, + AuthStyle: oauth2.AuthStylePrivateKeyJWT, + PrivateKey: dummyPrivateKey, + KeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + }, + { + ClientID: "CLIENT_ID_set_jwt_expiration_time", + Scopes: strings.Split(scopes, " "), + TokenURL: ts.URL + "/token", + EndpointParams: url.Values{"audience": {audience}}, + AuthStyle: oauth2.AuthStylePrivateKeyJWT, + PrivateKey: dummyPrivateKey, + KeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + JWTExpires: time.Minute, + }, + } { + t.Run(conf.ClientID, func(t *testing.T) { + _, err := conf.TokenSource(context.Background()).Token() + if err != nil { + t.Fatalf("Failed to fetch token: %v", err) + } + parts := strings.Split(assertion, ".") + if len(parts) != 3 { + t.Fatalf("assertion = %q; want 3 parts", assertion) + } + gotJson, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + t.Fatalf("invalid token payload; err = %v", err) + } + claimSet := jws.ClaimSet{} + if err := json.Unmarshal(gotJson, &claimSet); err != nil { + t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotJson, err) + } + if got, want := claimSet.Iss, conf.ClientID; got != want { + t.Errorf("payload iss = %q; want %q", got, want) + } + if claimSet.Jti == "" { + t.Errorf("payload jti is empty") + } + expectedDuration := time.Hour + if conf.JWTExpires > 0 { + expectedDuration = conf.JWTExpires + } + + if got, want := claimSet.Exp, time.Now().Add(expectedDuration).Unix(); got != want { + t.Errorf("payload exp = %q; want %q", got, want) + } + + errorMarginInSeconds := 5.0 + if got, want := claimSet.Exp, time.Now().Add(expectedDuration).Unix(); math.Abs(float64(got-want)) > errorMarginInSeconds { + t.Errorf("payload exp is not within the acceptable range: got %q, want around %q", got, want) + } + + if got, want := claimSet.Aud, conf.TokenURL; got != want { + t.Errorf("payload aud = %q; want %q", got, want) + } + if got, want := claimSet.Sub, conf.ClientID; got != want { + t.Errorf("payload sub = %q; want %q", got, want) + } + }) + } +} + func TestTokenRefreshRequest(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() == "/somethingelse" { diff --git a/clientcredentials/jwt.go b/clientcredentials/jwt.go new file mode 100644 index 000000000..a20a46591 --- /dev/null +++ b/clientcredentials/jwt.go @@ -0,0 +1,73 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package clientcredentials + +import ( + "crypto/rand" + "math/big" + "net/url" + "time" + + "golang.org/x/oauth2/internal" + "golang.org/x/oauth2/jws" +) + +const ( + clientAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" +) + +var ( + defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"} +) + +func randJWTID(n int) (string, error) { + const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890" + ret := make([]byte, n) + for i := 0; i < n; i++ { + num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) + if err != nil { + return "", err + } + ret = append(ret, letters[num.Int64()]) + } + + return string(ret), nil +} + +func (c *tokenSource) jwtAssertionValues() (url.Values, error) { + v := url.Values{ + "grant_type": {"client_credentials"}, + } + pk, err := internal.ParseKey(c.conf.PrivateKey) + if err != nil { + return nil, err + } + claimSet := &jws.ClaimSet{ + Iss: c.conf.ClientID, + Sub: c.conf.ClientID, + Aud: c.conf.TokenURL, + } + + claimSet.Jti, err = randJWTID(36) + if err != nil { + return nil, err + } + if t := c.conf.JWTExpires; t > 0 { + claimSet.Exp = time.Now().Add(t).Unix() + } else { + claimSet.Exp = time.Now().Add(time.Hour).Unix() + } + + h := *defaultHeader + h.KeyID = c.conf.KeyID + payload, err := jws.Encode(&h, claimSet, pk) + if err != nil { + return nil, err + } + v.Set("client_assertion", payload) + v.Set("client_assertion_type", clientAssertionType) + + return v, nil +} diff --git a/jws/jws.go b/jws/jws.go index 95015648b..e2030f3f5 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -49,6 +49,10 @@ type ClaimSet struct { // See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3 // This array is marshalled using custom code (see (c *ClaimSet) encode()). PrivateClaims map[string]interface{} `json:"-"` + + // See https://tools.ietf.org/html/rfc7523#section-3. + // Unique identifier for the jwt token. + Jti string `json:"jti"` } func (c *ClaimSet) encode() (string, error) { diff --git a/oauth2.go b/oauth2.go index 90a2c3d6d..286536ea9 100644 --- a/oauth2.go +++ b/oauth2.go @@ -103,6 +103,11 @@ const ( // using HTTP Basic Authorization. This is an optional style // described in the OAuth2 RFC 6749 section 2.3.1. AuthStyleInHeader AuthStyle = 2 + + // AuthStylePrivateKeyJWT send jwt token signed by private key. + // See https://openid.net/specs/openid-connect-core-1_0.html. + // See https://tools.ietf.org/html/rfc7523. + AuthStylePrivateKeyJWT AuthStyle = 3 ) var (