diff --git a/pkg/jwt/rsa_sha_signature_validator.go b/pkg/jwt/rsa_sha_signature_validator.go index b708c901d..c5e5d4ccd 100644 --- a/pkg/jwt/rsa_sha_signature_validator.go +++ b/pkg/jwt/rsa_sha_signature_validator.go @@ -15,7 +15,8 @@ type rsaSHASignatureValidator struct { // NewRSASHASignatureValidator creates a SignatureValidator that expects // the signature of a JWT to use the Rivest-Shamir-Adleman (RSA) // cryptosystem, using SHA-256, SHA-384 or SHA-512 as a hashing -// algorithm. +// algorithm. Both PKCS#1 v1.5 (RS256/RS384/RS512) and PSS +// (PS256/PS384/PS512) padding schemes are supported. // // RSA uses asymmetrical cryptography, meaning that signing is performed // using a private key, while verification only relies on a public key. @@ -30,7 +31,20 @@ func NewRSASHASignatureValidator(key *rsa.PublicKey) SignatureValidator { func (sv *rsaSHASignatureValidator) ValidateSignature(algorithm string, keyID *string, headerAndPayload string, signature []byte) bool { var hashType crypto.Hash var hasher hash.Hash + var pssOpts *rsa.PSSOptions switch algorithm { + case "PS256": + hashType = crypto.SHA256 + hasher = sha256.New() + pssOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: crypto.SHA256} + case "PS384": + hashType = crypto.SHA384 + hasher = sha512.New384() + pssOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: crypto.SHA384} + case "PS512": + hashType = crypto.SHA512 + hasher = sha512.New() + pssOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: crypto.SHA512} case "RS256": hashType = crypto.SHA256 hasher = sha256.New() @@ -44,5 +58,9 @@ func (sv *rsaSHASignatureValidator) ValidateSignature(algorithm string, keyID *s return false } hasher.Write([]byte(headerAndPayload)) - return rsa.VerifyPKCS1v15(sv.key, hashType, hasher.Sum(nil), signature) == nil + digest := hasher.Sum(nil) + if pssOpts != nil { + return rsa.VerifyPSS(sv.key, hashType, digest, signature, pssOpts) == nil + } + return rsa.VerifyPKCS1v15(sv.key, hashType, digest, signature) == nil } diff --git a/pkg/jwt/rsa_sha_signature_validator_test.go b/pkg/jwt/rsa_sha_signature_validator_test.go index 5bd8b5aa1..fd484e4f5 100644 --- a/pkg/jwt/rsa_sha_signature_validator_test.go +++ b/pkg/jwt/rsa_sha_signature_validator_test.go @@ -1,7 +1,12 @@ package jwt_test import ( + "bytes" + "crypto" + "crypto/rand" "crypto/rsa" + "crypto/sha256" + "crypto/sha512" "crypto/x509" "encoding/pem" "testing" @@ -276,4 +281,96 @@ mwIDAQAB 0x21, 0x5d, 0x0e, 0x9c, 0xf2, 0x7d, 0x07, 0x0f, 0x54, 0x13, 0xc7, 0xef, 0xe0, 0x25, 0xec, 0xe5, })) + + // RSA-PSS uses a different (statically embedded) private key + // because PSS signatures are randomized via salt and so must be + // generated at test time. The "invalid signature" cases below + // flip a byte of the valid signature at runtime. + pssBlock, _ := pem.Decode([]byte(`-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC16sKpfp7wxsn+ +nRsNgITBsn9vGOOVgRIw+Vm/riMo0b+8k2rKjhBnuY2+5ibn1bSC4m0U227FdaSQ +DhurnIx0+/ose5kPPii7tg/jA3Zh/awha+s+I7wvXFiZfqsKjMRtQKrMLLH4xSYR +mDOQt9lAeCRBxBv//ft8p8oKHbQ93TfTiXUQx0l9dHl9vu82PAAPT5Qpi2Afl8m1 +Z8IomeBdj5W9XhwcoIINvexrwr6L/qmMTeqa+0fU/1A9js8AsTeZdAn1WsoGsx+y +3VHvC+rGQ6vrAF5PBMN5dhCn1mBaqiLSeLJAuP214bYkINjDitCHuqazDXQS7W4b +QQ7GLalbAgMBAAECggEAASYwgJ2aJ286C2+2ogkzzX1xLKR9m44rLUAF/x38ft+E +VI6h2BG5rM4RRSlzCXfiIgmIIt+X/YtWtMpCBn2AraSB2hIVz3DXFQh7ZTW6Y7gU +hEmIPOZA5BOzQ5T3Q6PMAdyst4l4vleQe//4I2vl5XBc2kWA38cc9ZahwhqZ7uIa +JCHTXRZ/i0bYXSCks0j+C4gXpquiv1E9L+vNZcuLooHCrrmO9c0oUkosnTRaV9zB +cUdYQofyNRlt8+eY/BsTan6cte4DzKVp6EtLO5R+WLgQyOgiCMMUxVWHhHG/oW50 +TX0i1ZWXWg+NGWmkGvYfXFUAv0oU5zIGC0T8fn1toQKBgQDaq8+9QTMIUKL12UTX +ShbdieJEdzdzIqCURWPrZ/AeZHmVpshTYhKcAt6kSRMSuxkuN0S4USWRwXjpexEV +nqll10J3lkJb6Z6hRmZyTIPDPX5DLe1/lgKWZHYXu9ZXl9+0ZS9DwuMLJUpoZ+UQ +z3gd/LEidaSwUwFRjyCIPwZEWQKBgQDU+MA2oKtchZ8H/OXLSlNOuAwdwZwD5Xn9 +EWmrt2xGROZFC04/2WY/bYC7iLxRZxid8GANQ4Ya1yaGFt2jTjSjB26piLHzRSQ6 +lcUURfKOqyTxaaV1DGpZWNVlxOER7Mlb6U4tvKUjAwwTA4FVzZgQQJwErXAQ34D8 +1HTEsAN00wKBgQCaWRnaOVI/NUPBiunHmNlI6JGYyBmQoEl+PviHaicYHM2hb0cJ +bDk8e94RUi8vUnc0ovhTrZt6JXkmPKLTgtmJNAcLiDkwzVcV+S5I0W9T+WzNGHcC +Tq1m4GRm3kQuMdpKZ/2Ts9U0wc6ioWsTkY30hK+3ZhioCP7uRbutz+apiQKBgF4O +vlViAEyMdwAAITz3RnOttSwvJchSwN2ToyfDin4+T7SOmbB5Qz8gDYrFiOYqsiSO +1N0GxWN1Qf5Weux0zapyzdzyEiVuk+GL485gVg/MZjR4hCp9oTp0kUqw+PYBrax6 +DZ0Fg6lC30JGegh7FH2ZC07FiojpLP58llWHpv8hAoGALLmiK8EpcluJhKRt+P3q +gPKxDMHohCfDuaSRTVANbaOL4s84j+COYlbG3IFtsCyCZTd1vTQKcB7qZOtTmzoe +CaCi05WOYzU+tEoyjIQnkmvuYKfY6RPLWaOrEl+yFe7cpCV2f/vqJNvTX1C+aH8L +j5iYg0ok2Rc6EhMzXc5kxUY= +-----END PRIVATE KEY-----`)) + require.NotNil(t, pssBlock) + pssParsedKey, err := x509.ParsePKCS8PrivateKey(pssBlock.Bytes) + require.NoError(t, err) + pssKey := pssParsedKey.(*rsa.PrivateKey) + pssValidator := jwt.NewRSASHASignatureValidator(&pssKey.PublicKey) + + // RSA-PSS with SHA-256, both with a valid and invalid signature. + ps256Digest := sha256.Sum256([]byte("eyJhbGciOiJQUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0")) + ps256ValidSig, err := rsa.SignPSS(rand.Reader, pssKey, crypto.SHA256, ps256Digest[:], + &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: crypto.SHA256}) + require.NoError(t, err) + ps256InvalidSig := bytes.Clone(ps256ValidSig) + ps256InvalidSig[0] ^= 0xff + require.True(t, pssValidator.ValidateSignature( + "PS256", + /* keyID = */ nil, + "eyJhbGciOiJQUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0", + ps256ValidSig)) + require.False(t, pssValidator.ValidateSignature( + "PS256", + /* keyID = */ nil, + "eyJhbGciOiJQUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0", + ps256InvalidSig)) + + // RSA-PSS with SHA-384, both with a valid and invalid signature. + ps384Digest := sha512.Sum384([]byte("eyJhbGciOiJQUzM4NCIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0")) + ps384ValidSig, err := rsa.SignPSS(rand.Reader, pssKey, crypto.SHA384, ps384Digest[:], + &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: crypto.SHA384}) + require.NoError(t, err) + ps384InvalidSig := bytes.Clone(ps384ValidSig) + ps384InvalidSig[0] ^= 0xff + require.True(t, pssValidator.ValidateSignature( + "PS384", + /* keyID = */ nil, + "eyJhbGciOiJQUzM4NCIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0", + ps384ValidSig)) + require.False(t, pssValidator.ValidateSignature( + "PS384", + /* keyID = */ nil, + "eyJhbGciOiJQUzM4NCIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0", + ps384InvalidSig)) + + // RSA-PSS with SHA-512, both with a valid and invalid signature. + ps512Digest := sha512.Sum512([]byte("eyJhbGciOiJQUzUxMiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0")) + ps512ValidSig, err := rsa.SignPSS(rand.Reader, pssKey, crypto.SHA512, ps512Digest[:], + &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: crypto.SHA512}) + require.NoError(t, err) + ps512InvalidSig := bytes.Clone(ps512ValidSig) + ps512InvalidSig[0] ^= 0xff + require.True(t, pssValidator.ValidateSignature( + "PS512", + /* keyID = */ nil, + "eyJhbGciOiJQUzUxMiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0", + ps512ValidSig)) + require.False(t, pssValidator.ValidateSignature( + "PS512", + /* keyID = */ nil, + "eyJhbGciOiJQUzUxMiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0", + ps512InvalidSig)) }