Эх сурвалжийг харах

[jwt] Improve external JWT authentication for public APIs (#3337)

- Don't dynamically set value to True for thrift JWT for now.
- Getting in Hive:
Error 500 java.lang.IllegalArgumentException: Illegal base64 character 2e
Harsh Gupta 2 жил өмнө
parent
commit
f18275f189

+ 5 - 2
desktop/conf.dist/hue.ini

@@ -450,10 +450,13 @@ idle_session_timeout=-1
 # Endpoint to fetch the public key from verification server.
 ## key_server_url=https://ext_authz:8000
 
-# The identifier of the service issued the JWT
+# The JWT payload header containing the username.
+## username_header=sub
+
+# The identifier of the service issued the JWT.
 ## issuer=None
 
-# The identifier of the resource intend to access
+# The identifier of the resource intend to access.
 ## audience=None
 
 # Verify custom JWT signature.

+ 5 - 2
desktop/conf/pseudo-distributed.ini.tmpl

@@ -455,10 +455,13 @@
       # Endpoint to fetch the public key from verification server.
       ## key_server_url=https://ext_authz:8000
 
-      # The identifier of the service issued the JWT
+      # The JWT payload header containing the username.
+      ## username_header=sub
+
+      # The identifier of the service issued the JWT.
       ## issuer=None
 
-      # The identifier of the resource intend to access
+      # The identifier of the resource intend to access.
       ## audience=None
 
       # Verify custom JWT signature.

+ 22 - 22
desktop/core/src/desktop/auth/api_authentications.py

@@ -53,28 +53,25 @@ class JwtAuthentication(authentication.BaseAuthentication):
 
     LOG.debug('JwtAuthentication: got access token from %s: %s' % (request.path, access_token))
 
-    public_key_pem = ''
-    if AUTH.JWT.VERIFY.get():
-      public_key_pem = self._handle_public_key(access_token)
+    try:
+      public_key_pem = self._handle_public_key(access_token) if AUTH.JWT.VERIFY.get() else ''
+    except Exception as e:
+      LOG.error('JwtAuthentication: Error fetching public key %s' % str(e))
+      raise exceptions.AuthenticationFailed(e)
 
     params = {
       'jwt': access_token,
       'key': public_key_pem,
       'issuer': AUTH.JWT.ISSUER.get(),
       'audience': AUTH.JWT.AUDIENCE.get(),
-      'algorithms': ["RS256"]
-    }
-
-    if sys.version_info[0] > 2:
-      params['options'] = {
+      'algorithms': ["RS256"],
+      'options': {
         'verify_signature': AUTH.JWT.VERIFY.get()
       }
-    else:
-      params['verify'] = AUTH.JWT.VERIFY.get()
+    }
 
     try:
       payload = jwt.decode(**params)
-
     except jwt.DecodeError:
       LOG.error('JwtAuthentication: Invalid token')
       raise exceptions.AuthenticationFailed('JwtAuthentication: Invalid token')
@@ -90,14 +87,15 @@ class JwtAuthentication(authentication.BaseAuthentication):
     except Exception as e:
       LOG.error('JwtAuthentication: %s' % str(e))
       raise exceptions.AuthenticationFailed(e)
-    
-    if payload.get('user') is None:
-      LOG.debug('JwtAuthentication: no user ID in token')
+
+
+    if payload.get(AUTH.JWT.USERNAME_HEADER.get()) is None: 
+      LOG.debug('JwtAuthentication: no username in token')
       return None
 
-    LOG.debug('JwtAuthentication: got user ID %s and tenant ID %s' % (payload.get('user'), payload.get('tenantId')))
+    LOG.debug('JwtAuthentication: got username %s' % (payload.get(AUTH.JWT.USERNAME_HEADER.get())))
 
-    user = find_or_create_user(payload.get('user'), is_superuser=False)
+    user = find_or_create_user(payload.get(AUTH.JWT.USERNAME_HEADER.get()), is_superuser=False)
     ensure_has_a_group(user)
     user = rewrite_user(user)
 
@@ -112,18 +110,20 @@ class JwtAuthentication(authentication.BaseAuthentication):
 
   def _handle_public_key(self, access_token):
     token_metadata = jwt.get_unverified_header(access_token)
-    headers = {'kid': token_metadata.get('kid', {})} 
+    kid = token_metadata.get('kid')
 
     if AUTH.JWT.KEY_SERVER_URL.get():
-      response = requests.get(AUTH.JWT.KEY_SERVER_URL.get(), headers=headers)
+      response = requests.get(AUTH.JWT.KEY_SERVER_URL.get(), verify=False)
       jwk = json.loads(response.content)
 
       if jwk.get('keys'):
-        public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(jwk["keys"][0])).public_key()
-        public_key_pem = public_key.public_bytes(encoding=serialization.Encoding.PEM,
-                                                 format=serialization.PublicFormat.SubjectPublicKeyInfo)
+        for key in jwk.get('keys'):
+          if key.get('kid') and key.get('kid') == kid:
+            public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key))
+            public_key_pem = public_key.public_bytes(encoding=serialization.Encoding.PEM,
+                                                     format=serialization.PublicFormat.SubjectPublicKeyInfo)
 
