瀏覽代碼

[idbroker] Refactor IDBroker with improved HA support and giving more preference to RAZ when both are configured in Hue (#3626)

What changes were proposed in this pull request?

- Refactor IDBroker support and give more preference to RAZ when both are configured in Hue.
- Improved IDBroker HA code section to switchover to healthy instance correctly and not depend only on the first one for every scenario. This should improve Hue page loading performance also.

How was this patch tested?

- Tested manually in a live E2E setup with RAZ enabled to check for no regressions + correct IDBroker switchover + improved Hue page load time.
- Update existing unit tests.
- Adding new unit tests for IDBroker HA.
Harsh Gupta 1 年之前
父節點
當前提交
0e711eacb0

+ 2 - 3
desktop/core/src/desktop/conf.py

@@ -2623,9 +2623,8 @@ def is_cm_managed():
 def is_gs_enabled():
   from desktop.lib.idbroker import conf as conf_idbroker # Circular dependencies  desktop.conf -> idbroker.conf -> desktop.conf
 
-  return ('default' in list(GC_ACCOUNTS.keys()) and GC_ACCOUNTS['default'].JSON_CREDENTIALS.get()) or \
-      conf_idbroker.is_idbroker_enabled('gs') or \
-      is_raz_gs()
+  return ('default' in list(GC_ACCOUNTS.keys()) and GC_ACCOUNTS['default'].JSON_CREDENTIALS.get()) or is_raz_gs() or \
+    conf_idbroker.is_idbroker_enabled('gs')
 
 def has_gs_access(user):
   from desktop.auth.backend import is_admin

+ 11 - 8
desktop/core/src/desktop/lib/idbroker/client.py

@@ -13,8 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import absolute_import
-
 from builtins import object
 import logging
 
@@ -48,8 +46,11 @@ class IDBroker(object):
 
 
   def __init__(self, user=None, address=None, dt_path=None, path=None, security=None):
-    self.user=user
-    self.address=address
+    self.user = user
+    self.address = address
+    if not self.address:
+      raise PopupException('Failed to connect to IDBroker: No active or healthy instance was found.')
+
     self.dt_path = dt_path
     self.path = path
     self.security = security
@@ -60,9 +61,9 @@ class IDBroker(object):
   def _knox_token_params(self):
     if self.user:
       if self.security['type'] == 'kerberos':
-        return { 'doAs': self.user }
+        return {'doAs': self.user}
       else:
-        return { 'user.name': self.user }
+        return {'user.name': self.user}
     else:
       return None
 
@@ -73,7 +74,8 @@ class IDBroker(object):
     elif self.security['type'] == 'basic':
       self._client.set_basic_auth(self.security['params']['username'], self.security['params']['password'])
     try:
-      res = self._root.invoke("GET", self.dt_path + _KNOX_TOKEN_API, self._knox_token_params(), allow_redirects=True, log_response=False) # Can't log response because returns credentials
+      # Can't log response because returns credentials
+      res = self._root.invoke("GET", self.dt_path + _KNOX_TOKEN_API, self._knox_token_params(), allow_redirects=True, log_response=False)
       return res.get('access_token')
     except Exception as e:
       raise PopupException('Failed to authenticate to IDBroker with error: %s' % e.message)
@@ -82,6 +84,7 @@ class IDBroker(object):
   def get_cab(self):
     self._client.set_bearer_auth(self.get_auth_token())
     try:
-      return self._root.invoke("GET", self.path + _CAB_API_CREDENTIALS_GLOBAL, allow_redirects=True, log_response=False) # Can't log response because returns credentials
+      # Can't log response because returns credentials
+      return self._root.invoke("GET", self.path + _CAB_API_CREDENTIALS_GLOBAL, allow_redirects=True, log_response=False)
     except Exception as e:
       raise PopupException('Failed to obtain storage credentials from IDBroker with error: %s' % e.message)

+ 30 - 30
desktop/core/src/desktop/lib/idbroker/conf.py

@@ -13,61 +13,56 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import absolute_import
-
 import logging
-import sys
 import requests
 
 from requests_kerberos import HTTPKerberosAuth
 from hadoop.core_site import get_conf
 
-if sys.version_info[0] > 2:
-  from django.utils.translation import gettext_lazy as _t
-else:
-  from django.utils.translation import ugettext_lazy as _t
+from django.utils.translation import gettext_lazy as _t
+
 
 LOG = logging.getLogger()
 
+
 _CNF_CAB_ADDRESS = 'fs.%s.ext.cab.address' # http://host:8444/gateway
 _CNF_CAB_ADDRESS_DT_PATH = 'fs.%s.ext.cab.dt.path' # dt
 _CNF_CAB_ADDRESS_PATH = 'fs.%s.ext.cab.path' # aws-cab
 _CNF_CAB_USERNAME = 'fs.%s.ext.cab.username' # when not using kerberos
 _CNF_CAB_PASSWORD = 'fs.%s.ext.cab.password'
+
 SUPPORTED_FS = {'s3a': 's3a', 'adl': 'azure', 'abfs': 'azure', 'azure': 'azure', 'gs': 'gs'}
 
 def validate_fs(fs=None):
   if fs in SUPPORTED_FS:
     return SUPPORTED_FS[fs]
   else:
-    LOG.warning('Selected FS %s is not supported by Hue IDBroker client' % fs)
+    LOG.warning('Selected filesystem %s is not supported by Hue IDBroker client.' % fs)
     return None
 
 def _handle_idbroker_ha(fs=None):
-  fs = validate_fs(fs)
-  idbrokeraddr = get_conf().get(_CNF_CAB_ADDRESS % fs) if fs else None
-  response = None
+  idbroker_addr_list = []
   if fs:
-    id_broker_addr = get_conf().get(_CNF_CAB_ADDRESS % fs)
-    if id_broker_addr:
-      id_broker_addr_list = id_broker_addr.split(',')
-      for id_broker_addr in id_broker_addr_list:
-        try:
-          response = requests.get(id_broker_addr.rstrip('/') + '/dt/knoxtoken/api/v1/token', auth=HTTPKerberosAuth(), verify=False)
-        except Exception as e:
-          if 'Name or service not known' in str(e):
-            LOG.warn('IDBroker %s is not available for use' % id_broker_addr)
-        # Check response for None and if response code is successful (200) or authentication needed (401)
-        if (response is not None) and (response.status_code in (200, 401)):
-          idbrokeraddr = id_broker_addr
-          break
-      return idbrokeraddr
-    else:
-      return idbrokeraddr
-  else:
-    return idbrokeraddr
+    idbroker_addr = get_conf().get(_CNF_CAB_ADDRESS % fs, '')
+    idbroker_addr_list = idbroker_addr.split(',')
+
+  response = None
+  for idb in idbroker_addr_list:
+    LOG.info('Attempting to connect to IDBroker URL: %s' % idb)
+
+    try:
+      response = requests.get(idb.rstrip('/') + '/dt/knoxtoken/api/v1/token', auth=HTTPKerberosAuth(), verify=False)
+    except Exception as e:
+      if 'Failed to establish a new connection' in str(e):
+        LOG.warning('IDBroker URL %s is not available.' % idb)
+
+    # Check response for None and if response code is successful (200) or authentication needed (401)
+    if (response is not None) and (response.status_code in (200, 401)):
+      return idb
+
 
 def get_cab_address(fs=None):
+  fs = validate_fs(fs)
   return _handle_idbroker_ha(fs)
 
 def get_cab_dt_path(fs=None):
@@ -89,7 +84,12 @@ def get_cab_password(fs=None):
 def is_idbroker_enabled(fs=None):
   from desktop.conf import RAZ  # Must be imported dynamically in order to have proper value
 
-  return get_cab_address(fs) is not None and not RAZ.IS_ENABLED.get() # Skipping IDBroker for FS when RAZ is present
+  fs = validate_fs(fs)
+  idbroker_addr_from_coresite = get_conf().get(_CNF_CAB_ADDRESS % fs)
+
+  # When RAZ is configured, skip checking for IDBroker configs from core-site. 
+  # RAZ gets precedence over IDBroker when both are configured in Hue.
+  return (not RAZ.IS_ENABLED.get() and bool(idbroker_addr_from_coresite))
 
 def config_validator():
   res = []

+ 100 - 35
desktop/core/src/desktop/lib/idbroker/tests.py

@@ -13,63 +13,128 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import absolute_import
-
 import logging
 import unittest
-import sys
 
-from nose.tools import assert_equal, assert_true
+from nose.tools import assert_equal, assert_raises
+from unittest.mock import Mock, patch
 
 from desktop.lib.idbroker.client import IDBroker
+from desktop.lib.idbroker.conf import _handle_idbroker_ha
+from desktop.lib.exceptions_renderable import PopupException
 
-if sys.version_info[0] > 2:
-  from unittest.mock import patch
-else:
-  from mock import patch
 
 LOG = logging.getLogger()
 
-class TestIDBroker(unittest.TestCase):
+
+class TestIDBrokerClient(unittest.TestCase):
   def test_username_authentication(self):
     with patch('desktop.lib.idbroker.conf.get_conf') as conf:
       with patch('desktop.lib.idbroker.client.resource.Resource.invoke') as invoke:
         with patch('desktop.lib.idbroker.client.http_client.HttpClient.set_basic_auth') as set_basic_auth:
-          conf.return_value = {
-            'fs.s3a.ext.cab.address': 'address',
-            'fs.s3a.ext.cab.dt.path': 'dt_path',
-            'fs.s3a.ext.cab.path': 'path',
-            'fs.s3a.ext.cab.username': 'username',
-            'fs.s3a.ext.cab.password': 'password'
-          }
-          invoke.return_value = {
-             'Credentials': 'Credentials'
-          }
-          client = IDBroker.from_core_site('s3a', 'test')
-
-          cab = client.get_cab()
-          assert_equal(invoke.call_count, 2) # get_cab calls twice
-          assert_equal(cab.get('Credentials'), 'Credentials')
-          assert_equal(set_basic_auth.call_count, 1)
-
-  def test_kerberos_authentication(self):
-    with patch('desktop.lib.idbroker.conf.get_conf') as conf:
-      with patch('desktop.lib.idbroker.client.is_kerberos_enabled') as is_kerberos_enabled:
-        with patch('desktop.lib.idbroker.client.resource.Resource.invoke') as invoke:
-          with patch('desktop.lib.idbroker.client.http_client.HttpClient.set_kerberos_auth') as set_kerberos_auth:
-            is_kerberos_enabled.return_value = True
+          with patch('desktop.lib.idbroker.conf.get_cab_address') as get_cab_address:
             conf.return_value = {
               'fs.s3a.ext.cab.address': 'address',
               'fs.s3a.ext.cab.dt.path': 'dt_path',
               'fs.s3a.ext.cab.path': 'path',
-              'hadoop.security.authentication': 'kerberos',
+              'fs.s3a.ext.cab.username': 'username',
+              'fs.s3a.ext.cab.password': 'password'
             }
             invoke.return_value = {
               'Credentials': 'Credentials'
             }
-            client = IDBroker.from_core_site('s3a', 'test')
+            get_cab_address.return_value = 'address'
 
+            client = IDBroker.from_core_site('s3a', 'test')
             cab = client.get_cab()
+
             assert_equal(invoke.call_count, 2) # get_cab calls twice
             assert_equal(cab.get('Credentials'), 'Credentials')
-            assert_equal(set_kerberos_auth.call_count, 1)
+            assert_equal(set_basic_auth.call_count, 1)
+
+
+  def test_kerberos_authentication(self):
+    with patch('desktop.lib.idbroker.conf.get_conf') as conf:
+      with patch('desktop.lib.idbroker.client.is_kerberos_enabled') as is_kerberos_enabled:
+        with patch('desktop.lib.idbroker.client.resource.Resource.invoke') as invoke:
+          with patch('desktop.lib.idbroker.client.http_client.HttpClient.set_kerberos_auth') as set_kerberos_auth:
+            with patch('desktop.lib.idbroker.conf.get_cab_address') as get_cab_address:
+              is_kerberos_enabled.return_value = True
+              conf.return_value = {
+                'fs.s3a.ext.cab.address': 'address',
+                'fs.s3a.ext.cab.dt.path': 'dt_path',
+                'fs.s3a.ext.cab.path': 'path',
+                'hadoop.security.authentication': 'kerberos',
+              }
+              invoke.return_value = {
+                'Credentials': 'Credentials'
+              }
+              get_cab_address.return_value = 'address'
+
+              client = IDBroker.from_core_site('s3a', 'test')
+              cab = client.get_cab()
+
+              assert_equal(invoke.call_count, 2) # get_cab calls twice
+              assert_equal(cab.get('Credentials'), 'Credentials')
+              assert_equal(set_kerberos_auth.call_count, 1)
+
+
+  def test_no_idbroker_address_found(self):
+    with patch('desktop.lib.idbroker.conf.get_conf') as conf:
+      with patch('desktop.lib.idbroker.conf.get_cab_address') as get_cab_address:
+        conf.return_value = {
+          'fs.s3a.ext.cab.address': 'address',
+          'fs.s3a.ext.cab.dt.path': 'dt_path',
+          'fs.s3a.ext.cab.path': 'path'
+        }
+
+        # No active IDBroker URL available
+        get_cab_address.return_value = None
+        assert_raises(PopupException, IDBroker.from_core_site, 's3a', 'test')
+
+
+
+class TestIDBrokerHA(unittest.TestCase):
+  def test_idbroker_non_ha(self):
+    with patch('desktop.lib.idbroker.conf.get_conf') as conf:
+      with patch('desktop.lib.idbroker.conf.requests.get') as requests_get:
+        conf.return_value = {'fs.s3a.ext.cab.address': 'https://idbroker0.gethue.com:8444/gateway'}
+        requests_get.return_value = Mock(status_code=200)
+
+        idbroker_url = _handle_idbroker_ha(fs='s3a')
+        assert_equal(idbroker_url, 'https://idbroker0.gethue.com:8444/gateway')
+        assert_equal(requests_get.call_count, 1)
+
+
+  def test_idbroker_ha(self):
+    with patch('desktop.lib.idbroker.conf.get_conf') as conf:
+      with patch('desktop.lib.idbroker.conf.requests.get') as requests_get:
+        conf.return_value = {
+          'fs.s3a.ext.cab.address': 'https://idbroker0.gethue.com:8444/gateway,https://idbroker1.gethue.com:8444/gateway'
+        }
+
+        # When IDBroker0 is healthy and IDBroker1 is unhealthy
+        requests_get.side_effect = [Mock(status_code=200), Mock(status_code=404)]
+        idbroker_url = _handle_idbroker_ha(fs='s3a')
+
+        assert_equal(idbroker_url, 'https://idbroker0.gethue.com:8444/gateway')
+        assert_equal(requests_get.call_count, 1)
+        requests_get.reset_mock()
+
+
+        # When IDBroker0 is unhealthy and IDBroker1 is healthy
+        requests_get.side_effect = [Mock(status_code=404), Mock(status_code=200)]
+        idbroker_url = _handle_idbroker_ha(fs='s3a')
+
+        assert_equal(idbroker_url, 'https://idbroker1.gethue.com:8444/gateway')
+        assert_equal(requests_get.call_count, 2)
+        requests_get.reset_mock()
+
+
+        # When both IDBroker0 and IDBroker1 are unhealthy
+        requests_get.side_effect = [Mock(status_code=404), Mock(status_code=404)]
+        idbroker_url = _handle_idbroker_ha(fs='s3a')
+
+        assert_equal(idbroker_url, None)
+        assert_equal(requests_get.call_count, 2)
+

+ 2 - 2
desktop/libs/aws/src/aws/conf.py

@@ -273,8 +273,8 @@ AWS_ACCOUNTS = UnspecifiedConfigSection(
 def is_enabled():
   return ('default' in list(AWS_ACCOUNTS.keys()) and AWS_ACCOUNTS['default'].get_raw() and AWS_ACCOUNTS['default'].ACCESS_KEY_ID.get()) or \
       has_iam_metadata() or \
-      conf_idbroker.is_idbroker_enabled('s3a') or \
-      is_raz_s3()
+      is_raz_s3() or \
+      conf_idbroker.is_idbroker_enabled('s3a')
 
 
 def is_ec2_instance():

+ 70 - 59
desktop/libs/aws/src/aws/tests.py

@@ -13,10 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import absolute_import
-
 import logging
-import sys
 import unittest
 
 from nose.tools import assert_equal, assert_true, assert_not_equal
@@ -28,10 +25,8 @@ from desktop.lib.fsmanager import get_client, clear_cache
 from desktop.lib.python_util import current_ms_from_utc
 from desktop.conf import RAZ
 
-if sys.version_info[0] > 2:
-  from unittest.mock import patch
-else:
-  from mock import patch
+from unittest.mock import patch
+
 
 LOG = logging.getLogger()
 
@@ -54,86 +49,102 @@ class TestAWS(unittest.TestCase):
       clear_cache()
       conf.clear_cache()
 
+
   def test_with_idbroker(self):
     try:
       finish = conf.AWS_ACCOUNTS.set_for_testing({}) # Set empty to test when no configs are set
       with patch('aws.client.conf_idbroker.get_conf') as get_conf:
-        with patch('aws.client.Client.get_s3_connection'):
-          with patch('aws.client.IDBroker.get_cab') as get_cab:
-            with patch('aws.client.aws_conf.has_iam_metadata') as has_iam_metadata:
-              get_conf.return_value = {
-                'fs.s3a.ext.cab.address': 'address'
-              }
-              get_cab.return_value = {
-                'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0}
-              }
-              has_iam_metadata.return_value = True
-              provider = get_credential_provider('default', 'hue')
-              assert_equal(provider.get_credentials().get('AccessKeyId'), 'AccessKeyId')
-              client1 = get_client(name='default', fs='s3a', user='hue')
-              client2 = get_client(name='default', fs='s3a', user='hue')
-              assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal
-
-              get_cab.return_value = {
-                'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': int(current_ms_from_utc()) + 10*1000}
-              }
-              client3 = get_client(name='default', fs='s3a', user='hue')
-              client4 = get_client(name='default', fs='s3a', user='hue')
-              client5 = get_client(name='default', fs='s3a', user='test')
-              assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal
-              assert_not_equal(client4, client5) # Test different user have different clients
+        with patch('aws.client.conf_idbroker.get_cab_address') as get_cab_address:
+          with patch('aws.client.Client.get_s3_connection'):
+            with patch('aws.client.IDBroker.get_cab') as get_cab:
+              with patch('aws.client.aws_conf.has_iam_metadata') as has_iam_metadata:
+                get_conf.return_value = {
+                  'fs.s3a.ext.cab.address': 'address'
+                }
+                get_cab_address.return_value = 'address'
+                get_cab.return_value = {
+                  'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0}
+                }
+                has_iam_metadata.return_value = True
+                provider = get_credential_provider('default', 'hue')
+
+                assert_equal(provider.get_credentials().get('AccessKeyId'), 'AccessKeyId')
+
+                client1 = get_client(name='default', fs='s3a', user='hue')
+                client2 = get_client(name='default', fs='s3a', user='hue')
+
+                assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal
+
+                get_cab.return_value = {
+                  'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': int(current_ms_from_utc()) + 10*1000}
+                }
+                client3 = get_client(name='default', fs='s3a', user='hue')
+                client4 = get_client(name='default', fs='s3a', user='hue')
+                client5 = get_client(name='default', fs='s3a', user='test')
+
+                assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal
+                assert_not_equal(client4, client5) # Test different user have different clients
     finally:
       finish()
       clear_cache()
       conf.clear_cache()
 
