소스 검색

HUE-9073 [fb] Add IDBroker support for azure

Jean-Francois Desjeans Gauthier 6 년 전
부모
커밋
b8119cd166

+ 4 - 8
apps/filebrowser/src/filebrowser/views_test.py

@@ -1247,6 +1247,8 @@ class TestS3AccessPermissions(object):
 class TestABFSAccessPermissions(object):
 
   def setUp(self):
+    if not is_abfs_enabled():
+      raise SkipTest
     self.client = make_logged_in_client(username="test", groupname="default", recreate=True, is_superuser=False)
     grant_access('test', 'test', 'filebrowser')
     add_to_group('test')
@@ -1254,8 +1256,6 @@ class TestABFSAccessPermissions(object):
     self.user = User.objects.get(username="test")
 
   def test_no_default_permissions(self):
-    if not is_abfs_enabled():
-      raise SkipTest
     response = self.client.get('/filebrowser/view=ABFS://')
     assert_equal(500, response.status_code)
 
@@ -1266,8 +1266,6 @@ class TestABFSAccessPermissions(object):
 #       assert_raises(S3FileSystemException, self.client.post, '/filebrowser/upload/file?dest=%s' % DEST_DIR, dict(dest=DEST_DIR, hdfs_file=file(LOCAL_FILE)))
 
   def test_has_default_permissions(self):
-    if not is_abfs_enabled():
-      raise SkipTest
     add_permission(self.user.username, 'has_abfs', permname='abfs_access', appname='filebrowser')
 
     try:
@@ -1279,6 +1277,8 @@ class TestABFSAccessPermissions(object):
 class TestADLSAccessPermissions(object):
 
   def setUp(self):
+    if not is_adls_enabled():
+      raise SkipTest
     self.client = make_logged_in_client(username="test", groupname="default", recreate=True, is_superuser=False)
     grant_access('test', 'test', 'filebrowser')
     add_to_group('test')
@@ -1286,8 +1286,6 @@ class TestADLSAccessPermissions(object):
     self.user = User.objects.get(username="test")
 
   def test_no_default_permissions(self):
-    if not is_adls_enabled():
-      raise SkipTest
     response = self.client.get('/filebrowser/view=ADL://')
     assert_equal(500, response.status_code)
 
@@ -1313,8 +1311,6 @@ class TestADLSAccessPermissions(object):
 #       assert_raises(S3FileSystemException, self.client.post, '/filebrowser/upload/file?dest=%s' % DEST_DIR, dict(dest=DEST_DIR, hdfs_file=file(LOCAL_FILE)))
 
   def test_has_default_permissions(self):
-    if not is_adls_enabled():
-      raise SkipTest
     add_permission(self.user.username, 'has_adls', permname='adls_access', appname='filebrowser')
 
     try:

+ 2 - 42
desktop/core/src/desktop/lib/fs/gc/client.py

@@ -16,7 +16,6 @@
 from __future__ import absolute_import
 
 import boto
-import datetime
 import logging
 import gcs_oauth2_boto_plugin
 import json
@@ -28,55 +27,16 @@ from boto.provider import Provider
 from boto.s3.connection import SubdomainCallingFormat
 
 from desktop import conf
-from desktop.conf import DEFAULT_USER
 from desktop.lib.idbroker import conf as conf_idbroker
 from desktop.lib.idbroker.client import IDBroker
 
 LOG = logging.getLogger(__name__)
 
-CLIENT_CACHE = None
-
-_DEFAULT_USER = DEFAULT_USER.get()
-
-# FIXME: Should we check hue principal for the default user?
-def _get_cache_key(identifier='default', user=_DEFAULT_USER): # FIXME: Caching via username has issues when users get deleted. Need to switch to userid, but bigger change
-  return identifier + ':' + user
-
-
-def clear_cache():
-  global CLIENT_CACHE
-  CLIENT_CACHE = None
-
-
-def current_ms_from_utc():
-  return (datetime.datetime.utcnow() - datetime.datetime.utcfromtimestamp(0)).total_seconds() * 1000
-
-
-def get_client(identifier='default', user=_DEFAULT_USER):
-  global CLIENT_CACHE
-  _init_clients()
-
-  cache_key = _get_cache_key(identifier, user) if conf_idbroker.is_idbroker_enabled('gs') else _get_cache_key(identifier) # We don't want to cache by username when IDBroker not enabled
-  client = CLIENT_CACHE.get(cache_key)
-
-  if client and (client.expiration is None or client.expiration > int(current_ms_from_utc())): # expiration from IDBroker returns java timestamp in MS
-    return client
-  else:
-    client = _make_client(identifier, user)
-    CLIENT_CACHE[cache_key] = client
-    return client
-
-def get_credential_provider(config=None, user=_DEFAULT_USER):
+def get_credential_provider(config, user):
   return CredentialProviderIDBroker(IDBroker.from_core_site('gs', user)) if conf_idbroker.is_idbroker_enabled('gs') else CredentialProviderConf(config)
 
 
