Forráskód Böngészése

[impala][beeswax] Store cookies for beeswax http transport

Impala uses cookies for maintaining sticky sessions in active-active
setup.

With this commit,
1. we extract all the cookies from the response,
2. store the cookies in the properties (json) column of sessions
3. When making a call, we pull the cookies from session and set it in
the client.
4. More changes were needed to make sure that session info is flowing
all the way to the spot of making the call.
Amit Srivastava 4 hónapja
szülő
commit
88e04de22f

+ 17 - 25
apps/beeswax/src/beeswax/models.py

@@ -16,10 +16,10 @@
 # limitations under the License.
 # limitations under the License.
 
 
 import ast
 import ast
-import json
 import base64
 import base64
-import logging
 import datetime
 import datetime
+import json
+import logging
 from enum import Enum
 from enum import Enum
 
 
 from django.contrib.contenttypes.fields import GenericRelation
 from django.contrib.contenttypes.fields import GenericRelation
@@ -394,7 +394,7 @@ class SessionManager(models.Manager):
       if filter_open:
       if filter_open:
         q = q.filter(status_code=0)
         q = q.filter(status_code=0)
       return q.latest("last_used")
       return q.latest("last_used")
-    except Session.DoesNotExist as e:
+    except Session.DoesNotExist:
       return None
       return None
 
 
   def get_n_sessions(self, user, n, application='beeswax', filter_open=True):
   def get_n_sessions(self, user, n, application='beeswax', filter_open=True):
@@ -490,6 +490,17 @@ class Session(models.Model):
   def get_formatted_properties(self):
   def get_formatted_properties(self):
     return [dict({'key': key, 'value': value}) for key, value in list(self.get_properties().items())]
     return [dict({'key': key, 'value': value}) for key, value in list(self.get_properties().items())]
 
 
+  def get_cookies(self):
+    return self.get_properties().get('cookies', {})
+
+  def set_cookies(self, cookies):
+    self.set_property('cookies', cookies)
+
+  def set_property(self, key, value):
+    props = self.get_properties()
+    props[key] = value
+    self.properties = json.dumps(props)
+
   def __str__(self):
   def __str__(self):
     return '%s %s' % (self.owner, self.last_used)
     return '%s %s' % (self.owner, self.last_used)
 
 
@@ -543,6 +554,9 @@ class HiveServerQueryHandle(QueryHandle):
         modifiedRowCount=self.modified_row_count
         modifiedRowCount=self.modified_row_count
     )
     )
 
 
+  def get_session(self):
+    return Session.objects.filter(id=self.session_id).first() if self.session_id else None
+
   @classmethod
   @classmethod
   def get_decoded(cls, secret, guid):
   def get_decoded(cls, secret, guid):
     return base64.b64decode(secret), base64.b64decode(guid)
     return base64.b64decode(secret), base64.b64decode(guid)
@@ -551,28 +565,6 @@ class HiveServerQueryHandle(QueryHandle):
     return base64.b64encode(self.secret), base64.b64encode(self.guid)
     return base64.b64encode(self.secret), base64.b64encode(self.guid)
 
 
 
 
-# Deprecated. Could be removed.
-
-class BeeswaxQueryHandle(QueryHandle):
-  """
-  QueryHandle for Beeswax.
-  """
-  def __init__(self, secret, has_result_set, log_context):
-    super(BeeswaxQueryHandle, self).__init__(secret=secret,
-                                             has_result_set=has_result_set,
-                                             log_context=log_context)
-
-  def get(self):
-    return self.secret, None
-
-  def get_rpc_handle(self):
-    return BeeswaxdQueryHandle(id=self.secret, log_context=self.log_context)
-
-  # TODO remove
-  def get_encoded(self):
-    return self.get(), None
-
-
 class MetaInstall(models.Model):
 class MetaInstall(models.Model):
   """
   """
   Metadata about the installation. Should have at most one row.
   Metadata about the installation. Should have at most one row.

+ 82 - 35
apps/beeswax/src/beeswax/server/hive_server2_lib.py

@@ -15,9 +15,9 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-import re
 import json
 import json
 import logging
 import logging
+import re
 from operator import itemgetter
 from operator import itemgetter
 
 
 from django.utils.translation import gettext as _
 from django.utils.translation import gettext as _