+
   def test_with_idbroker_and_config(self):
     try:
       finish = conf.AWS_ACCOUNTS.set_for_testing({'default': {'region': 'ap-northeast-1'}})
       with patch('aws.client.conf_idbroker.get_conf') as get_conf:
-        with patch('aws.client.Client.get_s3_connection'):
-          with patch('aws.client.IDBroker.get_cab') as get_cab:
-            with patch('aws.client.aws_conf.has_iam_metadata') as has_iam_metadata:
-              get_conf.return_value = {
-                'fs.s3a.ext.cab.address': 'address'
-              }
-              get_cab.return_value = {
-                'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0}
-              }
-              has_iam_metadata.return_value = True
-              provider = get_credential_provider('default', 'hue')
-              assert_equal(provider.get_credentials().get('AccessKeyId'), 'AccessKeyId')
-
-              client = Client.from_config(conf.AWS_ACCOUNTS['default'], get_credential_provider('default', 'hue'))
-              assert_equal(client._region, 'ap-northeast-1')
-    finally:
-      finish()
-      clear_cache()
-      conf.clear_cache()
-
-  def test_with_idbroker_on_ec2(self):
-    try:
-      finish = conf.AWS_ACCOUNTS.set_for_testing({}) # Set empty to test when no configs are set
-      with patch('aws.client.aws_conf.get_region') as get_region:
-        with patch('aws.client.conf_idbroker.get_conf') as get_conf:
+        with patch('aws.client.conf_idbroker.get_cab_address') as get_cab_address:
           with patch('aws.client.Client.get_s3_connection'):
             with patch('aws.client.IDBroker.get_cab') as get_cab:
               with patch('aws.client.aws_conf.has_iam_metadata') as has_iam_metadata:
-                get_region.return_value = 'us-west-1'
                 get_conf.return_value = {
                   'fs.s3a.ext.cab.address': 'address'
                 }
+                get_cab_address.return_value = 'address'
                 get_cab.return_value = {
                   'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0}
                 }
                 has_iam_metadata.return_value = True
-                client = Client.from_config(None, get_credential_provider('default', 'hue'))
-                assert_equal(client._region, 'us-west-1') # Test different user have different clients
+
+                provider = get_credential_provider('default', 'hue')
+                assert_equal(provider.get_credentials().get('AccessKeyId'), 'AccessKeyId')
+
+                client = Client.from_config(conf.AWS_ACCOUNTS['default'], get_credential_provider('default', 'hue'))
+                assert_equal(client._region, 'ap-northeast-1')
     finally:
       finish()
       clear_cache()
       conf.clear_cache()
 
+
+  def test_with_idbroker_on_ec2(self):
+    try:
+      finish = conf.AWS_ACCOUNTS.set_for_testing({}) # Set empty to test when no configs are set
+      with patch('aws.client.aws_conf.get_region') as get_region:
+        with patch('aws.client.conf_idbroker.get_conf') as get_conf:
+          with patch('aws.client.conf_idbroker.get_cab_address') as get_cab_address:
+            with patch('aws.client.Client.get_s3_connection'):
+              with patch('aws.client.IDBroker.get_cab') as get_cab:
+                with patch('aws.client.aws_conf.has_iam_metadata') as has_iam_metadata:
+                  get_region.return_value = 'us-west-1'
+                  get_conf.return_value = {
+                    'fs.s3a.ext.cab.address': 'address'
+                  }
+                  get_cab_address.return_value = 'address'
+                  get_cab.return_value = {
+                    'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0}
+                  }
+                  has_iam_metadata.return_value = True
+                  client = Client.from_config(None, get_credential_provider('default', 'hue'))
+
+                  assert_equal(client._region, 'us-west-1') # Test different user have different clients
+    finally:
+      finish()
+      clear_cache()
+      conf.clear_cache()
+
+
   def test_with_raz_enabled(self):
     with patch('aws.client.RazS3Connection') as raz_s3_connection:
       resets = [

+ 3 - 3
desktop/libs/azure/src/azure/conf.py

@@ -166,9 +166,9 @@ def is_adls_enabled():
     or (conf_idbroker.is_idbroker_enabled('azure') and has_azure_metadata())) and 'default' in list(ADLS_CLUSTERS.keys())
 
 def is_abfs_enabled():