-def _init_clients():
-  global CLIENT_CACHE
-  if CLIENT_CACHE is not None:
-    return
-  CLIENT_CACHE = {} # Can't convert this to django cache, because S3FileSystem is not pickable
-
-def _make_client(identifier, user=_DEFAULT_USER):
+def _make_client(identifier, user):
   config = conf.GC_ACCOUNTS[identifier] if identifier in list(conf.GC_ACCOUNTS.keys()) else None
   client = Client.from_config(config, get_credential_provider(config, user))
   return S3FileSystem(client.get_s3_connection(), client.expiration, headers={"x-goog-project-id": client.project}, filebrowser_action=conf.PERMISSION_ACTION_GS) # It would be nice if the connection is lazy loaded

+ 3 - 3
desktop/core/src/desktop/lib/fs/gc/tests.py

@@ -23,8 +23,8 @@ from nose.plugins.skip import SkipTest
 from nose.tools import assert_equal, assert_true, assert_not_equal
 
 from desktop.conf import is_gs_enabled
-from desktop.lib.fs.gc.client import get_client
 
+from desktop.lib.fsmanager import get_client
 
 LOG = logging.getLogger(__name__)
 
@@ -35,7 +35,7 @@ class TestGCS(unittest.TestCase):
       raise SkipTest('gs not enabled')
 
   def test_with_credentials(self):
-    # Simple test that makes sure no errors are thrown. 
-    client = get_client()
+    # Simple test that makes sure no errors are thrown.
+    client = get_client(fs='gs')
     buckets = client.listdir_stats('gs://')
     LOG.info(len(buckets))

+ 48 - 23
desktop/core/src/desktop/lib/fsmanager.py

@@ -17,24 +17,38 @@
 
 from __future__ import absolute_import
 
+from functools import partial
 import logging
 
 import aws.client
 import azure.client
+import desktop.lib.fs.gc.client
 
 from aws.conf import is_enabled as is_s3_enabled, has_s3_access
 from azure.conf import is_adls_enabled, is_abfs_enabled, has_adls_access, has_abfs_access
 
+
+from desktop.conf import is_gs_enabled, has_gs_access, DEFAULT_USER
+
 from desktop.lib.fs.proxyfs import ProxyFS
-from desktop.conf import is_gs_enabled, has_gs_access
-from desktop.lib.fs.gc.client import get_client as get_client_gs
+from desktop.lib.python_util import current_ms_from_utc
+from desktop.lib.idbroker import conf as conf_idbroker
 
-from hadoop.cluster import get_hdfs
+from hadoop.cluster import get_hdfs, _make_filesystem
 from hadoop.conf import has_hdfs_enabled
 
 
 SUPPORTED_FS = ['hdfs', 's3a', 'adl', 'abfs', 'gs']
+CLIENT_CACHE = None
+_DEFAULT_USER = DEFAULT_USER.get()
+
+# FIXME: Should we check hue principal for the default user?
+def _get_cache_key(fs, identifier, user=_DEFAULT_USER): # FIXME: Caching via username has issues when users get deleted. Need to switch to userid, but bigger change
+  return fs + ':' + identifier + ':' + user
 
+def clear_cache():
+  global CLIENT_CACHE
+  CLIENT_CACHE = None
 
 def has_access(fs=None, user=None):
   if fs == 'hdfs':
@@ -66,21 +80,44 @@ def is_enabled_and_has_access(fs=None, user=None):
   return is_enabled(fs) and has_access(fs, user)
 
 
-def _get_client(fs=None):
+def _make_client(fs, name, user):
   if fs == 'hdfs':
-    return get_hdfs
+    return _make_filesystem(name)
   elif fs == 's3a':
-    return aws.client.get_client
+    return aws.client._make_client(name, user)
   elif fs == 'adl':
-    return azure.client.get_client
+    return azure.client._make_adls_client(name, user)
   elif fs == 'abfs':
-    return azure.client.get_client_abfs
+    return azure.client._make_abfs_client(name, user)
   elif fs == 'gs':
-    return get_client_gs
+    return desktop.lib.fs.gc.client._make_client(name, user)
+  return None
+
+
+def _get_client(fs=None):
+  if fs == 'hdfs':
+    return get_hdfs
+  elif fs in ['s3a', 'adl', 'abfs', 'gs']:
+    return partial(_get_client_cached, fs)
   return None
 
 
-def get_client(name='default', fs=None, user=None):
+def _get_client_cached(fs, name, user):
+  global CLIENT_CACHE
+  if CLIENT_CACHE is None:
+    CLIENT_CACHE = {}
+  cache_key = _get_cache_key(fs, name, user) if conf_idbroker.is_idbroker_enabled(fs) else _get_cache_key(fs, name) # We don't want to cache by username when IDBroker not enabled
+  client = CLIENT_CACHE.get(cache_key)
+
+  if client and (client.expiration is None or client.expiration > int(current_ms_from_utc())): # expiration from IDBroker returns java timestamp in MS
+    return client
+  else:
+    client = _make_client(fs, name, user)
+    CLIENT_CACHE[cache_key] = client
+    return client
+
+
+def get_client(name='default', fs=None, user=_DEFAULT_USER):
   fs_getter = _get_client(fs)
   if fs_getter:
     return fs_getter(name, user)
@@ -109,16 +146,4 @@ def get_filesystem(name='default'):
 
 
 def get_filesystems(user):