-        return public_key_pem
+            return public_key_pem
 
 
 class DummyCustomAuthentication(authentication.BaseAuthentication):

+ 78 - 65
desktop/core/src/desktop/auth/api_authentications_tests.py

@@ -39,22 +39,18 @@ else:
 
 class TestJwtAuthentication():
 
-  @classmethod
-  def setUpClass(cls):
-    if sys.version_info[0] < 3:
-      raise SkipTest
-
-
   def setUp(self):
     self.client = make_logged_in_client(username="test_user", groupname="default", recreate=True, is_superuser=False)
     self.user = rewrite_user(User.objects.get(username="test_user"))
 
-    self.sample_token = "eyJhbGciOiJSUzI1NiJ9.eyJhdWQiOlsid29ya2xvYWQtYXBwIiwicmFuZ2VyIl0sImV4cCI6MTYyNjI1Njg5MywiaWF" \
-                        "0IjoxNjI2MjU2NTkzLCJpc3MiOiJDbG91ZGVyYTEiLCJqdGkiOiJpZDEiLCJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJ1c2V" \
-                        "yIjoidGVzdF91c2VyIn0.jvyVDxbWTAik0jbdUcIc9ZANNrJZUCWH-Pg7FloRhg0ZYAETd_AO3p5v_ppoMmVcPD2xBSr" \
-                        "ngA5J3_A_zPBvQ_hdDlpb0_-mCCJfGhC5tju4bI9EE9Akdn2FrrsqrvQQ8cPyGsIlvoIxrK1De4f74MmUaxfN7Hrrcue" \
-                        "1PTY4u4IB9cWQqV9vIcX99Od5PUaNekLIee-I8gweqvfGEEsW7qWUM63nh59_TOB3LLq-YcEuaX1h_oiTATeCssjk_ee" \
-                        "9RrJGLNyKmC0WJ4UrEWn8a_T3bwCy8CMe0zV5PSuuvPHy0FvnTo2il5SDjGimxKcbpgNiJdfblslu6i35DlfiWg"
+    self.sample_token = "eyJhbGciOiJSUzI1NiIsImtpZCI6InJ0R1VlRjlickdVa3pTZUk1YUg2REpaa1hFQXJ3cHFNcWJTTlo0aTd3SGMifQ.e" \
+                        "yJhdWQiOlsid29ya2xvYWQtYXBwIiwicmFuZ2VyIl0sImV4cCI6MTYyNjI1Njg5MywiaWF0IjoxNjI2MjU2NTkzLCJpc" \
+                        "3MiOiJDbG91ZGVyYTEiLCJqdGkiOiJpZDEiLCJzdWIiOiJ0ZXN0X3VzZXIifQ.jvyVDxbWTAik0jbdUcIc9ZANNrJZUC" \
+                        "WH-Pg7FloRhg0ZYAETd_AO3p5v_ppoMmVcPD2xBSrngA5J3_A_zPBvQ_hdDlpb0_-mCCJfGhC5tju4bI9EE9Akdn2Frr" \
+                        "sqrvQQ8cPyGsIlvoIxrK1De4f74MmUaxfN7Hrrcue1PTY4u4IB9cWQqV9vIcX99Od5PUaNekLIee-I8gweqvfGEEsW7q" \
+                        "WUM63nh59_TOB3LLq-YcEuaX1h_oiTATeCssjk_ee9RrJGLNyKmC0WJ4UrEWn8a_T3bwCy8CMe0zV5PSuuvPHy0FvnTo" \
+                        "2il5SDjGimxKcbpgNiJdfblslu6i35DlfiWg"
+
     self.request = MagicMock(
       META={
         "HTTP_AUTHORIZATION": "Bearer " + self.sample_token
@@ -65,34 +61,48 @@ class TestJwtAuthentication():
   def test_authenticate_existing_user(self):
     with patch('desktop.auth.api_authentications.jwt.decode') as jwt_decode:
       with patch('desktop.auth.api_authentications.requests.get'):
-
         jwt_decode.return_value = {
-          "user": "test_user"
+          "sub": "test_user"
         }
+        resets = [
+          AUTH.JWT.VERIFY.set_for_testing(False),
+          AUTH.JWT.USERNAME_HEADER.set_for_testing('sub')
+        ]
 
-        user, token = JwtAuthentication().authenticate(request=self.request)
+        try:
+          user, token = JwtAuthentication().authenticate(request=self.request)
 
-        assert_equal(user, self.user)
-        assert_true(user.is_authenticated)
-        assert_false(user.is_superuser)
+          assert_equal(user, self.user)
+          assert_true(user.is_authenticated)
+          assert_false(user.is_superuser)
+        finally:
+          for reset in resets:
+            reset()
 
 
   def test_authenticate_new_user(self):
     with patch('desktop.auth.api_authentications.jwt.decode') as jwt_decode:
       with patch('desktop.auth.api_authentications.requests.get'):
-
         jwt_decode.return_value = {
-          "user": "test_new_user"
+          "sub": "test_new_user"
         }
 
         assert_false(User.objects.filter(username="test_new_user").exists())
 
-        user, token = JwtAuthentication().authenticate(request=self.request)
+        resets = [
+          AUTH.JWT.VERIFY.set_for_testing(False),
+          AUTH.JWT.USERNAME_HEADER.set_for_testing('sub')
+        ]
+        try:
+          user, token = JwtAuthentication().authenticate(request=self.request)
 
-        assert_true(User.objects.filter(username="test_new_user").exists())
-        assert_equal(User.objects.get(username="test_new_user"), user)
-        assert_true(user.is_authenticated)
-        assert_false(user.is_superuser)
+          assert_true(User.objects.filter(username="test_new_user").exists())
+          assert_equal(User.objects.get(username="test_new_user"), user)
+          assert_true(user.is_authenticated)
+          assert_false(user.is_superuser)
+        finally:
+          for reset in resets:
+            reset()
 
 
   def test_failed_authentication(self):
@@ -113,12 +123,20 @@ class TestJwtAuthentication():
     with patch('desktop.auth.api_authentications.jwt.decode') as jwt_decode:
       with patch('desktop.auth.api_authentications.requests.get'):
         jwt_decode.return_value = {
-          "user": "test_user"
+          "sub": "test_user"
         }
-        user, token = JwtAuthentication().authenticate(request=self.request)
+        resets = [
+          AUTH.JWT.VERIFY.set_for_testing(False),
+          AUTH.JWT.USERNAME_HEADER.set_for_testing('sub')
+        ]
+        try:
+          user, token = JwtAuthentication().authenticate(request=self.request)
 
-        assert_true('jwt_access_token' in user.profile.data)
-        assert_equal(user.profile.data['jwt_access_token'], self.sample_token)
+          assert_true('jwt_access_token' in user.profile.data)
+          assert_equal(user.profile.data['jwt_access_token'], self.sample_token)
+        finally:
+          for reset in resets:
+            reset()
 
 
   def test_check_token_verification_flag(self):
@@ -127,20 +145,28 @@ class TestJwtAuthentication():
         with patch('desktop.auth.api_authentications.JwtAuthentication._handle_public_key'):
 
           # When verification flag is True for old sample token
-          reset = AUTH.JWT.VERIFY.set_for_testing(True)
+          resets = [
+            AUTH.JWT.VERIFY.set_for_testing(True),
+            AUTH.JWT.USERNAME_HEADER.set_for_testing('sub')
+          ]
           try:
             assert_raises(exceptions.AuthenticationFailed, JwtAuthentication().authenticate, self.request)
           finally:
-            reset()
+            for reset in resets:
+              reset()
 
           # When verification flag is False
-          reset = AUTH.JWT.VERIFY.set_for_testing(False)
+          resets = [
+            AUTH.JWT.VERIFY.set_for_testing(False),
+            AUTH.JWT.USERNAME_HEADER.set_for_testing('sub')
+          ]
           try:
             user, token = JwtAuthentication().authenticate(request=self.request)
 
             assert_equal(user, self.user)
           finally:
-            reset()
+            for reset in resets:
+              reset()
 
 
   def test_handle_public_key(self):
@@ -148,33 +174,20 @@ class TestJwtAuthentication():
       with patch('desktop.auth.api_authentications.jwt.decode') as jwt_decode:
 
         jwt_decode.return_value = {
-          "user": "test_user"
+          "sub": "test_user"
         }
         jwk = {
           "keys": [
             {
               "kty": "RSA",
-              "kid": "1",
-              "alg": "RSA256",
-              "n": "rtT3gR0NDIx6gv8xYLiPue_ItaIbognCGGgQbipp3IOuobu2RnJjedsIRBTEOdkVx-xjV6m92VYtrpW6gM9vldwTfI0UmoSLGKT"
-                   "5uYd0JGHvYWoN9inCZYZcnala58T8HDgLiXa9KlEuQxGGQDemB3yf5rgS1OhLBKVsI8bMVgah7xNIiBOWsVeWIEr13Nem8HUuDq"
-                   "gIpL_8TgjxFOqFcdqPCfoIZ89JKEiKbsGbU-lqs1xYChFscI_w7Jc7l6rvf2nsLGMFs3U4ZJvS4AUpVno2e527clXzQisfJKwb4"
-                   "hjfKRMhHfnYfyJxaoHqWfx8DjXmH3CMqlWr_-hL3y1-4Q",
               "e": "AQAB",
-              "d": "XVj4jcelH_4hq6_1_V6N3wlYcSKM_oeXStDFdQzQWR02MMS5HgQVeQqp7y_nVbvDFWvx3uySoWiSG5V2bzBStAE9plLtnVMHsbD"
-                   "kZVsdeA-ScMDfk3_Ye7yx1ryF_RoAQlDqWAs-FUojGUxSEhekXnr8JYRDCcq9w01P4ApVL9iX9Togk8MFO68vKRykeFC21TGE87"
-                   "-2_ieIMksDf25r-uhYzdN1FCJuzHRaYBUBgBRq82rgno1f1Y9_j8TN30NQtOLr5UtYkH-iKb_wqgocFG9GamEbBzzZW2_BwRhyw"
-                   "Hm1ciJyiQ_Woikx798HoXlHOEHi8q4G-ay2JUFcbTyAAQ",
-              "p": "5umhRLdRjv30UO53l9gmVs2nUJPD-Uv_vDzx27aemTqaBxjTj_rVo3_KUwunQ4Y9aaaQo9BvlxG-tlmtYuDHYKavxqFQ6Q6jci3"
-                   "OWv2my9515akl5nUWj4SQD9xvve3b7x-nVGRefYmGvscXZU_Ryg1CZ_4FPsfljWwBTo7ggaE",
-              "q": "wdOQhh0NOxj1oI3cod_IQxl-5UjBzRvkm6Yx9r2QyOn2wk60b_ExWA8CrEr-eOSSSc0TMf2Y8vbCjzXSkd2-Gbsz4OOC-AkxY5W"
-                   "4FonLxF8AQabAXeIIfH7qF7Q0ByaZBFFaNQ3ejBunBa5ph0KUrxDrzVf1tcX3b8y8fHIudUE",
-              "dp": "ctEaojtw72PxNsjMaJFOxvytRFClMnGKsMOxEynkBJbx_bNnhwEXd5vUM6Tov5ehM8Zhx0KeKgTlynAe2bqhCLr5Tg_qVmgz91"
-                    "M1d2MGq_pqrw6DTOtNk4E7zNc0LMF4CZe4sSrTHSLkADqotHSTAR_EtEbHvubQiph4seIzWeE",
-              "dq": "q_htG0D9czjC_i-_2PO3OCmP2BkEsloULDF51ST-J_TF1kKEf2mtUScIRRvIyjRqwwYsCMerg66CkxO6_2aRez0IW3kgw7dMVc"
-                    "IJ8h1SaKmtjZJIzUN2Khdk1aEyJEIPs7AGbFog4YjLWRQVV0gwqV9HCAsJ27yIvG4XsgaQx8E",
-              "qi": "lNOWMacUcZtytxeTfeR6OWbqufAp56cICNTZX82JDnoi2KCmyeUERl1tLdYC1giK2lNw5j57ojTigPpyhBdeZ-3NqlJEH8pq6g"
-                    "JXNSpBOWTGzOT_EcW2jaCP4cT8q1Js3pFUynYPdXRU9FG0kdQgNIrDztNZJlPtdFxAVgCM4PY",
+              "use": "sig",
+              "kid": "rtGUeF9brGUkzSeI5aH6DJZkXEArwpqMqbSNZ4i7wHc",
+              "alg": "RS256",
+              "n": "we9gTbRxHl4Ye9mY9abYl_WHgx5QYZTwnHO5G5MX9gOiCbbxBqcOifVywX1_ienElksDIvjuQFL7zOSoXipuBUcfTwdtiOgBpNF"
+              "TvtMB4xjrYABg2nm47umJXNjN9KtMCC49sMp8bOvpgTvedghPhpGBDPoljYL_1VFAezjilCIaaa1NdXQDBSBdupQoxuVrkMiskmVt6lJ"
+              "MAiSPTteOtzXtm1WKvJftKZVk1bdrv-XqMQDxoiPirGZSwqkaKDmrdBinK0LbUPNt06BA7cXl04cgp2eu11tpY6cgnvWEfvK32S1IHci"
+              "XLipfwb1uHIdgX8i1pyiGj_JAQHodICzSww"
             }
           ]
         }
@@ -184,6 +197,7 @@ class TestJwtAuthentication():
 
         resets = [
           AUTH.JWT.VERIFY.set_for_testing(True),
+          AUTH.JWT.USERNAME_HEADER.set_for_testing('sub'),
           AUTH.JWT.KEY_SERVER_URL.set_for_testing('https://ext-authz:8000'),
           AUTH.JWT.ISSUER.set_for_testing('issuer'),
           AUTH.JWT.AUDIENCE.set_for_testing('audience')
@@ -194,18 +208,17 @@ class TestJwtAuthentication():
 
           jwt_decode.assert_called_with(
             algorithms=['RS256'],
-            audience=AUTH.JWT.AUDIENCE.get(),
-            issuer=AUTH.JWT.ISSUER.get(),
-            jwt=self.sample_token,
+            audience='audience',
+            issuer='issuer',
+            jwt=self.sample_token, 
             key=b'-----BEGIN PUBLIC KEY-----\n'
-            b'MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArtT3gR0NDIx6gv8xYLiP\n'
-            b'ue/ItaIbognCGGgQbipp3IOuobu2RnJjedsIRBTEOdkVx+xjV6m92VYtrpW6gM9v\n'
-            b'ldwTfI0UmoSLGKT5uYd0JGHvYWoN9inCZYZcnala58T8HDgLiXa9KlEuQxGGQDem\n'
-            b'B3yf5rgS1OhLBKVsI8bMVgah7xNIiBOWsVeWIEr13Nem8HUuDqgIpL/8TgjxFOqF\n'
-            b'cdqPCfoIZ89JKEiKbsGbU+lqs1xYChFscI/w7Jc7l6rvf2nsLGMFs3U4ZJvS4AUp\n'
-            b'Vno2e527clXzQisfJKwb4hjfKRMhHfnYfyJxaoHqWfx8DjXmH3CMqlWr/+hL3y1+\n'
-            b'4QIDAQAB\n'
-            b'-----END PUBLIC KEY-----\n',
+            b'MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwe9gTbRxHl4Ye9mY9abY\n'
+            b'l/WHgx5QYZTwnHO5G5MX9gOiCbbxBqcOifVywX1/ienElksDIvjuQFL7zOSoXipu\n'
+            b'BUcfTwdtiOgBpNFTvtMB4xjrYABg2nm47umJXNjN9KtMCC49sMp8bOvpgTvedghP\n'
+            b'hpGBDPoljYL/1VFAezjilCIaaa1NdXQDBSBdupQoxuVrkMiskmVt6lJMAiSPTteO\n'
+            b'tzXtm1WKvJftKZVk1bdrv+XqMQDxoiPirGZSwqkaKDmrdBinK0LbUPNt06BA7cXl\n'
+            b'04cgp2eu11tpY6cgnvWEfvK32S1IHciXLipfwb1uHIdgX8i1pyiGj/JAQHodICzS\n'
+            b'wwIDAQAB\n-----END PUBLIC KEY-----\n',
             options={'verify_signature': True}
           )
           assert_equal(user, self.user)

+ 7 - 1
desktop/core/src/desktop/conf.py

@@ -1246,6 +1246,12 @@ AUTH = ConfigSection(
           type=str,
           help=_("The identifier of the resource intend to access")
         ),
+        USERNAME_HEADER=Config(
+          key="username_header",
+          default="sub",
+          type=str,
+          help=_("The JWT payload header containing the username.")
+        ),
         VERIFY=Config(
             key="verify",
             default=True,
@@ -2097,7 +2103,7 @@ USE_THRIFT_HTTP_JWT = Config(
   key="use_thrift_http_jwt",
   help=_("Use JWT as Bearer header for authentication when using Thrift over HTTP transport."),
   type=coerce_bool,
-  dynamic_default=is_jwt_authentication_enabled
+  default=False
 )
 
 DISABLE_LOCAL_STORAGE = Config(