diff --git a/jose.py b/jose.py index fa2859c..fe5393c 100644 --- a/jose.py +++ b/jose.py @@ -158,14 +158,17 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', 'Unsupported compression algorithm: {}'.format(compression)) plaintext = compress(plaintext) + adata = _jwe_adata_str(adata, b64encode_url(json_encode(header))) + # body encryption/hash ((cipher, _), key_size), ((hash_fn, _), hash_mod) = JWA[enc] iv = rng(AES.block_size) encryption_key = rng(hash_mod.digest_size) + enc_key = encryption_key[-hash_mod.digest_size/2:] # first half + mac_key = encryption_key[:-hash_mod.digest_size/2] # second half - ciphertext = cipher(plaintext, encryption_key[-hash_mod.digest_size/2:], iv) - hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata), - encryption_key[:-hash_mod.digest_size/2], hash_mod) + ciphertext = cipher(plaintext, enc_key, iv) + hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata), mac_key, hash_mod) # cek encryption (cipher, _), _ = JWA[alg] @@ -209,17 +212,26 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None): # decrypt body ((_, decipher), _), ((hash_fn, _), mod) = JWA[header['enc']] - if header.get(_TEMP_VER_KEY) == _TEMP_VER: - plaintext = decipher(ciphertext, encryption_key[-mod.digest_size/2:], iv) - hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata), - encryption_key[:-mod.digest_size/2], mod=mod) + enc_key = encryption_key[-mod.digest_size/2:] # first half + mac_key = encryption_key[:-mod.digest_size/2] # second half + + if header.get(_TEMP_VER_KEY) == _TEMP_VER or \ + len(encryption_key) == mod.digest_size: + adata = _jwe_adata_str(adata, jwe[0]) + + hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata), mac_key, mod=mod) + if not const_compare(auth_tag(hash), tag): + raise Error('Mismatched authentication tags') + + plaintext = decipher(ciphertext, enc_key, iv) else: - plaintext = decipher(ciphertext, encryption_key[:-mod.digest_size], iv) - hash = hash_fn(_jwe_hash_str(plaintext, iv, adata, True), - encryption_key[-mod.digest_size:], mod=mod) + enc_key = encryption_key[:-mod.digest_size] # first half + mac_key = encryption_key[-mod.digest_size:] # second half + plaintext = decipher(ciphertext, enc_key, iv) + hash = hash_fn(_jwe_hash_str(plaintext, iv, adata, True), mac_key, mod=mod) - if not const_compare(auth_tag(hash), tag): - raise Error('Mismatched authentication tags') + if not const_compare(auth_tag(hash), tag): + raise Error('Mismatched authentication tags') if 'zip' in header: try: @@ -524,12 +536,28 @@ def _validate(claims, validate_claims, expiry_seconds): _check_not_before(now, not_before) +def _jwe_adata_str(adata, header): + if adata == '': + # draft-ietf-jose-json-web-encryption-40 5.1. Message Encryption step 14 + # 'Let the Additional Authenticated Data encryption parameter be + # ASCII(Encoded Protected Header).' + adata = str(header) + else: + # However if a JWE AAD value is present + # (which can only be the case when using the JWE JSON + # Serialization), instead let the Additional Authenticated Data + # encryption parameter be ASCII(Encoded Protected Header || '.' || + # BASE64URL(JWE AAD)). + adata = '.'.join([str(header), adata]) + return adata + + def _jwe_hash_str(ciphertext, iv, adata='', legacy=False): - # http://tools.ietf.org/html/ - # draft-ietf-jose-json-web-algorithms-24#section-5.2.2.1 if legacy: return '.'.join((adata, iv, ciphertext, str(len(adata)))) - return '.'.join((adata, iv, ciphertext, pack("!Q", len(adata) * 8))) + # http://tools.ietf.org/html/ + # draft-ietf-jose-json-web-algorithms-40#section-5.2.2.1 + return ''.join((adata, iv, ciphertext, pack("!Q", len(adata) * 8))) def _jws_hash_str(header, claims): diff --git a/tests.py b/tests.py index d502698..b49768b 100644 --- a/tests.py +++ b/tests.py @@ -37,9 +37,13 @@ def legacy_encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', if compression is not None: header['zip'] = compression try: - (compress, _) = jose.COMPRESSION[compression] + if compression == 'BAD': + # facilitate test for bad compression algorithm + (compress, _) = jose.COMPRESSION['DEF'] + else: + (compress, _) = jose.COMPRESSION[compression] except KeyError: - raise Error( + raise jose.Error( 'Unsupported compression algorithm: {}'.format(compression)) plaintext = compress(plaintext) @@ -274,13 +278,9 @@ def test_encrypt_invalid_compression_error(self): pass def test_decrypt_invalid_compression_error(self): - jwe = jose.encrypt(claims, rsa_pub_key, compression='DEF') - header = jose.b64encode_url( - '{"alg": "RSA-OAEP", ''"enc": "A128CBC-HS256", "__v": 1,' - '"zip": "BAD"}') - + jwe = legacy_encrypt(claims, rsa_pub_key, compression='BAD') try: - jose.decrypt(jose.JWE(*((header,) + (jwe[1:]))), rsa_priv_key) + jose.decrypt(jwe, rsa_priv_key) self.fail() except jose.Error as e: self.assertEqual(