-  return [fs for fs in SUPPORTED_FS if is_enabled(fs) and has_access(fs, user)]
-
-
-def _get_client(fs=None):
-  if fs == 'hdfs':
-    return get_hdfs
-  elif fs == 's3a':
-    return aws.client.get_client
-  elif fs == 'adl':
-    return azure.client.get_client
-  elif fs == 'abfs':
-    return azure.client.get_client_abfs
-  return None
+  return [fs for fs in SUPPORTED_FS if is_enabled(fs) and has_access(fs, user)]

+ 14 - 13
desktop/core/src/desktop/lib/idbroker/conf.py

@@ -28,33 +28,34 @@ _CNF_CAB_ADDRESS_DT_PATH='fs.%s.ext.cab.dt.path' # dt
 _CNF_CAB_ADDRESS_PATH='fs.%s.ext.cab.path' # aws-cab
 _CNF_CAB_USERNAME='fs.%s.ext.cab.username' # when not using kerberos
 _CNF_CAB_PASSWORD='fs.%s.ext.cab.password'
-SUPPORTED_FS = ['s3a', 'azure', 'gs']
+SUPPORTED_FS = {'s3a': 's3a', 'adl': 'azure', 'abfs': 'azure', 'azure': 'azure', 'gs': 'gs'}
 
 def validate_fs(fs=None):
   if fs in SUPPORTED_FS:
-    return True
+    return SUPPORTED_FS[fs]
   else:
-    raise ValueError('Selected FS %s is not supported by Hue IDBroker client' % fs)
+    LOG.warn('Selected FS %s is not supported by Hue IDBroker client' % fs)
+    return None
 
 def get_cab_address(fs=None):
-  validate_fs(fs)
-  return get_conf().get(_CNF_CAB_ADDRESS % fs)
+  fs = validate_fs(fs)
+  return get_conf().get(_CNF_CAB_ADDRESS % fs) if fs else None
 
 def get_cab_dt_path(fs=None):
-  validate_fs(fs)
-  return get_conf().get(_CNF_CAB_ADDRESS_DT_PATH % fs)
+  fs = validate_fs(fs)
+  return get_conf().get(_CNF_CAB_ADDRESS_DT_PATH % fs) if fs else None
 
 def get_cab_path(fs=None):
-  validate_fs(fs)
-  return get_conf().get(_CNF_CAB_ADDRESS_PATH % fs)
+  fs = validate_fs(fs)
+  return get_conf().get(_CNF_CAB_ADDRESS_PATH % fs) if fs else None
 
 def get_cab_username(fs=None):
-  validate_fs(fs)
-  return get_conf().get(_CNF_CAB_USERNAME % fs)
+  fs = validate_fs(fs)
+  return get_conf().get(_CNF_CAB_USERNAME % fs) if fs else None
 
 def get_cab_password(fs=None):
-  validate_fs(fs)
-  return get_conf().get(_CNF_CAB_PASSWORD % fs)
+  fs = validate_fs(fs)
+  return get_conf().get(_CNF_CAB_PASSWORD % fs) if fs else None
 
 def is_idbroker_enabled(fs=None):
   return get_cab_address(fs) is not None

+ 2 - 36
desktop/libs/aws/src/aws/client.py

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 
 from builtins import str
 from builtins import object
-import datetime
 import logging
 import os
 
@@ -27,7 +26,6 @@ from aws import conf as aws_conf
 from aws.s3.s3fs import S3FileSystemException
 from aws.s3.s3fs import S3FileSystem
 
-from desktop.conf import DEFAULT_USER
 from desktop.lib.idbroker import conf as conf_idbroker
 from desktop.lib.idbroker.client import IDBroker
 
@@ -35,45 +33,13 @@ LOG = logging.getLogger(__name__)
 
 HTTP_SOCKET_TIMEOUT_S = 60
 
-CLIENT_CACHE = None
 
-_DEFAULT_USER = DEFAULT_USER.get()
-
-# FIXME: Should we check hue principal for the default user?
-def _get_cache_key(identifier='default', user=_DEFAULT_USER): # FIXME: Caching via username has issues when users get deleted. Need to switch to userid, but bigger change
-  return identifier + ':' + user
-
-
-def clear_cache():
-  global CLIENT_CACHE
-  CLIENT_CACHE = None
-
-
-def current_ms_from_utc():
-  return (datetime.datetime.utcnow() - datetime.datetime.utcfromtimestamp(0)).total_seconds() * 1000
-
-
-def get_client(identifier='default', user=_DEFAULT_USER):
-  global CLIENT_CACHE
-  if not CLIENT_CACHE:
-    CLIENT_CACHE = {}
-
-  cache_key = _get_cache_key(identifier, user) if conf_idbroker.is_idbroker_enabled('s3a') else _get_cache_key(identifier) # We don't want to cache by username when IDBroker not enabled
-  client = CLIENT_CACHE.get(cache_key)
-
-  if client and (client.expiration is None or client.expiration > int(current_ms_from_utc())): # expiration from IDBroker returns java timestamp in MS
-    return client
-  else:
-    client = _make_client(identifier, user)
-    CLIENT_CACHE[cache_key] = client
-    return client
-
-def get_credential_provider(identifier='default', user=_DEFAULT_USER):
+def get_credential_provider(identifier, user):
   client_conf = aws_conf.AWS_ACCOUNTS[identifier] if identifier in aws_conf.AWS_ACCOUNTS else None
   return CredentialProviderIDBroker(IDBroker.from_core_site('s3a', user)) if conf_idbroker.is_idbroker_enabled('s3a') else CredentialProviderConf(client_conf)
 
 