@@ -49,7 +49,7 @@ from beeswax import conf as beeswax_conf, hive_site
 from beeswax.conf import CONFIG_WHITELIST, LIST_PARTITIONS_LIMIT, MAX_CATALOG_SQL_ENTRIES
 from beeswax.conf import CONFIG_WHITELIST, LIST_PARTITIONS_LIMIT, MAX_CATALOG_SQL_ENTRIES
 from beeswax.hive_site import hiveserver2_use_ssl
 from beeswax.hive_site import hiveserver2_use_ssl
 from beeswax.models import HiveServerQueryHandle, HiveServerQueryHistory, Session
 from beeswax.models import HiveServerQueryHandle, HiveServerQueryHistory, Session
-from beeswax.server.dbms import DataTable, InvalidSessionQueryServerException, QueryServerException, Table, reset_ha
+from beeswax.server.dbms import DataTable, InvalidSessionQueryServerException, QueryServerException, reset_ha, Table
 from desktop.conf import DEFAULT_USER, ENABLE_X_CSRF_TOKEN_FOR_HIVE_IMPALA, ENABLE_XFF_FOR_HIVE_IMPALA, USE_THRIFT_HTTP_JWT
 from desktop.conf import DEFAULT_USER, ENABLE_X_CSRF_TOKEN_FOR_HIVE_IMPALA, ENABLE_XFF_FOR_HIVE_IMPALA, USE_THRIFT_HTTP_JWT
 from desktop.lib import python_util, thrift_util
 from desktop.lib import python_util, thrift_util
 from notebook.connectors.base import get_interpreter
 from notebook.connectors.base import get_interpreter
@@ -413,7 +413,7 @@ class HiveServerDataTable(DataTable):
     for row in self.row_set:
     for row in self.row_set:
       try:
       try:
         yield row.fields()
         yield row.fields()
-      except StopIteration as e:
+      except StopIteration:
         return  # pep-0479: expected Py3.8 generator raised StopIteration
         return  # pep-0479: expected Py3.8 generator raised StopIteration
 
 
 
 
@@ -546,6 +546,30 @@ class HiveServerTColumnDesc(object):
         return ttype.userDefinedTypeEntry
         return ttype.userDefinedTypeEntry
 
 
 
 
