浏览代码

[raz] Improve RAZ HA unit tests (#3651)

Harsh Gupta 1 年之前
父节点
当前提交
4bd2f36b37
共有 2 个文件被更改,包括 30 次插入19 次删除
  1. 1 1
      desktop/core/src/desktop/lib/raz/raz_client.py
  2. 29 18
      desktop/core/src/desktop/lib/raz/raz_client_test.py

+ 1 - 1
desktop/core/src/desktop/lib/raz/raz_client.py

@@ -171,7 +171,7 @@ class RazClient(object):
 
     raz_response = None
     for r_url in raz_urls_list:
-      r_url = "%s/api/authz/%s/access?doAs=%s" % (r_url.rstrip('/'), self.service, self.username)
+      r_url = "%s/api/authz/%s/access?doAs=%s" % (r_url.strip(' ').rstrip('/'), self.service, self.username)
       LOG.info('Attempting to connect to RAZ URL: %s' % r_url)
 
       try:

+ 29 - 18
desktop/core/src/desktop/lib/raz/raz_client_test.py

@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import base64
-import sys
 import unittest
 
 from nose.tools import assert_equal, assert_true, assert_raises
@@ -23,10 +21,7 @@ from nose.tools import assert_equal, assert_true, assert_raises
 from desktop.lib.raz.raz_client import RazClient, get_raz_client
 from desktop.lib.exceptions_renderable import PopupException
 
-if sys.version_info[0] > 2:
-  from unittest.mock import patch, Mock
-else:
-  from mock import patch, Mock
+from unittest.mock import patch, Mock
 
 
 class RazClientTest(unittest.TestCase):
@@ -34,7 +29,7 @@ class RazClientTest(unittest.TestCase):
   def setUp(self):
     self.username = 'gethue'
     self.raz_url = 'https://raz.gethue.com:8080'
-    self.raz_urls_ha = 'https://raz_host_1.gethue.com:8080/,https://raz_host_2.gethue.com:8080/'
+    self.raz_urls_ha = 'https://raz_host_1.gethue.com:8080/, https://raz_host_2.gethue.com:8080/'
 
     self.s3_path = 'https://gethue-test.s3.amazonaws.com/gethue/data/customer.csv'
     self.adls_path = 'https://gethuestorage.dfs.core.windows.net/gethue-container/user/csso_hueuser/customer.csv'
@@ -346,12 +341,8 @@ class RazClientTest(unittest.TestCase):
 
               resp = client.check_access(method='GET', url=self.s3_path)
 
-              if sys.version_info[0] > 2:
-                signed_request = 'CiRodHRwczovL2dldGh1ZS10ZXN0LnMzLmFtYXpvbmF3cy5jb20Q' \
-                  'ATIYZ2V0aHVlL2RhdGEvY3VzdG9tZXIuY3N2OABCAnMzSgJzMw=='
-              else:
-                signed_request = b'CiRodHRwczovL2dldGh1ZS10ZXN0LnMzLmFtYXpvbmF3cy5jb20Q' \
-                  b'ATIYZ2V0aHVlL2RhdGEvY3VzdG9tZXIuY3N2OABCAnMzSgJzMw=='
+              signed_request = 'CiRodHRwczovL2dldGh1ZS10ZXN0LnMzLmFtYXpvbmF3cy5jb20Q' \
+                'ATIYZ2V0aHVlL2RhdGEvY3VzdG9tZXIuY3N2OABCAnMzSgJzMw=='
 
               requests_post.assert_called_with(
                 'https://raz.gethue.com:8080/api/authz/s3/access?doAs=gethue',
@@ -382,10 +373,10 @@ class RazClientTest(unittest.TestCase):
   def test_handle_raz_ha(self):
     with patch('desktop.lib.sdxaas.knox_jwt.requests_kerberos.HTTPKerberosAuth') as HTTPKerberosAuth:
       with patch('desktop.lib.raz.raz_client.requests.post') as requests_post:
-        requests_post.return_value = Mock(status_code=200)
         request_data = Mock()
 
         # Non-HA mode
+        requests_post.return_value = Mock(status_code=200)
         client = RazClient(self.raz_url, 'kerberos', username=self.username, service="s3", service_name="cm_s3", cluster_name="cl1")
         raz_response = client._handle_raz_ha(self.raz_url, auth_handler=HTTPKerberosAuth(), data=request_data, headers={})
 
@@ -397,9 +388,13 @@ class RazClientTest(unittest.TestCase):
           verify=False
         )
         assert_equal(raz_response.status_code, 200)
+        assert_equal(requests_post.call_count, 1)
+        requests_post.reset_mock()
 
-        # HA mode - where first URL sends 200 status code
+        # HA mode - When RAZ instance1 is healthy and RAZ instance2 is unhealthy
         client = RazClient(self.raz_urls_ha, 'kerberos', username=self.username, service="s3", service_name="cm_s3", cluster_name="cl1")
+
+        requests_post.side_effect = [Mock(status_code=200), Mock(status_code=404)]
         raz_response = client._handle_raz_ha(self.raz_urls_ha, auth_handler=HTTPKerberosAuth(), data=request_data, headers={})
 
         requests_post.assert_called_with(
@@ -410,12 +405,28 @@ class RazClientTest(unittest.TestCase):
           verify=False
         )
         assert_equal(raz_response.status_code, 200)
+        assert_equal(requests_post.call_count, 1)
+        requests_post.reset_mock()
 
-        # When no RAZ URL is healthy
-        requests_post.return_value = Mock(status_code=404)
+        # HA mode - When RAZ instance1 is unhealthy and RAZ instance2 is healthy
+        requests_post.side_effect = [Mock(status_code=404), Mock(status_code=200)]
+        raz_response = client._handle_raz_ha(self.raz_urls_ha, auth_handler=HTTPKerberosAuth(), data=request_data, headers={})
 
-        client = RazClient(self.raz_urls_ha, 'kerberos', username=self.username, service="s3", service_name="cm_s3", cluster_name="cl1")
+        requests_post.assert_called_with(
+          'https://raz_host_2.gethue.com:8080/api/authz/s3/access?doAs=gethue',
+          auth=HTTPKerberosAuth(),
+          headers={}, 
+          json=request_data, 
+          verify=False
+        )
+        assert_equal(raz_response.status_code, 200)
+        assert_equal(requests_post.call_count, 2)
+        requests_post.reset_mock()
+
+        # When no RAZ instance is healthy
+        requests_post.side_effect = [Mock(status_code=404), Mock(status_code=404)]
         raz_response = client._handle_raz_ha(self.raz_urls_ha, auth_handler=HTTPKerberosAuth(), data=request_data, headers={})
 
         assert_equal(raz_response, None)
+        assert_equal(requests_post.call_count, 2)