Bladeren bron

HUE-8758 [hive] Start adding the notion of dialect to simplify the logic

Romain 5 jaren geleden
bovenliggende
commit
33a24c8abe

+ 2 - 2
apps/beeswax/src/beeswax/server/dbms.py

@@ -81,7 +81,7 @@ def get(user, query_server=None, cluster=None):
       # Avoid circular dependency
       from beeswax.server.hive_server2_lib import HiveServerClientCompatible
 
-      if query_server['server_name'].startswith('impala'):
+      if query_server.get('dialect') == 'impala':
         from impala.dbms import ImpalaDbms
         from impala.server import ImpalaServerClient
         DBMS_CACHE[user.id][query_server['server_name']] = ImpalaDbms(HiveServerClientCompatible(ImpalaServerClient(query_server, user)), QueryHistory.SERVER_TYPE[1][0])
@@ -471,7 +471,7 @@ class HiveServer2Dbms(object):
 
   def cancel_operation(self, query_handle):
     resp = self.client.cancel_operation(query_handle)
-    if self.client.query_server['server_name'].startswith('impala'):
+    if self.client.query_server.get('dialect') == 'impala':
       resp = self.client.close_operation(query_handle)
     return resp
 

+ 15 - 15
apps/beeswax/src/beeswax/server/hive_server2_lib.py

@@ -347,7 +347,7 @@ class HiveServerDataTable(DataTable):
     self.schema = schema and schema.schema
     self.row_set = HiveServerTRowSet(results.results, schema)
     self.operation_handle = operation_handle
-    if query_server['server_name'].startswith('impala'):
+    if query_server.get('dialect') == 'impala':
       self.has_more = results.hasMoreRows
     else:
       self.has_more = not self.row_set.is_empty()    # Should be results.hasMoreRows but always True in HS2
@@ -534,7 +534,7 @@ class HiveServerClient(object):
     self.kerberos_principal_short_name = kerberos_principal_short_name
     self.impersonation_enabled = impersonation_enabled
 
-    if self.query_server['server_name'].startswith('impala'):
+    if self.query_server.get('dialect') == 'impala':
       from impala import conf as impala_conf
 
       ssl_enabled = impala_conf.SSL.ENABLED.get()
@@ -559,7 +559,7 @@ class HiveServerClient(object):
       password = None
 
     thrift_class = TCLIService
-    if self.query_server['server_name'].startswith('impala'):
+    if self.query_server.get('dialect') == 'impala':
       from ImpalaService import ImpalaHiveServer2Service
       thrift_class = ImpalaHiveServer2Service
 
@@ -599,7 +599,7 @@ class HiveServerClient(object):
       kerberos_principal_short_name = None
 
     use_sasl = self.query_server['use_sasl']
-    if self.query_server['server_name'].startswith('impala'):
+    if self.query_server.get('dialect') == 'impala':
       if auth_password: # Force LDAP/PAM.. auth if auth_password is provided
         mechanism = HiveServerClient.HS2_MECHANISMS['NONE']
       else:
@@ -626,7 +626,7 @@ class HiveServerClient(object):
     if self.impersonation_enabled:
       kwargs.update({'username': DEFAULT_USER})
 
-      if self.query_server['server_name'].startswith('impala'): # Only when Impala accepts it
+      if self.query_server.get('dialect') == 'impala': # Only when Impala accepts it
         kwargs['configuration'].update({'impala.doas.user': user.username})
 
     if self.query_server['server_name'] == 'beeswax': # All the time
@@ -638,7 +638,7 @@ class HiveServerClient(object):
     if self.query_server['server_name'] == 'sparksql': # All the time
       kwargs['configuration'].update({'hive.server2.proxy.user': user.username})
 
-    if self.query_server['server_name'].startswith('impala') and self.query_server['SESSION_TIMEOUT_S'] > 0:
+    if self.query_server.get('dialect') == 'impala' and self.query_server['SESSION_TIMEOUT_S'] > 0:
       kwargs['configuration'].update({'idle_session_timeout': str(self.query_server['SESSION_TIMEOUT_S'])})
 
     LOG.info('Opening %s thrift session for user %s' % (self.query_server['server_name'], user.username))
@@ -765,7 +765,7 @@ class HiveServerClient(object):
     req = TGetSchemasReq()
     if schemaName is not None:
       req.schemaName = schemaName
-    if self.query_server['server_name'].startswith('impala'):
+    if self.query_server.get('dialect') == 'impala':
       req.schemaName = None
 
     (res, session) = self.call(self._client.GetSchemas, req)
@@ -783,7 +783,7 @@ class HiveServerClient(object):
     desc_results, desc_schema, operation_handle, session = self.execute_statement(query, max_rows=5000, orientation=TFetchOrientation.FETCH_NEXT)
     self._close(operation_handle, session)
 
-    if self.query_server['server_name'].startswith('impala'):
+    if self.query_server.get('dialect') == 'impala':
       cols = ('name', 'location', 'comment') # Skip owner as on a new line
     else:
       cols = ('db_name', 'comment', 'location', 'owner_name', 'owner_type', 'parameters')
@@ -908,7 +908,7 @@ class HiveServerClient(object):
 
     configuration = {}
 
-    if self.query_server['server_name'].startswith('impala') and self.query_server['querycache_rows'] > 0:
+    if self.query_server.get('dialect') == 'impala' and self.query_server['querycache_rows'] > 0:
       configuration[IMPALA_RESULTSET_CACHE_SIZE] = str(self.query_server['querycache_rows'])
 
     # The query can override the default configuration