-def _make_client(identifier, user=_DEFAULT_USER):
+def _make_client(identifier, user):
   client_conf = aws_conf.AWS_ACCOUNTS[identifier] if identifier in aws_conf.AWS_ACCOUNTS else None
 
   client = Client.from_config(client_conf, get_credential_provider(identifier, user))

+ 2 - 2
desktop/libs/aws/src/aws/conf.py

@@ -289,10 +289,10 @@ def has_s3_access(user):
 
 def config_validator(user):
   res = []
-  from aws.client import get_client # Circular dependecy
+  import desktop.lib.fsmanager # Circular dependecy
   if is_enabled():
     try:
-      conn = get_client('default')._s3_connection
+      conn = desktop.lib.fsmanager.get_client(name='default', fs='s3a')._s3_connection
       conn.get_canonical_user_id()
     except Exception as e:
       LOG.exception('AWS failed configuration check.')

+ 34 - 11
desktop/libs/aws/src/aws/tests.py

@@ -22,8 +22,10 @@ from mock import patch, Mock
 from nose.tools import assert_equal, assert_true, assert_not_equal
 
 from aws import conf
-from aws.client import clear_cache, Client, get_client, get_credential_provider, current_ms_from_utc
+from aws.client import Client, get_credential_provider
 
+from desktop.lib.fsmanager import get_client, clear_cache
+from desktop.lib.python_util import current_ms_from_utc
 
 LOG = logging.getLogger(__name__)
 
@@ -35,10 +37,10 @@ class TestAWS(unittest.TestCase):
       with patch('aws.client.conf_idbroker.get_conf') as get_conf:
         with patch('aws.client.Client.get_s3_connection'):
           get_conf.return_value = {}
-          client1 = get_client('default')
-          client2 = get_client('default', 'test')
+          client1 = get_client(name='default', fs='s3a')
+          client2 = get_client(name='default', fs='s3a', user='test')
 
-          provider = get_credential_provider()
+          provider = get_credential_provider('default', 'hue')
           assert_equal(provider.get_credentials().get('AccessKeyId'), conf.AWS_ACCOUNTS['default'].ACCESS_KEY_ID.get())
           assert_equal(client1, client2) # Should be the same as no support for user based client with credentials & no Expiration
     finally:
@@ -58,18 +60,18 @@ class TestAWS(unittest.TestCase):
             get_cab.return_value = {
               'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0}
             }
-            provider = get_credential_provider()
+            provider = get_credential_provider('default', 'hue')
             assert_equal(provider.get_credentials().get('AccessKeyId'), 'AccessKeyId')
-            client1 = get_client('default', 'HUE')
-            client2 = get_client('default', 'HUE')
+            client1 = get_client(name='default', fs='s3a', user='hue')
+            client2 = get_client(name='default', fs='s3a', user='hue')
             assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal
 
             get_cab.return_value = {
               'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': int(current_ms_from_utc()) + 10*1000}
             }
-            client3 = get_client('default', 'HUE')
-            client4 = get_client('default', 'HUE')
-            client5 = get_client('default', 'test')
+            client3 = get_client(name='default', fs='s3a', user='hue')
+            client4 = get_client(name='default', fs='s3a', user='hue')
+            client5 = get_client(name='default', fs='s3a', user='test')
             assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal
             assert_not_equal(client4, client5) # Test different user have different clients
     finally:
@@ -89,7 +91,7 @@ class TestAWS(unittest.TestCase):
             get_cab.return_value = {
               'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0}
             }
-            provider = get_credential_provider()
+            provider = get_credential_provider('default', 'hue')
             assert_equal(provider.get_credentials().get('AccessKeyId'), 'AccessKeyId')
 
             client = Client.from_config(conf.AWS_ACCOUNTS['default'], get_credential_provider('default', 'hue'))
@@ -98,3 +100,24 @@ class TestAWS(unittest.TestCase):
       finish()
       clear_cache()
       conf.clear_cache()
+
+  def test_with_idbroker_on_ec2(self):
+    try:
+      finish = conf.AWS_ACCOUNTS.set_for_testing({}) # Set empty to test when no configs are set
+      with patch('aws.client.aws_conf.get_region') as get_region:
+        with patch('aws.client.conf_idbroker.get_conf') as get_conf:
+          with patch('aws.client.Client.get_s3_connection'):
+            with patch('aws.client.IDBroker.get_cab') as get_cab:
+              get_region.return_value = 'us-west-1'
+              get_conf.return_value = {
+                'fs.s3a.ext.cab.address': 'address'
+              }
+              get_cab.return_value = {
+                'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0}
+              }
+              client = Client.from_config(None, get_credential_provider('default', 'hue'))
+              assert_equal(client._region, 'us-west-1') # Test different user have different clients
+    finally:
+      finish()
+      clear_cache()
+      conf.clear_cache()

