Переглянути джерело

[jwt] Injecting JWT as bearer token over THttpClient to Impala (#2533)

- Add bearer token method for THttpClient.
- Inject jwt in THttpClient when USE_THRIFT_HTTP_JWT is enabled.
- Resolve User import error by cyclic dependency.
- Check for exception conditions of user not found and no token and update tests.
- Fix pylint long lines and bad spacing issues.
Harsh Gupta 4 роки тому
батько
коміт
56a6c734b0

+ 3 - 0
desktop/core/src/desktop/lib/thrift_/http_client.py

@@ -61,6 +61,9 @@ class THttpClient(TTransportBase):
   def set_basic_auth(self, username, password):
     self._client.set_basic_auth(username, password)
 
+  def set_bearer_auth(self, token):
+    self._client.set_bearer_auth(token)
+
   def set_verify(self, verify=True):
     self._client.set_verify(verify)
 

+ 17 - 1
desktop/core/src/desktop/lib/thrift_util.py

@@ -40,7 +40,7 @@ from thrift.protocol.TBinaryProtocol import TBinaryProtocol
 from thrift.protocol.TMultiplexedProtocol import TMultiplexedProtocol
 
 from django.conf import settings
-from desktop.conf import SASL_MAX_BUFFER, CHERRYPY_SERVER_THREADS, ENABLE_SMART_THRIFT_POOL
+from desktop.conf import SASL_MAX_BUFFER, CHERRYPY_SERVER_THREADS, ENABLE_SMART_THRIFT_POOL, USE_THRIFT_HTTP_JWT, ENABLE_ORGANIZATIONS
 
 from desktop.lib.apputil import WARN_LEVEL_CALL_DURATION_MS, INFO_LEVEL_CALL_DURATION_MS
 from desktop.lib.python_util import create_synchronous_io_multiplexer
@@ -338,6 +338,22 @@ def connect_to_thrift(conf):
   if conf.transport_mode == 'http':
     if conf.use_sasl and conf.mechanism != 'PLAIN':
       mode.set_kerberos_auth(service=conf.kerberos_principal)
+
+    elif USE_THRIFT_HTTP_JWT.get():
+      from desktop.auth.backend import find_user, rewrite_user # Cyclic dependency
+      user = rewrite_user(find_user(conf.username))
+
+      if user is None:
+        raise Exception("JWT: User not found.")
+
+      if ENABLE_ORGANIZATIONS.get() and user.token:
+        token = user.token
+      elif user.profile.data.get('jwt_access_token'):
+        token = user.profile.data['jwt_access_token']
+      else:
+        raise Exception("JWT: Could not retrive saved token from user.")
+
+      mode.set_bearer_auth(token)
     else:
       mode.set_basic_auth(conf.username, conf.password)
 

+ 121 - 20
desktop/core/src/desktop/lib/thrift_util_test.py

@@ -36,7 +36,7 @@ if not gen_py_path in sys.path:
 
 from djangothrift_test_gen import TestService
 from djangothrift_test_gen.ttypes import TestStruct, TestNesting, TestEnum, TestManyTypes
-from nose.tools import assert_equal, assert_true
+from nose.tools import assert_equal, assert_true, assert_raises
 from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
 from thrift.server import TServer
 from thrift.transport import TSocket
@@ -44,7 +44,11 @@ from thrift.transport.TTransport import TBufferedTransportFactory, TTransportExc
 
 from desktop.lib import python_util, thrift_util
 from desktop.lib.thrift_util import jsonable2thrift, thrift2json, _unpack_guid_secret_in_handle
+from desktop.conf import USE_THRIFT_HTTP_JWT
+from desktop.lib.django_test_util import make_logged_in_client
+from desktop.auth.backend import rewrite_user, find_or_create_user, ensure_has_a_group, create_user
 
+from useradmin.models import User
 
 LOG = logging.getLogger(__name__)
 
@@ -208,8 +212,8 @@ class ThriftUtilTest(unittest.TestCase):
 
   def test_enum_as_sequence(self):
     seq = thrift_util.enum_as_sequence(TestEnum)
-    self.assertEquals(len(seq),3)
-    self.assertEquals(sorted(seq),sorted(['ENUM_ONE','ENUM_TWO','ENUM_THREE']))
+    self.assertEquals(len(seq), 3)
+    self.assertEquals(sorted(seq), sorted(['ENUM_ONE', 'ENUM_TWO', 'ENUM_THREE']))
 
   def test_is_thrift_struct(self):
     self.assertTrue(thrift_util.is_thrift_struct(TestStruct()))
@@ -218,29 +222,54 @@ class ThriftUtilTest(unittest.TestCase):
   def test_fixup_enums(self):
     enum = TestEnum()
     struct1 = TestStruct()
-    self.assertTrue(hasattr(enum,"_VALUES_TO_NAMES"))
+    self.assertTrue(hasattr(enum, "_VALUES_TO_NAMES"))
     struct1.myenum = 0
-    thrift_util.fixup_enums(struct1,{"myenum":TestEnum})
-    self.assertTrue(hasattr(struct1,"myenumAsString"))
-    self.assertEquals(struct1.myenumAsString,'ENUM_ONE')
+    thrift_util.fixup_enums(struct1, {"myenum": TestEnum})
+    self.assertTrue(hasattr(struct1, "myenumAsString"))
+    self.assertEquals(struct1.myenumAsString, 'ENUM_ONE')
 
   def test_unpack_guid_secret_in_handle(self):
     if sys.version_info[0] > 2:
-      hive_handle = """(TGetTablesReq(sessionHandle=TSessionHandle(sessionId=THandleIdentifier(guid=%s, secret=%s)), catalogName=None, schemaName='default', tableName='customers', tableTypes=None),)""" % (str(b'N\xc5\xed\x14k\xbeI\xda\xb9\x14\xe7\xf2\x9a\xb7\xf0\xa5'), str(b']s(\xb5\xf6ZO\x03\x99\x955\xacl\xb4\x98\xae'))
-      self.assertEqual(_unpack_guid_secret_in_handle(hive_handle), """(TGetTablesReq(sessionHandle=TSessionHandle(sessionId=THandleIdentifier(guid=da49be6b14edc54e:a5f0b79af2e714b9, secret=034f5af6b528735d:ae98b46cac359599)), catalogName=None, schemaName=\'default\', tableName=\'customers\', tableTypes=None),)""")
+      hive_handle = ("(TGetTablesReq(sessionHandle=TSessionHandle(sessionId=THandleIdentifier(guid=%s, secret=%s)), catalogName=None,"
+      " schemaName='default', tableName='customers', tableTypes=None),"
+      ")") % (str(b'N\xc5\xed\x14k\xbeI\xda\xb9\x14\xe7\xf2\x9a\xb7\xf0\xa5'), str(b']s(\xb5\xf6ZO\x03\x99\x955\xacl\xb4\x98\xae'))
 
-      impala_handle = """(TExecuteStatementReq(sessionHandle=TSessionHandle(sessionId=THandleIdentifier(guid=%s, secret=%s)), statement=b\'USE `default`\', confOverlay={\'QUERY_TIMEOUT_S\': \'300\'}, runAsync=False),)""" % (str(b'\xc4\xccnI\xf1\xbdJ\xc3\xb2\n\xd5[9\xe1Mr'), str(b'\xb0\x9d\xfd\x82\x94%L\xae\x9ch$f=\xfa{\xd0'))
-      self.assertEqual(_unpack_guid_secret_in_handle(impala_handle), """(TExecuteStatementReq(sessionHandle=TSessionHandle(sessionId=THandleIdentifier(guid=c34abdf1496eccc4:724de1395bd50ab2, secret=ae4c259482fd9db0:d07bfa3d6624689c)), statement=b\'USE `default`\', confOverlay={\'QUERY_TIMEOUT_S\': \'300\'}, runAsync=False),)""")
+      self.assertEqual(_unpack_guid_secret_in_handle(hive_handle), ("(TGetTablesReq(sessionHandle=TSessionHandle(sessionId="
+      "THandleIdentifier(guid=da49be6b14edc54e:a5f0b79af2e714b9, secret=034f5af6b528735d:ae98b46cac359599)), catalogName=None, "
+      "schemaName=\'default\', tableName=\'customers\', tableTypes=None),)"))
+
+      impala_handle = ("(TExecuteStatementReq(sessionHandle=TSessionHandle(sessionId=THandleIdentifier(guid=%s, secret=%s)), "
+      "statement=b\'USE `default`\', confOverlay={\'QUERY_TIMEOUT_S\': \'300\'}, runAsync=False)"
+      ",)") % (str(b'\xc4\xccnI\xf1\xbdJ\xc3\xb2\n\xd5[9\xe1Mr'), str(b'\xb0\x9d\xfd\x82\x94%L\xae\x9ch$f=\xfa{\xd0'))
+
+      self.assertEqual(_unpack_guid_secret_in_handle(impala_handle), ("(TExecuteStatementReq(sessionHandle=TSessionHandle("
+      "sessionId=THandleIdentifier(guid=c34abdf1496eccc4:724de1395bd50ab2, secret=ae4c259482fd9db0:d07bfa3d6624689c)), "
+      "statement=b\'USE `default`\', confOverlay={\'QUERY_TIMEOUT_S\': \'300\'}, runAsync=False),)"))
     else:
-      hive_handle = """(TExecuteStatementReq(confOverlay={}, sessionHandle=TSessionHandle(sessionId=THandleIdentifier(secret=\'\x1aOYj\xf3\x86M\x95\xbb\xc8\xe9/;\xb0{9\', guid=\'\x86\xa6$\xb2\xb8\xdaF\xbd\xbd\xf5\xc5\xf4\xcb\x96\x03<\')), runAsync=True, statement="SELECT \'Hello World!\'"),)"""
-      self.assertEqual(_unpack_guid_secret_in_handle(hive_handle), """(TExecuteStatementReq(confOverlay={}, sessionHandle=TSessionHandle(sessionId=THandleIdentifier(secret=954d86f36a594f1a:397bb03b2fe9c8bb, guid=bd46dab8b224a686:3c0396cbf4c5f5bd)), runAsync=True, statement="SELECT \'Hello World!\'"),)""")
+      hive_handle = ("(TExecuteStatementReq(confOverlay={}, sessionHandle=TSessionHandle(sessionId=THandleIdentifier("
+      "secret=\'\x1aOYj\xf3\x86M\x95\xbb\xc8\xe9/;\xb0{9\', guid=\'\x86\xa6$\xb2\xb8\xdaF\xbd\xbd\xf5\xc5\xf4\xcb\x96\x03<\')), "
+      'runAsync=True, statement="SELECT \'Hello World!\'"),)')
+
+      self.assertEqual(_unpack_guid_secret_in_handle(hive_handle), ("(TExecuteStatementReq(confOverlay={}, sessionHandle=TSessionHandle("
+      "sessionId=THandleIdentifier(secret=954d86f36a594f1a:397bb03b2fe9c8bb, guid=bd46dab8b224a686:3c0396cbf4c5f5bd)), runAsync=True, "
+      'statement="SELECT \'Hello World!\'"),)'))
+
+      impala_handle = ("(TGetTablesReq(schemaName=u\'default\', sessionHandle=TSessionHandle(sessionId=THandleIdentifier(secret="
+      "\'\x7f\x98\x97s\xe1\xa8G\xf4\x8a\x8a\\r\x0e6\xc2\xee\xf0\', guid=\'\xfa\xb0/\x04 \xfeDX\x99\xfcq\xff2\x07\x02\xfe\')), "
+      "tableName=u\'customers\', tableTypes=None, catalogName=None),)")
 
-      impala_handle = """(TGetTablesReq(schemaName=u\'default\', sessionHandle=TSessionHandle(sessionId=THandleIdentifier(secret=\'\x7f\x98\x97s\xe1\xa8G\xf4\x8a\x8a\\r\x0e6\xc2\xee\xf0\', guid=\'\xfa\xb0/\x04 \xfeDX\x99\xfcq\xff2\x07\x02\xfe\')), tableName=u\'customers\', tableTypes=None, catalogName=None),)"""
-      self.assertEqual(_unpack_guid_secret_in_handle(impala_handle), """(TGetTablesReq(schemaName=u\'default\', sessionHandle=TSessionHandle(sessionId=THandleIdentifier(secret=f447a8e17397987f:f0eec2360e0d8a8a, guid=5844fe20042fb0fa:fe020732ff71fc99)), tableName=u\'customers\', tableTypes=None, catalogName=None),)""")
+      self.assertEqual(_unpack_guid_secret_in_handle(impala_handle), ("(TGetTablesReq(schemaName=u\'default\', sessionHandle="
+      "TSessionHandle(sessionId=THandleIdentifier(secret=f447a8e17397987f:f0eec2360e0d8a8a, guid=5844fe20042fb0fa:fe020732ff71fc99)),"
+      " tableName=u\'customers\', tableTypes=None, catalogName=None),)"))
 
     # Following should be added to test, but fails because eval doesn't handle null bytes
-    #impala_handle = """(TGetTablesReq(schemaName=u\'default\', sessionHandle=TSessionHandle(sessionId=THandleIdentifier(secret=\'\x7f\x98\x97s\xe1\xa8G\xf4\x8a\x8a\\r\x0e6\xc2\xee\xf0\', guid=\'\xd23\xfa\x150\xf5D\x91\x00\x00\x00\x00\xd7\xef\x91\x00\')), tableName=u\'customers\', tableTypes=None, catalogName=None),)"""
-    #self.assertEqual(_unpack_guid_secret_in_handle(impala_handle), """(TGetTablesReq(schemaName=u\'default\', sessionHandle=TSessionHandle(sessionId=THandleIdentifier(secret=f447a8e17397987f:f0eec2360e0d8a8a, guid=9144f53015fa33d2:0091efd700000000)), tableName=u\'customers\', tableTypes=None, catalogName=None),)""")
+    # impala_handle = ("(TGetTablesReq(schemaName=u\'default\', sessionHandle=TSessionHandle(sessionId=THandleIdentifier(secret="
+    # "\'\x7f\x98\x97s\xe1\xa8G\xf4\x8a\x8a\\r\x0e6\xc2\xee\xf0\', guid=\'\xd23\xfa\x150\xf5D\x91\x00\x00\x00\x00\xd7\xef\x91\x00\')), "
+    # "tableName=u\'customers\', tableTypes=None, catalogName=None),)")
+
+    # self.assertEqual(_unpack_guid_secret_in_handle(impala_handle), ("(TGetTablesReq(schemaName=u\'default\', "
+    # "sessionHandle=TSessionHandle(sessionId=THandleIdentifier(secret=f447a8e17397987f:f0eec2360e0d8a8a, "
+    # "guid=9144f53015fa33d2:0091efd700000000)), tableName=u\'customers\', tableTypes=None, catalogName=None),)"))
 
 class TestJsonable2Thrift(unittest.TestCase):
   """
@@ -281,17 +310,17 @@ class TestJsonable2Thrift(unittest.TestCase):
 
   def test_set(self):
     x = TestManyTypes()
-    x.a_set=set([1,2,3,4,5])
+    x.a_set = set([1, 2, 3, 4, 5])
     self.assertBackAndForth(x)
 
   def test_list(self):
     x = TestManyTypes()
-    x.a_list = [ TestStruct(b=i) for i in range(4) ]
+    x.a_list = [TestStruct(b=i) for i in range(4)]
     self.assertBackAndForth(x)
 
   def test_map(self):
     x = TestManyTypes()
-    x.a_map = dict([ (i, TestStruct(b=i)) for i in range(4) ])
+    x.a_map = dict([(i, TestStruct(b=i)) for i in range(4)])
     self.assertBackAndForth(x)
 
   def test_limits(self):
@@ -343,5 +372,77 @@ class TestSuperClient(unittest.TestCase):
       # Could check output for several "Thrift exception; retrying: some error"
 
 
+class TestThriftJWT(unittest.TestCase):
+  def setUp(self):
+    self.sample_token = "some_jwt_token"
+
+    self.client = make_logged_in_client(username="test_user", groupname="default", recreate=True, is_superuser=False)
+    self.user = rewrite_user(User.objects.get(username="test_user"))
+
+
+  def test_jwt_thrift(self):
+    with patch('desktop.lib.thrift_util.TBinaryProtocol'):
+      with patch('desktop.lib.thrift_util.TBufferedTransport'):
+        with patch('desktop.lib.thrift_util.TBufferedTransport'):
+          with patch('desktop.lib.thrift_util.THttpClient.set_bearer_auth') as set_bearer_auth:
+
+            self.user.profile.update_data({'jwt_access_token': self.sample_token})
+            self.user.profile.save()
+
+            reset = USE_THRIFT_HTTP_JWT.set_for_testing(True)
+
+            conf = Mock(
+              klass=Mock(),
+              username="test_user",
+              transport_mode='http',
+              timeout_seconds=None,
+              use_sasl=None,
+              http_url='some_http_url'
+            )
+
+            try:
+              service, protocol, transport = thrift_util.connect_to_thrift(conf)
+              set_bearer_auth.assert_called_with('some_jwt_token')
+            finally:
+              reset()
+
+
+  def test_jwt_thrift_exceptions(self):
+    with patch('desktop.lib.thrift_util.TBinaryProtocol'):
+      with patch('desktop.lib.thrift_util.TBufferedTransport'):
+        with patch('desktop.lib.thrift_util.TBufferedTransport'):
+          with patch('desktop.lib.thrift_util.THttpClient.set_bearer_auth') as set_bearer_auth:
+            reset = USE_THRIFT_HTTP_JWT.set_for_testing(True)
+
+            try:
+              # When token not stored in user profile
+              conf = Mock(
+                klass=Mock(),
+                username="test_user",
+                transport_mode='http',
+                timeout_seconds=None,
+                use_sasl=None,
+                http_url='some_http_url'
+              )
+
+              assert_raises(Exception, thrift_util.connect_to_thrift, conf)
+
+              # When user not found
+              self.user.profile.update_data({'jwt_access_token': self.sample_token})
+              self.user.profile.save()
+
+              conf = Mock(
+                klass=Mock(),
+                username="test_not_user",
+                transport_mode='http',
+                timeout_seconds=None,
+                use_sasl=None,
+                http_url='some_http_url'
+              )
+              assert_raises(Exception, thrift_util.connect_to_thrift, conf)
+            finally:
+              reset()
+
+
 if __name__ == '__main__':
   unittest.main()