Browse Source

[impala] Fix TSessionHandle guid and secret corrupted after save (#1918)

Ying Chen 4 năm trước cách đây
mục cha
commit
9df3435f86

+ 12 - 1
apps/beeswax/src/beeswax/models.py

@@ -17,6 +17,7 @@
 
 from builtins import range
 from builtins import object
+import ast
 import base64
 import datetime
 import json
@@ -475,11 +476,21 @@ class Session(models.Model):
   objects = SessionManager()
 
   def get_handle(self):
-    secret, guid = HiveServerQueryHandle.get_decoded(secret=self.secret, guid=self.guid)
+    secret, guid = self.get_adjusted_guid_secret()
+    secret, guid = HiveServerQueryHandle.get_decoded(secret=secret, guid=guid)
 
     handle_id = THandleIdentifier(secret=secret, guid=guid)
     return TSessionHandle(sessionId=handle_id)
 
+  def get_adjusted_guid_secret(self):
+    secret = self.secret
+    guid = self.guid
+    if sys.version_info[0] > 2 and not isinstance(self.secret, bytes) and not isinstance(self.guid, bytes):
+      # only for py3, after bytes saved, bytes wrapped in a string object
+      secret = ast.literal_eval(secret)
+      guid = ast.literal_eval(guid)
+    return secret, guid
+
   def get_properties(self):
     return json.loads(self.properties) if self.properties else {}
 

+ 35 - 10
apps/beeswax/src/beeswax/server/hive_server2_lib_tests.py

@@ -27,7 +27,7 @@ from desktop.lib.django_test_util import make_logged_in_client
 from useradmin.models import User
 
 from beeswax.conf import MAX_NUMBER_OF_SESSIONS, CLOSE_SESSIONS
-from beeswax.models import Session
+from beeswax.models import HiveServerQueryHandle, Session
 from beeswax.server.dbms import get_query_server_config, QueryServerException
 from beeswax.server.hive_server2_lib import HiveServerTable, HiveServerClient
 
@@ -64,6 +64,8 @@ class TestHiveServerClient():
     )
 
     with patch('beeswax.server.hive_server2_lib.thrift_util.get_client') as get_client:
+      original_secret = b's\xb6\x0ePP\xbdL\x17\xa3\x0f\\\xf7K\xe8Y\x1d'
+      original_guid = b'\xd9\xe0hT\xd6wO\xe1\xa3S\xfb\x04\xca\x93V\x01'
       get_client.return_value = Mock(
         OpenSession=Mock(
           return_value=Mock(
@@ -73,8 +75,8 @@ class TestHiveServerClient():
             configuration={},
             sessionHandle=Mock(
               sessionId=Mock(
-                secret=b'1',
-                guid=b'1'
+                secret=original_secret,
+                guid=original_guid
               )
             ),
             serverProtocolVersion=11
@@ -91,11 +93,29 @@ class TestHiveServerClient():
         session_count + 1,  # +1 as setUp resets the user which deletes cascade the sessions
         Session.objects.filter(owner=self.user, application=self.query_server['server_name']).count()
       )
+
+      session = Session.objects.get_session(self.user, self.query_server['server_name'])
+      secret, guid = session.get_adjusted_guid_secret()
+      secret, guid = HiveServerQueryHandle.get_decoded(secret, guid)
+      assert_equal(
+        original_secret,
+        secret
+      )
+      assert_equal(
+        original_guid,
+        guid
+      )
+      handle = session.get_handle()
+      assert_equal(
+        original_secret,
+        handle.sessionId.secret
+      )
       assert_equal(
-        str(session.guid),
-        Session.objects.get_session(self.user, self.query_server['server_name']).guid
+        original_guid,
+        handle.sessionId.guid
       )
 
+
   def test_get_configuration(self):
 
     with patch('beeswax.server.hive_server2_lib.HiveServerClient.execute_query_statement') as execute_query_statement:
@@ -236,6 +256,8 @@ class TestHiveServerClient():
       settings=[]
     )
 
+    original_secret = b's\xb6\x0ePP\xbdL\x17\xa3\x0f\\\xf7K\xe8Y\x1d'
+    original_guid = b'\xd9\xe0hT\xd6wO\xe1\xa3S\xfb\x04\xca\x93V\x01'
     with patch('beeswax.server.hive_server2_lib.thrift_util.get_client') as get_client:
       get_client.return_value = Mock(
         OpenSession=Mock(
@@ -246,8 +268,8 @@ class TestHiveServerClient():
             configuration={},
             sessionHandle=Mock(
               sessionId=Mock(
-                secret=b'1',
-                guid=b'1'
+                secret=original_secret,
+                guid=original_guid
               )
             ),
             serverProtocolVersion=11
@@ -298,11 +320,14 @@ class TestHiveServerClient():
         client.get_table(database='database', table_name='table_name')
       except QueryServerException as e:
         if sys.version_info[0] > 2:
-          req_string = ("TGetTablesReq(sessionHandle=TSessionHandle(sessionId=THandleIdentifier(guid=b'l\\xc4', secret=b'l\\xc4')), "
-            "catalogName=None, schemaName='database', tableName='table_name', tableTypes=None)")
+          req_string = ("TGetTablesReq(sessionHandle=TSessionHandle(sessionId=THandleIdentifier(guid=%s, secret=%s)), "
+            "catalogName=None, schemaName='database', tableName='table_name', tableTypes=None)")\
+            % (str(original_guid), str(original_secret))
         else:
           req_string = ("TGetTablesReq(schemaName='database', sessionHandle=TSessionHandle(sessionId=THandleIdentifier"
-            "(secret='1', guid='1')), tableName='table_name', tableTypes=None, catalogName=None)")
+            "(secret='%s', guid='%s')), tableName='table_name', tableTypes=None, catalogName=None)")\
+            % ('s\\xb6\\x0ePP\\xbdL\\x17\\xa3\\x0f\\\\\\xf7K\\xe8Y\\x1d',
+               '\\xd9\\xe0hT\\xd6wO\\xe1\\xa3S\\xfb\\x04\\xca\\x93V\\x01') # manually adding '\'
         assert_equal(
           "Bad status for request %s:\n%s" % (req_string, get_tables_res),
           str(e)