diff --git a/.gitignore b/.gitignore index 6ae333e5..c28366ec 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,5 @@ coverage.xml dist/ docs/_build/ htmlcov/ + +.idea/ diff --git a/graphql_jwt/settings.py b/graphql_jwt/settings.py index d7257b22..6af4ceee 100644 --- a/graphql_jwt/settings.py +++ b/graphql_jwt/settings.py @@ -59,6 +59,7 @@ "JWT_GET_REFRESH_TOKEN_HANDLER", "JWT_ALLOW_ANY_HANDLER", "JWT_ALLOW_ANY_CLASSES", + "JWT_PUBLIC_KEY", ) @@ -103,6 +104,9 @@ def __getattr__(self, attr): if attr in self.import_strings: value = perform_import(value, attr) + if attr == "JWT_PUBLIC_KEY" and value is not None and callable(value): + value = value() + self._cached_attrs.add(attr) setattr(self, attr, value) return value diff --git a/tests/test_utils.py b/tests/test_utils.py index 21177c6a..8fc66bfe 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,6 +10,12 @@ from .decorators import override_jwt_settings from .testcases import TestCase +PRIVATE_KEY = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend(), +) + class JWTPayloadTests(TestCase): @mock.patch( @@ -35,20 +41,34 @@ def test_issuer(self): class AsymmetricAlgorithmsTests(TestCase): def test_rsa_jwt(self): - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - backend=default_backend(), - ) - public_key = private_key.public_key() + public_key = PRIVATE_KEY.public_key() payload = utils.jwt_payload(self.user) with override_jwt_settings( JWT_PUBLIC_KEY=public_key, - JWT_PRIVATE_KEY=private_key, + JWT_PRIVATE_KEY=PRIVATE_KEY, JWT_ALGORITHM="RS256", ): + token = utils.jwt_encode(payload) + decoded = utils.jwt_decode(token) + + self.assertEqual(payload, decoded) + + +def get_rsa_jwt(): + public_key = PRIVATE_KEY.public_key() + return public_key + +class PublicKeyImportStringTest(TestCase): + def test_import_string_public_key(self): + payload = utils.jwt_payload(self.user) + + with override_jwt_settings( + JWT_PRIVATE_KEY=PRIVATE_KEY, + JWT_PUBLIC_KEY="tests.test_utils.get_rsa_jwt", + JWT_ALGORITHM="RS256", + ): token = utils.jwt_encode(payload) decoded = utils.jwt_decode(token)