Browse Source

[raz] Raz Client for ADLS to submit proper requests to getback SAS token (#2362)

- Update existing S3 client to support ADLS
- Updated UTs

- Currently, it sends the read request to get ADLS SAS Token, we need more info for mapping with request methods to read/write/delete ops for the token
Harsh Gupta 4 years ago
parent
commit
1da17c44ed

+ 62 - 31
desktop/core/src/desktop/lib/raz/raz_client.py

@@ -76,6 +76,7 @@ class RazClient(object):
     self.raz_token = raz_token
     self.username = username
     self.service = service
+
     if self.service == 'adls':
       self.service_params = {
         'endpoint_prefix': 'adls',
@@ -88,6 +89,7 @@ class RazClient(object):
         'service_name': 's3',
         'serviceType': 's3'
       }
+
     self.service_name = service_name
     self.cluster_name = cluster_name
     self.requestid = str(uuid.uuid4())
@@ -100,51 +102,34 @@ class RazClient(object):
     params = params if params is not None else {}
     headers = headers if headers is not None else {}
 
-    allparams = [raz_signer.StringListStringMapProto(key=key, value=[val]) for key, val in url_params.items()]
-    allparams.extend([raz_signer.StringListStringMapProto(key=key, value=[val]) for key, val in params.items()])
-    headers = [raz_signer.StringStringMapProto(key=key, value=val) for key, val in headers.items()]
     endpoint = "%s://%s" % (path.scheme, path.netloc)
     resource_path = path.path.lstrip("/")
 
-    LOG.debug(
-      "Preparing sign request with http_method: {%s}, headers: {%s}, parameters: {%s}, endpoint: {%s}, resource_path: {%s}" %
-      (method, headers, allparams, endpoint, resource_path)
-    )
-    raz_req = raz_signer.SignRequestProto(
-        endpoint_prefix=self.service_params['endpoint_prefix'],
-        service_name=self.service_params['service_name'],
-        endpoint=endpoint,
-        http_method=method,
-        headers=headers,
-        parameters=allparams,
-        resource_path=resource_path,
-        time_offset=0
-    )
-    raz_req_serialized = raz_req.SerializeToString()
-    signed_request = base64.b64encode(raz_req_serialized)
-
     request_data = {
       "requestId": self.requestid,
       "serviceType": self.service_params['serviceType'],
       "serviceName": self.service_name,
       "user": self.username,
       "userGroups": [],
-      "accessTime": "",
       "clientIpAddress": "",
       "clientType": "",
       "clusterName": self.cluster_name,
       "clusterType": "",
       "sessionId": "",
-      "context": {
-        "S3_SIGN_REQUEST": signed_request
-      }
+      "accessTime": "",
+      "context": {}
     }
-    headers = {"Content-Type":"application/json", "Accept-Encoding":"gzip,deflate"}
-    raz_url = "%s/api/authz/s3/access?delegation=%s" % (self.raz_url, self.raz_token)
-    LOG.debug('Raz url: %s' % raz_url)
+    request_headers = {"Content-Type": "application/json"}
+    raz_url = "%s/api/authz/%s/access?delegation=%s" % (self.raz_url, self.service, self.raz_token)
 
-    LOG.debug("Sending access check headers: {%s} request_data: {%s}" % (headers, request_data))
-    raz_req = requests.post(raz_url, headers=headers, json=request_data, verify=False)
+    if self.service == 'adls':
+      self._make_adls_request(request_data, path, resource_path)
+    elif self.service == 's3':
+      self._make_s3_request(request_data, request_headers, method, params, headers, url_params, endpoint, resource_path)
+
+    LOG.debug('Raz url: %s' % raz_url)
+    LOG.debug("Sending access check headers: {%s} request_data: {%s}" % (request_headers, request_data))
+    raz_req = requests.post(raz_url, headers=request_headers, json=request_data, verify=False)
 
     signed_response_result = None
     signed_response = None
@@ -164,21 +149,67 @@ class RazClient(object):
       if result == "ALLOWED":
         LOG.debug('Received allowed response %s' % raz_req.json())
         signed_response_data = raz_req.json()["operResult"]["additionalInfo"]
+
         if self.service == 'adls':
           LOG.debug("Received SAS %s" % signed_response_data["ADLS_DSAS"])
           return {'token': signed_response_data["ADLS_DSAS"]}
         else:
           signed_response_result = signed_response_data["S3_SIGN_RESPONSE"]
 
-          if signed_response_result:
+          if signed_response_result is not None:
             raz_response_proto = raz_signer.SignResponseProto()
             signed_response = raz_response_proto.FromString(base64.b64decode(signed_response_result))
             LOG.debug("Received signed Response %s" % signed_response)
 
           # Signed headers "only"
-          if signed_response:
+          if signed_response is not None:
             return dict([(i.key, i.value) for i in signed_response.signer_generated_headers])
 
+  def _make_adls_request(self, request_data, path, resource_path):
+    storage_account = path.netloc.split('.')[0]
+    container, relative_path = resource_path.split('/', 1)
+
+    request_data.update({
+      "clientType": "adls",
+      "operation": {
+        "resource": {
+          "storageaccount": storage_account,
+          "container": container,
+          "relativepath": relative_path,
+        },
+        "resourceOwner": "",
+        "action": "read",
+        "accessTypes":["read"]
+      }
+    })
+
+  def _make_s3_request(self, request_data, request_headers, method, params, headers, url_params, endpoint, resource_path):
+
+    allparams = [raz_signer.StringListStringMapProto(key=key, value=[val]) for key, val in url_params.items()]
+    allparams.extend([raz_signer.StringListStringMapProto(key=key, value=[val]) for key, val in params.items()])
+    headers = [raz_signer.StringStringMapProto(key=key, value=val) for key, val in headers.items()]
+
+    LOG.debug(
+      "Preparing sign request with http_method: {%s}, headers: {%s}, parameters: {%s}, endpoint: {%s}, resource_path: {%s}" %
+      (method, headers, allparams, endpoint, resource_path)
+    )
+    raz_req = raz_signer.SignRequestProto(
+        endpoint_prefix=self.service_params['endpoint_prefix'],
+        service_name=self.service_params['service_name'],
+        endpoint=endpoint,
+        http_method=method,
+        headers=headers,
+        parameters=allparams,
+        resource_path=resource_path,
+        time_offset=0
+    )
+    raz_req_serialized = raz_req.SerializeToString()
+    signed_request = base64.b64encode(raz_req_serialized)
+
+    request_headers["Accept-Encoding"] = {"gzip,deflate"}
+    request_data["context"] = {
+      "S3_SIGN_REQUEST": signed_request
+    }
 
 def get_raz_client(raz_url, username, auth='kerberos', service='s3', service_name='cm_s3', cluster_name='myCluster'):
   if not username:

+ 83 - 21
desktop/core/src/desktop/lib/raz/raz_client_test.py

@@ -81,10 +81,9 @@ class RazTokenTest(unittest.TestCase):
 
           t = token.renew_delegation_token(user=self.username)
 
-          assert_equal(
+          assert_equal(t,
             'https://gethue-test.s3.amazonaws.com/gethue/data/customer.csv?AWSAccessKeyId=AKIA23E77ZX2HVY76YGL&'
-            'Signature=3lhK%2BwtQ9Q2u5VDIqb4MEpoY3X4%3D&Expires=1617207304',
-            t
+            'Signature=3lhK%2BwtQ9Q2u5VDIqb4MEpoY3X4%3D&Expires=1617207304'
           )
 
           with patch('desktop.lib.raz.raz_client.requests.put') as requests_put:
@@ -100,36 +99,103 @@ class RazClientTest(unittest.TestCase):
   def setUp(self):
     self.username = 'gethue'
     self.raz_url = 'https://raz.gethue.com:8080'
-    self.resource_url = 'https://gethue-test.s3.amazonaws.com/gethue/data/customer.csv'
 
-  def test_get_raz_client(self):
+    self.s3_path = 'https://gethue-test.s3.amazonaws.com/gethue/data/customer.csv'
+    self.adls_path = 'https://gethuestorageaccount.blob.core.windows.net/demo-gethue-container/demo-dir1/customer.csv'
 
+  def test_get_raz_client_adls(self):
     with patch('desktop.lib.raz.raz_client.RazToken') as RazToken:
       with patch('desktop.lib.raz.raz_client.requests_kerberos.HTTPKerberosAuth') as HTTPKerberosAuth:
         client = get_raz_client(
           raz_url=self.raz_url,
           username=self.username,
           auth='kerberos',
-          service='s3',
-          service_name='gethue_s3',
+          service='adls',
+          service_name='gethue_adls',
           cluster_name='gethueCluster'
         )
 
         assert_true(isinstance(client, RazClient))
 
         HTTPKerberosAuth.assert_called()
-        assert_equal(
-          client.raz_url, self.raz_url
+        assert_equal(client.raz_url, self.raz_url)
+        assert_equal(client.service_name, 'gethue_adls')
+        assert_equal(client.cluster_name, 'gethueCluster')
+
+  def test_check_access_adls(self):
+    with patch('desktop.lib.raz.raz_client.requests.post') as requests_post:
+      with patch('desktop.lib.raz.raz_client.uuid.uuid4') as uuid:
+        raz_token = "mock_RAZ_token"
+
+        requests_post.return_value = Mock(
+          json=Mock(return_value=
+          {
+            'operResult': {
+              'result': 'ALLOWED',
+              'additionalInfo': {
+                "ADLS_DSAS": "nulltenantIdnullnullbnullALLOWEDnullnull1.05nSlN7t/QiPJ1OFlCruTEPLibFbAhEYYj5wbJuaeQqs="
+                }
+              }
+            }
+          )
         )
-        assert_equal(
-          client.service_name, 'gethue_s3'
+        uuid.return_value = 'mock_request_id'
+
+        client = RazClient(self.raz_url, raz_token, username=self.username, service="adls", service_name="adls", cluster_name="cl1")
+
+        resp = client.check_access(method='GET', url=self.adls_path)
+
+        requests_post.assert_called_with(
+          "https://raz.gethue.com:8080/api/authz/adls/access?delegation=" + raz_token,
+          headers={"Content-Type": "application/json"},
+          json={
+            'requestId': 'mock_request_id', 
+            'serviceType': 'adls', 
+            'serviceName': 'adls', 
+            'user': 'gethue', 
+            'userGroups': [], 
+            'clientIpAddress': '', 
+            'clientType': 'adls', 
+            'clusterName': 'cl1', 
+            'clusterType': '', 
+            'sessionId': '', 
+            'accessTime': '', 
+            'context': {}, 
+            'operation': {
+              'resource': {
+                'storageaccount': 'gethuestorageaccount', 
+                'container': 'demo-gethue-container', 
+                'relativepath': 'demo-dir1/customer.csv'
+              }, 
+              'resourceOwner': '', 
+              'action': 'read', 
+              'accessTypes': ['read']
+            }
+          },
+          verify=False
         )
-        assert_equal(
-          client.cluster_name, 'gethueCluster'
+        assert_equal(resp['token'], "nulltenantIdnullnullbnullALLOWEDnullnull1.05nSlN7t/QiPJ1OFlCruTEPLibFbAhEYYj5wbJuaeQqs=")
+
+  def test_get_raz_client_s3(self):
+    with patch('desktop.lib.raz.raz_client.RazToken') as RazToken:
+      with patch('desktop.lib.raz.raz_client.requests_kerberos.HTTPKerberosAuth') as HTTPKerberosAuth:
+        client = get_raz_client(
+          raz_url=self.raz_url,
+          username=self.username,
+          auth='kerberos',
+          service='s3',
+          service_name='gethue_s3',
+          cluster_name='gethueCluster'
         )
 
+        assert_true(isinstance(client, RazClient))
+
+        HTTPKerberosAuth.assert_called()
+        assert_equal(client.raz_url, self.raz_url)
+        assert_equal(client.service_name, 'gethue_s3')
+        assert_equal(client.cluster_name, 'gethueCluster')
 
-  def test_check_access(self):
+  def test_check_access_s3(self):
     raz_token = Mock()
 
     client = RazClient(self.raz_url, raz_token, username=self.username)
@@ -162,11 +228,7 @@ class RazClientTest(unittest.TestCase):
             )
           )
 
-          resp = client.check_access(method='GET', url=self.resource_url)
+          resp = client.check_access(method='GET', url=self.s3_path)
 
-          assert_true(
-            resp
-          )
-          assert_equal(
-            resp['AWSAccessKeyId'], 'AKIA23E77ZX2HVY76YGL'
-          )
+          assert_true(resp)
+          assert_equal(resp['AWSAccessKeyId'], 'AKIA23E77ZX2HVY76YGL')