Procházet zdrojové kódy

[abfs] Hook-in Raz client into HttpClient

Romain Rigaux před 4 roky
rodič
revize
dccc557608

+ 11 - 14
desktop/core/src/desktop/lib/raz/clients.py

@@ -16,12 +16,8 @@
 
 import logging
 
-from requests_kerberos import HTTPKerberosAuth
-
 from desktop.conf import RAZ
 from desktop.lib.raz.raz_client import get_raz_client
-from desktop.lib.raz.ranger.clients.ranger_raz_adls import RangerRazAdls
-from desktop.lib.raz.ranger.clients.ranger_raz_s3 import RangerRazS3
 
 
 LOG = logging.getLogger(__name__)
@@ -32,7 +28,7 @@ class S3RazClient():
   def __init__(self, username):
     self.username = username
 
-  def get_url(self, action='GET', path=None, headers=None, perm='read'):
+  def get_url(self, action='GET', path=None, headers=None):
     '''
     Example of headers:
     {
@@ -56,14 +52,15 @@ class S3RazClient():
 
 class AdlsRazClient():
 
-  def __init__(self):
-    if RAZ.API_AUTHENTICATION.get() == 'kerberos':
-      auth = HTTPKerberosAuth()
-    else:
-      auth = None
+  def __init__(self, username):
+    self.username = username
 
-    self.ranger = RangerRazAdls(RAZ.API_URL.get(), auth)
+  def get_url(self, action='GET', path=None, headers=None):
+    c = get_raz_client(
+      raz_url=RAZ.API_URL.get(),
+      username=self.username,
+      auth=RAZ.API_AUTHENTICATION.get(),
+      service='adls',
+    )
 
-  def get_url(self, storage_account, container, relative_path, perm='read'):
-    # e.g. get_url('<storage_account>', '<container>', '<relative_path>', 'read')
-    return self.ranger.get_dsas_token(storage_account, container, relative_path, perm)
+    return c.check_access(method=action, url=path, headers=headers)

+ 23 - 11
desktop/core/src/desktop/lib/raz/raz_client.py

@@ -75,7 +75,14 @@ class RazClient(object):
     self.raz_url = raz_url.strip('/')
     self.raz_token = raz_token
     self.username = username
-    if service == 's3' or True:  # True until ABFS option
+    self.service = service
+    if self.service == 'adls':
+      self.service_params = {
+        'endpoint_prefix': 'adls',
+        'service_name': 'adls',
+        'serviceType': 'adls'
+      }
+    else:
       self.service_params = {
         'endpoint_prefix': 's3',
         'service_name': 's3',
@@ -139,7 +146,7 @@ class RazClient(object):
     LOG.debug("Sending access check headers: {%s} request_data: {%s}" % (headers, request_data))
     raz_req = requests.post(raz_url, headers=headers, json=request_data, verify=False)
 
-    s3_sign_response = None
+    signed_response_result = None
     signed_response = None
 
     if raz_req.ok:
@@ -156,20 +163,25 @@ class RazClient(object):
 
       if result == "ALLOWED":
         LOG.debug('Received allowed response %s' % raz_req.json())
-        s3_sign_response = raz_req.json()["operResult"]["additionalInfo"]["S3_SIGN_RESPONSE"]
+        signed_response_data = raz_req.json()["operResult"]["additionalInfo"]
+        if self.service == 'adls':
+          LOG.debug("Received SAS %s" % signed_response_data["ADLS_DSAS"])
+          return {'token': signed_response_data["ADLS_DSAS"]}
+        else:
+          signed_response_result = signed_response_data["S3_SIGN_RESPONSE"]
 
-      if s3_sign_response:
-        raz_response_proto = raz_signer.SignResponseProto()
-        signed_response = raz_response_proto.FromString(base64.b64decode(s3_sign_response))
-        LOG.debug("Received signed Response %s" % signed_response)
+          if signed_response_result:
+            raz_response_proto = raz_signer.SignResponseProto()
+            signed_response = raz_response_proto.FromString(base64.b64decode(signed_response_result))
+            LOG.debug("Received signed Response %s" % signed_response)
 
-      # Currently returning signed headers "only"
-      if signed_response:
-        return dict([(i.key, i.value) for i in signed_response.signer_generated_headers])
+          # Signed headers "only"
+          if signed_response:
+            return dict([(i.key, i.value) for i in signed_response.signer_generated_headers])
 
 
 def get_raz_client(raz_url, username, auth='kerberos', service='s3', service_name='cm_s3', cluster_name='myCluster'):
-  if auth == 'kerberos' or True:  # True until ABFS option
+  if auth == 'kerberos' or True:  # True until JWT option
     auth_handler = requests_kerberos.HTTPKerberosAuth(mutual_authentication=requests_kerberos.OPTIONAL)
 
   raz = RazToken(raz_url, auth_handler)

+ 29 - 7
desktop/core/src/desktop/lib/rest/raz_http_client.py

@@ -33,18 +33,40 @@ LOG = logging.getLogger(__name__)
 
 class RazHttpClient(HttpClient):
 
+  def __init__(self, username, base_url, exc_class=None, logger=None):
+    super(RazHttpClient, self).__init__(base_url, exc_class, logger)
+    self.username = username
+
   def execute(self, http_method, path, params=None, data=None, headers=None, allow_redirects=False, urlencode=True,
               files=None, stream=False, clear_cookies=False, timeout=conf.REST_CONN_TIMEOUT.get()):
-
-    raz_client = AdlsRazClient()
+    """
+    From an object URL we get back the SAS token as a GET param string, e.g.:
+    https://[storageaccountname].blob.core.windows.net/[containername]/[blobname]
+    -->
+    https://[storageaccountname].blob.core.windows.net/[containername]/[blobname]?sv=2014-02-14&sr=b&
+    sig=pJL%2FWyed41tptiwBM5ymYre4qF8wzrO05tS5MCjkutc%3D&st=2015-01-02T01%3A40%3A51Z&se=2015-01-02T02%3A00%3A51Z&sp=r
+    """
+    raz_client = AdlsRazClient(username=self.username)
 
     container = 'hue'
     storage_account = 'gethue.dfs.core.windows.net'
 
-    # https://[storageaccountname].blob.core.windows.net/[containername]/[blobname]?sv=2014-02-14&sr=b&
-    #   sig=pJL%2FWyed41tptiwBM5ymYre4qF8wzrO05tS5MCjkutc%3D&st=2015-01-02T01%3A40%3A51Z&se=2015-01-02T02%3A00%3A51Z&sp=r
-    tmp_url = raz_client.get_url(storage_account, container, relative_path=path, perm='read')
+    url = self._make_url(path, params)
+
+    token = raz_client.get_url(action=http_method, path=url, headers=headers)
 
-    # TODO: get clean `path` etc
+    signed_path = path + ('?' if '?' in url else '&') + token
 
-    return super(RazHttpClient, self).execute(http_method=http_method, path=tmp_url)
+    return super(RazHttpClient, self).execute(
+        http_method=http_method,
+        path=signed_path,
+        params=params,
+        data=data,
+        headers=headers,
+        allow_redirects=allow_redirects,
+        urlencode=urlencode,
+        files=files,
+        stream=stream,
+        clear_cookies=clear_cookies,
+        timeout=timeout
+    )

+ 10 - 8
desktop/libs/azure/src/azure/abfs/abfs.py

@@ -76,7 +76,8 @@ class ABFS(object):
       hdfs_supergroup=None,
       access_token=None,
       token_type=None,
-      expiration=None
+      expiration=None,
+      username=None
     ):
     self._url = url
     self._superuser = hdfs_superuser
@@ -96,12 +97,12 @@ class ABFS(object):
     self._is_remote = True
     self._has_trash_support = False
     self._filebrowser_action = PERMISSION_ACTION_ABFS
-
     self.expiration = expiration
-    self._root = self.get_client(url)
+    self._user = username
 
     # To store user info
-    self._thread_local = threading.local()
+    self._thread_local = threading.local()  # Unused
+    self._root = self.get_client(url)
 
     LOG.debug("Initializing ABFS : %s (security: %s, superuser: %s)" % (self._url, self._security_enabled, self._superuser))
 
@@ -119,12 +120,13 @@ class ABFS(object):
         hdfs_supergroup=None,
         access_token=credentials.get('access_token'),
         token_type=credentials.get('token_type'),
-        expiration=int(credentials.get('expires_on')) * 1000 if credentials.get('expires_on') is not None else None
+        expiration=int(credentials.get('expires_on')) * 1000 if credentials.get('expires_on') is not None else None,
+        username=credentials.get('username')
     )
 
   def get_client(self, url):
     if RAZ.IS_ENABLED.get():
-      client = RazHttpClient(url, exc_class=WebHdfsException, logger=LOG)
+      client = RazHttpClient(self._user, url, exc_class=WebHdfsException, logger=LOG)
     else:
       client = http_client.HttpClient(url, exc_class=WebHdfsException, logger=LOG)
 
@@ -192,7 +194,7 @@ class ABFS(object):
       raise IOError
     if dir_name == '':
       return ABFSStat.for_filesystem(self._statsf(file_system, params, **kwargs), path)
-    return ABFSStat.for_single(self._stats(file_system + '/' +dir_name, params, **kwargs), path)
+    return ABFSStat.for_single(self._stats(file_system + '/' + dir_name, params, **kwargs), path)
 
   def listdir_stats(self,path, params=None, **kwargs):
     """
@@ -221,7 +223,7 @@ class ABFS(object):
       dir_stats.append(ABFSStat.for_directory(res.headers, x, root + file_system + "/" + x['name']))
     return dir_stats
 
-  def listfilesystems_stats(self, root = Init_ABFS.ABFS_ROOT, params=None, **kwargs):
+  def listfilesystems_stats(self, root=Init_ABFS.ABFS_ROOT, params=None, **kwargs):
     """
     Lists the stats inside the File Systems, No functionality for params
     """

+ 5 - 2
desktop/libs/azure/src/azure/client.py

@@ -44,7 +44,7 @@ def _make_abfs_client(identifier, user):
 def get_credential_provider(identifier, user, version=None):
   from desktop.conf import RAZ
   if RAZ.IS_ENABLED.get():
-    return RazCredentialProvider()
+    return RazCredentialProvider(username=user)
   else:
     client_conf = conf.AZURE_ACCOUNTS[identifier] if identifier in conf.AZURE_ACCOUNTS else None
     return CredentialProviderIDBroker(IDBroker.from_core_site('azure', user)) if conf_idbroker.is_idbroker_enabled('azure') \
@@ -66,6 +66,9 @@ class CredentialProviderIDBroker(object):
     return self.idbroker.get_cab()
 
 class RazCredentialProvider(object):
+  def __init__(self, username):
+    self.username = username
+
   def get_credentials(self):
     # No credentials are required
-    return {}
+    return {'username': self.username}