+def extract_cookies(connection):
+  if hasattr(connection, 'conf') and hasattr(connection.conf, 'transport_mode') and connection.conf.transport_mode == 'http':
+    http_transport = connection.transport
+    from thrift.transport.TTransport import TBufferedTransport
+    if isinstance(http_transport, TBufferedTransport):
+      http_transport = http_transport._TBufferedTransport__trans
+    if hasattr(http_transport, '_client') and hasattr(http_transport._client, '_cookies'):
+      cookies = http_transport._client._cookies
+      from requests.utils import dict_from_cookiejar
+      return dict_from_cookiejar(cookies) if cookies else {}
+
+
+def set_cookies(connection, cookies):
+  if hasattr(connection, 'conf') and hasattr(connection.conf, 'transport_mode') and connection.conf.transport_mode == 'http':
+    http_transport = connection.transport
+    from thrift.transport.TTransport import TBufferedTransport
+
+    if isinstance(http_transport, TBufferedTransport):
+      http_transport = http_transport._TBufferedTransport__trans
+    if hasattr(http_transport, '_client') and hasattr(http_transport._client, '_cookies'):
+      from requests.utils import cookiejar_from_dict
+      http_transport._client._cookies = cookiejar_from_dict(cookies or {})
+
+
 class HiveServerClient(object):
 class HiveServerClient(object):
   HS2_MECHANISMS = {
   HS2_MECHANISMS = {
       'KERBEROS': 'GSSAPI',
       'KERBEROS': 'GSSAPI',
@@ -733,6 +757,9 @@ class HiveServerClient(object):
     sessionId = res.sessionHandle.sessionId
     sessionId = res.sessionHandle.sessionId
     LOG.info('Session %s opened' % repr(sessionId.guid))
     LOG.info('Session %s opened' % repr(sessionId.guid))
 
 
+    cookies = extract_cookies(self._client) or {}  # Extract cookies from the response
+    res.configuration['cookies'] = cookies
+
     encoded_status, encoded_guid = HiveServerQueryHandle(secret=sessionId.secret, guid=sessionId.guid).get()
     encoded_status, encoded_guid = HiveServerQueryHandle(secret=sessionId.secret, guid=sessionId.guid).get()
     properties = json.dumps(res.configuration)
     properties = json.dumps(res.configuration)
 
 
@@ -799,11 +826,22 @@ class HiveServerClient(object):
     return self._call_return_result_and_session(fn, req, status=status, session=session)
     return self._call_return_result_and_session(fn, req, status=status, session=session)
 
 
   def _call_return_result_and_session(self, fn, req, status=TStatusCode.SUCCESS_STATUS, session=None):
   def _call_return_result_and_session(self, fn, req, status=TStatusCode.SUCCESS_STATUS, session=None):
-    if hasattr(req, 'sessionHandle') and session:
+    if hasattr(req, 'sessionHandle') and session and not isinstance(req, TCloseOperationReq):
       req.sessionHandle = session.get_handle()
       req.sessionHandle = session.get_handle()
+    cookies = session.get_cookies() if session else {}
+    set_cookies(self._client, cookies)
+    LOG.debug('setting cookies for call req: %s, session: %s, cookies: %s' % (req, session, cookies))
 
 
     res = fn(req)
     res = fn(req)
 
 
+    cookies = extract_cookies(self._client) or {}  # Extract cookies from the response
+    LOG.debug('storing cookies received from server cookies: %s' % cookies)
+    if hasattr(res, 'configuration') and isinstance(res.configuration, dict):
+      res.configuration['cookies'] = cookies
+    if session:
+      session.set_cookies(cookies)
+      session.save()
+
     # Not supported currently in HS2 and Impala: TStatusCode.INVALID_HANDLE_STATUS
     # Not supported currently in HS2 and Impala: TStatusCode.INVALID_HANDLE_STATUS
     if res.status.statusCode == TStatusCode.ERROR_STATUS and \
     if res.status.statusCode == TStatusCode.ERROR_STATUS and \
       re.search('Invalid SessionHandle|Invalid session|Client session expired|Could not connect', res.status.errorMessage or '', re.I):
       re.search('Invalid SessionHandle|Invalid session|Client session expired|Could not connect', res.status.errorMessage or '', re.I):
@@ -845,7 +883,8 @@ class HiveServerClient(object):
     (res, session) = self.call(self._client.GetSchemas, req)
     (res, session) = self.call(self._client.GetSchemas, req)
 
 
     results, schema = self.fetch_result(
     results, schema = self.fetch_result(
-      res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=MAX_CATALOG_SQL_ENTRIES.get()
+      res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=MAX_CATALOG_SQL_ENTRIES.get(),
+      session=session
     )
     )
     self._close(res.operationHandle, session)
     self._close(res.operationHandle, session)
 
 
@@ -885,7 +924,8 @@ class HiveServerClient(object):
 
 
     while True:
     while True:
       results, schema = self.fetch_result(
       results, schema = self.fetch_result(
-        res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=MAX_CATALOG_SQL_ENTRIES.get()
+        res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=MAX_CATALOG_SQL_ENTRIES.get(),
+        session=session
       )
       )
       fetched_tables = HiveServerTRowSet(results.results, schema.schema).cols(cols)
       fetched_tables = HiveServerTRowSet(results.results, schema.schema).cols(cols)
       table_metadata += fetched_tables
       table_metadata += fetched_tables
@@ -903,7 +943,8 @@ class HiveServerClient(object):
     (res, session) = self.call(self._client.GetTables, req)
     (res, session) = self.call(self._client.GetTables, req)
 
 
     results, schema = self.fetch_result(
     results, schema = self.fetch_result(
-      res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=MAX_CATALOG_SQL_ENTRIES.get()
+      res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=MAX_CATALOG_SQL_ENTRIES.get(),
+      session=session
     )
     )
     self._close(res.operationHandle, session)
     self._close(res.operationHandle, session)
 
 
@@ -913,7 +954,7 @@ class HiveServerClient(object):
     req = TGetTablesReq(schemaName=database.lower(), tableName=table_name.lower())  # Impala returns empty if not lower case
     req = TGetTablesReq(schemaName=database.lower(), tableName=table_name.lower())  # Impala returns empty if not lower case
     (res, session) = self.call(self._client.GetTables, req)
     (res, session) = self.call(self._client.GetTables, req)
 
 
-    table_results, table_schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT)
+    table_results, table_schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, session=session)
     self.close_operation(res.operationHandle)
     self.close_operation(res.operationHandle)
 
 
     if partition_spec:
     if partition_spec:
@@ -1003,7 +1044,7 @@ class HiveServerClient(object):
     )
     )
 
 
     if close_operation:
     if close_operation:
-      self.close_operation(operation_handle)
+      self.close_operation(operation_handle, session=session)
 
 
     return HiveServerDataTable(results, schema, operation_handle, self.query_server, session=session)
     return HiveServerDataTable(results, schema, operation_handle, self.query_server, session=session)
 
 