-  return ('default' in list(AZURE_ACCOUNTS.keys()) and AZURE_ACCOUNTS['default'].get_raw() and AZURE_ACCOUNTS['default'].CLIENT_ID.get() \
-    or (conf_idbroker.is_idbroker_enabled('azure') and has_azure_metadata())) and 'default' in list(ABFS_CLUSTERS.keys()) \
-    or is_raz_abfs()
+  return is_raz_abfs() or \
+    ('default' in list(AZURE_ACCOUNTS.keys()) and AZURE_ACCOUNTS['default'].get_raw() and AZURE_ACCOUNTS['default'].CLIENT_ID.get() or \
+    (conf_idbroker.is_idbroker_enabled('azure') and has_azure_metadata())) and 'default' in list(ABFS_CLUSTERS.keys())
 
 def has_adls_access(user):
   from desktop.conf import RAZ  # Must be imported dynamically in order to have proper value

+ 96 - 65
desktop/libs/azure/src/azure/tests.py

@@ -13,25 +13,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import absolute_import
-
 import logging
-import sys
 import unittest
 
-from nose.plugins.skip import SkipTest
-from nose.tools import assert_equal, assert_true, assert_not_equal
+from nose.tools import assert_equal, assert_not_equal
+from unittest.mock import patch
 
 from azure import conf
 from azure.client import get_credential_provider
 