+ 16 - 7
desktop/libs/azure/src/azure/abfs/abfs.py

@@ -70,7 +70,9 @@ class ABFS(object):
                temp_dir="/tmp",
                umask=0o1022,
                hdfs_supergroup=None,
-               auth_provider=None):
+               access_token=None,
+               token_type=None,
+               expiration=None):
     self._url = url
     self._superuser = hdfs_superuser
     self._security_enabled = security_enabled
@@ -80,7 +82,8 @@ class ABFS(object):
     self._fs_defaultfs = fs_defaultfs
     self._logical_name = logical_name
     self._supergroup = hdfs_supergroup
-    self._auth_provider = auth_provider
+    self._access_token = access_token
+    self._token_type = token_type
     split = lib_urlparse(fs_defaultfs)
     self._scheme = split.scheme
     self._netloc = split.netloc
@@ -88,8 +91,8 @@ class ABFS(object):
     self._has_trash_support = False
     self._filebrowser_action = PERMISSION_ACTION_ABFS
 
-    self._client = http_client.HttpClient(url, exc_class=WebHdfsException, logger=LOG)
-    self._root = resource.Resource(self._client)
+    self.expiration = expiration
+    self._root = self.get_client(url)
 
     # To store user info
     self._thread_local = threading.local()
@@ -98,6 +101,7 @@ class ABFS(object):
 
   @classmethod
   def from_config(cls, hdfs_config, auth_provider):
+    credentials = auth_provider.get_credentials()
     return cls(url=hdfs_config.WEBHDFS_URL.get(),
                fs_defaultfs=hdfs_config.FS_DEFAULTFS.get(),
                logical_name=None,
@@ -106,14 +110,19 @@ class ABFS(object):
                temp_dir=None,
                umask=get_umask_mode(),
                hdfs_supergroup=None,
-               auth_provider=auth_provider)
+               access_token=credentials.get('access_token'),
+               token_type=credentials.get('token_type'),
+               expiration=int(credentials.get('expires_on')) * 1000 if credentials.get('expires_on') is not None else None)
+
+  def get_client(self, url):
+    return resource.Resource(http_client.HttpClient(url, exc_class=WebHdfsException, logger=LOG))
 
   def _getheaders(self):
     return {
-      "Authorization": self._auth_provider.get_token(),
+      "Authorization": self._token_type + " " + self._access_token,
       "x-ms-version" : "2019-02-02" #note this is required for setaccesscontrols
     }
-  
+
   # Parse info about filesystems, directories, and files
   # --------------------------------
   def isdir(self, path):

+ 6 - 7
desktop/libs/azure/src/azure/abfs/upload.py

@@ -27,7 +27,7 @@ from django.core.files.uploadedfile import SimpleUploadedFile
 from django.core.files.uploadhandler import FileUploadHandler, SkipFile, StopFutureHandlers, StopUpload, UploadFileException
 from django.utils.translation import ugettext as _
 
-from azure.client import get_client_abfs
+from desktop.lib import fsmanager
 from azure.abfs.__init__ import parse_uri
 from azure.abfs.abfs import ABFSFileSystemException
 
@@ -109,16 +109,16 @@ class ABFSFileUploadHandler(FileUploadHandler):
       return None
 
   def _get_abfs(self, request):
-    fs = get_client_abfs()
-    
+    fs = fsmanager.get_client(fs='abfs')
+
     if not fs:
       raise ABFSFileUploadError(_("No ABFS filesystem found"))
-    
+
     return fs
-  
+
   def _is_abfs_upload(self):
     return self._get_scheme() and self._get_scheme().startswith('ABFS')
-  
+
   def _get_scheme(self):
     if self.destination:
       dst_parts = self.destination.split('://')
@@ -128,4 +128,3 @@ class ABFSFileUploadHandler(FileUploadHandler):
         raise ABFSFileSystemException('Destination does not start with a valid scheme.')
     else:
       return None
-  

+ 11 - 15
desktop/libs/azure/src/azure/active_directory.py

@@ -18,8 +18,8 @@ from __future__ import absolute_import
 from builtins import object
 import logging
 
-from time import time
 from azure.conf import AZURE_ACCOUNTS, get_default_refresh_url
+from desktop.lib.python_util import current_ms_from_utc
 from desktop.lib.rest import http_client, resource
 
 LOG = logging.getLogger(__name__)
@@ -32,7 +32,6 @@ class ActiveDirectory(object):
 
     self._client = http_client.HttpClient(url, logger=LOG)
     self._root = resource.Resource(self._client)
-    self._token = None
     self._version = version
 
 
@@ -44,19 +43,16 @@ class ActiveDirectory(object):
 
 
   def _get_token(self, params=None):