@@ -1034,7 +1075,7 @@ class HiveServerClient(object):
     req = TExecuteStatementReq(statement=statement, confOverlay=configuration)
     req = TExecuteStatementReq(statement=statement, confOverlay=configuration)
     (res, session) = self.call(self._client.ExecuteStatement, req, session=session)
     (res, session) = self.call(self._client.ExecuteStatement, req, session=session)
 
 
-    results, schema = self.fetch_result(res.operationHandle, max_rows=max_rows, orientation=orientation)
+    results, schema = self.fetch_result(res.operationHandle, max_rows=max_rows, orientation=orientation, session=session)
     return results, schema, res.operationHandle, session
     return results, schema, res.operationHandle, session
 
 
   def execute_async_statement(self, statement=None, thrift_function=None, thrift_request=None, conf_overlay=None, session=None):
   def execute_async_statement(self, statement=None, thrift_function=None, thrift_request=None, conf_overlay=None, session=None):
@@ -1062,39 +1103,39 @@ class HiveServerClient(object):
 
 
   # Note: An operation_handle is attached to a session. All operations that require operation_handle cannot recover if the session is
   # Note: An operation_handle is attached to a session. All operations that require operation_handle cannot recover if the session is
   # closed. Passing the session is not required
   # closed. Passing the session is not required
-  def fetch_data(self, operation_handle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=1000):
+  def fetch_data(self, operation_handle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=1000, session=None):
     # Fetch until the result is empty dues to a HS2 bug instead of looking at hasMoreRows
     # Fetch until the result is empty dues to a HS2 bug instead of looking at hasMoreRows
-    results, schema = self.fetch_result(operation_handle, orientation, max_rows)
+    results, schema = self.fetch_result(operation_handle, orientation, max_rows, session=session)
     return HiveServerDataTable(results, schema, operation_handle, self.query_server)
     return HiveServerDataTable(results, schema, operation_handle, self.query_server)
 
 
-  def cancel_operation(self, operation_handle):
+  def cancel_operation(self, operation_handle, session=None):
     req = TCancelOperationReq(operationHandle=operation_handle)
     req = TCancelOperationReq(operationHandle=operation_handle)
-    (res, session) = self.call(self._client.CancelOperation, req)
+    (res, session) = self.call(self._client.CancelOperation, req, session=session)
     return res
     return res
 
 
-  def close_operation(self, operation_handle):
+  def close_operation(self, operation_handle, session=None):
     req = TCloseOperationReq(operationHandle=operation_handle)
     req = TCloseOperationReq(operationHandle=operation_handle)
-    (res, session) = self.call(self._client.CloseOperation, req)
+    (res, session) = self.call(self._client.CloseOperation, req, session=session)
     return res
     return res
 
 
-  def fetch_result(self, operation_handle, orientation=TFetchOrientation.FETCH_FIRST, max_rows=1000):
+  def fetch_result(self, operation_handle, orientation=TFetchOrientation.FETCH_FIRST, max_rows=1000, session=None):
     if operation_handle.hasResultSet:
     if operation_handle.hasResultSet:
       fetch_req = TFetchResultsReq(operationHandle=operation_handle, orientation=orientation, maxRows=max_rows)
       fetch_req = TFetchResultsReq(operationHandle=operation_handle, orientation=orientation, maxRows=max_rows)
-      (res, session) = self.call(self._client.FetchResults, fetch_req)
+      (res, session) = self.call(self._client.FetchResults, fetch_req, session=session)
     else:
     else:
       res = TFetchResultsResp(results=TRowSet(startRowOffset=0, rows=[], columns=[]))
       res = TFetchResultsResp(results=TRowSet(startRowOffset=0, rows=[], columns=[]))
 
 
     if operation_handle.hasResultSet and TFetchOrientation.FETCH_FIRST:  # Only fetch for the first call that should be with start_over
     if operation_handle.hasResultSet and TFetchOrientation.FETCH_FIRST:  # Only fetch for the first call that should be with start_over
       meta_req = TGetResultSetMetadataReq(operationHandle=operation_handle)
       meta_req = TGetResultSetMetadataReq(operationHandle=operation_handle)
-      (schema, session) = self.call(self._client.GetResultSetMetadata, meta_req)
+      (schema, session) = self.call(self._client.GetResultSetMetadata, meta_req, session=session)
     else:
     else:
       schema = None
       schema = None
 
 
     return res, schema
     return res, schema
 
 
