소스 검색

[jwt] Fetch RSA public key from key_server (#2445)

* [jwt] Convert JWK to PEM formatted public key
Harsh Gupta 4 년 전
부모
커밋
c5478f91e1
2개의 변경된 파일120개의 추가작업 그리고 45개의 파일을 삭제
  1. 21 1
      desktop/core/src/desktop/auth/api_authentications.py
  2. 99 44
      desktop/core/src/desktop/auth/api_authentications_tests.py

+ 21 - 1
desktop/core/src/desktop/auth/api_authentications.py

@@ -16,8 +16,11 @@
 # limitations under the License.
 
 import logging
+import requests
 import jwt
+import json
 
+from cryptography.hazmat.primitives import serialization
 from rest_framework import authentication, exceptions
 
 from desktop.auth.backend import find_or_create_user, ensure_has_a_group, rewrite_user
@@ -49,10 +52,14 @@ class JwtAuthentication(authentication.BaseAuthentication):
 
     LOG.debug('JwtAuthentication: got access token from %s: %s' % (request.path, access_token))
 
+    public_key_pem = ''
+    if AUTH.JWT.VERIFY.get():
+      public_key_pem = self._handle_public_key(access_token)
+
     try:
       payload = jwt.decode(
         access_token,
-        'secret',
+        public_key_pem,
         algorithms=["RS256"],
         verify=AUTH.JWT.VERIFY.get()
       )
@@ -82,6 +89,19 @@ class JwtAuthentication(authentication.BaseAuthentication):
 
     return (user, None)
 
+  def _handle_public_key(self, access_token):
+    token_metadata = jwt.get_unverified_header(access_token)
+    headers = {'kid': token_metadata.get('kid', {})} 
+
+    if AUTH.JWT.KEY_SERVER_URL.get():
+      jwk = requests.get(AUTH.JWT.KEY_SERVER_URL.get(), headers=headers)
+
+      if jwk.get('keys'):
+        public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(jwk["keys"][0])).public_key()
+        public_key_pem = public_key.public_bytes(encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo)
+
+        return public_key_pem
+
 
 class DummyCustomAuthentication(authentication.BaseAuthentication):
   """

+ 99 - 44
desktop/core/src/desktop/auth/api_authentications_tests.py

@@ -48,73 +48,128 @@ class TestJwtAuthentication():
       }
     )
 
+
   def test_authenticate_existing_user(self):
     with patch('desktop.auth.api_authentications.jwt.decode') as jwt_decode:
+      with patch('desktop.auth.api_authentications.requests.get'):
 
-      jwt_decode.return_value = {
-        "user": "test_user"
-      }
+        jwt_decode.return_value = {
+          "user": "test_user"
+        }
 
-      user, token = JwtAuthentication().authenticate(request=self.request)
+        user, token = JwtAuthentication().authenticate(request=self.request)
 
-      assert_equal(user, self.user)
-      assert_true(user.is_authenticated)
-      assert_false(user.is_superuser)
+        assert_equal(user, self.user)
+        assert_true(user.is_authenticated)
+        assert_false(user.is_superuser)
 
 
   def test_authenticate_new_user(self):
     with patch('desktop.auth.api_authentications.jwt.decode') as jwt_decode:
+      with patch('desktop.auth.api_authentications.requests.get'):
 
-      jwt_decode.return_value = {
-        "user": "test_new_user"
-      }
+        jwt_decode.return_value = {
+          "user": "test_new_user"
+        }
 
-      assert_false(User.objects.filter(username="test_new_user").exists())
+        assert_false(User.objects.filter(username="test_new_user").exists())
 
-      user, token = JwtAuthentication().authenticate(request=self.request)
+        user, token = JwtAuthentication().authenticate(request=self.request)
 
-      assert_true(User.objects.filter(username="test_new_user").exists())
-      assert_equal(User.objects.get(username="test_new_user"), user)
-      assert_true(user.is_authenticated)
-      assert_false(user.is_superuser)
+        assert_true(User.objects.filter(username="test_new_user").exists())
+        assert_equal(User.objects.get(username="test_new_user"), user)
+        assert_true(user.is_authenticated)
+        assert_false(user.is_superuser)
 
 
   def test_failed_authentication(self):
     with patch('desktop.auth.api_authentications.jwt.decode') as jwt_decode:
+      with patch('desktop.auth.api_authentications.requests.get'):
+        with patch('desktop.auth.api_authentications.JwtAuthentication._handle_public_key'):
 
-      # Invalid token
-      jwt_decode.side_effect = exceptions.AuthenticationFailed('JwtAuthentication: Invalid token')
-      assert_raises(exceptions.AuthenticationFailed, JwtAuthentication().authenticate, self.request)
+          # Invalid token
+          jwt_decode.side_effect = exceptions.AuthenticationFailed('JwtAuthentication: Invalid token')
+          assert_raises(exceptions.AuthenticationFailed, JwtAuthentication().authenticate, self.request)
 
-      # Expired token
-      jwt_decode.side_effect = exceptions.AuthenticationFailed('JwtAuthentication: Token expired')
-      assert_raises(exceptions.AuthenticationFailed, JwtAuthentication().authenticate, self.request)
+          # Expired token
+          jwt_decode.side_effect = exceptions.AuthenticationFailed('JwtAuthentication: Token expired')
+          assert_raises(exceptions.AuthenticationFailed, JwtAuthentication().authenticate, self.request)
 
 
   def test_check_user_token_storage(self):
     with patch('desktop.auth.api_authentications.jwt.decode') as jwt_decode:
-      jwt_decode.return_value = {
-        "user": "test_user"
-      }
-      user, token = JwtAuthentication().authenticate(request=self.request)
+      with patch('desktop.auth.api_authentications.requests.get'):
+        jwt_decode.return_value = {
+          "user": "test_user"
+        }
+        user, token = JwtAuthentication().authenticate(request=self.request)
 
-      assert_true('jwt_access_token' in user.profile.data)
-      assert_equal(user.profile.data['jwt_access_token'], self.sample_token)
+        assert_true('jwt_access_token' in user.profile.data)
+        assert_equal(user.profile.data['jwt_access_token'], self.sample_token)
 
-  def test_check_token_verification_flag(self):
 
-    # When verification flag is True for old sample token
-    reset = AUTH.JWT.VERIFY.set_for_testing(True)
-    try:
-      assert_raises(exceptions.AuthenticationFailed, JwtAuthentication().authenticate, self.request)
-    finally:
-      reset()
-
-    # When verification flag is False
-    reset = AUTH.JWT.VERIFY.set_for_testing(False)
-    try:
-      user, token = JwtAuthentication().authenticate(request=self.request)
-
-      assert_equal(user, self.user)
-    finally:
-      reset()
+  def test_check_token_verification_flag(self):
+    with patch('desktop.auth.api_authentications.requests.get'):
+      with patch('desktop.auth.api_authentications.jwt.algorithms.RSAAlgorithm.from_jwk'):
+        with patch('desktop.auth.api_authentications.JwtAuthentication._handle_public_key'):
+
+          # When verification flag is True for old sample token
+          reset = AUTH.JWT.VERIFY.set_for_testing(True)
+          try:
+            assert_raises(exceptions.AuthenticationFailed, JwtAuthentication().authenticate, self.request)
+          finally:
+            reset()
+
+          # When verification flag is False
+          reset = AUTH.JWT.VERIFY.set_for_testing(False)
+          try:
+            user, token = JwtAuthentication().authenticate(request=self.request)
+
+            assert_equal(user, self.user)
+          finally:
+            reset()
+
+
+  def test_handle_public_key(self):
+    with patch('desktop.auth.api_authentications.requests.get') as key_server_request:
+      with patch('desktop.auth.api_authentications.jwt.decode') as jwt_decode:
+
+        jwt_decode.return_value = {
+          "user": "test_user"
+        }
+        key_server_request.return_value = {
+          "keys": [
+            {
+              "kty": "RSA",
+              "kid": "1",
+              "alg": "RSA256",
+              "n": "rtT3gR0NDIx6gv8xYLiPue_ItaIbognCGGgQbipp3IOuobu2RnJjedsIRBTEOdkVx-xjV6m92VYtrpW6gM9vldwTfI0UmoSLGKT5uYd0JGHvYWoN9inCZYZcnala58T8HDgLiXa9KlEuQxGGQDemB3yf5rgS1OhLBKVsI8bMVgah7xNIiBOWsVeWIEr13Nem8HUuDqgIpL_8TgjxFOqFcdqPCfoIZ89JKEiKbsGbU-lqs1xYChFscI_w7Jc7l6rvf2nsLGMFs3U4ZJvS4AUpVno2e527clXzQisfJKwb4hjfKRMhHfnYfyJxaoHqWfx8DjXmH3CMqlWr_-hL3y1-4Q",
+              "e": "AQAB",
+              "d": "XVj4jcelH_4hq6_1_V6N3wlYcSKM_oeXStDFdQzQWR02MMS5HgQVeQqp7y_nVbvDFWvx3uySoWiSG5V2bzBStAE9plLtnVMHsbDkZVsdeA-ScMDfk3_Ye7yx1ryF_RoAQlDqWAs-FUojGUxSEhekXnr8JYRDCcq9w01P4ApVL9iX9Togk8MFO68vKRykeFC21TGE87-2_ieIMksDf25r-uhYzdN1FCJuzHRaYBUBgBRq82rgno1f1Y9_j8TN30NQtOLr5UtYkH-iKb_wqgocFG9GamEbBzzZW2_BwRhywHm1ciJyiQ_Woikx798HoXlHOEHi8q4G-ay2JUFcbTyAAQ",
+              "p": "5umhRLdRjv30UO53l9gmVs2nUJPD-Uv_vDzx27aemTqaBxjTj_rVo3_KUwunQ4Y9aaaQo9BvlxG-tlmtYuDHYKavxqFQ6Q6jci3OWv2my9515akl5nUWj4SQD9xvve3b7x-nVGRefYmGvscXZU_Ryg1CZ_4FPsfljWwBTo7ggaE",
+              "q": "wdOQhh0NOxj1oI3cod_IQxl-5UjBzRvkm6Yx9r2QyOn2wk60b_ExWA8CrEr-eOSSSc0TMf2Y8vbCjzXSkd2-Gbsz4OOC-AkxY5W4FonLxF8AQabAXeIIfH7qF7Q0ByaZBFFaNQ3ejBunBa5ph0KUrxDrzVf1tcX3b8y8fHIudUE",
+              "dp": "ctEaojtw72PxNsjMaJFOxvytRFClMnGKsMOxEynkBJbx_bNnhwEXd5vUM6Tov5ehM8Zhx0KeKgTlynAe2bqhCLr5Tg_qVmgz91M1d2MGq_pqrw6DTOtNk4E7zNc0LMF4CZe4sSrTHSLkADqotHSTAR_EtEbHvubQiph4seIzWeE",
+              "dq": "q_htG0D9czjC_i-_2PO3OCmP2BkEsloULDF51ST-J_TF1kKEf2mtUScIRRvIyjRqwwYsCMerg66CkxO6_2aRez0IW3kgw7dMVcIJ8h1SaKmtjZJIzUN2Khdk1aEyJEIPs7AGbFog4YjLWRQVV0gwqV9HCAsJ27yIvG4XsgaQx8E",
+              "qi": "lNOWMacUcZtytxeTfeR6OWbqufAp56cICNTZX82JDnoi2KCmyeUERl1tLdYC1giK2lNw5j57ojTigPpyhBdeZ-3NqlJEH8pq6gJXNSpBOWTGzOT_EcW2jaCP4cT8q1Js3pFUynYPdXRU9FG0kdQgNIrDztNZJlPtdFxAVgCM4PY",
+            }
+          ]
+        }
+
+        resets = [
+          AUTH.JWT.VERIFY.set_for_testing(True),
+          AUTH.JWT.KEY_SERVER_URL.set_for_testing('https://ext-authz:8000')
+        ]
+
+        try:
+          user, token = JwtAuthentication().authenticate(request=self.request)
+
+          jwt_decode.assert_called_with(
+            self.sample_token,
+            b'-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArtT3gR0NDIx6gv8xYLiP\nue/ItaIbognCGGgQbipp3IOuobu2RnJjedsIRBTEOdkVx+xjV6m92VYtrpW6gM9v\nldwTfI0UmoSLGKT5uYd0JGHvYWoN9inCZYZcnala58T8HDgLiXa9KlEuQxGGQDem\nB3yf5rgS1OhLBKVsI8bMVgah7xNIiBOWsVeWIEr13Nem8HUuDqgIpL/8TgjxFOqF\ncdqPCfoIZ89JKEiKbsGbU+lqs1xYChFscI/w7Jc7l6rvf2nsLGMFs3U4ZJvS4AUp\nVno2e527clXzQisfJKwb4hjfKRMhHfnYfyJxaoHqWfx8DjXmH3CMqlWr/+hL3y1+\n4QIDAQAB\n-----END PUBLIC KEY-----\n',
+            algorithms=['RS256'],
+            verify=True
+          )
+          assert_equal(user, self.user)
+        finally:
+          for reset in resets:
+            reset()