소스 검색

[raz_adls] Add RAZ req mapping and update tests (#2499)

- Much better 'container' and 'relative_path' value extraction.
- Added mapping v1 for stats(), listdir() and read() calls.
- Update unit tests for all calls above.
Harsh Gupta 4 년 전
부모
커밋
3f0665074c

+ 48 - 11
desktop/core/src/desktop/lib/raz/raz_client.py

@@ -33,9 +33,10 @@ from desktop.lib.exceptions_renderable import PopupException
 import desktop.lib.raz.signer_protos_pb2 as raz_signer
 
 if sys.version_info[0] > 2:
-  from urllib.parse import urlparse as lib_urlparse
+  from urllib.parse import urlparse as lib_urlparse, unquote as lib_urlunquote
 else:
   from urlparse import urlparse as lib_urlparse
+  from urllib import unquote as lib_urlunquote
 
 
 LOG = logging.getLogger(__name__)
@@ -56,16 +57,21 @@ class RazToken:
 
   def get_delegation_token(self, user):
     ip_address = socket.gethostbyname(self.raz_hostname)
-    GET_PARAMS = {"op": "GETDELEGATIONTOKEN", "service": "%s:%s" % (ip_address, self.raz_port), "renewer": AUTH_USERNAME.get(), "doAs": user}
+    GET_PARAMS = {
+      "op": "GETDELEGATIONTOKEN",
+      "service": "%s:%s" % (ip_address, self.raz_port),
+      "renewer": AUTH_USERNAME.get(),
+      "doAs": user
+    }
     r = requests.get(self.raz_url, GET_PARAMS, auth=self.auth_handler, verify=False)
     self.raz_token = json.loads(r.text)['Token']['urlString']
     return self.raz_token
 
   def renew_delegation_token(self, user):
     if self.raz_token is None:
-        self.raz_token = self.get_delegation_token(user=user)
+      self.raz_token = self.get_delegation_token(user=user)
     if (self.init_time - timedelta(hours=8)) > datetime.now():
-        r = requests.put("%s?op=RENEWDELEGATIONTOKEN&token=%s"%(self.raz_url, self.raz_token), auth=self.auth_handler, verify=False)
+      r = requests.put("%s?op=RENEWDELEGATIONTOKEN&token=%s"%(self.raz_url, self.raz_token), auth=self.auth_handler, verify=False)
     return self.raz_token
 
 
@@ -94,6 +100,7 @@ class RazClient(object):
     self.cluster_name = cluster_name
     self.requestid = str(uuid.uuid4())
 
+
   def check_access(self, method, url, params=None, headers=None):
     LOG.debug("Check access: method {%s}, url {%s}, params {%s}, headers {%s}" % (method, url, params, headers))
 
@@ -123,7 +130,7 @@ class RazClient(object):
     raz_url = "%s/api/authz/%s/access?delegation=%s" % (self.raz_url, self.service, self.raz_token)
 
     if self.service == 'adls':
-      self._make_adls_request(request_data, path, resource_path)
+      self._make_adls_request(request_data, method, path, url_params, resource_path)
     elif self.service == 's3':
       self._make_s3_request(request_data, request_headers, method, params, headers, url_params, endpoint, resource_path)
 
@@ -165,9 +172,18 @@ class RazClient(object):
           if signed_response is not None:
             return dict([(i.key, i.value) for i in signed_response.signer_generated_headers])
 
-  def _make_adls_request(self, request_data, path, resource_path):
+
+  def _make_adls_request(self, request_data, method, path, url_params, resource_path):
     storage_account = path.netloc.split('.')[0]
-    container, relative_path = resource_path.split('/', 1)
+    resource_path = resource_path.split('/', 1)
+
+    container = resource_path[0]
+    relative_path = "/"
+
+    if len(resource_path) == 2:
+      relative_path += resource_path[1]
+
+    req_params = self.handle_adls_req_mapping(method, url_params, relative_path)
 
     request_data.update({
       "clientType": "adls",
@@ -175,14 +191,34 @@ class RazClient(object):
         "resource": {
           "storageaccount": storage_account,
           "container": container,
-          "relativepath": relative_path,
+          "relativepath": req_params.get('relative_path'),
         },
-        "resourceOwner": storage_account,
-        "action": "read",
-        "accessTypes":["read"]
+        "action": req_params.get('access_type'),
+        "accessTypes": [req_params.get('access_type')]
       }
     })
 