-  def fetch_log(self, operation_handle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=1000):
+  def fetch_log(self, operation_handle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=1000, session=None):
     req = TFetchResultsReq(operationHandle=operation_handle, orientation=orientation, maxRows=max_rows, fetchType=1)
     req = TFetchResultsReq(operationHandle=operation_handle, orientation=orientation, maxRows=max_rows, fetchType=1)
-    (res, session) = self.call(self._client.FetchResults, req)
+    (res, session) = self.call(self._client.FetchResults, req, session=session)
 
 
     if beeswax_conf.THRIFT_VERSION.get() >= 7:
     if beeswax_conf.THRIFT_VERSION.get() >= 7:
       lines = res.results.columns[0].stringVal.values
       lines = res.results.columns[0].stringVal.values
@@ -1103,15 +1144,15 @@ class HiveServerClient(object):
 
 
     return '\n'.join(lines)
     return '\n'.join(lines)
 
 
-  def get_operation_status(self, operation_handle):
+  def get_operation_status(self, operation_handle, session=None):
     req = TGetOperationStatusReq(operationHandle=operation_handle)
     req = TGetOperationStatusReq(operationHandle=operation_handle)
-    (res, session) = self.call(self._client.GetOperationStatus, req)
+    (res, session) = self.call(self._client.GetOperationStatus, req, session=session)
     return res
     return res
 
 
-  def get_log(self, operation_handle):
+  def get_log(self, operation_handle, session=None):
     try:
     try:
       req = TGetLogReq(operationHandle=operation_handle)
       req = TGetLogReq(operationHandle=operation_handle)
-      (res, session) = self.call(self._client.GetLog, req)
+      (res, session) = self.call(self._client.GetLog, req, session=session)
       return res.log
       return res.log
     except Exception as e:
     except Exception as e:
       if 'Invalid query handle' in str(e) or 'Invalid or unknown query handle' in str(e):
       if 'Invalid query handle' in str(e) or 'Invalid or unknown query handle' in str(e):
@@ -1127,13 +1168,13 @@ class HiveServerClient(object):
     if self.has_close_sessions:  # Close session will close all associated operation_handle
     if self.has_close_sessions:  # Close session will close all associated operation_handle
       self.close_session(session)
       self.close_session(session)
     else:
     else:
-      self.close_operation(operation_handle)
+      self.close_operation(operation_handle, session=session)
 
 
   def get_columns(self, database, table):
   def get_columns(self, database, table):
     req = TGetColumnsReq(schemaName=database, tableName=table)
     req = TGetColumnsReq(schemaName=database, tableName=table)
     (res, session) = self.call(self._client.GetColumns, req)
     (res, session) = self.call(self._client.GetColumns, req)
 
 
-    results, schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT)
+    results, schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, session=session)
     self._close(res.operationHandle, session)
     self._close(res.operationHandle, session)
 
 
     return results, schema
     return results, schema
@@ -1433,16 +1474,18 @@ class HiveServerClientCompatible(object):
 
 
   def get_state(self, handle):
   def get_state(self, handle):
     operationHandle = handle.get_rpc_handle()
     operationHandle = handle.get_rpc_handle()
-    res = self._client.get_operation_status(operationHandle)
+    session = handle.get_session()
+    res = self._client.get_operation_status(operationHandle, session=session)
     return HiveServerQueryHistory.STATE_MAP[res.operationState]
     return HiveServerQueryHistory.STATE_MAP[res.operationState]
 
 
   def get_operation_status(self, handle):
   def get_operation_status(self, handle):
     operationHandle = handle.get_rpc_handle()
     operationHandle = handle.get_rpc_handle()
-    return self._client.get_operation_status(operationHandle)
+    session = handle.get_session()
+    return self._client.get_operation_status(operationHandle, session=session)
 
 
   def use(self, query, session=None):
   def use(self, query, session=None):
     data = self._client.execute_query(query, session=session)
     data = self._client.execute_query(query, session=session)
-    self._client.close_operation(data.operation_handle)
+    self._client.close_operation(data.operation_handle, session=session)
     return data
     return data
 
 
   def explain(self, query):
   def explain(self, query):
@@ -1453,6 +1496,7 @@ class HiveServerClientCompatible(object):
 
 
   def fetch(self, handle, start_over=False, max_rows=None):
   def fetch(self, handle, start_over=False, max_rows=None):
     operationHandle = handle.get_rpc_handle()
     operationHandle = handle.get_rpc_handle()
+    session = handle.get_session()
     if max_rows is None:
     if max_rows is None:
       max_rows = 1000
       max_rows = 1000
 
 
