|
|
@@ -15,9 +15,9 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
-import re
|
|
|
import json
|
|
|
import logging
|
|
|
+import re
|
|
|
from operator import itemgetter
|
|
|
|
|
|
from django.utils.translation import gettext as _
|
|
|
@@ -49,7 +49,7 @@ from beeswax import conf as beeswax_conf, hive_site
|
|
|
from beeswax.conf import CONFIG_WHITELIST, LIST_PARTITIONS_LIMIT, MAX_CATALOG_SQL_ENTRIES
|
|
|
from beeswax.hive_site import hiveserver2_use_ssl
|
|
|
from beeswax.models import HiveServerQueryHandle, HiveServerQueryHistory, Session
|
|
|
-from beeswax.server.dbms import DataTable, InvalidSessionQueryServerException, QueryServerException, Table, reset_ha
|
|
|
+from beeswax.server.dbms import DataTable, InvalidSessionQueryServerException, QueryServerException, reset_ha, Table
|
|
|
from desktop.conf import DEFAULT_USER, ENABLE_X_CSRF_TOKEN_FOR_HIVE_IMPALA, ENABLE_XFF_FOR_HIVE_IMPALA, USE_THRIFT_HTTP_JWT
|
|
|
from desktop.lib import python_util, thrift_util
|
|
|
from notebook.connectors.base import get_interpreter
|
|
|
@@ -413,7 +413,7 @@ class HiveServerDataTable(DataTable):
|
|
|
for row in self.row_set:
|
|
|
try:
|
|
|
yield row.fields()
|
|
|
- except StopIteration as e:
|
|
|
+ except StopIteration:
|
|
|
return # pep-0479: expected Py3.8 generator raised StopIteration
|
|
|
|
|
|
|
|
|
@@ -546,6 +546,30 @@ class HiveServerTColumnDesc(object):
|
|
|
return ttype.userDefinedTypeEntry
|
|
|
|
|
|
|
|
|
+def extract_cookies(connection):
|
|
|
+ if hasattr(connection, 'conf') and hasattr(connection.conf, 'transport_mode') and connection.conf.transport_mode == 'http':
|
|
|
+ http_transport = connection.transport
|
|
|
+ from thrift.transport.TTransport import TBufferedTransport
|
|
|
+ if isinstance(http_transport, TBufferedTransport):
|
|
|
+ http_transport = http_transport._TBufferedTransport__trans
|
|
|
+ if hasattr(http_transport, '_client') and hasattr(http_transport._client, '_cookies'):
|
|
|
+ cookies = http_transport._client._cookies
|
|
|
+ from requests.utils import dict_from_cookiejar
|
|
|
+ return dict_from_cookiejar(cookies) if cookies else {}
|
|
|
+
|
|
|
+
|
|
|
+def set_cookies(connection, cookies):
|
|
|
+ if hasattr(connection, 'conf') and hasattr(connection.conf, 'transport_mode') and connection.conf.transport_mode == 'http':
|
|
|
+ http_transport = connection.transport
|
|
|
+ from thrift.transport.TTransport import TBufferedTransport
|
|
|
+
|
|
|
+ if isinstance(http_transport, TBufferedTransport):
|
|
|
+ http_transport = http_transport._TBufferedTransport__trans
|
|
|
+ if hasattr(http_transport, '_client') and hasattr(http_transport._client, '_cookies'):
|
|
|
+ from requests.utils import cookiejar_from_dict
|
|
|
+ http_transport._client._cookies = cookiejar_from_dict(cookies or {})
|
|
|
+
|
|
|
+
|
|
|
class HiveServerClient(object):
|
|
|
HS2_MECHANISMS = {
|
|
|
'KERBEROS': 'GSSAPI',
|
|
|
@@ -733,6 +757,9 @@ class HiveServerClient(object):
|
|
|
sessionId = res.sessionHandle.sessionId
|
|
|
LOG.info('Session %s opened' % repr(sessionId.guid))
|
|
|
|
|
|
+ cookies = extract_cookies(self._client) or {} # Extract cookies from the response
|
|
|
+ res.configuration['cookies'] = cookies
|
|
|
+
|
|
|
encoded_status, encoded_guid = HiveServerQueryHandle(secret=sessionId.secret, guid=sessionId.guid).get()
|
|
|
properties = json.dumps(res.configuration)
|
|
|
|
|
|
@@ -799,11 +826,22 @@ class HiveServerClient(object):
|
|
|
return self._call_return_result_and_session(fn, req, status=status, session=session)
|
|
|
|
|
|
def _call_return_result_and_session(self, fn, req, status=TStatusCode.SUCCESS_STATUS, session=None):
|
|
|
- if hasattr(req, 'sessionHandle') and session:
|
|
|
+ if hasattr(req, 'sessionHandle') and session and not isinstance(req, TCloseOperationReq):
|
|
|
req.sessionHandle = session.get_handle()
|
|
|
+ cookies = session.get_cookies() if session else {}
|
|
|
+ set_cookies(self._client, cookies)
|
|
|
+ LOG.debug('setting cookies for call req: %s, session: %s, cookies: %s' % (req, session, cookies))
|
|
|
|
|
|
res = fn(req)
|
|
|
|
|
|
+ cookies = extract_cookies(self._client) or {} # Extract cookies from the response
|
|
|
+ LOG.debug('storing cookies received from server cookies: %s' % cookies)
|
|
|
+ if hasattr(res, 'configuration') and isinstance(res.configuration, dict):
|
|
|
+ res.configuration['cookies'] = cookies
|
|
|
+ if session:
|
|
|
+ session.set_cookies(cookies)
|
|
|
+ session.save()
|
|
|
+
|
|
|
# Not supported currently in HS2 and Impala: TStatusCode.INVALID_HANDLE_STATUS
|
|
|
if res.status.statusCode == TStatusCode.ERROR_STATUS and \
|
|
|
re.search('Invalid SessionHandle|Invalid session|Client session expired|Could not connect', res.status.errorMessage or '', re.I):
|
|
|
@@ -845,7 +883,8 @@ class HiveServerClient(object):
|
|
|
(res, session) = self.call(self._client.GetSchemas, req)
|
|
|
|
|
|
results, schema = self.fetch_result(
|
|
|
- res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=MAX_CATALOG_SQL_ENTRIES.get()
|
|
|
+ res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=MAX_CATALOG_SQL_ENTRIES.get(),
|
|
|
+ session=session
|
|
|
)
|
|
|
self._close(res.operationHandle, session)
|
|
|
|
|
|
@@ -885,7 +924,8 @@ class HiveServerClient(object):
|
|
|
|
|
|
while True:
|
|
|
results, schema = self.fetch_result(
|
|
|
- res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=MAX_CATALOG_SQL_ENTRIES.get()
|
|
|
+ res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=MAX_CATALOG_SQL_ENTRIES.get(),
|
|
|
+ session=session
|
|
|
)
|
|
|
fetched_tables = HiveServerTRowSet(results.results, schema.schema).cols(cols)
|
|
|
table_metadata += fetched_tables
|
|
|
@@ -903,7 +943,8 @@ class HiveServerClient(object):
|
|
|
(res, session) = self.call(self._client.GetTables, req)
|
|
|
|
|
|
results, schema = self.fetch_result(
|
|
|
- res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=MAX_CATALOG_SQL_ENTRIES.get()
|
|
|
+ res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=MAX_CATALOG_SQL_ENTRIES.get(),
|
|
|
+ session=session
|
|
|
)
|
|
|
self._close(res.operationHandle, session)
|
|
|
|
|
|
@@ -913,7 +954,7 @@ class HiveServerClient(object):
|
|
|
req = TGetTablesReq(schemaName=database.lower(), tableName=table_name.lower()) # Impala returns empty if not lower case
|
|
|
(res, session) = self.call(self._client.GetTables, req)
|
|
|
|
|
|
- table_results, table_schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT)
|
|
|
+ table_results, table_schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, session=session)
|
|
|
self.close_operation(res.operationHandle)
|
|
|
|
|
|
if partition_spec:
|
|
|
@@ -1003,7 +1044,7 @@ class HiveServerClient(object):
|
|
|
)
|
|
|
|
|
|
if close_operation:
|
|
|
- self.close_operation(operation_handle)
|
|
|
+ self.close_operation(operation_handle, session=session)
|
|
|
|
|
|
return HiveServerDataTable(results, schema, operation_handle, self.query_server, session=session)
|
|
|
|
|
|
@@ -1034,7 +1075,7 @@ class HiveServerClient(object):
|
|
|
req = TExecuteStatementReq(statement=statement, confOverlay=configuration)
|
|
|
(res, session) = self.call(self._client.ExecuteStatement, req, session=session)
|
|
|
|
|
|
- results, schema = self.fetch_result(res.operationHandle, max_rows=max_rows, orientation=orientation)
|
|
|
+ results, schema = self.fetch_result(res.operationHandle, max_rows=max_rows, orientation=orientation, session=session)
|
|
|
return results, schema, res.operationHandle, session
|
|
|
|
|
|
def execute_async_statement(self, statement=None, thrift_function=None, thrift_request=None, conf_overlay=None, session=None):
|
|
|
@@ -1062,39 +1103,39 @@ class HiveServerClient(object):
|
|
|
|
|
|
# Note: An operation_handle is attached to a session. All operations that require operation_handle cannot recover if the session is
|
|
|
# closed. Passing the session is not required
|
|
|
- def fetch_data(self, operation_handle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=1000):
|
|
|
+ def fetch_data(self, operation_handle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=1000, session=None):
|
|
|
# Fetch until the result is empty dues to a HS2 bug instead of looking at hasMoreRows
|
|
|
- results, schema = self.fetch_result(operation_handle, orientation, max_rows)
|
|
|
+ results, schema = self.fetch_result(operation_handle, orientation, max_rows, session=session)
|
|
|
return HiveServerDataTable(results, schema, operation_handle, self.query_server)
|
|
|
|
|
|
- def cancel_operation(self, operation_handle):
|
|
|
+ def cancel_operation(self, operation_handle, session=None):
|
|
|
req = TCancelOperationReq(operationHandle=operation_handle)
|
|
|
- (res, session) = self.call(self._client.CancelOperation, req)
|
|
|
+ (res, session) = self.call(self._client.CancelOperation, req, session=session)
|
|
|
return res
|
|
|
|
|
|
- def close_operation(self, operation_handle):
|
|
|
+ def close_operation(self, operation_handle, session=None):
|
|
|
req = TCloseOperationReq(operationHandle=operation_handle)
|
|
|
- (res, session) = self.call(self._client.CloseOperation, req)
|
|
|
+ (res, session) = self.call(self._client.CloseOperation, req, session=session)
|
|
|
return res
|
|
|
|
|
|
- def fetch_result(self, operation_handle, orientation=TFetchOrientation.FETCH_FIRST, max_rows=1000):
|
|
|
+ def fetch_result(self, operation_handle, orientation=TFetchOrientation.FETCH_FIRST, max_rows=1000, session=None):
|
|
|
if operation_handle.hasResultSet:
|
|
|
fetch_req = TFetchResultsReq(operationHandle=operation_handle, orientation=orientation, maxRows=max_rows)
|
|
|
- (res, session) = self.call(self._client.FetchResults, fetch_req)
|
|
|
+ (res, session) = self.call(self._client.FetchResults, fetch_req, session=session)
|
|
|
else:
|
|
|
res = TFetchResultsResp(results=TRowSet(startRowOffset=0, rows=[], columns=[]))
|
|
|
|
|
|
if operation_handle.hasResultSet and TFetchOrientation.FETCH_FIRST: # Only fetch for the first call that should be with start_over
|
|
|
meta_req = TGetResultSetMetadataReq(operationHandle=operation_handle)
|
|
|
- (schema, session) = self.call(self._client.GetResultSetMetadata, meta_req)
|
|
|
+ (schema, session) = self.call(self._client.GetResultSetMetadata, meta_req, session=session)
|
|
|
else:
|
|
|
schema = None
|
|
|
|
|
|
return res, schema
|
|
|
|
|
|
- def fetch_log(self, operation_handle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=1000):
|
|
|
+ def fetch_log(self, operation_handle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=1000, session=None):
|
|
|
req = TFetchResultsReq(operationHandle=operation_handle, orientation=orientation, maxRows=max_rows, fetchType=1)
|
|
|
- (res, session) = self.call(self._client.FetchResults, req)
|
|
|
+ (res, session) = self.call(self._client.FetchResults, req, session=session)
|
|
|
|
|
|
if beeswax_conf.THRIFT_VERSION.get() >= 7:
|
|
|
lines = res.results.columns[0].stringVal.values
|
|
|
@@ -1103,15 +1144,15 @@ class HiveServerClient(object):
|
|
|
|
|
|
return '\n'.join(lines)
|
|
|
|
|
|
- def get_operation_status(self, operation_handle):
|
|
|
+ def get_operation_status(self, operation_handle, session=None):
|
|
|
req = TGetOperationStatusReq(operationHandle=operation_handle)
|
|
|
- (res, session) = self.call(self._client.GetOperationStatus, req)
|
|
|
+ (res, session) = self.call(self._client.GetOperationStatus, req, session=session)
|
|
|
return res
|
|
|
|
|
|
- def get_log(self, operation_handle):
|
|
|
+ def get_log(self, operation_handle, session=None):
|
|
|
try:
|
|
|
req = TGetLogReq(operationHandle=operation_handle)
|
|
|
- (res, session) = self.call(self._client.GetLog, req)
|
|
|
+ (res, session) = self.call(self._client.GetLog, req, session=session)
|
|
|
return res.log
|
|
|
except Exception as e:
|
|
|
if 'Invalid query handle' in str(e) or 'Invalid or unknown query handle' in str(e):
|
|
|
@@ -1127,13 +1168,13 @@ class HiveServerClient(object):
|
|
|
if self.has_close_sessions: # Close session will close all associated operation_handle
|
|
|
self.close_session(session)
|
|
|
else:
|
|
|
- self.close_operation(operation_handle)
|
|
|
+ self.close_operation(operation_handle, session=session)
|
|
|
|
|
|
def get_columns(self, database, table):
|
|
|
req = TGetColumnsReq(schemaName=database, tableName=table)
|
|
|
(res, session) = self.call(self._client.GetColumns, req)
|
|
|
|
|
|
- results, schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT)
|
|
|
+ results, schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT, session=session)
|
|
|
self._close(res.operationHandle, session)
|
|
|
|
|
|
return results, schema
|
|
|
@@ -1433,16 +1474,18 @@ class HiveServerClientCompatible(object):
|
|
|
|
|
|
def get_state(self, handle):
|
|
|
operationHandle = handle.get_rpc_handle()
|
|
|
- res = self._client.get_operation_status(operationHandle)
|
|
|
+ session = handle.get_session()
|
|
|
+ res = self._client.get_operation_status(operationHandle, session=session)
|
|
|
return HiveServerQueryHistory.STATE_MAP[res.operationState]
|
|
|
|
|
|
def get_operation_status(self, handle):
|
|
|
operationHandle = handle.get_rpc_handle()
|
|
|
- return self._client.get_operation_status(operationHandle)
|
|
|
+ session = handle.get_session()
|
|
|
+ return self._client.get_operation_status(operationHandle, session=session)
|
|
|
|
|
|
def use(self, query, session=None):
|
|
|
data = self._client.execute_query(query, session=session)
|
|
|
- self._client.close_operation(data.operation_handle)
|
|
|
+ self._client.close_operation(data.operation_handle, session=session)
|
|
|
return data
|
|
|
|
|
|
def explain(self, query):
|
|
|
@@ -1453,6 +1496,7 @@ class HiveServerClientCompatible(object):
|
|
|
|
|
|
def fetch(self, handle, start_over=False, max_rows=None):
|
|
|
operationHandle = handle.get_rpc_handle()
|
|
|
+ session = handle.get_session()
|
|
|
if max_rows is None:
|
|
|
max_rows = 1000
|
|
|
|
|
|
@@ -1461,20 +1505,22 @@ class HiveServerClientCompatible(object):
|
|
|
else:
|
|
|
orientation = TFetchOrientation.FETCH_NEXT
|
|
|
|
|
|
- data_table = self._client.fetch_data(operationHandle, orientation=orientation, max_rows=max_rows)
|
|
|
+ data_table = self._client.fetch_data(operationHandle, orientation=orientation, max_rows=max_rows, session=session)
|
|
|
|
|
|
return ResultCompatible(data_table)
|
|
|
|
|
|
def cancel_operation(self, handle):
|
|
|
operationHandle = handle.get_rpc_handle()
|
|
|
- return self._client.cancel_operation(operationHandle)
|
|
|
+ session = handle.get_session()
|
|
|
+ return self._client.cancel_operation(operationHandle, session=session)
|
|
|
|
|
|
def close(self, handle):
|
|
|
return self.close_operation(handle)
|
|
|
|
|
|
def close_operation(self, handle):
|
|
|
operationHandle = handle.get_rpc_handle()
|
|
|
- return self._client.close_operation(operationHandle)
|
|
|
+ session = handle.get_session()
|
|
|
+ return self._client.close_operation(operationHandle, session=session)
|
|
|
|
|
|
def close_session(self, session):
|
|
|
return self._client.close_session(session)
|
|
|
@@ -1484,16 +1530,17 @@ class HiveServerClientCompatible(object):
|
|
|
|
|
|
def get_log(self, handle, start_over=True):
|
|
|
operationHandle = handle.get_rpc_handle()
|
|
|
+ session = handle.get_session()
|
|
|
|
|
|
if beeswax_conf.USE_GET_LOG_API.get() or self.query_server.get('dialect') == 'impala':
|
|
|
- return self._client.get_log(operationHandle)
|
|
|
+ return self._client.get_log(operationHandle, session=session)
|
|
|
else:
|
|
|
if start_over:
|
|
|
orientation = TFetchOrientation.FETCH_FIRST
|
|
|
else:
|
|
|
orientation = TFetchOrientation.FETCH_NEXT
|
|
|
|
|
|
- return self._client.fetch_log(operationHandle, orientation=orientation, max_rows=-1)
|
|
|
+ return self._client.fetch_log(operationHandle, orientation=orientation, max_rows=-1, session=session)
|
|
|
|
|
|
def get_databases(self, schemaName=None):
|
|
|
col = 'TABLE_SCHEM'
|