[beeswax] Fix incorrect set_nulls in Hive results

The HS2 Thrift API returns two fields for every column:
1. `values`: an array of the column's values in sequential order.
2. `nulls`: a binary field that is in fact a bit array; a `1` in the `n`th position indicates that `values[n]` should be treated as null.

Our unit tests assumed the `nulls` field was a hex string, which is not
the case: it is a bit array that merely displays as hex escapes when
printed.
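
For illustration only (not part of the patch), here is how such a mask
decodes, assuming the LSB-first bit order HS2 uses within each byte:

```python
# Illustration only: `nulls` is raw bytes forming a bit array
# (LSB-first within each byte), not a hex string.
nulls = b'\x05'            # bits 0 and 2 set -> rows 0 and 2 are null
values = ['a', 'b', 'c']

masked = [None if nulls[i // 8] & (1 << (i % 8)) else v
          for i, v in enumerate(values)]

print(masked)  # [None, 'b', None]
print(nulls)   # b'\x05' -- the bytes merely display as hex escapes
```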

This change stops feeding the `.decode('utf-8')`-transformed value into
the mask computation in `HiveServerTColumnValue2.set_nulls()` and passes
the original bytes to `mark_nulls()` instead.
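
A minimal standalone sketch of the corrected behavior (simplified; the
helper names mirror `set_nulls()`/`mark_nulls()` but this is not the
production code):

```python
def mark_nulls(values, nulls):
    # Yield one flag per bit, reading the raw mask bytes directly.
    # (`values` is unused here; kept only to mirror the real signature.)
    for n in bytearray(nulls):
        for bit in (0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80):
            yield n & bit

def set_nulls(values, nulls):
    if not nulls or all(b == 0 for b in bytearray(nulls)):
        return values  # HS2 sends '' or all-\x00 bytes when nothing is null
    marked = [None if is_null else v
              for v, is_null in zip(values, mark_nulls(values, nulls))]
    marked.extend(values[len(marked):])  # the mask can be shorter than values
    return marked

print(set_nulls(['a', 'b', 'c'], b'\x02'))  # ['a', None, 'c']
```

Passing the raw bytes through keeps the mask intact; decoding them as
UTF-8 first can mangle or reject byte values above 0x7f, which is why the
patch only attempts the decode for the all-`\x00` fast-path check.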

Change-Id: If9583b92f527fc780ea22af28673d584e2f30ca1
Amit Srivastava 1 year ago
Parent
Commit
ba387d49e2
2 changed files with 105 additions and 257 deletions
  1. +29 -90   apps/beeswax/src/beeswax/server/hive_server2_lib.py
  2. +76 -167  apps/beeswax/src/beeswax/tests.py

+ 29 - 90
apps/beeswax/src/beeswax/server/hive_server2_lib.py

@@ -60,7 +60,7 @@ class HiveServerTable(Table):
       if not table_results.columns:
         raise QueryServerException('No table columns')
       self.table = table_results.columns
-    else: # Deprecated. To remove in Hue 4.
+    else:  # Deprecated. To remove in Hue 4.
       if not table_results.rows:
         raise QueryServerException('No table rows')
       self.table = table_results.rows and table_results.rows[0] or ''
@@ -68,7 +68,7 @@ class HiveServerTable(Table):
     self.table_schema = table_schema
     self.desc_results = desc_results
     self.desc_schema = desc_schema
-    self.is_impala_only = False # Aka Kudu
+    self.is_impala_only = False  # Aka Kudu
 
     self.describe = HiveServerTTableSchema(self.desc_results, self.desc_schema).cols()
     self._details = None
@@ -85,7 +85,7 @@ class HiveServerTable(Table):
   def partition_keys(self):
     try:
       return [PartitionKeyCompatible(row['col_name'], row['data_type'], row['comment']) for row in self._get_partition_columns()]
-    except:
+    except Exception:
       LOG.exception('failed to get partition keys')
       return []
 
@@ -99,7 +99,7 @@ class HiveServerTable(Table):
       rows = [row for row in rows if row['col_name'].startswith('Location:')]
       if rows:
         return rows[0]['data_type']
-    except:
+    except Exception:
       LOG.exception('failed to get path location')
       return None
 
@@ -119,18 +119,18 @@ class HiveServerTable(Table):
       return rows[col_row_index:][:end_cols_index] + self._get_partition_columns()
     except ValueError:  # DESCRIBE on nested columns does not always contain additional rows beyond cols
       return rows[col_row_index:]
-    except:
+    except Exception:
       return rows
 
   def _get_partition_columns(self):
     rows = self.describe
     try:
       col_row_index = list(map(itemgetter('col_name'), rows)).index('# Partition Information') + 2
-      if rows[col_row_index]['col_name'] == '': # Impala has a blank line
+      if rows[col_row_index]['col_name'] == '':  # Impala has a blank line
         col_row_index += 1
       end_cols_index = list(map(itemgetter('col_name'), rows[col_row_index:])).index('')
       return rows[col_row_index:][:end_cols_index]
-    except:
+    except Exception:
       # Not partitioned
       return []
 
@@ -198,7 +198,7 @@ class HiveServerTable(Table):
       col_row_index = list(map(itemgetter('col_name'), rows)).index('Table Parameters:') + 1
       end_cols_index = list(map(itemgetter('data_type'), rows[col_row_index:])).index(None)
       return rows[col_row_index:][:end_cols_index]
-    except:
+    except Exception:
       LOG.exception('Table stats could not be retrieved')
       return []
 
@@ -291,14 +291,14 @@ class HiveServerTRow2(object):
   def col(self, colName):
     pos = self._get_col_position(colName)
     try:
-      return HiveServerTColumnValue2(self.cols[pos]).val[0] # Return only first element
-    except:
+      return HiveServerTColumnValue2(self.cols[pos]).val[0]  # Return only first element
+    except Exception:
       # Bug with SparkSql
       return ''
 
   def full_col(self, colName):
     pos = self._get_col_position(colName)
-    return HiveServerTColumnValue2(self.cols[pos]).val # Return the full column and its values
+    return HiveServerTColumnValue2(self.cols[pos]).val  # Return the full column and its values
 
   def _get_col_position(self, column_name):
     return list(filter(lambda i_col1: i_col1[1].columnName == column_name, enumerate(self.schema.columns)))[0][0]
@@ -337,7 +337,7 @@ class HiveServerTColumnValue2(object):
   @classmethod
   def _get_val(cls, column):
     column.values = cls.set_nulls(column.values, column.nulls)
-    column.nulls = '' # Clear the null values for not re-marking again the column with nulls at the next call
+    column.nulls = ''  # Clear the null values for not re-marking again the column with nulls at the next call
     return column.values
 
   @classmethod
@@ -360,19 +360,20 @@ class HiveServerTColumnValue2(object):
       yield n & 0x80
 
   @classmethod
-  def set_nulls(cls, values, bytestring):
+  def set_nulls(cls, values, nulls):
     can_decode = True
+    bytestring = nulls
     if sys.version_info[0] == 3 and isinstance(bytestring, bytes):
       try:
         bytestring = bytestring.decode('utf-8')
-      except:
+      except Exception:
         can_decode = False
 
-    if bytestring == '' or (can_decode and re.match('^(\x00)+$', bytestring)): # HS2 has just \x00 or '', Impala can have \x00\x00...
+    if bytestring == '' or (can_decode and re.match('^(\x00)+$', bytestring)):  # HS2 has just \x00 or '', Impala can have \x00\x00...
       return values
     else:
-      _values = [None if is_null else value for value, is_null in zip(values, cls.mark_nulls(values, bytestring))]
-      if len(values) != len(_values): # HS2 can have just \x00\x01 instead of \x00\x01\x00...
+      _values = [None if is_null else value for value, is_null in zip(values, cls.mark_nulls(values, nulls))]
+      if len(values) != len(_values):  # HS2 can have just \x00\x01 instead of \x00\x01\x00...
         _values.extend(values[len(_values):])
       return _values
 
@@ -410,7 +411,6 @@ class HiveServerDataTable(DataTable):
           raise e
 
 
-
 class HiveServerTTableSchema(object):
   def __init__(self, columns, schema):
     self.columns = columns
@@ -419,7 +419,7 @@ class HiveServerTTableSchema(object):
   def cols(self):
     try:
       return HiveServerTRowSet(self.columns, self.schema).cols(('col_name', 'data_type', 'comment'))
-    except:
+    except Exception:
       # Impala API is different
       cols = HiveServerTRowSet(self.columns, self.schema).cols(('name', 'type', 'comment'))
       for col in cols:
@@ -628,7 +628,6 @@ class HiveServerClient(object):
         coordinator_host=self.coordinator_host
     )
 
-
   def get_security(self):
     principal = self.query_server['principal']
     impersonation_enabled = False
@@ -658,7 +657,6 @@ class HiveServerClient(object):
 
     return use_sasl, mechanism, kerberos_principal_short_name, impersonation_enabled, auth_username, auth_password
 
-
   def open_session(self, user):
     self.user = user
     kwargs = {
@@ -688,13 +686,13 @@ class HiveServerClient(object):
       if csrf_header and ENABLE_X_CSRF_TOKEN_FOR_HIVE_IMPALA.get():
         kwargs['configuration'].update({'X-CSRF-TOKEN': csrf_header})
 
-    if self.query_server['server_name'] == 'hplsql' or interpreter_dialect == 'hplsql': # All the time
+    if self.query_server['server_name'] == 'hplsql' or interpreter_dialect == 'hplsql':  # All the time
       kwargs['configuration'].update({'hive.server2.proxy.user': user.username, 'set:hivevar:mode': 'HPLSQL'})
 
-    if self.query_server['server_name'] == 'llap': # All the time
+    if self.query_server['server_name'] == 'llap':  # All the time
       kwargs['configuration'].update({'hive.server2.proxy.user': user.username})
 
-    if self.query_server['server_name'] == 'sparksql': # All the time
+    if self.query_server['server_name'] == 'sparksql':  # All the time
       kwargs['configuration'].update({'hive.server2.proxy.user': user.username})
 
     if self.query_server.get('dialect') == 'impala' and self.query_server['SESSION_TIMEOUT_S'] > 0:
@@ -753,11 +751,9 @@ class HiveServerClient(object):
 
     return session
 
-
   def call(self, fn, req, status=TStatusCode.SUCCESS_STATUS, session=None):
     return self.call_return_result_and_session(fn, req, status, session=session)
 
-
   def call_return_result_and_session(self, fn, req, status=TStatusCode.SUCCESS_STATUS, session=None):
     if not hasattr(req, 'sessionHandle'):
       return self._call_return_result_and_session(fn, req, status=TStatusCode.SUCCESS_STATUS, session=session)
@@ -775,7 +771,7 @@ class HiveServerClient(object):
 
     if self.has_session_pool:
       session = Session.objects.get_tez_session(self.user, self.query_server['server_name'], self.max_number_of_sessions)
-    elif self.max_number_of_sessions == 1: # Default behaviour: reuse opened session
+    elif self.max_number_of_sessions == 1:  # Default behaviour: reuse opened session
       session = Session.objects.get_session(self.user, self.query_server['server_name'])
 
     if session:
@@ -796,7 +792,6 @@ class HiveServerClient(object):
 
     return self._call_return_result_and_session(fn, req, status=status, session=session)
 
-
   def _call_return_result_and_session(self, fn, req, status=TStatusCode.SUCCESS_STATUS, session=None):
     if hasattr(req, 'sessionHandle') and session:
       req.sessionHandle = session.get_handle()
@@ -821,7 +816,6 @@ class HiveServerClient(object):
     else:
       return (res, session)
 
-
   def close_session(self, session):
     req = TCloseSessionReq(sessionHandle=session.get_handle())
     try:
@@ -834,7 +828,6 @@ class HiveServerClient(object):
       session.save()
       raise e
 
-
   def get_databases(self, schemaName=None):
     # GetCatalogs() is not implemented in HS2
     req = TGetSchemasReq()
@@ -853,7 +846,6 @@ class HiveServerClient(object):
     col = 'TABLE_SCHEM'
     return HiveServerTRowSet(results.results, schema.schema).cols((col,))
 
-
   def get_database(self, database):
     query = 'DESCRIBE DATABASE EXTENDED `%s`' % (database)
 
@@ -863,7 +855,7 @@ class HiveServerClient(object):
     self._close(operation_handle, session)
 
     if self.query_server.get('dialect') == 'impala':
-      cols = ('name', 'location', 'comment') # Skip owner as on a new line
+      cols = ('name', 'location', 'comment')  # Skip owner as on a new line
     else:
       cols = ('db_name', 'comment', 'location', 'owner_name', 'owner_type', 'parameters')
 
@@ -890,7 +882,6 @@ class HiveServerClient(object):
     cols = ('TABLE_NAME', 'TABLE_TYPE', 'REMARKS')
     return HiveServerTRowSet(results.results, schema.schema).cols(cols)
 
-
   def get_tables(self, database, table_names, table_types=None):
     if not table_types:
       table_types = self.DEFAULT_TABLE_TYPES
@@ -904,9 +895,8 @@ class HiveServerClient(object):
 
     return HiveServerTRowSet(results.results, schema.schema).cols(('TABLE_NAME',))
 
-
   def get_table(self, database, table_name, partition_spec=None):
-    req = TGetTablesReq(schemaName=database.lower(), tableName=table_name.lower()) # Impala returns empty if not lower case
+    req = TGetTablesReq(schemaName=database.lower(), tableName=table_name.lower())  # Impala returns empty if not lower case
     (res, session) = self.call(self._client.GetTables, req)
 
     table_results, table_schema = self.fetch_result(res.operationHandle, orientation=TFetchOrientation.FETCH_NEXT)
@@ -927,7 +917,7 @@ class HiveServerClient(object):
       self.close_operation(operation_handle)
     except Exception as e:
       ex_string = str(e)
-      if 'cannot find field' in ex_string: # Workaround until Hive 2.0 and HUE-3751
+      if 'cannot find field' in ex_string:  # Workaround until Hive 2.0 and HUE-3751
         desc_results, desc_schema, operation_handle, session = self.execute_statement('USE `%s`' % database, session=session)
         self.close_operation(operation_handle)
         if partition_spec:
@@ -965,7 +955,7 @@ class HiveServerClient(object):
           desc_results.results.columns[1].stringVal.values = desc_results.results.columns[1].stringVal.values[:part_index]
           desc_results.results.columns[2].stringVal.values = desc_results.results.columns[2].stringVal.values[:part_index]
 
-          desc_results.results.columns[1].stringVal.nulls = '' # Important to not clear the last two types
+          desc_results.results.columns[1].stringVal.nulls = ''  # Important to not clear the last two types
 
           desc_results.results.columns[1].stringVal.values[-1] = None
           desc_results.results.columns[2].stringVal.values[-1] = None
@@ -981,12 +971,10 @@ class HiveServerClient(object):
 
     return HiveServerTable(table_results.results, table_schema.schema, desc_results.results, desc_schema.schema)
 
-
   def execute_query(self, query, max_rows=1000, session=None):
     configuration = self._get_query_configuration(query)
     return self.execute_query_statement(statement=query.query['query'], max_rows=max_rows, configuration=configuration, session=session)
 
-
   def execute_query_statement(self, statement, max_rows=1000, configuration=None, orientation=TFetchOrientation.FETCH_FIRST,
       close_operation=False, session=None):
     if configuration is None:
@@ -1005,7 +993,6 @@ class HiveServerClient(object):
 
     return HiveServerDataTable(results, schema, operation_handle, self.query_server, session=session)
 
-
   def execute_async_query(self, query, statement=0, session=None):
     if statement == 0:
       # Impala just has settings currently
@@ -1024,7 +1011,6 @@ class HiveServerClient(object):
 
     return self.execute_async_statement(statement=query_statement, conf_overlay=configuration, session=session)
 
-
   def execute_statement(self, statement, max_rows=1000, configuration=None, orientation=TFetchOrientation.FETCH_NEXT, session=None):
     if configuration is None:
       configuration = {}
@@ -1040,7 +1026,6 @@ class HiveServerClient(object):
     results, schema = self.fetch_result(res.operationHandle, max_rows=max_rows, orientation=orientation)
     return results, schema, res.operationHandle, session
 
-
   def execute_async_statement(self, statement=None, thrift_function=None, thrift_request=None, conf_overlay=None, session=None):
     if conf_overlay is None:
       conf_overlay = {}
@@ -1074,19 +1059,16 @@ class HiveServerClient(object):
     results, schema = self.fetch_result(operation_handle, orientation, max_rows)
     return HiveServerDataTable(results, schema, operation_handle, self.query_server)
 
-
   def cancel_operation(self, operation_handle):
     req = TCancelOperationReq(operationHandle=operation_handle)
     (res, session) = self.call(self._client.CancelOperation, req)
     return res
 
-
   def close_operation(self, operation_handle):
     req = TCloseOperationReq(operationHandle=operation_handle)
     (res, session) = self.call(self._client.CloseOperation, req)
     return res
 
-
   def fetch_result(self, operation_handle, orientation=TFetchOrientation.FETCH_FIRST, max_rows=1000):
     if operation_handle.hasResultSet:
       fetch_req = TFetchResultsReq(operationHandle=operation_handle, orientation=orientation, maxRows=max_rows)
@@ -1102,7 +1084,6 @@ class HiveServerClient(object):
 
     return res, schema
 
-
   def fetch_log(self, operation_handle, orientation=TFetchOrientation.FETCH_NEXT, max_rows=1000):
     req = TFetchResultsReq(operationHandle=operation_handle, orientation=orientation, maxRows=max_rows, fetchType=1)
     (res, session) = self.call(self._client.FetchResults, req)
@@ -1114,13 +1095,11 @@ class HiveServerClient(object):
 
     return '\n'.join(lines)
 
-
   def get_operation_status(self, operation_handle):
     req = TGetOperationStatusReq(operationHandle=operation_handle)
     (res, session) = self.call(self._client.GetOperationStatus, req)
     return res
 
-
   def get_log(self, operation_handle):
     try:
       req = TGetLogReq(operationHandle=operation_handle)
@@ -1136,14 +1115,12 @@ class HiveServerClient(object):
 
       return message
 
-
   def _close(self, operation_handle, session):
-    if self.has_close_sessions: # Close session will close all associated operation_handle
+    if self.has_close_sessions:  # Close session will close all associated operation_handle
       self.close_session(session)
     else:
       self.close_operation(operation_handle)
 
-
   def get_columns(self, database, table):
     req = TGetColumnsReq(schemaName=database, tableName=table)
     (res, session) = self.call(self._client.GetColumns, req)
@@ -1153,14 +1130,12 @@ class HiveServerClient(object):
 
     return results, schema
 
-
   def explain(self, query):
     query_statement = query.get_query_statement(0)
     configuration = self._get_query_configuration(query)
     return self.execute_query_statement(statement='EXPLAIN %s' % query_statement, configuration=configuration,
                                         orientation=TFetchOrientation.FETCH_NEXT)
 
-
   # TODO execute both requests in same session
   def get_partitions(self, database, table_name, partition_spec=None, max_parts=None, reverse_sort=True):
     table = self.get_table(database, table_name)
@@ -1214,7 +1189,6 @@ class HiveServerClient(object):
 
     return partitions[:max_parts]
 
-
   def get_configuration(self, session=None):
     configuration = {}
 
@@ -1237,11 +1211,9 @@ class HiveServerClient(object):
 
     return configuration
 
-
   def _get_query_configuration(self, query):
     return dict([(setting['key'], setting['value']) for setting in query.settings])
 
-
   def get_functions(self):
     '''
     Could support parameters.
@@ -1256,7 +1228,6 @@ class HiveServerClient(object):
 
     return self.execute_async_statement(thrift_function=thrift_function, thrift_request=req)
 
-
   def get_primary_keys(self, database_name, table_name, catalog_name=None):
     '''
     Get the Primary Keys of a Table entity (seems like database name is required).
@@ -1290,7 +1261,6 @@ class HiveServerClient(object):
 
     return results
 
-
   def get_foreign_keys(self, parent_catalog_name=None, parent_database_name=None, parent_table_name=None, foreign_catalog_name=None,
       foreign_database_name=None, foreign_table_name=None):
     '''
@@ -1352,7 +1322,7 @@ class HiveServerTableCompatible(HiveServerTable):
     self._details = None
     try:
       self.is_impala_only = 'org.apache.hadoop.hive.kudu.KuduSerDe' in str(hive_table.properties) or \
-        'org.apache.kudu.mapreduce.KuduTableOutputFormat' in str(hive_table.properties) # Deprecated since CDP
+        'org.apache.kudu.mapreduce.KuduTableOutputFormat' in str(hive_table.properties)  # Deprecated since CDP
     except Exception as e:
       LOG.warning('Autocomplete data fetching error: %s' % e)
       self.is_impala_only = False
@@ -1419,11 +1389,9 @@ class PartitionValueCompatible(object):
     self.values = [pv[1] for pv in [part.split('=') for part in parts]]
     self.sd = type('Sd', (object,), properties,)
 
-
   def __repr__(self):
     return 'PartitionValueCompatible(spec:%s, values:%s, sd:%s)' % (self.partition_spec, self.values, self.sd)
 
-
   def _get_partition_spec(self, name, value):
     partition_spec = "`%s`='%s'" % (name, value)
     partition_key = next((key for key in self.partition_keys if key.name == name), None)
@@ -1452,35 +1420,29 @@ class HiveServerClientCompatible(object):
     self.user = client.user
     self.query_server = client.query_server
 
-
   def query(self, query, statement=0, session=None):
     return self._client.execute_async_query(query, statement, session=session)
 
-
   def get_state(self, handle):
     operationHandle = handle.get_rpc_handle()
     res = self._client.get_operation_status(operationHandle)
     return HiveServerQueryHistory.STATE_MAP[res.operationState]
 
-
   def get_operation_status(self, handle):
     operationHandle = handle.get_rpc_handle()
     return self._client.get_operation_status(operationHandle)
 
-
   def use(self, query, session=None):
     data = self._client.execute_query(query, session=session)
     self._client.close_operation(data.operation_handle)
     return data
 
-
   def explain(self, query):
     data_table = self._client.explain(query)
     data = ExplainCompatible(data_table)
     self._client.close_operation(data_table.operation_handle)
     return data
 
-
   def fetch(self, handle, start_over=False, max_rows=None):
     operationHandle = handle.get_rpc_handle()
     if max_rows is None:
@@ -1495,29 +1457,23 @@ class HiveServerClientCompatible(object):
 
     return ResultCompatible(data_table)
 
-
   def cancel_operation(self, handle):
     operationHandle = handle.get_rpc_handle()
     return self._client.cancel_operation(operationHandle)
 
-
   def close(self, handle):
     return self.close_operation(handle)
 
-
   def close_operation(self, handle):
     operationHandle = handle.get_rpc_handle()
     return self._client.close_operation(operationHandle)
 
-
   def close_session(self, session):
     return self._client.close_session(session)
 
-
   def dump_config(self):
     return 'Does not exist in HS2'
 
-
   def get_log(self, handle, start_over=True):
     operationHandle = handle.get_rpc_handle()
 
@@ -1531,16 +1487,13 @@ class HiveServerClientCompatible(object):
 
       return self._client.fetch_log(operationHandle, orientation=orientation, max_rows=-1)
 
-
   def get_databases(self, schemaName=None):
     col = 'TABLE_SCHEM'
     return [table[col] for table in self._client.get_databases(schemaName)]
 
-
   def get_database(self, database):
     return self._client.get_database(database)
 
-
   def get_tables_meta(self, database, table_names, table_types=None):
     tables = self._client.get_tables_meta(database, table_names, table_types)
     massaged_tables = []
@@ -1554,57 +1507,43 @@ class HiveServerClientCompatible(object):
     massaged_tables = sorted(massaged_tables, key=lambda table_: table_['name'])
     return massaged_tables
 
-
   def get_tables(self, database, table_names, table_types=None):
     tables = [table['TABLE_NAME'] for table in self._client.get_tables(database, table_names, table_types)]
     tables.sort()
     return tables
 
-
   def get_table(self, database, table_name, partition_spec=None):
     table = self._client.get_table(database, table_name, partition_spec)
     return HiveServerTableCompatible(table)
 
-
   def get_columns(self, database, table):
     return self._client.get_columns(database, table)
 
-
   def get_default_configuration(self, *args, **kwargs):
     return []
 
-
   def get_results_metadata(self, handle):
     # We just need to mock
     return ResultMetaCompatible()
 
-
   def create_database(self, name, description): raise NotImplementedError()
 
-
   def alter_table(self, dbname, tbl_name, new_tbl): raise NotImplementedError()
 
-
   def open_session(self, user):
     return self._client.open_session(user)
 
-
   def add_partition(self, new_part): raise NotImplementedError()
 
-
   def get_partition(self, *args, **kwargs): raise NotImplementedError()
 
-
   def get_partitions(self, database, table_name, partition_spec, max_parts, reverse_sort=True):
     return self._client.get_partitions(database, table_name, partition_spec, max_parts, reverse_sort)
 
-
   def alter_partition(self, db_name, tbl_name, new_part): raise NotImplementedError()
 
-
   def get_configuration(self):
     return self._client.get_configuration()
 
-
   def get_functions(self):
     return self._client.get_functions()

File diff suppressed because it is too large
+ 76 - 167
apps/beeswax/src/beeswax/tests.py