@@ -1461,20 +1505,22 @@ class HiveServerClientCompatible(object):
     else:
     else:
       orientation = TFetchOrientation.FETCH_NEXT
       orientation = TFetchOrientation.FETCH_NEXT
 
 
-    data_table = self._client.fetch_data(operationHandle, orientation=orientation, max_rows=max_rows)
+    data_table = self._client.fetch_data(operationHandle, orientation=orientation, max_rows=max_rows, session=session)
 
 
     return ResultCompatible(data_table)
     return ResultCompatible(data_table)
 
 
   def cancel_operation(self, handle):
   def cancel_operation(self, handle):
     operationHandle = handle.get_rpc_handle()
     operationHandle = handle.get_rpc_handle()
-    return self._client.cancel_operation(operationHandle)
+    session = handle.get_session()
+    return self._client.cancel_operation(operationHandle, session=session)
 
 
   def close(self, handle):
   def close(self, handle):
     return self.close_operation(handle)
     return self.close_operation(handle)
 
 
   def close_operation(self, handle):
   def close_operation(self, handle):
     operationHandle = handle.get_rpc_handle()
     operationHandle = handle.get_rpc_handle()
-    return self._client.close_operation(operationHandle)
+    session = handle.get_session()
+    return self._client.close_operation(operationHandle, session=session)
 
 
   def close_session(self, session):
   def close_session(self, session):
     return self._client.close_session(session)
     return self._client.close_session(session)
@@ -1484,16 +1530,17 @@ class HiveServerClientCompatible(object):
 
 
   def get_log(self, handle, start_over=True):
   def get_log(self, handle, start_over=True):
     operationHandle = handle.get_rpc_handle()
     operationHandle = handle.get_rpc_handle()
+    session = handle.get_session()
 
 
     if beeswax_conf.USE_GET_LOG_API.get() or self.query_server.get('dialect') == 'impala':
     if beeswax_conf.USE_GET_LOG_API.get() or self.query_server.get('dialect') == 'impala':
-      return self._client.get_log(operationHandle)
+      return self._client.get_log(operationHandle, session=session)
     else:
     else:
       if start_over:
       if start_over:
         orientation = TFetchOrientation.FETCH_FIRST
         orientation = TFetchOrientation.FETCH_FIRST
       else:
       else:
         orientation = TFetchOrientation.FETCH_NEXT
         orientation = TFetchOrientation.FETCH_NEXT
 
 
-      return self._client.fetch_log(operationHandle, orientation=orientation, max_rows=-1)
+      return self._client.fetch_log(operationHandle, orientation=orientation, max_rows=-1, session=session)
 
 
   def get_databases(self, schemaName=None):
   def get_databases(self, schemaName=None):
     col = 'TABLE_SCHEM'
     col = 'TABLE_SCHEM'

+ 3 - 4
apps/impala/src/impala/server.py

@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-import sys
 import json
 import json
 import logging
 import logging
 import threading
 import threading
@@ -52,7 +51,7 @@ def _get_impala_server_url(session):
     properties = session.get_properties()
     properties = session.get_properties()
     http_addr = properties.get('coordinator_host', properties.get('http_addr'))
     http_addr = properties.get('coordinator_host', properties.get('http_addr'))
 
 
-  http_addr = http_addr.replace('http://', '').replace('https://', '')
+  http_addr = http_addr.replace('http://', '').replace('https://', '').replace('coordinator-int.', '')
   return ('https://' if get_webserver_certificate_file() else 'http://') + http_addr
   return ('https://' if get_webserver_certificate_file() else 'http://') + http_addr
 
 
 
 
@@ -76,7 +75,7 @@ class ImpalaServerClient(HiveServerClient):
     # GetExecSummary() only works for closed queries
     # GetExecSummary() only works for closed queries
     try:
     try:
       self.close_operation(operation_handle)
       self.close_operation(operation_handle)
-    except QueryServerException as e:
+    except QueryServerException:
       LOG.warning('Failed to close operation for query handle, query may be invalid or already closed.')
       LOG.warning('Failed to close operation for query handle, query may be invalid or already closed.')
 
 
     resp = self.call(self._client.GetExecSummary, req)
     resp = self.call(self._client.GetExecSummary, req)
@@ -93,7 +92,7 @@ class ImpalaServerClient(HiveServerClient):
     # TGetRuntimeProfileReq() only works for closed queries
     # TGetRuntimeProfileReq() only works for closed queries
     try:
     try:
       self.close_operation(operation_handle)
       self.close_operation(operation_handle)