-    is_token_expired = self._token is None or time() >= self._token["expires_on"]
-    if is_token_expired:
-      LOG.debug("Authenticating to Azure Active Directory: %s" % self._url)
-      data = {
-        "grant_type" : "client_credentials",
-        "client_id" : self._access_key_id,
-        "client_secret" : self._secret_access_key
-      }
-      data.update(params)
-      self._token = self._root.post("/", data=data, log_response=False);
-      self._token["expires_on"] = int(self._token.get("expires_on", self._token.get("expires_in")))
-
-    return self._token["token_type"] + " " + self._token["access_token"]
+    LOG.debug("Authenticating to Azure Active Directory: %s" % self._url)
+    data = {
+      "grant_type" : "client_credentials",
+      "client_id" : self._access_key_id,
+      "client_secret" : self._secret_access_key
+    }
+    data.update(params)
+    token = self._root.post("/", data=data, log_response=False)
+    token["expires_on"] = int(token.get("expires_on", (current_ms_from_utc() + int(token.get("expires_in")) * 1000) / 1000))
+    return token
 
 
   @classmethod

+ 15 - 6
desktop/libs/azure/src/azure/adls/webhdfs.py

@@ -50,7 +50,9 @@ class WebHdfs(HadoopWebHdfs):
                temp_dir="/tmp",
                umask=0o1022,
                hdfs_supergroup=None,
-               auth_provider=None):
+               access_token=None,
+               token_type=None,
+               expiration=None):
     self._url = url
     self._superuser = hdfs_superuser
     self._security_enabled = security_enabled
@@ -60,7 +62,8 @@ class WebHdfs(HadoopWebHdfs):
     self._fs_defaultfs = fs_defaultfs
     self._logical_name = logical_name
     self._supergroup = hdfs_supergroup
-    self._auth_provider = auth_provider
+    self._access_token = access_token
+    self._token_type = token_type
     split = urlparse(fs_defaultfs)
     self._scheme = split.scheme
     self._netloc = split.netloc
@@ -68,8 +71,8 @@ class WebHdfs(HadoopWebHdfs):
     self._has_trash_support = False
     self._filebrowser_action = PERMISSION_ACTION_ADLS
 
-    self._client = http_client.HttpClient(url, exc_class=WebHdfsException, logger=LOG)
-    self._root = resource.Resource(self._client)
+    self._root = self.get_client(url)
+    self.expiration = expiration
 
     # To store user info
     self._thread_local = threading.local()
@@ -78,6 +81,7 @@ class WebHdfs(HadoopWebHdfs):
 
   @classmethod
   def from_config(cls, hdfs_config, auth_provider):
+    credentials = auth_provider.get_credentials()
     fs_defaultfs = get_default_adls_fs()
     url = get_default_adls_url()
     return cls(url=url,
@@ -88,11 +92,16 @@ class WebHdfs(HadoopWebHdfs):
                temp_dir=None,
                umask=get_umask_mode(),
                hdfs_supergroup=None,
-               auth_provider=auth_provider)
+               access_token=credentials.get('access_token'),
+               token_type=credentials.get('token_type'),
+               expiration=int(credentials.get('expires_on')) * 1000 if credentials.get('expires_on')  is not None else None)
+
+  def get_client(self, url):
+    return resource.Resource(http_client.HttpClient(url, exc_class=WebHdfsException, logger=LOG))
 
   def _getheaders(self):
     return {
-      "Authorization": self._auth_provider.get_token(),
+      "Authorization": self._token_type + " " + self._access_token,
     }
 
   def is_web_accessible(self):

+ 22 - 41
desktop/libs/azure/src/azure/client.py

@@ -16,58 +16,39 @@
 from __future__ import absolute_import
 
 import logging
-import os
 
 from azure import conf
 from azure.adls.webhdfs import WebHdfs
 from azure.abfs.abfs import ABFS
 from azure.active_directory import ActiveDirectory
 
-LOG = logging.getLogger(__name__)
-
-CLIENT_CACHE = None
+from desktop.lib.idbroker import conf as conf_idbroker
+from desktop.lib.idbroker.client import IDBroker
 
-def get_client(identifier='default', user=None):
-  global CLIENT_CACHE
-  _init_clients()
-  if identifier not in CLIENT_CACHE["adls"]:
-    raise ValueError('Unknown azure client: %s, check your configuration' % identifier)
-  return CLIENT_CACHE["adls"][identifier]
+LOG = logging.getLogger(__name__)
 
-def get_client_abfs(identifier='default', user=None):
-  global CLIENT_CACHE
-  _init_clients()
-  if identifier not in CLIENT_CACHE["abfs"]:
-    raise ValueError('Unknown azure client: %s, check your configuration' % identifier)
-  return CLIENT_CACHE["abfs"][identifier]
+def _make_adls_client(identifier, user):
+  client_conf = conf.ADLS_CLUSTERS[identifier]
+  return WebHdfs.from_config(client_conf, get_credential_provider(identifier, user))
 
-def _init_clients():
-  global CLIENT_CACHE
-  if CLIENT_CACHE is not None:
-    return
-  CLIENT_CACHE = {}
-  CLIENT_CACHE["azure"] = {}
-  CLIENT_CACHE["adls"] = {}
-  CLIENT_CACHE["abfs"] = {}
-  for identifier in list(conf.AZURE_ACCOUNTS.keys()):
-    CLIENT_CACHE["azure"][identifier] = _make_azure_client(identifier)
+def _make_abfs_client(identifier, user):
+  client_conf = conf.ABFS_CLUSTERS[identifier]
+  return ABFS.from_config(client_conf, get_credential_provider(identifier, user, version='v2.0'))
 