@@ -921,7 +921,7 @@ class HiveServerClient(object):
   def execute_statement(self, statement, max_rows=1000, configuration=None, orientation=TFetchOrientation.FETCH_NEXT, session=None):
     if configuration is None:
       configuration = {}
-    if self.query_server['server_name'].startswith('impala') and self.query_server['QUERY_TIMEOUT_S'] > 0:
+    if self.query_server.get('dialect') == 'impala' and self.query_server['QUERY_TIMEOUT_S'] > 0:
       configuration['QUERY_TIMEOUT_S'] = str(self.query_server['QUERY_TIMEOUT_S'])
 
     if sys.version_info[0] == 2:
@@ -935,7 +935,7 @@ class HiveServerClient(object):
 
 
   def execute_async_statement(self, statement, confOverlay, session=None):
-    if self.query_server['server_name'].startswith('impala') and self.query_server['QUERY_TIMEOUT_S'] > 0:
+    if self.query_server.get('dialect') == 'impala' and self.query_server['QUERY_TIMEOUT_S'] > 0:
       confOverlay['QUERY_TIMEOUT_S'] = str(self.query_server['QUERY_TIMEOUT_S'])
 
     if sys.version_info[0] == 2:
@@ -1059,7 +1059,7 @@ class HiveServerClient(object):
     if self.has_close_sessions:
       self.close_session(partition_table.session)
 
-    if self.query_server['server_name'].startswith('impala'):
+    if self.query_server.get('dialect') == 'impala':
       try:
         # Fetch all partition key names, which are listed before the #Rows column
         cols = [col.name for col in partition_table.cols()]
@@ -1101,7 +1101,7 @@ class HiveServerClient(object):
   def get_configuration(self, session=None):
     configuration = {}
 
-    if self.query_server['server_name'].startswith('impala'):  # Return all configuration settings
+    if self.query_server.get('dialect') == 'impala':  # Return all configuration settings
       query = 'SET'
       results = self.execute_query_statement(query, orientation=TFetchOrientation.FETCH_NEXT, close_operation=True, session=session)
       configuration = dict((row[0], row[1]) for row in results.rows())
@@ -1265,7 +1265,7 @@ class HiveServerClientCompatible(object):
     if max_rows is None:
       max_rows = 1000
 
-    if start_over and not (self.query_server['server_name'].startswith('impala') and self.query_server['querycache_rows'] == 0): # Backward compatibility for impala
+    if start_over and not (self.query_server.get('dialect') == 'impala' and self.query_server['querycache_rows'] == 0): # Backward compatibility for impala
       orientation = TFetchOrientation.FETCH_FIRST
     else:
       orientation = TFetchOrientation.FETCH_NEXT
@@ -1300,7 +1300,7 @@ class HiveServerClientCompatible(object):
   def get_log(self, handle, start_over=True):
     operationHandle = handle.get_rpc_handle()
 
-    if beeswax_conf.USE_GET_LOG_API.get() or self.query_server['server_name'].startswith('impala'):
+    if beeswax_conf.USE_GET_LOG_API.get() or self.query_server.get('dialect') == 'impala':
       return self._client.get_log(operationHandle)
     else:
       if start_over:

+ 31 - 3
apps/beeswax/src/beeswax/server/hive_server2_lib_tests.py

@@ -30,7 +30,6 @@ from beeswax.server.hive_server2_lib import HiveServerTable, HiveServerClient
 from useradmin.models import User
 
 from desktop.lib.django_test_util import make_logged_in_client
-from desktop.lib.test_utils import grant_access
 
 if sys.version_info[0] > 2:
   from unittest.mock import patch, Mock, MagicMock
@@ -47,8 +46,6 @@ class TestHiveServerClient():
     self.client = make_logged_in_client(username="test_hive_server2_lib", groupname="default", recreate=True, is_superuser=False)
     self.user = User.objects.get(username="test_hive_server2_lib")
 
-    grant_access(self.user.username, self.user.username, "beeswax")
-
     self.query_server = {
         'principal': 'hue',
         'server_name': 'hive',
@@ -180,6 +177,37 @@ class TestHiveServerClient():
         Session.objects.filter(owner=self.user, application=self.query_server['server_name']).count()
       )
 
+  def test_get_databases_impala_specific(self):
+    query = Mock(
+      get_query_statement=Mock(return_value=['SELECT 1']),
+      settings=[]
+    )
+
+    with patch('beeswax.server.hive_server2_lib.HiveServerTRowSet') as HiveServerTRowSet:
+
+      client = HiveServerClient(self.query_server, self.user)
+
+      client.call = Mock(return_value=(Mock(), Mock()))
+      client.fetch_result = Mock(return_value=(Mock(), Mock()))
+      client._close = Mock()
+
+      client.get_databases(query)
+
+      assert_not_equal(
+        None,
+        client.call.call_args[0][1].schemaName,
+        client.call.call_args.args
+      )
+
+      with patch.dict(self.query_server, {'dialect': 'impala'}, clear=True):
+        client.get_databases(query)
+
+        assert_equal(
+          None, # Should be empty and not '*' with Impala
+          client.call.call_args[0][1].schemaName,
+          client.call.call_args.args
+        )
+
 
 class TestHiveServerTable():