-    except QueryServerException as e:
+    except QueryServerException:
       LOG.warning('Failed to close operation for query handle, query may be invalid or already closed.')
       LOG.warning('Failed to close operation for query handle, query may be invalid or already closed.')
 
 
     resp = self.call(self._client.GetRuntimeProfile, req)
     resp = self.call(self._client.GetRuntimeProfile, req)

+ 6 - 6
apps/jobbrowser/src/jobbrowser/apis/query_api.py

@@ -15,18 +15,16 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
+import itertools
+import logging
 import os
 import os
 import re
 import re
-import sys
 import time
 import time
-import logging
-import itertools
 from builtins import filter, range
 from builtins import filter, range
 from datetime import datetime
 from datetime import datetime
 from urllib.parse import urlparse
 from urllib.parse import urlparse
 
 
 import pytz
 import pytz
-from babel import localtime
 from django.utils.translation import gettext as _
 from django.utils.translation import gettext as _
 
 
 from desktop.lib import export_csvxls
 from desktop.lib import export_csvxls
@@ -52,11 +50,13 @@ def _get_api(user, cluster=None):
       compute = Compute.objects.get(id=compute['id']).to_dict()  # Reload the full compute from db
       compute = Compute.objects.get(id=compute['id']).to_dict()  # Reload the full compute from db
     if compute.get('options') and compute['options'].get('api_url'):
     if compute.get('options') and compute['options'].get('api_url'):
       server_url = compute['options'].get('api_url')
       server_url = compute['options'].get('api_url')
+    application = compute.get('name')
   else:
   else:
     # TODO: multi computes if snippet.get('compute') or snippet['type'] has computes
     # TODO: multi computes if snippet.get('compute') or snippet['type'] has computes
     application = cluster['compute']['type'] if cluster.get('compute') else cluster.get('interface', 'impala')
     application = cluster['compute']['type'] if cluster.get('compute') else cluster.get('interface', 'impala')
-    session = Session.objects.get_session(user, application=application)
-    server_url = _get_impala_server_url(session)
+
+  session = Session.objects.get_session(user, application=application)
+  server_url = _get_impala_server_url(session)
   return get_impalad_api(user=user, url=server_url)
   return get_impalad_api(user=user, url=server_url)
 
 
 
 

+ 18 - 20
desktop/libs/notebook/src/notebook/connectors/hiveserver2.py

@@ -15,13 +15,12 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-import re
-import sys
+import binascii
 import copy
 import copy
+import importlib.util
 import json
 import json
-import struct
 import logging
 import logging
-import binascii
+import re
 from builtins import next, object
 from builtins import next, object
 from urllib.parse import quote as urllib_quote, unquote as urllib_unquote
 from urllib.parse import quote as urllib_quote, unquote as urllib_unquote
 
 
@@ -30,7 +29,7 @@ from django.utils.translation import gettext as _
 
 
 from beeswax.common import is_compute
 from beeswax.common import is_compute
 from desktop.auth.backend import is_admin
 from desktop.auth.backend import is_admin
-from desktop.conf import USE_DEFAULT_CONFIGURATION, has_connectors
+from desktop.conf import has_connectors, USE_DEFAULT_CONFIGURATION
 from desktop.lib.conf import BoundConfig
 from desktop.lib.conf import BoundConfig
 from desktop.lib.exceptions import StructuredException
 from desktop.lib.exceptions import StructuredException
 from desktop.lib.exceptions_renderable import PopupException
 from desktop.lib.exceptions_renderable import PopupException
@@ -40,47 +39,45 @@ from desktop.lib.rest.http_client import RestException
 from desktop.lib.thrift_util import unpack_guid, unpack_guid_base64
 from desktop.lib.thrift_util import unpack_guid, unpack_guid_base64
 from desktop.models import DefaultConfiguration, Document2
 from desktop.models import DefaultConfiguration, Document2
 from notebook.connectors.base import (
 from notebook.connectors.base import (
+  _get_snippet_name,
   Api,
   Api,
+  get_interpreter,
   Notebook,
   Notebook,
   OperationNotSupported,
   OperationNotSupported,
   OperationTimeout,
   OperationTimeout,
+  patch_snippet_for_connector,
   QueryError,
   QueryError,
   QueryExpired,
   QueryExpired,
-  _get_snippet_name,
-  get_interpreter,
-  patch_snippet_for_connector,
 )
 )
 
 
 LOG = logging.getLogger()
 LOG = logging.getLogger()
 
 
 
 
 try:
 try:
-  from beeswax import conf as beeswax_conf, data_export
   from beeswax.api import _autocomplete, _get_sample_data
   from beeswax.api import _autocomplete, _get_sample_data
   from beeswax.conf import (
   from beeswax.conf import (
     CLOSE_SESSIONS,
     CLOSE_SESSIONS,
     CONFIG_WHITELIST as hive_settings,
     CONFIG_WHITELIST as hive_settings,
     DOWNLOAD_BYTES_LIMIT,
     DOWNLOAD_BYTES_LIMIT,
     DOWNLOAD_ROW_LIMIT,
     DOWNLOAD_ROW_LIMIT,
-    MAX_NUMBER_OF_SESSIONS,
     has_multiple_sessions,
     has_multiple_sessions,
     has_session_pool,
     has_session_pool,
+    MAX_NUMBER_OF_SESSIONS,
   )
   )
   from beeswax.data_export import upload
   from beeswax.data_export import upload
   from beeswax.design import hql_query
   from beeswax.design import hql_query
-  from beeswax.models import QUERY_TYPES, HiveServerQueryHandle, HiveServerQueryHistory, QueryHistory, Session
+  from beeswax.models import HiveServerQueryHandle, HiveServerQueryHistory, QUERY_TYPES, QueryHistory, Session
   from beeswax.server import dbms
   from beeswax.server import dbms
-  from beeswax.server.dbms import QueryServerException, get_query_server_config, reset_ha
+  from beeswax.server.dbms import get_query_server_config, QueryServerException, reset_ha
   from beeswax.views import parse_out_jobs, parse_out_queries
   from beeswax.views import parse_out_jobs, parse_out_queries
 except ImportError as e:
 except ImportError as e:
   LOG.warning('Hive and HiveServer2 interfaces are not enabled: %s' % e)
   LOG.warning('Hive and HiveServer2 interfaces are not enabled: %s' % e)
   hive_settings = None
   hive_settings = None
 
 
-try:
-  from impala import api  # Force checking if Impala is enabled
+if importlib.util.find_spec('impala.api') is not None:
   from impala.conf import CONFIG_WHITELIST as impala_settings
   from impala.conf import CONFIG_WHITELIST as impala_settings
-  from impala.server import ImpalaDaemonApiException, _get_impala_server_url, get_api as get_impalad_api
-except ImportError as e:
+  from impala.server import _get_impala_server_url, get_api as get_impalad_api, ImpalaDaemonApiException
+else:
   LOG.warning("Impala app is not enabled")
   LOG.warning("Impala app is not enabled")
   impala_settings = None
   impala_settings = None
 
 
@@ -91,7 +88,7 @@ try:
   has_query_browser = ENABLE_QUERY_BROWSER.get()
   has_query_browser = ENABLE_QUERY_BROWSER.get()
   has_hive_query_browser = ENABLE_HIVE_QUERY_BROWSER.get()
   has_hive_query_browser = ENABLE_HIVE_QUERY_BROWSER.get()
   has_jobbrowser = True
   has_jobbrowser = True
-except (AttributeError, ImportError, RuntimeError) as e:
+except (AttributeError, ImportError, RuntimeError):
   LOG.warning("Job Browser app is not enabled")
   LOG.warning("Job Browser app is not enabled")
   has_jobbrowser = False
   has_jobbrowser = False
   has_query_browser = False
   has_query_browser = False
@@ -123,11 +120,11 @@ def query_error_handler(func):
 
 
 
 
 def is_hive_enabled():
 def is_hive_enabled():
-  return hive_settings is not None and type(hive_settings) == BoundConfig
+  return hive_settings is not None and isinstance(hive_settings, BoundConfig)
 
 
 
 
 def is_impala_enabled():
 def is_impala_enabled():
-  return impala_settings is not None and type(impala_settings) == BoundConfig
+  return impala_settings is not None and isinstance(impala_settings, BoundConfig)
 
 
 
 
 class HiveConfiguration(object):
 class HiveConfiguration(object):
@@ -783,7 +780,8 @@ DROP TABLE IF EXISTS `%(table)s`;
       LOG.warning('Handle already base 64 decoded')
       LOG.warning('Handle already base 64 decoded')
 
 
     for key in list(handle.keys()):
     for key in list(handle.keys()):
-      if key not in ('log_context', 'secret', 'has_result_set', 'operation_type', 'modified_row_count', 'guid'):
+      if key not in ('log_context', 'secret', 'has_result_set', 'operation_type',
+                     'modified_row_count', 'guid', 'session_id', 'session_guid'):
         handle.pop(key)
         handle.pop(key)
 
 
     return HiveServerQueryHandle(**handle)
     return HiveServerQueryHandle(**handle)