Browse Source

HUE-8339 [impala] Smart thrift connection pool for Impala

Chris Conner 7 years ago
parent
commit
bd05df617b

+ 4 - 1
apps/beeswax/src/beeswax/server/hive_server2_lib.py

@@ -482,6 +482,7 @@ class HiveServerClient:
   def __init__(self, query_server, user):
     self.query_server = query_server
     self.user = user
+    self.coordinator_host = ''
 
     use_sasl, mechanism, kerberos_principal_short_name, impersonation_enabled, auth_username, auth_password = self.get_security()
     LOG.info(
@@ -539,7 +540,8 @@ class HiveServerClient:
         certfile=certfile,
         validate=validate,
         transport_mode=query_server.get('transport_mode', 'socket'),
-        http_url=query_server.get('http_url', '')
+        http_url=query_server.get('http_url', ''),
+        coordinator_host=self.coordinator_host
     )
 
 
@@ -602,6 +604,7 @@ class HiveServerClient:
 
     req = TOpenSessionReq(**kwargs)
     res = self._client.OpenSession(req)
+    self.coordinator_host = self._client.get_coordinator_host()
 
     if res.status is not None and res.status.statusCode not in (TStatusCode.SUCCESS_STATUS,):
       if hasattr(res.status, 'errorMessage') and res.status.errorMessage:

+ 46 - 11
desktop/core/src/desktop/lib/thrift_util.py

@@ -93,7 +93,8 @@ class ConnectionConfig(object):
                transport='buffered',
                multiple=False,
                transport_mode='socket',
-               http_url=''):
+               http_url='',
+               coordinator_host=''):
     """
     @param klass The thrift client class
     @param host Host to connect to
@@ -116,6 +117,7 @@ class ConnectionConfig(object):
     @param multiple Whether Use MultiplexedProtocol
     @param transport_mode Can be socket or http
     @param Url used when using http transport mode
+    @param Host for Impala coordinator to create coordinator specific pool
     """
     self.klass = klass
     self.host = host
@@ -136,11 +138,18 @@ class ConnectionConfig(object):
     self.multiple = multiple
     self.transport_mode = transport_mode
     self.http_url = http_url
+    self.coordinator_host = coordinator_host
 
   def __str__(self):
     return ', '.join(map(str, [self.klass, self.host, self.port, self.service_name, self.use_sasl, self.kerberos_principal, self.timeout_seconds,
                                self.mechanism, self.username, self.use_ssl, self.ca_certs, self.keyfile, self.certfile, self.validate, self.transport,
-                               self.multiple, self.transport_mode, self.http_url]))
+                               self.multiple, self.transport_mode, self.http_url, self.coordinator_host]))
+
+  def update_coordinator_host(self, coordinator_host):
+    self.coordinator_host = coordinator_host
+
+  def get_coordinator_host(self):
+    return self.coordinator_host
 
 class ConnectionPooler(object):
   """
@@ -164,13 +173,16 @@ class ConnectionPooler(object):
     self.poolsize = poolsize
     self.dictlock = threading.Lock()
 
-  def get_client(self, conf, get_client_timeout=None):
-    """
-    Could block while we wait for the pool to become non-empty.
+  def create_pool_impala(self, conf):
+    self.dictlock.acquire()
+    try:
+      if _get_pool_key(conf) not in self.pooldict:
+        q = LifoQueue(self.poolsize)
+        self.pooldict[_get_pool_key(conf)] = q
+    finally:
+      self.dictlock.release()
 
-    @param get_client_timeout: how long (in seconds) to wait on the pool
-                               to get a client before failing
-    """
+  def create_pool(self, conf):
     # First up, check to see if we have a pool for this endpoint
     if _get_pool_key(conf) not in self.pooldict:
       # Uh-oh, we need to initialise the queue. Take the dict lock.
@@ -206,11 +218,20 @@ class ConnectionPooler(object):
       finally:
         self.dictlock.release()
 
+  def get_client(self, conf, get_client_timeout=None):
+    """
+    Could block while we wait for the pool to become non-empty.
+
+    @param get_client_timeout: how long (in seconds) to wait on the pool
+                               to get a client before failing
+    """
     connection = None
 
     start_pool_get_time = time.time()
     has_waited_for = 0
 
+    self.create_pool(conf)
+
     while connection is None:
       if get_client_timeout is not None:
         this_round_timeout = max(min(get_client_timeout - has_waited_for, 1), 0)
@@ -229,7 +250,7 @@ class ConnectionPooler(object):
           raise socket.timeout(
             ("Timed out after %.2f seconds waiting to retrieve a %s client from the pool.") % (has_waited_for, conf.service_name))
         else:
-          message = "Waited %d seconds for a Thrift client to %s:%d" % (has_waited_for, conf.host, conf.port)
+          message = "Waited %d seconds for a Thrift client to %s:%d %s" % (has_waited_for, conf.host, conf.port, conf.get_coordinator_host())
           log_if_slow_call(duration=has_waited_for, message=message)
 
     return connection
@@ -240,6 +261,12 @@ class ConnectionPooler(object):
     pass back a client that was not retrieved from a pool, and
     you might well get an exception for doing so.
     """
+    if client.get_coordinator_host() is not None:
+      conf.update_coordinator_host(client.get_coordinator_host())
+      self.create_pool_impala(conf)
+    if client.get_coordinator_host() is not None and client.get_coordinator_host() != conf.get_coordinator_host():
+      conf.update_coordinator_host(client.get_coordinator_host())
+
     self.pooldict[_get_pool_key(conf)].put(client)
 
 def _get_pool_key(conf):
@@ -247,7 +274,7 @@ def _get_pool_key(conf):
   Given a ConnectionConfig, return the tuple used as the key in the dictionary
   of connections by the ConnectionPooler class.
   """
-  return (conf.klass, conf.host, conf.port)
+  return (conf.klass, conf.host, conf.port, conf.get_coordinator_host())
 
 def construct_superclient(conf):
   """
@@ -411,10 +438,14 @@ class SuperClient(object):
   TODO(todd): get this into the Thrift lib
   """
 
-  def __init__(self, wrapped_client, transport, timeout_seconds=None):
+  def __init__(self, wrapped_client, transport, timeout_seconds=None, coordinator_host=None):
     self.wrapped = wrapped_client
     self.transport = transport
     self.timeout_seconds = timeout_seconds
+    self.coordinator_host = coordinator_host
+
+  def get_coordinator_host(self):
+    return self.coordinator_host
 
   def __getattr__(self, attr):
     if attr in self.__dict__:
@@ -439,6 +470,10 @@ class SuperClient(object):
           logging.debug("Thrift call: %s.%s(args=%s, kwargs=%s)" % (str(self.wrapped.__class__), attr, str_args, repr(kwargs)))
 
           ret = res(*args, **kwargs)
+          if 'http_addr' in repr(ret):
+            coordinator_host = re.search('http_addr\':\ \'(.*:[0-9]{2,})\', \'', repr(ret))
+            self.coordinator_host = coordinator_host.group(1)
+
           log_msg = _unpack_guid_secret_in_handle(repr(ret))
 
           # Truncate log message, increase output in DEBUG mode