-  for identifier in list(conf.ADLS_CLUSTERS.keys()):
-    CLIENT_CACHE["adls"][identifier] = _make_adls_client(identifier)
+def get_credential_provider(identifier, user, version=None):
+  client_conf = conf.AZURE_ACCOUNTS[identifier] if identifier in conf.AZURE_ACCOUNTS else None
+  return CredentialProviderIDBroker(IDBroker.from_core_site('azure', user)) if conf_idbroker.is_idbroker_enabled('azure') else CredentialProviderAD(ActiveDirectory.from_config(client_conf, version=version))
 
-  for identifier in list(conf.ABFS_CLUSTERS.keys()):
-    CLIENT_CACHE["abfs"][identifier] = _make_abfs_client(identifier)
+class CredentialProviderAD(object):
+  def __init__(self, ad):
+    self.ad=ad
 
-def _make_adls_client(identifier):
-  client_conf = conf.ADLS_CLUSTERS[identifier]
-  azure_client = CLIENT_CACHE["azure"][identifier]
-  return WebHdfs.from_config(client_conf, azure_client)
+  def get_credentials(self):
+    return self.ad.get_token()
 
-def _make_abfs_client(identifier):
-  client_conf = conf.ABFS_CLUSTERS[identifier]
-  azure_client_conf = conf.AZURE_ACCOUNTS[identifier]
-  return ABFS.from_config(client_conf, ActiveDirectory.from_config(azure_client_conf, version='v2.0'))#temporary fix
+class CredentialProviderIDBroker(object):
+  def __init__(self, idbroker):
+    self.idbroker=idbroker
 
-def _make_azure_client(identifier):
-  client_conf = conf.AZURE_ACCOUNTS[identifier]
-  return ActiveDirectory.from_config(client_conf)
+  def get_credentials(self):
+    return self.idbroker.get_cab()

+ 5 - 4
desktop/libs/azure/src/azure/conf.py

@@ -21,6 +21,7 @@ from django.utils.translation import ugettext_lazy as _, ugettext as _t
 from hadoop.core_site import get_adls_client_id, get_adls_authentication_code, get_adls_refresh_url
 
 from desktop.lib.conf import Config, UnspecifiedConfigSection, ConfigSection, coerce_password_from_script
+from desktop.lib.idbroker import conf as conf_idbroker
 
 
 LOG = logging.getLogger(__name__)
@@ -144,10 +145,10 @@ ABFS_CLUSTERS = UnspecifiedConfigSection(
 )
 
 def is_adls_enabled():
-  return ('default' in list(AZURE_ACCOUNTS.keys()) and AZURE_ACCOUNTS['default'].get_raw() and AZURE_ACCOUNTS['default'].CLIENT_ID.get() is not None and 'default' in list(ADLS_CLUSTERS.keys()) and ADLS_CLUSTERS['default'].get_raw())
+  return ('default' in list(AZURE_ACCOUNTS.keys() and AZURE_ACCOUNTS['default'].get_raw() and AZURE_ACCOUNTS['default'].CLIENT_ID.get()) or (conf_idbroker.is_idbroker_enabled('azure')) and 'default' in list(ADLS_CLUSTERS.keys()))
 
 def is_abfs_enabled():
-  return ('default' in list(AZURE_ACCOUNTS.keys()) and AZURE_ACCOUNTS['default'].get_raw() and AZURE_ACCOUNTS['default'].CLIENT_ID.get() is not None and 'default' in list(ABFS_CLUSTERS.keys()) and ABFS_CLUSTERS['default'].get_raw())
+  return ('default' in list(AZURE_ACCOUNTS.keys() and AZURE_ACCOUNTS['default'].get_raw() and AZURE_ACCOUNTS['default'].CLIENT_ID.get()) or (conf_idbroker.is_idbroker_enabled('azure')) and 'default' in list(ABFS_CLUSTERS.keys()))
 
 def has_adls_access(user):
   from desktop.auth.backend import is_admin
@@ -160,11 +161,11 @@ def has_abfs_access(user):
 def config_validator(user):
   res = []
 
-  import azure.client # Avoid cyclic loop
+  import desktop.lib.fsmanager # Avoid cyclic loop
 
   if is_adls_enabled() or is_abfs_enabled():
     try:
-      headers = azure.client.get_client('default')._getheaders()
+      headers = desktop.lib.fsmanager.get_client(name='default', fs='abfs')._getheaders()
       if not headers.get('Authorization'):
         raise ValueError('Failed to obtain Azure authorization token')
     except Exception as e:

+ 139 - 0
desktop/libs/azure/src/azure/tests.py

