Browse Source

HUE-8998 [fb] Add test for S3 + IDBroker + Conf

Jean-Francois Desjeans Gauthier 6 years ago
parent
commit
a04c389557

+ 1 - 1
desktop/libs/aws/src/aws/client.py

@@ -166,7 +166,7 @@ class Client(object):
         aws_access_key_id=credentials.get('AccessKeyId'),
         aws_secret_access_key=credentials.get('SecretAccessKey'),
         aws_security_token=credentials.get('SessionToken'),
-        region=aws_conf.get_default_region(),
+        region=aws_conf.get_region(conf),
         host=conf.HOST.get(),
         proxy_address=conf.PROXY_ADDRESS.get(),
         proxy_port=conf.PROXY_PORT.get(),

+ 10 - 5
desktop/libs/aws/src/aws/conf.py

@@ -82,20 +82,24 @@ def get_default_session_token():
 
 
 def get_default_region():
+  return get_region(AWS_ACCOUNTS['default']) if 'default' in AWS_ACCOUNTS else ''
+
+
+def get_region(conf):
   region = ''
 
-  if 'default' in AWS_ACCOUNTS:
+  if conf:
     # First check the host/endpoint configuration
-    if AWS_ACCOUNTS['default'].HOST.get():
-      endpoint = AWS_ACCOUNTS['default'].HOST.get()
+    if conf.HOST.get():
+      endpoint = conf.HOST.get()
       if re.search(SUBDOMAIN_ENDPOINT_RE, endpoint, re.IGNORECASE):
         region = re.search(SUBDOMAIN_ENDPOINT_RE, endpoint, re.IGNORECASE).group('region')
       elif re.search(HYPHEN_ENDPOINT_RE, endpoint, re.IGNORECASE):
         region = re.search(HYPHEN_ENDPOINT_RE, endpoint, re.IGNORECASE).group('region')
       elif re.search(DUALSTACK_ENDPOINT_RE, endpoint, re.IGNORECASE):
         region = re.search(DUALSTACK_ENDPOINT_RE, endpoint, re.IGNORECASE).group('region')
-    elif AWS_ACCOUNTS['default'].REGION.get():
-      region = AWS_ACCOUNTS['default'].REGION.get()
+    elif conf.REGION.get():
+      region = conf.REGION.get()
 
     # If the parsed out region is not in the list of supported regions, fallback to the default
     if region not in get_locations():
@@ -104,6 +108,7 @@ def get_default_region():
 
   return region
 
+
 def get_key_expiry():
   if 'default' in AWS_ACCOUNTS:
     return AWS_ACCOUNTS['default'].KEY_EXPIRY.get()

+ 22 - 1
desktop/libs/aws/src/aws/tests.py

@@ -22,7 +22,7 @@ from mock import patch, Mock
 from nose.tools import assert_equal, assert_true, assert_not_equal
 
 from aws import conf
-from aws.client import clear_cache, get_client, get_credential_provider, current_ms_from_utc
+from aws.client import clear_cache, Client, get_client, get_credential_provider, current_ms_from_utc
 
 LOG = logging.getLogger(__name__)
 
@@ -72,3 +72,24 @@ class TestAWS(unittest.TestCase):
     finally:
       finish()
       clear_cache()
+
+  def test_with_idbroker_and_config(self):
+    try:
+      finish = conf.AWS_ACCOUNTS.set_for_testing({'default': {'region': 'ap-northeast-1'}})
+      with patch('aws.client.conf_idbroker.get_conf') as get_conf:
+        with patch('aws.client.Client.get_s3_connection'):
+          with patch('aws.client.IDBroker.get_cab') as get_cab:
+            get_conf.return_value = {
+              'fs.s3a.ext.cab.address': 'address'
+            }
+            get_cab.return_value = {
+              'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0}
+            }
+            provider = get_credential_provider()
+            assert_equal(provider.get_credentials().get('AccessKeyId'), 'AccessKeyId')
+
+            client = Client.from_config(conf.AWS_ACCOUNTS['default'], get_credential_provider('default', 'hue'))
+            assert_equal(client._region, 'ap-northeast-1')
+    finally:
+      finish()
+      clear_cache()