-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchecksum_test.go
119 lines (102 loc) · 3.05 KB
/
checksum_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package keyfunc_test
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/dgrijalva/jwt-go"
"github.com/MicahParks/compatibility-keyfunc"
)
// TestChecksum confirms that the JWKS will only perform a refresh if a new JWKS is read from the remote resource.
//
// The test serves a JWKS file from a temporary directory over HTTP and snapshots
// the keys initially loaded. It then forces refreshes by parsing a JWT whose
// "kid" is not in the JWKS (RefreshUnknownKID is enabled):
//   - while the served file is unchanged, the keys must stay exactly the same;
//   - after the file is overwritten with a different JWKS, the key set must change.
func TestChecksum(t *testing.T) {
	tempDir, err := ioutil.TempDir("", "*")
	if err != nil {
		t.Fatalf("Failed to create a temporary directory.\nError: %s", err.Error())
	}
	defer func() {
		// Use a dedicated variable so the cleanup does not overwrite the test's err.
		if removeErr := os.RemoveAll(tempDir); removeErr != nil {
			t.Errorf("Failed to remove temporary directory.\nError: %s", removeErr.Error())
		}
	}()

	jwksFile := filepath.Join(tempDir, jwksFilePath)
	if err = ioutil.WriteFile(jwksFile, []byte(jwksJSON), 0600); err != nil {
		t.Fatalf("Failed to write JWKS file to temporary directory.\nError: %s", err.Error())
	}

	server := httptest.NewServer(http.FileServer(http.Dir(tempDir)))
	defer server.Close()

	// NOTE(review): this handler may be invoked outside the test goroutine
	// (presumably by the JWKS background refresh stopped via EndBackground),
	// where t.FailNow must not be called; panicking makes any refresh error fatal.
	testingRefreshErrorHandler := func(err error) {
		panic(fmt.Sprintf("Unhandled JWKS error: %s", err.Error()))
	}
	opts := keyfunc.Options{
		RefreshErrorHandler: testingRefreshErrorHandler,
		RefreshUnknownKID:   true,
	}

	// NOTE(review): assumes jwksFilePath starts with "/" so it is a valid URL path.
	jwksURL := server.URL + jwksFilePath
	jwks, err := keyfunc.Get(jwksURL, opts)
	if err != nil {
		t.Fatalf("Failed to get JWKS from testing URL.\nError: %s", err.Error())
	}
	defer jwks.EndBackground()

	// Snapshot the keys loaded from the original JWKS.
	originalKeys := make(map[string]interface{})
	for kid, cryptoKey := range jwks.ReadOnlyKeys() {
		originalKeys[kid] = cryptoKey
	}

	// Create a JWT that will not be in the JWKS.
	token := jwt.New(jwt.SigningMethodHS256)
	token.Header["kid"] = "unknown"
	signed, err := token.SignedString([]byte("test"))
	if err != nil {
		t.Fatalf("Failed to sign test JWT.\nError: %s", err.Error())
	}

	// Force the JWKS to refresh. The parse result and error are deliberately
	// ignored: the "kid" is unknown, so only the refresh attempt matters here.
	_, _ = jwt.Parse(signed, jwks.KeyfuncLegacy)

	// Confirm the keys in the JWKS have not been refreshed.
	newKeys := jwks.ReadOnlyKeys()
	if len(newKeys) != len(originalKeys) {
		t.Fatal("The number of keys should not be different.")
	}
	for kid, cryptoKey := range newKeys {
		if !reflect.DeepEqual(originalKeys[kid], cryptoKey) {
			t.Fatal("The JWKS should not have refreshed without a checksum change.")
		}
	}

	// Write a different JWKS.
	_, _, jwksBytes, _, err := keysAndJWKS()
	if err != nil {
		t.Fatalf("Failed to create a test JWKS.\nError: %s", err.Error())
	}
	if err = ioutil.WriteFile(jwksFile, jwksBytes, 0600); err != nil {
		t.Fatalf("Failed to write JWKS file to temporary directory.\nError: %s", err.Error())
	}

	// Force the JWKS to refresh; the parse outcome is irrelevant, as above.
	_, _ = jwt.Parse(signed, jwks.KeyfuncLegacy)

	// Confirm the keys in the JWKS have been refreshed. A change in key count is
	// already a refresh (this also catches a new set that is a subset of the old
	// one); otherwise look for any key that differs from, or is absent from, the
	// original snapshot.
	newKeys = jwks.ReadOnlyKeys()
	different := len(newKeys) != len(originalKeys)
	if !different {
		for kid, cryptoKey := range newKeys {
			if !reflect.DeepEqual(originalKeys[kid], cryptoKey) {
				different = true
				break
			}
		}
	}
	if !different {
		t.Fatal("A different JWKS checksum should have triggered a JWKS refresh.")
	}
}