
HUE-9100 [beeswax] Multi-session support

Jean-Francois Desjeans Gauthier 6 years ago
parent
commit
bb0a842df7

+ 19 - 1
apps/beeswax/src/beeswax/conf.py

@@ -227,7 +227,9 @@ CLOSE_QUERIES = Config(
 
 MAX_NUMBER_OF_SESSIONS = Config(
   key="max_number_of_sessions",
-  help=_t("Hue will use at most this many HiveServer2 sessions per user at a time"),
+  help=_t("Hue will use at most this many HiveServer2 sessions per user at a time"
+          # The motivation for -1 is that Hue does currently keep track of session state perfectly and the user does not have ability to manage them effectively. The cost of a session is low
+          "-1 is unlimited number of sessions."),
   type=int,
   default=1
 )
@@ -324,3 +326,19 @@ USE_SASL = Config(
   private=False,
   type=coerce_bool,
   dynamic_default=get_use_sasl_default)
+
+def has_multiple_sessions():
+  """When true will create multiple sessions for user queries"""
+  return MAX_NUMBER_OF_SESSIONS.get() != 1
+
+CLOSE_SESSIONS = Config(
+  key="close_sessions",
+  help=_t('When set to True, Hue will close sessions created for background queries and open new ones as needed. '
+          'When set to False, Hue will keep sessions created for background queries open and reuse them as needed. '
+          'This flag is useful when max_number_of_sessions != 1.'),
+  type=coerce_bool,
+  dynamic_default=has_multiple_sessions
+)
+
+def has_session_pool():
+  return has_multiple_sessions() and not CLOSE_SESSIONS.get()

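The two settings above combine into three effective modes: the default single reused session, a pool of reusable sessions, and close-after-use sessions. Below is a minimal standalone sketch (illustrative, not Hue code) of how has_multiple_sessions() and has_session_pool() evaluate for a few hypothetical settings:

# Standalone sketch; the real helpers read the MAX_NUMBER_OF_SESSIONS and CLOSE_SESSIONS Config objects above.
def has_multiple_sessions(max_number_of_sessions):
  # Any value other than 1 (e.g. 2, 3, or -1 for unlimited) enables multi-session mode.
  return max_number_of_sessions != 1

def has_session_pool(max_number_of_sessions, close_sessions):
  # Sessions are pooled (kept open and reused) only when multi-session mode is on
  # and sessions are not closed after each background query.
  return has_multiple_sessions(max_number_of_sessions) and not close_sessions

if __name__ == '__main__':
  for n, close in [(1, False), (2, False), (2, True), (-1, True)]:
    print('max_number_of_sessions=%s close_sessions=%s -> multiple=%s pool=%s'
          % (n, close, has_multiple_sessions(n), has_session_pool(n, close)))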
+ 51 - 3
apps/beeswax/src/beeswax/models.py

@@ -32,7 +32,7 @@ from enum import Enum
 from TCLIService.ttypes import TSessionHandle, THandleIdentifier, TOperationState, TOperationHandle, TOperationType
 
 from desktop.lib.exceptions_renderable import PopupException
-from desktop.models import Document
+from desktop.models import Document, Document2
 from desktop.redaction import global_redaction_engine
 from librdbms.server import dbms as librdbms_dbms
 from useradmin.models import User
@@ -404,7 +404,54 @@ class SessionManager(models.Manager):
     q = self.filter(owner=user, application=application).exclude(guid='').exclude(secret='')
     if filter_open:
       q = q.filter(status_code=0)
-    return q.order_by("-last_used")[0:n]
+    q = q.order_by("-last_used")
+    if n > 0:
+      return q[0:n]
+    else:
+      return q
+
+  def get_tez_session(self, user, application, n_sessions):
+    # Get 2 + n_sessions sessions and filter out the busy ones
+    sessions = Session.objects.get_n_sessions(user, n=2 + n_sessions, application=application)
+    LOG.debug('%s sessions found' % len(sessions))
+    if sessions:
+      # Include trashed documents to keep the query lazy
+      # and avoid retrieving all documents
+      docs = Document2.objects.get_history(doc_type='query-hive', user=user, include_trashed=True)
+      busy_sessions = set()
+
+      # Only check last 40 documents for performance
+      for doc in docs[:40]:
+        try:
+          snippet_data = json.loads(doc.data)['snippets'][0]
+        except (KeyError, IndexError):
+          # data might not contain a 'snippets' field or it might be empty
+          LOG.warn('No snippets in Document2 object of type query-hive')
+          continue
+        session_guid = snippet_data.get('result', {}).get('handle', {}).get('session_guid')
+        status = snippet_data.get('status')
+
+        if status in [QueryHistory.STATE.submitted.name, QueryHistory.STATE.running.name]:
+          if session_guid is not None and session_guid not in busy_sessions:
+            busy_sessions.add(session_guid)
+
+      n_busy_sessions = 0
+      available_sessions = []
+      for session in sessions:
+        if session.guid not in busy_sessions:
+          available_sessions.append(session)
+        else:
+          n_busy_sessions += 1
+
+      if n_sessions > 0 and n_busy_sessions == n_sessions:
+        raise Exception('Too many open sessions. Stop a running query before starting a new one')
+
+      if available_sessions:
+        session = available_sessions[0]
+      else:
+        session = None # No available session found
+
+      return session
 
 
 class Session(models.Model):
@@ -439,7 +486,7 @@ class Session(models.Model):
 
 
 class QueryHandle(object):
-  def __init__(self, secret=None, guid=None, operation_type=None, has_result_set=None, modified_row_count=None, log_context=None, session_guid=None):
+  def __init__(self, secret=None, guid=None, operation_type=None, has_result_set=None, modified_row_count=None, log_context=None, session_guid=None, session_id=None):
     self.secret = secret
     self.guid = guid
     self.operation_type = operation_type
@@ -467,6 +514,7 @@ class HiveServerQueryHandle(QueryHandle):
     super(HiveServerQueryHandle, self).__init__(**kwargs)
     self.secret, self.guid = self.get_encoded()
     self.session_guid = kwargs.get('session_guid')
+    self.session_id = kwargs.get('session_id')
 
   def get(self):
     return self.secret, self.guid

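For context, the new Session.objects.get_tez_session() manager method added above is consumed later in this commit from HiveServerClient.call_return_result_and_session() and HS2Api.create_session(). A hedged usage sketch, where user and application are placeholders:

# Usage sketch based on the callers added later in this commit; `user` is a Django User, `application` e.g. 'beeswax'.
from beeswax.conf import MAX_NUMBER_OF_SESSIONS
from beeswax.models import Session

def pick_idle_session(user, application='beeswax'):
  # Returns an idle Session from the pool, or None if none is available (the caller then opens a new one);
  # raises when max_number_of_sessions is positive and all of those sessions are busy with submitted/running queries.
  return Session.objects.get_tez_session(user, application, MAX_NUMBER_OF_SESSIONS.get())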
+ 12 - 4
apps/beeswax/src/beeswax/server/dbms.py

@@ -42,7 +42,7 @@ from beeswax import hive_site
 from beeswax.conf import HIVE_SERVER_HOST, HIVE_SERVER_PORT, HIVE_SERVER_HOST, HIVE_HTTP_THRIFT_PORT, HIVE_METASTORE_HOST, HIVE_METASTORE_PORT, LIST_PARTITIONS_LIMIT, SERVER_CONN_TIMEOUT, \
   AUTH_USERNAME, AUTH_PASSWORD, APPLY_NATURAL_SORT_MAX, QUERY_PARTITIONS_LIMIT, HIVE_DISCOVERY_HIVESERVER2_ZNODE, \
   HIVE_DISCOVERY_HS2, HIVE_DISCOVERY_LLAP, HIVE_DISCOVERY_LLAP_HA, HIVE_DISCOVERY_LLAP_ZNODE, CACHE_TIMEOUT, \
-  LLAP_SERVER_HOST, LLAP_SERVER_PORT, LLAP_SERVER_THRIFT_PORT, USE_SASL as HIVE_USE_SASL
+  LLAP_SERVER_HOST, LLAP_SERVER_PORT, LLAP_SERVER_THRIFT_PORT, USE_SASL as HIVE_USE_SASL, CLOSE_SESSIONS, has_session_pool, MAX_NUMBER_OF_SESSIONS
 from beeswax.common import apply_natural_sort
 from beeswax.design import hql_query
 from beeswax.hive_site import hiveserver2_use_ssl
@@ -183,7 +183,10 @@ def get_query_server_config(name='beeswax', connector=None):
           'transport_mode': 'http' if hive_site.hiveserver2_transport_mode() == 'HTTP' else 'socket',
           'auth_username': AUTH_USERNAME.get(),
           'auth_password': AUTH_PASSWORD.get(),
-          'use_sasl': HIVE_USE_SASL.get()
+          'use_sasl': HIVE_USE_SASL.get(),
+          'close_sessions': CLOSE_SESSIONS.get(),
+          'has_session_pool': has_session_pool(),
+          'max_number_of_sessions': MAX_NUMBER_OF_SESSIONS.get()
         }
 
     if name == 'sparksql': # Extends Hive as very similar
@@ -235,6 +238,11 @@ class QueryServerException(Exception):
     self.message = message
 
 
+class InvalidSessionQueryServerException(QueryServerException):
+  def __init__(self, e, message=''):
+    super(InvalidSessionQueryServerException, self).__init__(e, message=message)
+
+
 class QueryServerTimeoutException(Exception):
 
   def __init__(self, message=''):
@@ -832,9 +840,9 @@ class HiveServer2Dbms(object):
     return query_history
 
 
-  def use(self, database):
+  def use(self, database, session=None):
     query = hql_query('USE `%s`' % database)
-    return self.client.use(query)
+    return self.client.use(query, session=session)
 
 
   def get_log(self, query_handle, start_over=True):

+ 1 - 1
apps/beeswax/src/beeswax/server/hive_metastore_server.py

@@ -146,7 +146,7 @@ class HiveMetastoreClient(object):
     pass
 
 
-  def query(self, query, statement=0, with_multiple_session=False):
+  def query(self, query, statement=0):
     return HiveServerQueryHandle(secret='mock', guid='mock')
 
 

+ 137 - 140
apps/beeswax/src/beeswax/server/hive_server2_lib.py

@@ -32,7 +32,6 @@ from django.utils.translation import ugettext as _
 
 from desktop.lib import thrift_util
 from desktop.conf import DEFAULT_USER
-from desktop.models import Document2
 from beeswax import conf
 
 from TCLIService import TCLIService
@@ -47,7 +46,7 @@ from beeswax import hive_site
 from beeswax.hive_site import hiveserver2_use_ssl
 from beeswax.conf import CONFIG_WHITELIST, LIST_PARTITIONS_LIMIT
 from beeswax.models import Session, HiveServerQueryHandle, HiveServerQueryHistory, QueryHistory
-from beeswax.server.dbms import Table, DataTable, QueryServerException
+from beeswax.server.dbms import Table, DataTable, QueryServerException, InvalidSessionQueryServerException
 
 
 LOG = logging.getLogger(__name__)
@@ -343,7 +342,7 @@ class HiveServerTColumnValue2(object):
 
 
 class HiveServerDataTable(DataTable):
-  def __init__(self, results, schema, operation_handle, query_server):
+  def __init__(self, results, schema, operation_handle, query_server, session=None):
     self.schema = schema and schema.schema
     self.row_set = HiveServerTRowSet(results.results, schema)
     self.operation_handle = operation_handle
@@ -352,6 +351,7 @@ class HiveServerDataTable(DataTable):
     else:
       self.has_more = not self.row_set.is_empty()    # Should be results.hasMoreRows but always True in HS2
     self.startRowOffset = self.row_set.startRowOffset    # Always 0 in HS2
+    self.session = session
 
   @property
   def ready(self):
@@ -518,6 +518,9 @@ class HiveServerClient(object):
     self.query_server = query_server
     self.user = user
     self.coordinator_host = ''
+    self.has_close_sessions = query_server.get('close_sessions', False)
+    self.has_session_pool = query_server.get('has_session_pool', False)
+    self.max_number_of_sessions = query_server.get('max_number_of_sessions', 1)
 
     use_sasl, mechanism, kerberos_principal_short_name, impersonation_enabled, auth_username, auth_password = self.get_security()
     LOG.info(
@@ -671,7 +674,7 @@ class HiveServerClient(object):
     # TEZ returns properties, but we need the configuration to detect engine
     properties = session.get_properties()
     if not properties or self.query_server['server_name'] == 'beeswax':
-      configuration = self.get_configuration()
+      configuration = self.get_configuration(session=session)
       properties.update(configuration)
       session.properties = json.dumps(properties)
       session.save()
@@ -679,49 +682,59 @@ class HiveServerClient(object):
     return session
 
 
-  def call(self, fn, req, status=TStatusCode.SUCCESS_STATUS, with_multiple_session=False): # Note: with_multiple_session currently ignored
-    (res, session) = self.call_return_result_and_session(fn, req, status, with_multiple_session)
-    return res
-
+  def call(self, fn, req, status=TStatusCode.SUCCESS_STATUS, session=None):
+    return self.call_return_result_and_session(fn, req, status, session=session)
 
-  def call_return_result_and_session(self, fn, req, status=TStatusCode.SUCCESS_STATUS, with_multiple_session=False):
-    n_sessions = conf.MAX_NUMBER_OF_SESSIONS.get()
 
-    # When a single session is allowed, avoid multiple session logic
-    with_multiple_session = n_sessions > 1
+  def call_return_result_and_session(self, fn, req, status=TStatusCode.SUCCESS_STATUS, session=None):
+    if not hasattr(req, 'sessionHandle'):
+      return self._call_return_result_and_session(fn, req, status=TStatusCode.SUCCESS_STATUS, session=session)
 
-    session = None
+    if session:
+      if session.status_code not in (
+        TStatusCode.SUCCESS_STATUS, TStatusCode.SUCCESS_WITH_INFO_STATUS, TStatusCode.STILL_EXECUTING_STATUS):
+        LOG.info('Retrying with a new session for %s because status is %s' % (self.user, str(session.status_code)))
+        session = None
+      else:
+        try:
+          return self._call_return_result_and_session(fn, req, status=status, session=session)
+        except InvalidSessionQueryServerException as e:
+          LOG.info('Retrying with a new session for %s because of %s' % (self.user, str(e)))
 
-    if not with_multiple_session:
-      # Default behaviour: get one session
+    if self.has_session_pool:
+      session = Session.objects.get_tez_session(self.user, self.query_server['server_name'], self.max_number_of_sessions)
+    elif self.max_number_of_sessions == 1: # Default behaviour: reuse opened session
       session = Session.objects.get_session(self.user, self.query_server['server_name'])
-    else:
-      session = self._get_tez_session(n_sessions)
 
-    if session is None:
-      session = self.open_session(self.user)
+    if session:
+      try:
+        return self._call_return_result_and_session(fn, req, status=status, session=session)
+      except InvalidSessionQueryServerException as e:
+        LOG.info('Retrying with a new session for %s because of %s' % (self.user, str(e)))
+
+    if self.has_close_sessions and self.max_number_of_sessions > 1 and Session.objects.get_n_sessions(self.user, n=self.max_number_of_sessions, application=self.query_server['server_name']) >= self.max_number_of_sessions:
+      raise Exception('Too many open sessions. Stop a running query before starting a new one')
+
+    session = self.open_session(self.user)
+    return self._call_return_result_and_session(fn, req, status=status, session=session)
 
-    if hasattr(req, 'sessionHandle') and req.sessionHandle is None:
+
+  def _call_return_result_and_session(self, fn, req, status=TStatusCode.SUCCESS_STATUS, session=None):
+    if hasattr(req, 'sessionHandle') and session:
       req.sessionHandle = session.get_handle()
 
     res = fn(req)
 
     # Not supported currently in HS2 and Impala: TStatusCode.INVALID_HANDLE_STATUS
     if res.status.statusCode == TStatusCode.ERROR_STATUS and \
-        re.search('Invalid SessionHandle|Invalid session|Client session expired', res.status.errorMessage or '', re.I):
-      LOG.info('Retrying with a new session because for %s of %s' % (self.user, res))
-      session.status_code = TStatusCode.INVALID_HANDLE_STATUS
-      session.save()
-
-      session = self.open_session(self.user)
-
-      req.sessionHandle = session.get_handle()
-
-      # Get back the name of the function to call
-      res = getattr(self._client, fn.attr)(req)
+      re.search('Invalid SessionHandle|Invalid session|Client session expired', res.status.errorMessage or '', re.I):
+      if session:
+        session.status_code = TStatusCode.INVALID_HANDLE_STATUS
+        session.save()
+      raise InvalidSessionQueryServerException(Exception('Invalid Session %s:\n%s' % (req, res)))
 
     if status is not None and res.status.statusCode not in (
-        TStatusCode.SUCCESS_STATUS, TStatusCode.SUCCESS_WITH_INFO_STATUS, TStatusCode.STILL_EXECUTING_STATUS):
+      TStatusCode.SUCCESS_STATUS, TStatusCode.SUCCESS_WITH_INFO_STATUS, TStatusCode.STILL_EXECUTING_STATUS):
       if hasattr(res.status, 'errorMessage') and res.status.errorMessage:
         message = res.status.errorMessage
       else:
@@ -731,53 +744,17 @@ class HiveServerClient(object):
       return (res, session)
 
 
-  def _get_tez_session(self, n_sessions):
-    # Get 2 + n_sessions sessions and filter out the busy ones
-    sessions = Session.objects.get_n_sessions(self.user, n=2 + n_sessions, application=self.query_server['server_name'])
-    LOG.debug('%s sessions found' % len(sessions))
-    if sessions:
-      # Include trashed documents to keep the query lazy
-      # and avoid retrieving all documents
-      docs = Document2.objects.get_history(doc_type='query-hive', user=self.user, include_trashed=True)
-      busy_sessions = set()
-
-      # Only check last 40 documents for performance
-      for doc in docs[:40]:
-        try:
-          snippet_data = json.loads(doc.data)['snippets'][0]
-        except (KeyError, IndexError):
-          # data might not contain a 'snippets' field or it might be empty
-          LOG.warn('No snippets in Document2 object of type query-hive')
-          continue
-        session_guid = snippet_data.get('result', {}).get('handle', {}).get('session_guid')
-        status = snippet_data.get('status')
-
-        if status in [QueryHistory.STATE.submitted.name, QueryHistory.STATE.running.name]:
-          if session_guid is not None and session_guid not in busy_sessions:
-            busy_sessions.add(session_guid)
-
-      n_busy_sessions = 0
-      available_sessions = []
-      for session in sessions:
-        if session.guid not in busy_sessions:
-          available_sessions.append(session)
-        else:
-          n_busy_sessions += 1
-
-      if n_busy_sessions == n_sessions:
-        raise Exception('Too many open sessions. Stop a running query before starting a new one')
-
-      if available_sessions:
-        session = available_sessions[0]
-      else:
-        session = None # No available session found
-
-      return session
-
-
-  def close_session(self, sessionHandle):
-    req = TCloseSessionReq(sessionHandle=sessionHandle)
-    return self._client.CloseSession(req)
+  def close_session(self, session):
+    req = TCloseSessionReq(sessionHandle=session.get_handle())
+    try:
+      res = self._client.CloseSession(req)
+      session.status_code = TStatusCode.INVALID_HANDLE_STATUS
+      session.save()
+      return res
+    except Exception as e:
+      session.status_code = TStatusCode.ERROR_STATUS
+      session.save()
+      raise e
 
 
   def get_databases(self, schemaName=None):
@@ -788,10 +765,10 @@ class HiveServerClient(object):
     if self.query_server['server_name'].startswith('impala'):
       req.schemaName = None
 
-    res = self.call(self._client.GetSchemas, req)
+    (res, session) = self.call(self._client.GetSchemas, req)
 
     results, schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=5000)
-    self.close_operation(res.operationHandle)
+    self._close(res.operationHandle, session)
 
     col = 'TABLE_SCHEM'
     return HiveServerTRowSet(results.results, schema.schema).cols((col,))
@@ -800,8 +777,8 @@ class HiveServerClient(object):
   def get_database(self, database):
     query = 'DESCRIBE DATABASE EXTENDED `%s`' % (database)
 
-    (desc_results, desc_schema), operation_handle = self.execute_statement(query, max_rows=5000, orientation=TFetchOrientation.FETCH_NEXT)
-    self.close_operation(operation_handle)
+    desc_results, desc_schema, operation_handle, session = self.execute_statement(query, max_rows=5000, orientation=TFetchOrientation.FETCH_NEXT)
+    self._close(operation_handle, session)
 
     if self.query_server['server_name'].startswith('impala'):
       cols = ('name', 'location', 'comment') # Skip owner as on a new line
@@ -821,10 +798,10 @@ class HiveServerClient(object):
     if not table_types:
       table_types = self.DEFAULT_TABLE_TYPES
     req = TGetTablesReq(schemaName=database, tableName=table_names, tableTypes=table_types)
-    res = self.call(self._client.GetTables, req)
+    (res, session) = self.call(self._client.GetTables, req)
 
     results, schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=5000)
-    self.close_operation(res.operationHandle)
+    self._close(res.operationHandle, session)
 
     cols = ('TABLE_NAME', 'TABLE_TYPE', 'REMARKS')
     return HiveServerTRowSet(results.results, schema.schema).cols(cols)
@@ -834,17 +811,17 @@ class HiveServerClient(object):
     if not table_types:
       table_types = self.DEFAULT_TABLE_TYPES
     req = TGetTablesReq(schemaName=database, tableName=table_names, tableTypes=table_types)
-    res = self.call(self._client.GetTables, req)
+    (res, session) = self.call(self._client.GetTables, req)
 
     results, schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=5000)
-    self.close_operation(res.operationHandle)
+    self._close(res.operationHandle, session)
 
     return HiveServerTRowSet(results.results, schema.schema).cols(('TABLE_NAME',))
 
 
   def get_table(self, database, table_name, partition_spec=None):
     req = TGetTablesReq(schemaName=database.lower(), tableName=table_name.lower()) # Impala returns empty if not lower case
-    res = self.call(self._client.GetTables, req)
+    (res, session) = self.call(self._client.GetTables, req)
 
     table_results, table_schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT)
     self.close_operation(res.operationHandle)
@@ -855,22 +832,22 @@ class HiveServerClient(object):
       query = 'DESCRIBE FORMATTED `%s`.`%s`' % (database, table_name)
 
     try:
-      (desc_results, desc_schema), operation_handle = self.execute_statement(query, max_rows=10000, orientation=TFetchOrientation.FETCH_NEXT)
+      desc_results, desc_schema, operation_handle, session = self.execute_statement(query, max_rows=10000, orientation=TFetchOrientation.FETCH_NEXT, session=session)
       self.close_operation(operation_handle)
     except Exception as e:
       ex_string = str(e)
       if 'cannot find field' in ex_string: # Workaround until Hive 2.0 and HUE-3751
-        (desc_results, desc_schema), operation_handle = self.execute_statement('USE `%s`' % database)
+        desc_results, desc_schema, operation_handle, session = self.execute_statement('USE `%s`' % database, session=session)
         self.close_operation(operation_handle)
         if partition_spec:
           query = 'DESCRIBE FORMATTED `%s` PARTITION(%s)' % (table_name, partition_spec)
         else:
           query = 'DESCRIBE FORMATTED `%s`' % table_name
-        (desc_results, desc_schema), operation_handle = self.execute_statement(query, max_rows=10000, orientation=TFetchOrientation.FETCH_NEXT)
+        desc_results, desc_schema, operation_handle, session = self.execute_statement(query, max_rows=10000, orientation=TFetchOrientation.FETCH_NEXT, session=session)
         self.close_operation(operation_handle)
       elif 'not have privileges for DESCTABLE' in ex_string or 'AuthorizationException' in ex_string: # HUE-5608: No table permission but some column permissions
         query = 'DESCRIBE `%s`.`%s`' % (database, table_name)
-        (desc_results, desc_schema), operation_handle = self.execute_statement(query, max_rows=10000, orientation=TFetchOrientation.FETCH_NEXT)
+        desc_results, desc_schema, operation_handle, session = self.execute_statement(query, max_rows=10000, orientation=TFetchOrientation.FETCH_NEXT, session=session)
         self.close_operation(operation_handle)
 
         desc_results.results.columns[0].stringVal.values.insert(0, '# col_name')
@@ -895,32 +872,35 @@ class HiveServerClient(object):
           desc_results.results.columns[2].stringVal.values.append(None)
       else:
         raise e
+    finally:
+      if self.has_close_sessions:
+        self.close_session(session)
 
     return HiveServerTable(table_results.results, table_schema.schema, desc_results.results, desc_schema.schema)
 
 
-  def execute_query(self, query, max_rows=1000):
+  def execute_query(self, query, max_rows=1000, session=None):
     configuration = self._get_query_configuration(query)
-    return self.execute_query_statement(statement=query.query['query'], max_rows=max_rows, configuration=configuration)
+    return self.execute_query_statement(statement=query.query['query'], max_rows=max_rows, configuration=configuration, session=session)
 
 
-  def execute_query_statement(self, statement, max_rows=1000, configuration=None, orientation=TFetchOrientation.FETCH_FIRST, close_operation=False):
+  def execute_query_statement(self, statement, max_rows=1000, configuration=None, orientation=TFetchOrientation.FETCH_FIRST, close_operation=False, session=None):
     if configuration is None:
       configuration = {}
-    (results, schema), operation_handle = self.execute_statement(statement=statement, max_rows=max_rows, configuration=configuration, orientation=orientation)
+    results, schema, operation_handle, session = self.execute_statement(statement=statement, max_rows=max_rows, configuration=configuration, orientation=orientation, session=session)
 
     if close_operation:
       self.close_operation(operation_handle)
 
-    return HiveServerDataTable(results, schema, operation_handle, self.query_server)
+    return HiveServerDataTable(results, schema, operation_handle, self.query_server, session=session)
 
 
-  def execute_async_query(self, query, statement=0, with_multiple_session=False):
+  def execute_async_query(self, query, statement=0, session=None):
     if statement == 0:
       # Impala just has settings currently
       if self.query_server['server_name'] == 'beeswax':
         for resource in query.get_configuration_statements():
-          self.execute_statement(resource.strip())
+          self.execute_statement(resource.strip(), session=session)
 
     configuration = {}
 
@@ -931,27 +911,28 @@ class HiveServerClient(object):
     configuration.update(self._get_query_configuration(query))
     query_statement = query.get_query_statement(statement)
 
-    return self.execute_async_statement(statement=query_statement, confOverlay=configuration, with_multiple_session=with_multiple_session)
+    return self.execute_async_statement(statement=query_statement, confOverlay=configuration, session=session)
 
 
-  def execute_statement(self, statement, max_rows=1000, configuration=None, orientation=TFetchOrientation.FETCH_NEXT):
+  def execute_statement(self, statement, max_rows=1000, configuration=None, orientation=TFetchOrientation.FETCH_NEXT, session=None):
     if configuration is None:
       configuration = {}
     if self.query_server['server_name'].startswith('impala') and self.query_server['QUERY_TIMEOUT_S'] > 0:
       configuration['QUERY_TIMEOUT_S'] = str(self.query_server['QUERY_TIMEOUT_S'])
 
     req = TExecuteStatementReq(statement=statement.encode('utf-8'), confOverlay=configuration)
-    res = self.call(self._client.ExecuteStatement, req)
+    (res, session) = self.call(self._client.ExecuteStatement, req, session=session)
 
-    return self.fetch_result(res.operationHandle, max_rows=max_rows, orientation=orientation), res.operationHandle
+    results, schema = self.fetch_result(res.operationHandle, max_rows=max_rows, orientation=orientation)
+    return results, schema, res.operationHandle, session
 
 
-  def execute_async_statement(self, statement, confOverlay, with_multiple_session=False):
+  def execute_async_statement(self, statement, confOverlay, session=None):
     if self.query_server['server_name'].startswith('impala') and self.query_server['QUERY_TIMEOUT_S'] > 0:
       confOverlay['QUERY_TIMEOUT_S'] = str(self.query_server['QUERY_TIMEOUT_S'])
 
     req = TExecuteStatementReq(statement=statement.encode('utf-8'), confOverlay=confOverlay, runAsync=True)
-    (res, session) = self.call_return_result_and_session(self._client.ExecuteStatement, req, with_multiple_session=with_multiple_session)
+    (res, session) = self.call_return_result_and_session(self._client.ExecuteStatement, req, session=session)
 
     return HiveServerQueryHandle(
         secret=res.operationHandle.operationId.secret,
@@ -959,10 +940,11 @@ class HiveServerClient(object):
         operation_type=res.operationHandle.operationType,
         has_result_set=res.operationHandle.hasResultSet,
         modified_row_count=res.operationHandle.modifiedRowCount,
-        session_guid=session.guid
+        session_guid=thrift_util.unpack_guid(session.get_handle().sessionId.guid),
+        session_id=session.id
     )
 
-
+  # Note: an operation_handle is attached to a session. Operations that take an operation_handle cannot recover if the session is closed, so passing the session here is not required.
   def fetch_data(self, operation_handle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=1000):
    # Fetch until the result is empty due to a HS2 bug instead of looking at hasMoreRows
     results, schema = self.fetch_result(operation_handle, orientation, max_rows)
@@ -971,34 +953,26 @@ class HiveServerClient(object):
 
   def cancel_operation(self, operation_handle):
     req = TCancelOperationReq(operationHandle=operation_handle)
-    return self.call(self._client.CancelOperation, req)
+    (res, session) = self.call(self._client.CancelOperation, req)
+    return res
 
 
   def close_operation(self, operation_handle):
     req = TCloseOperationReq(operationHandle=operation_handle)
-    return self.call(self._client.CloseOperation, req)
-
-
-  def get_columns(self, database, table):
-    req = TGetColumnsReq(schemaName=database, tableName=table)
-    res = self.call(self._client.GetColumns, req)
-
-    res, schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT)
-    self.close_operation(res.operationHandle)
-
-    return res, schema
+    (res, session) = self.call(self._client.CloseOperation, req)
+    return res
 
 
   def fetch_result(self, operation_handle, orientation=TFetchOrientation.FETCH_FIRST, max_rows=1000):
     if operation_handle.hasResultSet:
       fetch_req = TFetchResultsReq(operationHandle=operation_handle, orientation=orientation, maxRows=max_rows)
-      res = self.call(self._client.FetchResults, fetch_req)
+      (res, session) = self.call(self._client.FetchResults, fetch_req)
     else:
       res = TFetchResultsResp(results=TRowSet(startRowOffset=0, rows=[], columns=[]))
 
     if operation_handle.hasResultSet and TFetchOrientation.FETCH_FIRST: # Only fetch for the first call that should be with start_over
       meta_req = TGetResultSetMetadataReq(operationHandle=operation_handle)
-      schema = self.call(self._client.GetResultSetMetadata, meta_req)
+      (schema, session) = self.call(self._client.GetResultSetMetadata, meta_req)
     else:
       schema = None
 
@@ -1007,7 +981,7 @@ class HiveServerClient(object):
 
   def fetch_log(self, operation_handle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=1000):
     req = TFetchResultsReq(operationHandle=operation_handle, orientation=orientation, maxRows=max_rows, fetchType=1)
-    res = self.call(self._client.FetchResults, req)
+    (res, session) = self.call(self._client.FetchResults, req)
 
     if beeswax_conf.THRIFT_VERSION.get() >= 7:
       lines = res.results.columns[0].stringVal.values
@@ -1019,18 +993,14 @@ class HiveServerClient(object):
 
   def get_operation_status(self, operation_handle):
     req = TGetOperationStatusReq(operationHandle=operation_handle)
-    return self.call(self._client.GetOperationStatus, req)
-
-  def explain(self, query):
-    query_statement = query.get_query_statement(0)
-    configuration = self._get_query_configuration(query)
-    return self.execute_query_statement(statement='EXPLAIN %s' % query_statement, configuration=configuration, orientation=TFetchOrientation.FETCH_NEXT)
+    (res, session) = self.call(self._client.GetOperationStatus, req)
+    return res
 
 
   def get_log(self, operation_handle):
     try:
       req = TGetLogReq(operationHandle=operation_handle)
-      res = self.call(self._client.GetLog, req)
+      (res, session) = self.call(self._client.GetLog, req)
       return res.log
     except Exception as e:
       if 'Invalid query handle' in str(e):
@@ -1043,7 +1013,29 @@ class HiveServerClient(object):
       return message
 
 
-  def get_partitions(self, database, table_name, partition_spec=None, max_parts=None, reverse_sort=True):
+  def _close(self, operation_handle, session):
+    if self.has_close_sessions: # Closing the session closes all associated operation handles
+      self.close_session(session)
+    else:
+      self.close_operation(operation_handle)
+
+
+  def get_columns(self, database, table):
+    req = TGetColumnsReq(schemaName=database, tableName=table)
+    (res, session) = self.call(self._client.GetColumns, req)
+
+    results, schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT)
+    self._close(res.operationHandle, session)
+
+    return results, schema
+
+
+  def explain(self, query):
+    query_statement = query.get_query_statement(0)
+    configuration = self._get_query_configuration(query)
+    return self.execute_query_statement(statement='EXPLAIN %s' % query_statement, configuration=configuration, orientation=TFetchOrientation.FETCH_NEXT)
+
+  def get_partitions(self, database, table_name, partition_spec=None, max_parts=None, reverse_sort=True): # TODO: execute both requests in the same session
     table = self.get_table(database, table_name)
 
     query = 'SHOW PARTITIONS `%s`.`%s`' % (database, table_name)
@@ -1054,11 +1046,17 @@ class HiveServerClient(object):
     # Need to fetch more like this until SHOW PARTITIONS offers a LIMIT and ORDER BY
     partition_table = self.execute_query_statement(query, max_rows=10000, orientation=TFetchOrientation.FETCH_NEXT, close_operation=True)
 
+    if self.has_close_sessions:
+      self.close_session(partition_table.session)
+
     if self.query_server['server_name'].startswith('impala'):
       try:
         # Fetch all partition key names, which are listed before the #Rows column
         cols = [col.name for col in partition_table.cols()]
-        stop = cols.index('#Rows')
+        try:
+          stop = cols.index('#Rows')
+        except ValueError:
+          stop = -1
         partition_keys = cols[:stop]
         num_parts = len(partition_keys)
 
@@ -1090,16 +1088,16 @@ class HiveServerClient(object):
     return partitions[:max_parts]
 
 
-  def get_configuration(self):
+  def get_configuration(self, session=None):
     configuration = {}
 
     if self.query_server['server_name'].startswith('impala'):  # Return all configuration settings
       query = 'SET'
-      results = self.execute_query_statement(query, orientation=TFetchOrientation.FETCH_NEXT, close_operation=True)
+      results = self.execute_query_statement(query, orientation=TFetchOrientation.FETCH_NEXT, close_operation=True, session=session)
       configuration = dict((row[0], row[1]) for row in results.rows())
     else:  # For Hive, only return white-listed configurations
       query = 'SET -v'
-      results = self.execute_query_statement(query, orientation=TFetchOrientation.FETCH_FIRST, max_rows=-1, close_operation=True)
+      results = self.execute_query_statement(query, orientation=TFetchOrientation.FETCH_FIRST, max_rows=-1, close_operation=True, session=session)
       config_whitelist = [config.lower() for config in CONFIG_WHITELIST.get()]
       properties = [(row[0].split('=')[0], row[0].split('=')[1]) for row in results.rows() if '=' in row[0]]
       configuration = dict((prop, value) for prop, value in properties if prop.lower() in config_whitelist)
@@ -1224,8 +1222,8 @@ class HiveServerClientCompatible(object):
     self.query_server = client.query_server
 
 
-  def query(self, query, statement=0, with_multiple_session=False):
-    return self._client.execute_async_query(query, statement, with_multiple_session=with_multiple_session)
+  def query(self, query, statement=0, session=None):
+    return self._client.execute_async_query(query, statement, session=session)
 
 
   def get_state(self, handle):
@@ -1239,8 +1237,8 @@ class HiveServerClientCompatible(object):
     return self._client.get_operation_status(operationHandle)
 
 
-  def use(self, query):
-    data = self._client.execute_query(query)
+  def use(self, query, session=None):
+    data = self._client.execute_query(query, session=session)
     self._client.close_operation(data.operation_handle)
     return data
 
@@ -1282,8 +1280,7 @@ class HiveServerClientCompatible(object):
 
 
   def close_session(self, session):
-    operationHandle = session.get_handle()
-    return self._client.close_session(operationHandle)
+    return self._client.close_session(session)
 
 
   def dump_config(self):

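To summarize the refactor above: call_return_result_and_session() now resolves a session explicitly instead of relying on the old with_multiple_session flag. Below is a standalone outline of the selection order, where the inputs stand in for the client state held by HiveServerClient; it is not a drop-in replacement:

# Condensed, self-contained outline of the session-selection order in call_return_result_and_session().
def choose_session_source(has_session_handle, passed_session_ok, has_session_pool, max_number_of_sessions):
  if not has_session_handle:
    return 'no selection'            # requests without a sessionHandle skip session selection entirely
  if passed_session_ok:
    return 'reuse passed session'    # caller-provided session with a healthy status_code
  if has_session_pool:
    return 'get_tez_session (pool)'  # pick an idle session from the pool, if any
  if max_number_of_sessions == 1:
    return 'get_session (single)'    # default behaviour: reuse the last opened session
  return 'open_session (new)'        # otherwise open a new one (capped when close_sessions is on)

assert choose_session_source(True, False, False, 1) == 'get_session (single)'
assert choose_session_source(True, False, True, 2) == 'get_tez_session (pool)'

On an InvalidSessionQueryServerException the chosen session is marked invalid and the call falls through to the next source, ultimately opening a fresh session, as shown in the hunk above.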
+ 179 - 3
apps/beeswax/src/beeswax/server/hive_server2_lib_tests.py

@@ -24,15 +24,18 @@ if sys.version_info[0] > 2:
 else:
   from mock import patch, Mock, MagicMock
 
-from nose.tools import assert_equal, assert_true
+from nose.tools import assert_equal, assert_true, assert_raises, assert_not_equal
+from nose.plugins.skip import SkipTest
 from TCLIService.ttypes import TStatusCode
 
+from beeswax.conf import MAX_NUMBER_OF_SESSIONS, CLOSE_SESSIONS
+from beeswax.models import Session
+from beeswax.server.hive_server2_lib import HiveServerTable, HiveServerClient
+
 from desktop.lib.django_test_util import make_logged_in_client
 from desktop.lib.test_utils import grant_access
 from useradmin.models import User
 
-from beeswax.models import Session
-from beeswax.server.hive_server2_lib import HiveServerTable, HiveServerClient
 
 
 LOG = logging.getLogger(__name__)
@@ -350,3 +353,176 @@ class TestHiveServerTable():
       assert_equal(table.primary_keys[1].name, 'id2')
       assert_equal(table.primary_keys[1].type, 'NULL')
       assert_equal(table.primary_keys[1].comment, 'NULL')
+
+class SessionTest():
+  def test_call_session_single(self):
+    finish = (MAX_NUMBER_OF_SESSIONS.set_for_testing(1),
+                CLOSE_SESSIONS.set_for_testing(False))
+    try:
+      with patch('beeswax.server.hive_server2_lib.thrift_util.get_client') as get_client:
+        with patch('beeswax.server.hive_server2_lib.HiveServerClient.open_session') as open_session:
+          with patch('beeswax.server.hive_server2_lib.Session.objects.get_session') as get_session:
+            open_session.return_value = MagicMock(status_code=0)
+            get_session.return_value = None
+            fn = MagicMock(attr='test')
+            req = MagicMock()
+
+            client = HiveServerClient(MagicMock(), MagicMock())
+            (res, session1) = client.call(fn, req, status=None)
+            open_session.assert_called_once()
+
+            # Reuse session from argument
+            (res, session2) = client.call(fn, req, status=None, session=session1)
+            open_session.assert_called_once() # open_session should not be called again, because we're reusing session
+            assert_equal(session1, session2)
+
+            # Reuse session from get_session
+            get_session.return_value = session1
+            (res, session3) = client.call(fn, req, status=None)
+            open_session.assert_called_once() # open_session should not be called again, because we're reusing session
+            assert_equal(session1, session3)
+    finally:
+      for f in finish:
+        f()
+
+  def test_call_session_pool(self):
+    finish = (MAX_NUMBER_OF_SESSIONS.set_for_testing(2),
+                CLOSE_SESSIONS.set_for_testing(False))
+    try:
+      with patch('beeswax.server.hive_server2_lib.thrift_util.get_client') as get_client:
+        with patch('beeswax.server.hive_server2_lib.HiveServerClient.open_session') as open_session:
+          with patch('beeswax.server.hive_server2_lib.Session.objects.get_tez_session') as get_session:
+            open_session.return_value = MagicMock(status_code=0)
+            get_session.return_value = None
+            fn = MagicMock(return_value=MagicMock(status=MagicMock(statusCode=0)))
+            req = MagicMock()
+
+            client = HiveServerClient(MagicMock(), MagicMock())
+            (res, session1) = client.call(fn, req, status=None)
+            open_session.assert_called_once()
+
+            # Reuse session from argument
+            (res, session2) = client.call(fn, req, status=None, session=session1)
+            open_session.assert_called_once() # open_session should not be called again, because we're reusing session
+            assert_equal(session1, session2)
+
+            # Reuse session from get_session
+            get_session.return_value = session1
+            (res, session3) = client.call(fn, req, status=None)
+            open_session.assert_called_once() # open_session should not be called again, because we're reusing session
+            assert_equal(session1, session3)
+    finally:
+      for f in finish:
+        f()
+
+  def test_call_session_pool_limit(self):
+    finish = (MAX_NUMBER_OF_SESSIONS.set_for_testing(2),
+                CLOSE_SESSIONS.set_for_testing(False))
+    try:
+      with patch('beeswax.server.hive_server2_lib.thrift_util.get_client') as get_client:
+        with patch('beeswax.server.hive_server2_lib.HiveServerClient.open_session') as open_session:
+          with patch('beeswax.server.hive_server2_lib.Session.objects.get_tez_session') as get_tez_session:
+            get_tez_session.side_effect=Exception('')
+            open_session.return_value = MagicMock(status_code=0)
+            fn = MagicMock(return_value=MagicMock(status=MagicMock(statusCode=0)))
+            req = MagicMock()
+            client = HiveServerClient(MagicMock(), MagicMock())
+            assert_raises(Exception, client.call, fn, req, status=None)
+    finally:
+      for f in finish:
+        f()
+
+  def test_call_session_close_idle(self):
+    finish = (MAX_NUMBER_OF_SESSIONS.set_for_testing(-1),
+                CLOSE_SESSIONS.set_for_testing(True))
+    try:
+      with patch('beeswax.server.hive_server2_lib.thrift_util.get_client') as get_client:
+        with patch('beeswax.server.hive_server2_lib.HiveServerClient.open_session') as open_session:
+          open_session.return_value = MagicMock(status_code=0)
+          fn = MagicMock(return_value=MagicMock(status=MagicMock(statusCode=0)))
+          req = MagicMock()
+
+          client = HiveServerClient(MagicMock(), MagicMock())
+          (res, session1) = client.call(fn, req, status=None)
+          open_session.assert_called_once()
+
+          # Reuse session from argument
+          (res, session2) = client.call(fn, req, status=None, session=session1)
+          open_session.assert_called_once() # open_session should not be called again, because we're reusing session
+          assert_equal(session1, session2)
+
+          # Create new session
+          open_session.return_value = MagicMock(status_code=0)
+          (res, session3) = client.call(fn, req, status=None)
+          assert_equal(open_session.call_count, 2)
+          assert_not_equal(session1, session3)
+    finally:
+      for f in finish:
+        f()
+
+  def test_call_session_close_idle_managed_queries(self):
+    finish = (MAX_NUMBER_OF_SESSIONS.set_for_testing(-1),
+                CLOSE_SESSIONS.set_for_testing(True))
+    try:
+      with patch('beeswax.server.hive_server2_lib.thrift_util.get_client') as get_client:
+        with patch('beeswax.server.hive_server2_lib.HiveServerClient.open_session') as open_session:
+          with patch('beeswax.server.hive_server2_lib.HiveServerClient.close_session') as close_session:
+            with patch('beeswax.server.hive_server2_lib.HiveServerTRowSet') as HiveServerTRowSet:
+              status = MagicMock(status=MagicMock(statusCode=0))
+              status_return = MagicMock(return_value=status)
+              get_client.return_value = MagicMock(return_value=status, GetSchemas=status_return, FetchResults=status_return, GetResultSetMetadata=status_return, CloseOperation=status_return, ExecuteStatement=status_return, GetTables=status_return, GetColumns=status_return)
+
+              open_session.return_value = MagicMock(status_code=0)
+              client = HiveServerClient(MagicMock(), MagicMock())
+
+              res = client.get_databases()
+              assert_equal(open_session.call_count, 1)
+              assert_equal(close_session.call_count, 1)
+
+              res = client.get_database(MagicMock())
+              assert_equal(open_session.call_count, 2)
+              assert_equal(close_session.call_count, 2)
+
+              res = client.get_tables_meta(MagicMock(), MagicMock())
+              assert_equal(open_session.call_count, 3)
+              assert_equal(close_session.call_count, 3)
+
+              res = client.get_tables(MagicMock(), MagicMock())
+              assert_equal(open_session.call_count, 4)
+              assert_equal(close_session.call_count, 4)
+
+              res = client.get_table(MagicMock(), MagicMock())
+              assert_equal(open_session.call_count, 5)
+              assert_equal(close_session.call_count, 5)
+
+              res = client.get_columns(MagicMock(), MagicMock())
+              assert_equal(open_session.call_count, 6)
+              assert_equal(close_session.call_count, 6)
+
+              res = client.get_partitions(MagicMock(), MagicMock()) # get_partitions does 2 requests with 1 session each
+              assert_equal(open_session.call_count, 8)
+              assert_equal(close_session.call_count, 8)
+    finally:
+      for f in finish:
+        f()
+
+  def test_call_session_close_idle_limit(self):
+    finish = (MAX_NUMBER_OF_SESSIONS.set_for_testing(2),
+                CLOSE_SESSIONS.set_for_testing(True))
+    try:
+      with patch('beeswax.server.hive_server2_lib.thrift_util.get_client') as get_client:
+        with patch('beeswax.server.hive_server2_lib.HiveServerClient.open_session') as open_session:
+          with patch('beeswax.server.hive_server2_lib.Session.objects.get_n_sessions') as get_n_sessions:
+            get_n_sessions.return_value = [MagicMock(), MagicMock()]
+            open_session.return_value = MagicMock(status_code=0)
+            fn = MagicMock(return_value=MagicMock(status=MagicMock(statusCode=0)))
+            req = MagicMock()
+            client = HiveServerClient(MagicMock(), MagicMock())
+            assert_raises(Exception, client.call, fn, req, status=None)
+
+            get_n_sessions.return_value = [MagicMock()]
+            (res, session1) = client.call(fn, req, status=None)
+            open_session.assert_called_once()
+    finally:
+      for f in finish:
+        f()

+ 6 - 0
desktop/conf.dist/hue.ini

@@ -1253,8 +1253,14 @@
 
   # Hue will use at most this many HiveServer2 sessions per user at a time.
  # For Tez, increase the number if you need more than one query at a time, e.g. 2 or 3 (Tez has a maximum of 1 query per session).
+  # -1 means an unlimited number of sessions.
   ## max_number_of_sessions=1
 
+  # When set to True, Hue will close sessions created for background queries and open new ones as needed.
+  # When set to False, Hue will keep sessions created for background queries open and reuse them as needed.
+  # This flag is useful when max_number_of_sessions != 1.
+  ## close_sessions=max_number_of_sessions != 1
+
   # Thrift version to use when communicating with HiveServer2.
   # Version 11 comes with Hive 3.0. If issues, try 7.
   ## thrift_version=11

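For example, to run up to two concurrent Tez queries per user while keeping the sessions pooled and reused, the [beeswax] section could be configured as follows (illustrative values):

[beeswax]
  # Allow up to 2 HiveServer2 sessions per user (Tez runs at most 1 query per session).
  max_number_of_sessions=2
  # Keep background-query sessions open and reuse them instead of closing them after each query.
  close_sessions=false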
+ 6 - 0
desktop/conf/pseudo-distributed.ini.tmpl

@@ -1237,8 +1237,14 @@
 
   # Hue will use at most this many HiveServer2 sessions per user at a time.
  # For Tez, increase the number if you need more than one query at a time, e.g. 2 or 3 (Tez has a maximum of 1 query per session).
+  # -1 means an unlimited number of sessions.
   ## max_number_of_sessions=1
 
+  # When set to True, Hue will close sessions created for background queries and open new ones as needed.
+  # When set to False, Hue will keep sessions created for background queries open and reuse them as needed.
+  # This flag is useful when max_number_of_sessions != 1.
+  ## close_sessions=max_number_of_sessions != 1
+
   # Thrift version to use when communicating with HiveServer2.
   # Version 11 comes with Hive 3.0. If issues, try 7.
   ## thrift_version=11

+ 12 - 18
desktop/libs/notebook/src/notebook/api.py

@@ -130,7 +130,7 @@ def _execute_notebook(request, notebook, snippet):
 
   try:
     try:
-      session = notebook.get('sessions') and notebook['sessions'][0] # Session reference for snippet execution without persisting it
+      sessions = notebook.get('sessions') and notebook['sessions'] # Session reference for snippet execution without persisting it
 
       active_executable = json.loads(request.POST.get('executable', '{}')) # Editor v2
 
@@ -142,10 +142,14 @@ def _execute_notebook(request, notebook, snippet):
 
       interpreter = get_api(request, snippet)
       if snippet.get('interface') == 'sqlalchemy':
-        interpreter.options['session'] = session
+        interpreter.options['session'] = sessions[0]
 
       with opentracing.tracer.start_span('interpreter') as span:
+        # interpreter.execute needs the sessions, but we don't want to persist them
+        pre_execute_sessions = notebook['sessions']
+        notebook['sessions'] = sessions
         response['handle'] = interpreter.execute(notebook, snippet)
+        notebook['sessions'] = pre_execute_sessions
 
       # Retrieve and remove the result from the handle
       if response['handle'].get('sync'):
@@ -626,27 +630,17 @@ def close_notebook(request):
 
   notebook = json.loads(request.POST.get('notebook', '{}'))
 
-  for session in [_s for _s in notebook['sessions'] if _s['type'] in ('scala', 'spark', 'pyspark', 'sparkr', 'r')]:
+  for session in notebook['sessions']:
     try:
-      response['result'].append(get_api(request, session).close_session(session))
-    except QueryExpired:
-      pass
-    except Exception as e:
-      LOG.exception('Error closing session %s' % str(e))
-
-  for snippet in [_s for _s in notebook['snippets'] if _s['type'] in ('hive', 'impala')]:
-    try:
-      if snippet['status'] != 'running':
-        response['result'].append(get_api(request, snippet).close_statement(notebook, snippet))
+      api = get_api(request, session)
+      if hasattr(api, 'close_session_idle'):
+        response['result'].append(api.close_session_idle(notebook, session))
       else:
-        LOG.info('Not closing SQL snippet as still running.')
+        response['result'].append(api.close_session(session))
     except QueryExpired:
       pass
     except Exception as e:
-      LOG.exception('Error closing statement %s' % str(e))
-
-  response['status'] = 0
-  response['message'] = _('Notebook closed successfully')
+      LOG.exception('Error closing session %s' % str(e))
 
   return JsonResponse(response)
 

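The design choice in close_notebook() above is duck typing: a connector that can tell whether its statements are still running exposes close_session_idle(), and everything else falls back to plain close_session(). A minimal sketch of the dispatch, where the api object stands for whatever get_api(request, session) returns:

# Dispatch sketch; `api` stands in for the connector returned by get_api(request, session).
def close_one_session(api, notebook, session):
  if hasattr(api, 'close_session_idle'):
    # HS2-style connectors close finished statements first, then the session if nothing is running.
    return api.close_session_idle(notebook, session)
  # Other connectors keep the previous behaviour and close the session directly.
  return api.close_session(session)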
+ 49 - 6
desktop/libs/notebook/src/notebook/connectors/hiveserver2.py

@@ -56,7 +56,7 @@ LOG = logging.getLogger(__name__)
 try:
   from beeswax import conf as beeswax_conf, data_export
   from beeswax.api import _autocomplete, _get_sample_data
-  from beeswax.conf import CONFIG_WHITELIST as hive_settings, DOWNLOAD_ROW_LIMIT, DOWNLOAD_BYTES_LIMIT
+  from beeswax.conf import CONFIG_WHITELIST as hive_settings, DOWNLOAD_ROW_LIMIT, DOWNLOAD_BYTES_LIMIT, MAX_NUMBER_OF_SESSIONS, has_session_pool, has_multiple_sessions, CLOSE_SESSIONS
   from beeswax.data_export import upload
   from beeswax.design import hql_query
   from beeswax.models import QUERY_TYPES, HiveServerQueryHandle, HiveServerQueryHistory, QueryHistory, Session
@@ -178,7 +178,12 @@ class HS2Api(Api):
   def create_session(self, lang='hive', properties=None):
     application = 'beeswax' if lang == 'hive' or lang =='llap' else lang
 
-    session = Session.objects.get_session(self.user, application=application)
+    if has_session_pool():
+      session = Session.objects.get_tez_session(self.user, application, MAX_NUMBER_OF_SESSIONS.get())
+    elif not has_multiple_sessions():
+      session = Session.objects.get_session(self.user, application=application)
+    else:
+      session = None
 
     reuse_session = session is not None
     if not reuse_session:
@@ -251,6 +256,30 @@ class HS2Api(Api):
 
     return response
 
+  def close_session_idle(self, notebook, session):
+    idle = True
+    response = {'result': []}
+    for snippet in [_s for _s in notebook['snippets'] if _s['type'] == session['type']]:
+      try:
+        if snippet['status'] != 'running':
+          response['result'].append(self.close_statement(notebook, snippet))
+        else:
+          idle = False
+          LOG.info('Not closing SQL snippet as still running.')
+      except QueryExpired:
+        pass
+      except Exception as e:
+        LOG.exception('Error closing statement %s' % str(e))
+
+    try:
+      if idle and CLOSE_SESSIONS.get():
+        response['result'].append(self.close_session(session))
+    except QueryExpired:
+      pass
+    except Exception as e:
+      LOG.exception('Error closing session %s' % str(e))
+
+    return response['result']
 
   @query_error_handler
   def execute(self, notebook, snippet):
@@ -260,12 +289,15 @@ class HS2Api(Api):
     session = self._get_session(notebook, snippet['type'])
 
     query = self._prepare_hql_query(snippet, statement['statement'], session)
+    _session = self._get_session_by_id(notebook, snippet['type'])
 
     try:
-      if statement.get('statement_id') == 0:
+      if statement.get('statement_id') == 0: # TODO: move this to client
         if query.database and not statement['statement'].lower().startswith('set'):
-          db.use(query.database)
-      handle = db.client.query(query, with_multiple_session=True) # Note: with_multiple_session currently ignored
+          result = db.use(query.database, session=_session)
+          if result.session:
+            _session = result.session
+      handle = db.client.query(query, session=_session)
     except QueryServerException as ex:
       raise QueryError(ex.message, handle=statement)
 
@@ -282,7 +314,8 @@ class HS2Api(Api):
       'has_result_set': handle.has_result_set,
       'modified_row_count': handle.modified_row_count,
       'log_context': handle.log_context,
-      'session_guid': handle.session_guid
+      'session_guid': handle.session_guid,
+      'session_id': handle.session_id
     }
     response.update(statement)
 
@@ -643,6 +676,16 @@ DROP TABLE IF EXISTS `%(table)s`;
     session = next((session for session in notebook['sessions'] if session['type'] == type), None)
     return session
 
+  def _get_session_by_id(self, notebook, type='hive'):
+    session = self._get_session(notebook, type)
+    if session:
+      session_id = session.get('id')
+      if session_id:
+        filters = {'id': session_id, 'application': 'beeswax' if type == 'hive' else type}
+        if not is_admin(self.user):
+          filters['owner'] = self.user
+        return Session.objects.get(**filters)
+
 
   def _get_hive_execution_engine(self, notebook, snippet):
     # Get hive.execution.engine from snippet properties, if none, then get from session
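To recap the handle changes: the response built in HS2Api.execute() now persists both session identifiers, so a snippet's result handle carries roughly the following extra fields (values made up for illustration; other existing fields elided):

# Illustrative shape of the handle fields added in this commit.
handle = {
  # ... existing fields such as has_result_set, modified_row_count, log_context ...
  'session_guid': '4fe3a1...',  # unpacked GUID of the HiveServer2 session that ran the statement
  'session_id': 42,             # primary key of the beeswax Session row, later resolved by _get_session_by_id()
}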