-from desktop.lib.fsmanager import get_client, clear_cache, is_enabled
+from desktop.lib.fsmanager import get_client, clear_cache
 from desktop.lib.python_util import current_ms_from_utc
 
-if sys.version_info[0] > 2:
-  from unittest.mock import patch
-else:
-  from mock import patch
 
 LOG = logging.getLogger()
 
@@ -48,7 +41,11 @@ class TestAzureAdl(unittest.TestCase):
             with patch('azure.conf.core_site.get_conf') as core_site_get_conf:
               get_token.return_value = {'access_token': 'access_token', 'token_type': '', 'expires_on': None}
               get_conf.return_value = {}
-              core_site_get_conf.return_value = {'dfs.adls.oauth2.client.id': 'client_id', 'dfs.adls.oauth2.credential': 'client_secret', 'dfs.adls.oauth2.refresh.url': 'refresh_url'}
+              core_site_get_conf.return_value = {
+                'dfs.adls.oauth2.client.id': 'client_id',
+                'dfs.adls.oauth2.credential': 'client_secret',
+                'dfs.adls.oauth2.refresh.url': 'refresh_url'
+              }
               client1 = get_client(name='default', fs='adl')
               client2 = get_client(name='default', fs='adl', user='test')
 
@@ -60,10 +57,15 @@ class TestAzureAdl(unittest.TestCase):
         f()
       clear_cache()
 