+
+  def handle_adls_req_mapping(self, method, params, relative_path):
+    access_type = ''
+
+    if method == 'HEAD':
+      # Stats
+      if params.get('action') == 'getStatus':
+        access_type = 'get-status'
+    
+    if method == 'GET':
+      access_type = 'read'
+
+      # List
+      if params.get('resource') == 'filesystem':
+        if params.get('directory'):
+          relative_path += lib_urlunquote(params['directory'])
+          access_type = 'list'
+
+    return {'access_type': access_type, 'relative_path': relative_path}
+
+
   def _make_s3_request(self, request_data, request_headers, method, params, headers, url_params, endpoint, resource_path):
 
     allparams = [raz_signer.StringListStringMapProto(key=key, value=[val]) for key, val in url_params.items()]
@@ -211,6 +247,7 @@ class RazClient(object):
       "S3_SIGN_REQUEST": signed_request
     }
 
+
 def get_raz_client(raz_url, username, auth='kerberos', service='s3', service_name='cm_s3', cluster_name='myCluster'):
   if not username:
     from crequest.middleware import CrequestMiddleware

+ 54 - 26
desktop/core/src/desktop/lib/raz/raz_client_test.py

@@ -71,27 +71,27 @@ class RazTokenTest(unittest.TestCase):
 
     with patch('desktop.lib.raz.raz_client.requests.get') as requests_get:
       with patch('desktop.lib.raz.raz_client.socket.gethostbyname') as gethostbyname:
-          requests_get.return_value = Mock(
-            text='{"Token":{"urlString":"https://gethue-test.s3.amazonaws.com/gethue/data/customer.csv?' + \
-                  'AWSAccessKeyId=AKIA23E77ZX2HVY76YGL' + \
-                  '&Signature=3lhK%2BwtQ9Q2u5VDIqb4MEpoY3X4%3D&Expires=1617207304"}}'
-          )
-          gethostbyname.return_value = '128.0.0.1'
-          token = RazToken(raz_url='https://raz.gethue.com:8080', auth_handler=kerb_auth)
+        requests_get.return_value = Mock(
+          text='{"Token":{"urlString":"https://gethue-test.s3.amazonaws.com/gethue/data/customer.csv?' + \
+                'AWSAccessKeyId=AKIA23E77ZX2HVY76YGL' + \
+                '&Signature=3lhK%2BwtQ9Q2u5VDIqb4MEpoY3X4%3D&Expires=1617207304"}}'
+        )
+        gethostbyname.return_value = '128.0.0.1'
+        token = RazToken(raz_url='https://raz.gethue.com:8080', auth_handler=kerb_auth)
 
-          t = token.renew_delegation_token(user=self.username)
+        t = token.renew_delegation_token(user=self.username)
 
-          assert_equal(t,
-            'https://gethue-test.s3.amazonaws.com/gethue/data/customer.csv?AWSAccessKeyId=AKIA23E77ZX2HVY76YGL&'
-            'Signature=3lhK%2BwtQ9Q2u5VDIqb4MEpoY3X4%3D&Expires=1617207304'
-          )
+        assert_equal(t,
+          'https://gethue-test.s3.amazonaws.com/gethue/data/customer.csv?AWSAccessKeyId=AKIA23E77ZX2HVY76YGL&'
+          'Signature=3lhK%2BwtQ9Q2u5VDIqb4MEpoY3X4%3D&Expires=1617207304'
+        )
 
-          with patch('desktop.lib.raz.raz_client.requests.put') as requests_put:
-            token.init_time += timedelta(hours=9)
+        with patch('desktop.lib.raz.raz_client.requests.put') as requests_put:
+          token.init_time += timedelta(hours=9)
 
-            t = token.renew_delegation_token(user=self.username)
+          t = token.renew_delegation_token(user=self.username)
 
-            requests_put.assert_called()
+          requests_put.assert_called()
 
 
 class RazClientTest(unittest.TestCase):
@@ -102,7 +102,8 @@ class RazClientTest(unittest.TestCase):
     self.raz_token = "mock_RAZ_token"
 
     self.s3_path = 'https://gethue-test.s3.amazonaws.com/gethue/data/customer.csv'
-    self.adls_path = 'https://gethuestorageaccount.dfs.core.windows.net/demo-gethue-container/demo-dir1/customer.csv'
+    self.adls_path = 'https://gethuestorage.dfs.core.windows.net/gethue-container/user/csso_hueuser/customer.csv'
+
 
   def test_get_raz_client_adls(self):
     with patch('desktop.lib.raz.raz_client.RazToken') as RazToken:
@@ -123,6 +124,7 @@ class RazClientTest(unittest.TestCase):
         assert_equal(client.service_name, 'gethue_adls')
         assert_equal(client.cluster_name, 'gethueCluster')
 
+
   def test_check_access_adls(self):
     with patch('desktop.lib.raz.raz_client.requests.post') as requests_post:
       with patch('desktop.lib.raz.raz_client.uuid.uuid4') as uuid:
@@ -143,6 +145,7 @@ class RazClientTest(unittest.TestCase):
 
         client = RazClient(self.raz_url, self.raz_token, username=self.username, service="adls", service_name="cm_adls", cluster_name="cl1")
 
+        # Read file operation
         resp = client.check_access(method='GET', url=self.adls_path)
 
         requests_post.assert_called_with(
@@ -163,11 +166,10 @@ class RazClientTest(unittest.TestCase):
             'context': {}, 
             'operation': {
               'resource': {
-                'storageaccount': 'gethuestorageaccount', 
-                'container': 'demo-gethue-container', 
-                'relativepath': 'demo-dir1/customer.csv'
-              }, 
-              'resourceOwner': 'gethuestorageaccount', 
+                'storageaccount': 'gethuestorage', 
+                'container': 'gethue-container', 
+                'relativepath': '/user/csso_hueuser/customer.csv'
+              },
               'action': 'read', 
               'accessTypes': ['read']
             }
@@ -176,6 +178,32 @@ class RazClientTest(unittest.TestCase):
         )
         assert_equal(resp['token'], "nulltenantIdnullnullbnullALLOWEDnullnull1.05nSlN7t/QiPJ1OFlCruTEPLibFbAhEYYj5wbJuaeQqs=")
 
+
+  def test_handle_adls_action_types_mapping(self):
+
+    client = RazClient(self.raz_url, self.raz_token, username=self.username, service="adls", service_name="cm_adls", cluster_name="cl1")
+
+    # List directory
+    method = 'GET'
+    relative_path = '/'
+    url_params = {'directory': 'user%2Fcsso_hueuser', 'resource': 'filesystem', 'recursive': 'false'}
+
+    response = client.handle_adls_req_mapping(method, url_params, relative_path)
+
+    assert_equal(response['access_type'], 'list')
+    assert_equal(response['relative_path'], '/user/csso_hueuser')
+
+    # Stats
+    method = 'HEAD'
+    relative_path = '/user/csso_hueuser'
+    url_params = {'action': 'getStatus'}
+
+    response = client.handle_adls_req_mapping(method, url_params, relative_path)
+
+    assert_equal(response['access_type'], 'get-status')
+    assert_equal(response['relative_path'], '/user/csso_hueuser')
+
+
   def test_get_raz_client_s3(self):
     with patch('desktop.lib.raz.raz_client.RazToken') as RazToken:
       with patch('desktop.lib.raz.raz_client.requests_kerberos.HTTPKerberosAuth') as HTTPKerberosAuth:
@@ -195,6 +223,7 @@ class RazClientTest(unittest.TestCase):
         assert_equal(client.service_name, 'gethue_s3')
         assert_equal(client.cluster_name, 'gethueCluster')
 
+
   def test_check_access_s3(self):
     with patch('desktop.lib.raz.raz_client.requests.post') as requests_post:
       with patch('desktop.lib.raz.raz_client.raz_signer.SignResponseProto') as SignResponseProto:
@@ -242,17 +271,16 @@ class RazClientTest(unittest.TestCase):
                 'userGroups': [],
                 'clientIpAddress': '',
                 'clientType': '',
-                'clusterName':
-                'myCluster',
+                'clusterName': 'myCluster',
                 'clusterType': '',
                 'sessionId': '',
                 'accessTime': '',
                 'context': {
-                  'S3_SIGN_REQUEST': b'CiRodHRwczovL2dldGh1ZS10ZXN0LnMzLmFtYXpvbmF3cy5jb20QATIYZ2V0aHVlL2RhdGEvY3VzdG9tZXIuY3N2OABCAnMzSgJzMw=='
+                  'S3_SIGN_REQUEST': b'CiRodHRwczovL2dldGh1ZS10ZXN0LnMzLmFtYXpvbmF3cy5jb20Q' \
+                    b'ATIYZ2V0aHVlL2RhdGEvY3VzdG9tZXIuY3N2OABCAnMzSgJzMw=='
                 }
               },
               verify=False
             )
-
             assert_true(resp)
             assert_equal(resp['AWSAccessKeyId'], 'AKIA23E77ZX2HVY76YGL')

+ 1 - 1
desktop/core/src/desktop/lib/rest/raz_http_client.py

@@ -56,7 +56,7 @@ class RazHttpClient(HttpClient):
     if response and response.get('token'):
       signed_url += ('?' if '?' not in url else '&') + response.get('token')
     else:
-      raise PopupException(_('No SAS token in response'), error_code=503)
+      raise PopupException(_('No SAS token in RAZ response'), error_code=503)
 
     # Required because `self._make_url` is called in base class execute method also
     signed_path = path + signed_url.partition(path)[2]