Преглед изворни кода

HUE-6780 [s3] Correctly infer and display region when connected to S3 by endpoint

Jenny Kim пре 8 година
родитељ
комит
484fb77

+ 2 - 2
desktop/libs/aws/src/aws/client.py

@@ -40,7 +40,7 @@ class Client(object):
     self._access_key_id = aws_access_key_id
     self._secret_access_key = aws_secret_access_key
     self._security_token = aws_security_token
-    self._region = region.lower() if region else get_default_region()
+    self._region = region.lower()
     self._timeout = timeout
     self._host = host
     self._proxy_address = proxy_address
@@ -70,7 +70,7 @@ class Client(object):
       aws_access_key_id=access_key_id,
       aws_secret_access_key=secret_access_key,
       aws_security_token=security_token,
-      region=conf.REGION.get(),
+      region=get_default_region(),
       host=conf.HOST.get(),
       proxy_address=conf.PROXY_ADDRESS.get(),
       proxy_port=conf.PROXY_PORT.get(),

+ 32 - 2
desktop/libs/aws/src/aws/conf.py

@@ -16,6 +16,7 @@
 from __future__ import absolute_import
 
 import logging
+import re
 
 import boto.utils
 from boto.s3.connection import Location
@@ -30,7 +31,16 @@ from hadoop.core_site import get_s3a_access_key, get_s3a_secret_key
 LOG = logging.getLogger(__name__)
 
 
-DEFAULT_CALLING_FORMAT='boto.s3.connection.OrdinaryCallingFormat'
+DEFAULT_CALLING_FORMAT = 'boto.s3.connection.OrdinaryCallingFormat'
+SUBDOMAIN_ENDPOINT_RE = 's3.(?P<region>[a-z0-9-]+).amazonaws.com'
+HYPHEN_ENDPOINT_RE = 's3-(?P<region>[a-z0-9-]+).amazonaws.com'
+DUALSTACK_ENDPOINT_RE = 's3.dualstack.(?P<region>[a-z0-9-]+).amazonaws.com'
+
+
+def get_locations():
+  return (Location.EU, Location.EUCentral1, Location.EUWest, Location.EUWest2, Location.CACentral, Location.USEast,
+          Location.USEast2, Location.USWest, Location.USWest2, Location.SAEast, Location.APNortheast,
+          Location.APNortheast2, Location.APSoutheast, Location.APSoutheast2, Location.APSouth, Location.CNNorth1)
 
 
 def get_default_access_key_id():
@@ -50,7 +60,27 @@ def get_default_secret_key():
 
 
 def get_default_region():
-  return AWS_ACCOUNTS['default'].REGION.get() if 'default' in AWS_ACCOUNTS else Location.DEFAULT
+  region = Location.DEFAULT
+
+  if 'default' in AWS_ACCOUNTS:
+    # First check the host/endpoint configuration
+    if AWS_ACCOUNTS['default'].HOST.get():
+      endpoint = AWS_ACCOUNTS['default'].HOST.get()
+      if re.search(SUBDOMAIN_ENDPOINT_RE, endpoint, re.IGNORECASE):
+        region = re.search(SUBDOMAIN_ENDPOINT_RE, endpoint, re.IGNORECASE).group('region')
+      elif re.search(HYPHEN_ENDPOINT_RE, endpoint, re.IGNORECASE):
+        region = re.search(HYPHEN_ENDPOINT_RE, endpoint, re.IGNORECASE).group('region')
+      elif re.search(DUALSTACK_ENDPOINT_RE, endpoint, re.IGNORECASE):
+        region = re.search(DUALSTACK_ENDPOINT_RE, endpoint, re.IGNORECASE).group('region')
+    elif AWS_ACCOUNTS['default'].REGION.get():
+      region = AWS_ACCOUNTS['default'].REGION.get()
+
+    # If the parsed out region is not in the list of supported regions, fallback to the default
+    if region not in get_locations():
+      LOG.warn("Region, %s, not found in the list of supported regions: %s" % (region, ', '.join(get_locations())))
+      region = Location.DEFAULT
+
+  return region
 
 
 AWS_ACCOUNTS = UnspecifiedConfigSection(

+ 49 - 2
desktop/libs/aws/src/aws/s3/s3_test.py

@@ -15,10 +15,12 @@
 # limitations under the License.
 from __future__ import absolute_import
 
-from nose.plugins.skip import SkipTest
-from nose.tools import assert_raises, eq_
+from boto.s3.connection import Location
+from nose.tools import assert_equal, assert_raises, eq_
 
 from aws import s3
+from aws import conf
+from aws.conf import get_default_region
 
 
 def test_parse_uri():
@@ -68,3 +70,48 @@ def test_s3datetime_to_timestamp():
 
   assert_raises(AssertionError, f, 'Thu, 26 Feb 2015 20:42:07 PDT')
   assert_raises(AssertionError, f, '2015-02-26T20:42:07.040Z')
+
+
+def test_get_default_region():
+  # Verify that Hue can infer region from subdomain hosts
+  finish = conf.AWS_ACCOUNTS['default'].HOST.set_for_testing('s3.ap-northeast-2.amazonaws.com')
+  try:
+    assert_equal('ap-northeast-2', get_default_region())
+  finally:
+    if finish:
+      finish()
+
+  # Verify that Hue can infer region from hyphenated hosts
+  finish = conf.AWS_ACCOUNTS['default'].HOST.set_for_testing('s3-ap-south-1.amazonaws.com')
+  try:
+    assert_equal('ap-south-1', get_default_region())
+  finally:
+    if finish:
+      finish()
+
+  # Verify that Hue can infer region from hyphenated hosts
+  finish = conf.AWS_ACCOUNTS['default'].HOST.set_for_testing('s3.dualstack.ap-southeast-2.amazonaws.com')
+  try:
+    assert_equal('ap-southeast-2', get_default_region())
+  finally:
+    if finish:
+      finish()
+
+  # Verify that Hue falls back to the default if the region is not valid
+  finish = conf.AWS_ACCOUNTS['default'].HOST.set_for_testing('s3-external-1.amazonaws.com')
+  try:
+    assert_equal(Location.DEFAULT, get_default_region())
+  finally:
+    if finish:
+      finish()
+
+  # Verify that Hue uses the region if specified
+  finish = [
+    conf.AWS_ACCOUNTS['default'].HOST.set_for_testing(''),
+    conf.AWS_ACCOUNTS['default'].REGION.set_for_testing('ca-central-1'),
+  ]
+  try:
+    assert_equal('ca-central-1', get_default_region())
+  finally:
+    for reset in finish:
+      reset()

+ 2 - 3
desktop/libs/aws/src/aws/s3/s3fs.py

@@ -31,7 +31,7 @@ from boto.s3.prefix import Prefix
 from django.utils.translation import ugettext as _
 
 from aws import s3
-from aws.conf import get_default_region
+from aws.conf import get_default_region, get_locations
 from aws.s3 import normpath, s3file, translate_s3_error, S3A_ROOT
 from aws.s3.s3stat import S3Stat
 
@@ -153,8 +153,7 @@ class S3FileSystem(object):
         raise S3FileSystemException(e.message or e.reason)
 
   def _get_location(self):
-    if get_default_region() in (Location.EU, Location.EUCentral1, Location.CACentral, Location.USWest, Location.USWest2, Location.SAEast,
-                                Location.APNortheast, Location.APSoutheast, Location.APSoutheast2, Location.CNNorth1):
+    if get_default_region() in get_locations():
       return get_default_region()
     else:
       return Location.DEFAULT