+
   def test_with_credentials(self):
     try:
-      finish = (conf.AZURE_ACCOUNTS.set_for_testing({'default': {'client_id':'client_id', 'client_secret': 'client_secret', 'tenant_id': 'tenant_id'}}),
-                conf.ADLS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}}))
+      finish = (
+        conf.AZURE_ACCOUNTS.set_for_testing({
+          'default': {'client_id': 'client_id', 'client_secret': 'client_secret', 'tenant_id': 'tenant_id'}
+        }),
+        conf.ADLS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}})
+      )
       with patch('azure.client.conf_idbroker.get_conf') as get_conf:
         with patch('azure.client.WebHdfs.get_client'):
           with patch('azure.client.ActiveDirectory.get_token') as get_token:
@@ -83,31 +85,41 @@ class TestAzureAdl(unittest.TestCase):
 
   def test_with_idbroker(self):
     try:
-      finish = (conf.AZURE_ACCOUNTS.set_for_testing({}),
-                conf.ADLS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}}))
+      finish = (
+        conf.AZURE_ACCOUNTS.set_for_testing({}),
+        conf.ADLS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}})
+      )
       with patch('azure.client.conf_idbroker.get_conf') as get_conf:
-        with patch('azure.client.WebHdfs.get_client'):
-          with patch('azure.client.IDBroker.get_cab') as get_cab:
-            with patch('azure.client.conf.has_azure_metadata') as has_azure_metadata:
-              get_conf.return_value = {
-                'fs.azure.ext.cab.address': 'address'
-              }
-              has_azure_metadata.return_value = True
-              get_cab.return_value = { 'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': 0 }
-              provider = get_credential_provider('default', 'hue')
-              assert_equal(provider.get_credentials().get('access_token'), 'access_token')
-              client1 = get_client(name='default', fs='adl', user='hue')
-              client2 = get_client(name='default', fs='adl', user='hue')
-              assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal
-
-              get_cab.return_value = {
-                'Credentials': {'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': int(current_ms_from_utc()) + 10*1000}
-              }
-              client3 = get_client(name='default', fs='adl', user='hue')
-              client4 = get_client(name='default', fs='adl', user='hue')
-              client5 = get_client(name='default', fs='adl', user='test')
-              assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal
-              assert_not_equal(client4, client5) # Test different user have different clients
+        with patch('azure.client.conf_idbroker.get_cab_address') as get_cab_address:
+          with patch('azure.client.WebHdfs.get_client'):
+            with patch('azure.client.IDBroker.get_cab') as get_cab:
+              with patch('azure.client.conf.has_azure_metadata') as has_azure_metadata:
+                get_conf.return_value = {
+                  'fs.azure.ext.cab.address': 'address'
+                }
+                get_cab_address.return_value = 'address'
+                has_azure_metadata.return_value = True
+                get_cab.return_value = {'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': 0}
+                provider = get_credential_provider('default', 'hue')
+
+                assert_equal(provider.get_credentials().get('access_token'), 'access_token')
+
+                client1 = get_client(name='default', fs='adl', user='hue')
+                client2 = get_client(name='default', fs='adl', user='hue')
+
+                assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal
+
+                get_cab.return_value = {
+                  'Credentials': {
+                    'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': int(current_ms_from_utc()) + 10*1000
+                  }
+                }
+                client3 = get_client(name='default', fs='adl', user='hue')
+                client4 = get_client(name='default', fs='adl', user='hue')
+                client5 = get_client(name='default', fs='adl', user='test')
+
+                assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal
+                assert_not_equal(client4, client5) # Test different user have different clients
     finally:
       for f in finish:
         f()