@@ -0,0 +1,139 @@
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+
+import logging
+import unittest
+
+from mock import patch, Mock, PropertyMock
+from nose.plugins.skip import SkipTest
+from nose.tools import assert_equal, assert_true, assert_not_equal
+
+from azure import conf
+from azure.client import get_credential_provider
+
+from desktop.lib.fsmanager import get_client, clear_cache, is_enabled
+from desktop.lib.python_util import current_ms_from_utc
+
+LOG = logging.getLogger(__name__)
+
+
+class TestAzureAdl(unittest.TestCase):
+  def setUp(self):
+    if not is_enabled('adl'):
+      raise SkipTest('adl not enabled')
+
+
+  def test_with_credentials(self):
+    try:
+      finish = conf.AZURE_ACCOUNTS.set_for_testing({'default': {'client_id':'client_id', 'client_secret': 'client_secret', 'tenant_id': 'tenant_id'}})
+      with patch('azure.client.conf_idbroker.get_conf') as get_conf:
+        with patch('azure.client.WebHdfs'):
+          with patch('azure.client.ActiveDirectory.get_token') as get_token:
+            get_token.return_value = {'access_token': 'access_token', 'token_type': '', 'expires_on': None}
+            get_conf.return_value = {}
+            client1 = get_client(name='default', fs='adl')
+            client2 = get_client(name='default', fs='adl', user='test')
+
+            provider = get_credential_provider('default', 'hue')
+            assert_equal(provider.get_credentials().get('access_token'), 'access_token')
+            assert_equal(client1, client2) # Should be the same as no support for user based client with credentials & no Expiration
+    finally:
+      finish()
+      clear_cache()
+
+
+  def test_with_idbroker(self):
+    try:
+      finish = conf.AZURE_ACCOUNTS.set_for_testing({}) # Set empty to test when no configs are set
+      with patch('azure.client.conf_idbroker.get_conf') as get_conf:
+        with patch('azure.client.WebHdfs.get_client'):
+          with patch('azure.client.IDBroker.get_cab') as get_cab:
+            get_conf.return_value = {
+              'fs.azure.ext.cab.address': 'address'
+            }
+            get_cab.return_value = { 'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': 0 }
+            provider = get_credential_provider('default', 'hue')
+            assert_equal(provider.get_credentials().get('access_token'), 'access_token')
+            client1 = get_client(name='default', fs='adl', user='hue')
+            client2 = get_client(name='default', fs='adl', user='hue')
+            assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal
+
+            get_cab.return_value = {
+              'Credentials': {'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': int(current_ms_from_utc()) + 10*1000}
+            }
+            client3 = get_client(name='default', fs='adl', user='hue')
+            client4 = get_client(name='default', fs='adl', user='hue')
+            client5 = get_client(name='default', fs='adl', user='test')
+            assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal
+            assert_not_equal(client4, client5) # Test different user have different clients
+    finally:
+      finish()
+      clear_cache()
+
+
+class TestAzureAbfs(unittest.TestCase):
+  def setUp(self):
+    if not is_enabled('abfs'):
+      raise SkipTest('abfs not enabled')
+
+
+  def test_with_credentials(self):
+    try:
+      finish = conf.AZURE_ACCOUNTS.set_for_testing({'default': {'client_id':'client_id', 'client_secret': 'client_secret', 'tenant_id': 'tenant_id'}})
+      with patch('azure.client.conf_idbroker.get_conf') as get_conf:
+        with patch('azure.client.ABFS'):
+          with patch('azure.client.ActiveDirectory.get_token') as get_token:
+            get_token.return_value = {'access_token': 'access_token', 'token_type': '', 'expires_on': None}
+            get_conf.return_value = {}
+            client1 = get_client(name='default', fs='abfs')
+            client2 = get_client(name='default', fs='abfs', user='test')
+
+            provider = get_credential_provider('default', 'hue')
+            assert_equal(provider.get_credentials().get('access_token'), 'access_token')
+            assert_equal(client1, client2) # Should be the same as no support for user based client with credentials & no Expiration
+    finally:
+      finish()
+      clear_cache()
+
+
+  def test_with_idbroker(self):
+    try:
+      finish = conf.AZURE_ACCOUNTS.set_for_testing({}) # Set empty to test when no configs are set
+      with patch('azure.client.conf_idbroker.get_conf') as get_conf:
+        with patch('azure.client.ABFS.get_client'):
+          with patch('azure.client.IDBroker.get_cab') as get_cab:
+            get_conf.return_value = {
+              'fs.azure.ext.cab.address': 'address'
+            }
+            get_cab.return_value = { 'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': 0 }
+            provider = get_credential_provider('default', 'hue')
+            assert_equal(provider.get_credentials().get('access_token'), 'access_token')
+            client1 = get_client(name='default', fs='abfs', user='hue')
+            client2 = get_client(name='default', fs='abfs', user='hue')
+            assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal
+
+            get_cab.return_value = {
+              'Credentials': {'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': int(current_ms_from_utc()) + 10*1000}
+            }
+            client3 = get_client(name='default', fs='abfs', user='hue')
+            client4 = get_client(name='default', fs='abfs', user='hue')
+            client5 = get_client(name='default', fs='abfs', user='test')
+            assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal
+            assert_not_equal(client4, client5) # Test different user have different clients
+    finally:
+      finish()
+      clear_cache()

+ 1 - 0
desktop/libs/hadoop/src/hadoop/fs/webhdfs.py

@@ -93,6 +93,7 @@ class WebHdfs(Hdfs):
     self._netloc = "";
     self._is_remote = False
     self._has_trash_support = True
+    self.expiration = None
 
     self._client = self._make_client(url, security_enabled, ssl_cert_ca_verify)
     self._root = resource.Resource(self._client)