@@ -126,7 +138,11 @@ class TestAzureAbfs(unittest.TestCase):
             with patch('azure.conf.core_site.get_conf') as core_site_get_conf:
               get_token.return_value = {'access_token': 'access_token', 'token_type': '', 'expires_on': None}
               get_conf.return_value = {}
-              core_site_get_conf.return_value = {'fs.azure.account.oauth2.client.id': 'client_id', 'fs.azure.account.oauth2.client.secret': 'client_secret', 'fs.azure.account.oauth2.client.endpoint': 'refresh_url'}
+              core_site_get_conf.return_value = {
+                'fs.azure.account.oauth2.client.id': 'client_id',
+                'fs.azure.account.oauth2.client.secret': 'client_secret',
+                'fs.azure.account.oauth2.client.endpoint': 'refresh_url'
+              }
               client1 = get_client(name='default', fs='abfs')
               client2 = get_client(name='default', fs='abfs', user='test')
 
@@ -138,10 +154,15 @@ class TestAzureAbfs(unittest.TestCase):
         f()
       clear_cache()
 
+
   def test_with_credentials(self):
     try:
-      finish = (conf.AZURE_ACCOUNTS.set_for_testing({'default': {'client_id':'client_id', 'client_secret': 'client_secret', 'tenant_id': 'tenant_id'}}),
-                conf.ABFS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}}))
+      finish = (
+        conf.AZURE_ACCOUNTS.set_for_testing({
+          'default': {'client_id': 'client_id', 'client_secret': 'client_secret', 'tenant_id': 'tenant_id'}
+        }),
+        conf.ABFS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}})
+      )
       with patch('azure.client.conf_idbroker.get_conf') as get_conf:
         with patch('azure.client.ABFS.get_client'):
           with patch('azure.client.ActiveDirectory.get_token') as get_token:
@@ -161,32 +182,42 @@ class TestAzureAbfs(unittest.TestCase):
 
   def test_with_idbroker(self):
     try:
-      finish = (conf.AZURE_ACCOUNTS.set_for_testing({}),
-                conf.ABFS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}}))
+      finish = (
+        conf.AZURE_ACCOUNTS.set_for_testing({}),
+        conf.ABFS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}})
+      )
       with patch('azure.client.conf_idbroker.get_conf') as get_conf:
-        with patch('azure.client.ABFS.get_client'):
-          with patch('azure.client.IDBroker.get_cab') as get_cab:
-            with patch('azure.client.conf.has_azure_metadata') as has_azure_metadata:
-              get_conf.return_value = {
-                'fs.azure.ext.cab.address': 'address'
-              }
-              has_azure_metadata.return_value = True
-              get_cab.return_value = { 'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': 0 }
-              provider = get_credential_provider('default', 'hue')
-              assert_equal(provider.get_credentials().get('access_token'), 'access_token')
-              client1 = get_client(name='default', fs='abfs', user='hue')
-              client2 = get_client(name='default', fs='abfs', user='hue')
-              assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal
-
-              get_cab.return_value = {
-                'Credentials': {'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': int(current_ms_from_utc()) + 10*1000}
-              }
-              client3 = get_client(name='default', fs='abfs', user='hue')
-              client4 = get_client(name='default', fs='abfs', user='hue')
-              client5 = get_client(name='default', fs='abfs', user='test')
-              assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal
-              assert_not_equal(client4, client5) # Test different user have different clients
+        with patch('azure.client.conf_idbroker.get_cab_address') as get_cab_address:
+          with patch('azure.client.ABFS.get_client'):
+            with patch('azure.client.IDBroker.get_cab') as get_cab:
+              with patch('azure.client.conf.has_azure_metadata') as has_azure_metadata:
+                get_conf.return_value = {
+                  'fs.azure.ext.cab.address': 'address'
+                }
+                get_cab_address.return_value = 'address'
+                has_azure_metadata.return_value = True
+                get_cab.return_value = {'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': 0}
+                provider = get_credential_provider('default', 'hue')
+
+                assert_equal(provider.get_credentials().get('access_token'), 'access_token')
+
+                client1 = get_client(name='default', fs='abfs', user='hue')
+                client2 = get_client(name='default', fs='abfs', user='hue')
+
+                assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal
+
+                get_cab.return_value = {
+                  'Credentials': {
+                    'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': int(current_ms_from_utc()) + 10*1000
+                  }
+                }
+                client3 = get_client(name='default', fs='abfs', user='hue')
+                client4 = get_client(name='default', fs='abfs', user='hue')
+                client5 = get_client(name='default', fs='abfs', user='test')
+
+                assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal
+                assert_not_equal(client4, client5) # Test different user have different clients
     finally:
       for f in finish:
         f()
-      clear_cache()
+      clear_cache()