Browse Source

[beeswax] Support HiveServer2 PLAIN SASL authentication

Previously only NOSASL was supported
Remove unused hive.metastore.local property
Romain Rigaux 12 years ago
parent
commit
5a8d816

+ 3 - 1
apps/beeswax/src/beeswax/hive_site.py

@@ -37,11 +37,11 @@ _HIVE_SITE_PATH = None                  # Path to hive-site.xml
 _HIVE_SITE_DICT = None                  # A dictionary of name/value config options
 _METASTORE_LOC_CACHE = None
 
-_CNF_METASTORE_LOCAL = 'hive.metastore.local'
 _CNF_METASTORE_SASL = 'hive.metastore.sasl.enabled'
 _CNF_METASTORE_URIS = 'hive.metastore.uris'
 _CNF_METASTORE_KERBEROS_PRINCIPAL = 'hive.metastore.kerberos.principal'
 _CNF_HIVESERVER2_KERBEROS_PRINCIPAL = 'hive.server2.authentication.kerberos.principal'
+_CNF_HIVESERVER2_AUTHENTICATION = 'hive.server2.authentication'
 
 # Host is whatever up to the colon. Allow and ignore a trailing slash.
 _THRIFT_URI_RE = re.compile("^thrift://([^:]+):(\d+)[/]?$")
@@ -108,6 +108,8 @@ def get_metastore():
 def get_hiveserver2_kerberos_principal():
   return security_util.get_kerberos_principal(get_conf().get(_CNF_HIVESERVER2_KERBEROS_PRINCIPAL, None), socket.getfqdn())
 
+def get_hiveserver2_authentication():
+  return get_conf().get(_CNF_HIVESERVER2_AUTHENTICATION, 'NONE').upper() # NONE == PLAIN SASL
 
 def _parse_hive_site():
   """

+ 1 - 1
apps/beeswax/src/beeswax/server/dbms.py

@@ -78,7 +78,7 @@ def get_query_server_config(name='beeswax'):
       kerberos_principal = KERBEROS.HUE_PRINCIPAL.get()
 
     query_server = {
-        'server_name': 'beeswax',
+        'server_name': 'beeswax', # Aka HS2 too
         'server_host': BEESWAX_SERVER_HOST.get(),
         'server_port': BEESWAX_SERVER_PORT.get(),
         'server_interface': SERVER_INTERFACE.get(),

+ 31 - 2
apps/beeswax/src/beeswax/server/hive_server2_lib.py

@@ -20,6 +20,7 @@ import re
 import thrift
 
 from desktop.lib import thrift_util
+from hadoop import cluster
 
 from TCLIService import TCLIService
 from TCLIService.ttypes import TOpenSessionReq, TGetTablesReq, TFetchResultsReq,\
@@ -28,10 +29,10 @@ from TCLIService.ttypes import TOpenSessionReq, TGetTablesReq, TFetchResultsReq,
   TCloseSessionReq, TGetSchemasReq, TGetLogReq, TCancelOperationReq
 
 from beeswax import conf
+from beeswax import hive_site
 from beeswax.models import Session, HiveServerQueryHandle, HiveServerQueryHistory
 from beeswax.server.dbms import Table, NoSuchObjectException, DataTable,\
   QueryServerException
-from beeswax.server.beeswax_lib import BeeswaxClient
 
 
 LOG = logging.getLogger(__name__)
@@ -249,12 +250,13 @@ class HiveServerTColumnDesc:
 
 
 class HiveServerClient:
+  HS2_MECHANISMS = {'KERBEROS': 'GSSAPI', 'NONE': 'PLAIN', 'NOSASL': 'NOSASL'}
 
   def __init__(self, query_server, user):
     self.query_server = query_server
     self.user = user
 
-    use_sasl, kerberos_principal_short_name = BeeswaxClient.get_security(query_server)
+    use_sasl, mechanism, kerberos_principal_short_name = HiveServerClient.get_security(query_server)
 
     self._client = thrift_util.get_client(TCLIService.Client,
                                           query_server['server_host'],
@@ -262,9 +264,36 @@ class HiveServerClient:
                                           service_name=query_server['server_name'],
                                           kerberos_principal=kerberos_principal_short_name,
                                           use_sasl=use_sasl,
+                                          mechanism=mechanism,
+                                          username=user.username,
                                           timeout_seconds=conf.BEESWAX_SERVER_CONN_TIMEOUT.get())
 
 
+  @classmethod
+  def get_security(cls, query_server):
+    principal = query_server['principal']
+
+    if query_server['server_name'] == 'impala':
+      cluster_conf = cluster.get_cluster_conf_for_job_submission()
+      use_sasl = cluster_conf is not None and cluster_conf.SECURITY_ENABLED.get()
+      mechanism = HiveServerClient.HS2_MECHANISMS['KERBEROS']
+    else:
+      hive_mechanism = hive_site.get_hiveserver2_authentication()
+      if hive_mechanism not in HiveServerClient.HS2_MECHANISMS:
+        raise Exception(_('%s server authentication not supported. Valid are %s.' % (hive_mechanism, HiveServerClient.HS2_MECHANISMS.keys())))
+      use_sasl = hive_mechanism in ('KERBEROS', 'NONE')
+      mechanism = 'NOSASL'
+      if use_sasl:
+        mechanism = HiveServerClient.HS2_MECHANISMS[hive_mechanism]
+
+    if principal:
+      kerberos_principal_short_name = principal.split('/', 1)[0]
+    else:
+      kerberos_principal_short_name = None
+
+    return use_sasl, mechanism, kerberos_principal_short_name
+
+
   def open_session(self, user):
     req = TOpenSessionReq(username=user.username, configuration={})
     res = self._client.OpenSession(req)

+ 84 - 40
apps/beeswax/src/beeswax/tests.py

@@ -30,7 +30,7 @@ import socket
 import tempfile
 import threading
 
-from nose.tools import assert_true, assert_equal, assert_false
+from nose.tools import assert_true, assert_equal, assert_false, assert_not_equal
 from nose.plugins.skip import SkipTest
 
 from django.utils.encoding import smart_str
@@ -50,7 +50,7 @@ import beeswax.hive_site
 import beeswax.models
 import beeswax.views
 
-from beeswax import conf
+from beeswax import conf, hive_site
 from beeswax.views import collapse_whitespace
 from beeswax.test_base import make_query, wait_for_query_to_finish, verify_history, get_query_server_config,\
   BEESWAXD_TEST_PORT
@@ -59,6 +59,7 @@ from beeswax.data_export import download
 from beeswax.models import SavedQuery, QueryHistory, HQL
 from beeswax.server import dbms
 from beeswax.server.beeswax_lib import BeeswaxDataTable, BeeswaxClient
+from beeswax.server.hive_server2_lib import HiveServerClient
 from beeswax.test_base import BeeswaxSampleProvider
 import hadoop
 
@@ -110,27 +111,6 @@ class TestBeeswaxWithHadoop(BeeswaxSampleProvider):
     assert_equal(beeswax.models.QueryHistory.STATE[last_state], state)
     return history.id
 
-  def test_beeswax_get_kerberos_security(self):
-    principal = get_query_server_config('beeswax')['principal']
-    assert_true(principal.startswith('hue/'), principal)
-
-    principal = get_query_server_config('impala')['principal']
-    assert_true(principal.startswith('impala/'), principal)
-
-    beeswax_query_server = {'server_name': 'beeswax', 'principal': 'hue'}
-    impala_query_server = {'server_name': 'impala', 'principal': 'impala'}
-
-    assert_equal((False, 'hue'), BeeswaxClient.get_security(beeswax_query_server))
-    assert_equal((False, 'impala'), BeeswaxClient.get_security(impala_query_server))
-
-    cluster_conf = hadoop.cluster.get_cluster_conf_for_job_submission()
-    finish = cluster_conf.SECURITY_ENABLED.set_for_testing(True)
-    try:
-      assert_equal((True, 'hue'), BeeswaxClient.get_security(beeswax_query_server))
-      assert_equal((True, 'impala'), BeeswaxClient.get_security(impala_query_server))
-    finally:
-      finish()
-
   def test_query_with_error(self):
     """
     Creating a table "again" should not work; error should be displayed.
@@ -1320,7 +1300,7 @@ def test_hive_site():
       def get(self):
         return tmpdir
 
-    xml = hive_site_xml(is_local=False, use_sasl=False)
+    xml = hive_site_xml(is_local=True, use_sasl=False)
     file(os.path.join(tmpdir, 'hive-site.xml'), 'w').write(xml)
 
     beeswax.hive_site.reset()
@@ -1328,12 +1308,14 @@ def test_hive_site():
     beeswax.conf.BEESWAX_HIVE_CONF_DIR = Getter()
 
     is_local, host, port, kerberos_principal = beeswax.hive_site.get_metastore()
-    assert_false(is_local)
-    assert_equal(host, 'darkside-1234')
-    assert_equal(port, 9999)
+    assert_true(is_local)
+    # Local so don't use hive-site.xml
+    assert_not_equal(host, 'darkside-1234')
+    assert_not_equal(port, 9999)
+    assert_not_equal(kerberos_principal, 'test/test.com@TEST.COM')
     assert_equal(beeswax.hive_site.get_conf()['hive.metastore.warehouse.dir'], u'/abc')
-    assert_equal(kerberos_principal, 'test/test.com@TEST.COM')
     assert_equal(beeswax.hive_site.get_hiveserver2_kerberos_principal(), 'hs2test/test.com@TEST.COM')
+    assert_equal(beeswax.hive_site.get_hiveserver2_authentication(), 'NONE')
   finally:
     beeswax.hive_site.reset()
     if saved is not None:
@@ -1400,7 +1382,7 @@ def test_hive_site_external_metastore():
     beeswax.hive_site.reset()
     is_local, host, port, kerberos_principal = beeswax.hive_site.get_metastore()
     assert_false(is_local)
-    assert_equal(host, 'test.com')
+    assert_equal(host, 'darkside-1234')
     assert_equal(port, 9999)
     assert_equal(beeswax.hive_site.get_conf()['hive.metastore.warehouse.dir'], u'/abc')
     assert_equal(kerberos_principal, 'test/test.com@TEST.COM')
@@ -1627,21 +1609,78 @@ def search_log_line(component, expected_log, all_logs):
   """Checks if 'expected_log' can be found in one line of 'all_logs' outputed by the logging component 'component'."""
   return re.compile('.+?%(component)s(.+?)%(expected_log)s' % {'component': component, 'expected_log': expected_log}).search(all_logs)
 
+def test_beeswax_get_kerberos_security():
+  principal = get_query_server_config('beeswax')['principal']
+  assert_true(principal.startswith('hue/'), principal)
+
+  principal = get_query_server_config('impala')['principal']
+  assert_true(principal.startswith('impala/'), principal)
+
+  beeswax_query_server = {'server_name': 'beeswax', 'principal': 'hue'}
+  impala_query_server = {'server_name': 'impala', 'principal': 'impala'}
+
+  assert_equal((False, 'hue'), BeeswaxClient.get_security(beeswax_query_server))
+  assert_equal((False, 'impala'), BeeswaxClient.get_security(impala_query_server))
+
+  cluster_conf = hadoop.cluster.get_cluster_conf_for_job_submission()
+  finish = cluster_conf.SECURITY_ENABLED.set_for_testing(True)
+  try:
+    assert_equal((True, 'hue'), BeeswaxClient.get_security(beeswax_query_server))
+    assert_equal((True, 'impala'), BeeswaxClient.get_security(impala_query_server))
+  finally:
+    finish()
+
+def test_hiveserver2_get_security():
+  principal = get_query_server_config('beeswax')['principal']
+  assert_true(principal.startswith('hue/'), principal)
+
+  principal = get_query_server_config('impala')['principal']
+  assert_true(principal.startswith('impala/'), principal)
+
+  beeswax_query_server = {'server_name': 'beeswax', 'principal': 'hue'}
+  impala_query_server = {'server_name': 'impala', 'principal': 'impala'}
+
+  assert_equal((True, 'PLAIN', 'hue'), HiveServerClient.get_security(beeswax_query_server))
+  assert_equal((False, 'GSSAPI', 'impala'), HiveServerClient.get_security(impala_query_server))
+
+  cluster_conf = hadoop.cluster.get_cluster_conf_for_job_submission()
+  finish = cluster_conf.SECURITY_ENABLED.set_for_testing(True)
+  try:
+    assert_equal((True, 'GSSAPI', 'impala'), HiveServerClient.get_security(impala_query_server))
+  finally:
+    finish()
+
+  # Bad but easy mocking
+  prev = hive_site._HIVE_SITE_DICT.get(hive_site._CNF_HIVESERVER2_AUTHENTICATION)
+  try:
+    hive_site._HIVE_SITE_DICT[hive_site._CNF_HIVESERVER2_AUTHENTICATION] = 'NOSASL'
+    assert_equal((False, 'NOSASL', 'hue'), HiveServerClient.get_security(beeswax_query_server))
+    hive_site._HIVE_SITE_DICT[hive_site._CNF_HIVESERVER2_AUTHENTICATION] = 'KERBEROS'
+    assert_equal((True, 'GSSAPI', 'hue'), HiveServerClient.get_security(beeswax_query_server))
+  finally:
+    if prev is not None:
+      hive_site._HIVE_SITE_DICT[hive_site._CNF_HIVESERVER2_AUTHENTICATION] = prev
+    else:
+      del hive_site._HIVE_SITE_DICT[hive_site._CNF_HIVESERVER2_AUTHENTICATION]
+
+
 def hive_site_xml(is_local=False, use_sasl=False, thrift_uris='thrift://darkside-1234:9999',
                   warehouse_dir='/abc', kerberos_principal='test/test.com@TEST.COM',
-                  hs2_kerberos_principal='hs2test/test.com@TEST.COM'):
-  return """
-    <configuration>
-      <property>
-        <name>hive.metastore.local</name>
-        <value>%(is_local)s</value>
-      </property>
-
-      <property>
+                  hs2_kerberos_principal='hs2test/test.com@TEST.COM',
+                  hs2_kauthentication='NOSASL'):
+  if not is_local:
+    uris = """
+       <property>
         <name>hive.metastore.uris</name>
         <value>%(thrift_uris)s</value>
       </property>
+    """ % {'thrift_uris': thrift_uris}
+  else:
+    uris = ''
 
+  return """
+    <configuration>
+      %(uris)s
       <property>
         <name>hive.metastore.warehouse.dir</name>
         <value>%(warehouse_dir)s</value>
@@ -1657,16 +1696,21 @@ def hive_site_xml(is_local=False, use_sasl=False, thrift_uris='thrift://darkside
         <value>%(hs2_kerberos_principal)s</value>
       </property>
 
+      <property>
+        <name>hive.metastore.sasl.enabled</name>
+        <value>%(hs2_kauthentication)s</value>
+      </property>
+
       <property>
         <name>hive.metastore.sasl.enabled</name>
         <value>%(use_sasl)s</value>
       </property>
     </configuration>
   """ % {
-    'is_local': str(is_local).lower(),
-    'thrift_uris': thrift_uris,
+    'uris': uris,
     'warehouse_dir': warehouse_dir,
     'kerberos_principal': kerberos_principal,
     'hs2_kerberos_principal': hs2_kerberos_principal,
+    'hs2_kauthentication': hs2_kauthentication,
     'use_sasl': str(use_sasl).lower()
   }

+ 1 - 1
desktop/core/src/desktop/lib/thrift_sasl.py

@@ -35,7 +35,7 @@ class TSaslClientTransport(TTransportBase, CReadableTransport):
   def __init__(self, sasl_client_factory, mechanism, trans):
     """
     @param sasl_client_factory: a callable that returns a new sasl.Client object
-    @param mechanism: the SASL mechanism (e.g. "GSSAPI")
+    @param mechanism: the SASL mechanism (e.g. "GSSAPI", "PLAIN")
     @param trans: the underlying transport over which to communicate.
     """
     self._trans = trans

+ 22 - 12
desktop/core/src/desktop/lib/thrift_util.py

@@ -70,13 +70,17 @@ class ConnectionConfig(object):
   def __init__(self, klass, host, port, service_name,
                use_sasl=False,
                kerberos_principal="thrift",
+               mechanism='GSSAPI',
+               username='hue',
                timeout_seconds=45):
     """
     @param klass The thrift client class
     @param host Host to connect to
     @param port Port to connect to
     @param service_name A human-readable name to describe the service
-    @param use_sasl If true, will use Kerberos over SASL to authenticate
+    @param use_sasl If true, will use KERBEROS or PLAIN over SASL to authenticate
+    @param mechanism: GSSAPI or PLAIN if SASL
+    @param username: username if PLAIN SASL only
     @param kerberos_principal The Kerberos service name to connect to.
               NOTE: for a service like fooservice/foo.blah.com@REALM only
               specify "fooservice", NOT the full principal name.
@@ -87,11 +91,14 @@ class ConnectionConfig(object):
     self.port = port
     self.service_name = service_name
     self.use_sasl = use_sasl
+    self.mechanism = mechanism
+    self.username = username
     self.kerberos_principal = kerberos_principal
     self.timeout_seconds = timeout_seconds
 
   def __str__(self):
-    return ', '.join(map(str, [self.klass, self.host, self.port, self.service_name, self.use_sasl, self.kerberos_principal, self.timeout_seconds]))
+    return ', '.join(map(str, [self.klass, self.host, self.port, self.service_name, self.use_sasl, self.kerberos_principal, self.timeout_seconds,
+                               self.mechanism, self.username]))
 
 class ConnectionPooler(object):
   """
@@ -209,16 +216,19 @@ def connect_to_thrift(conf):
   if conf.timeout_seconds:
     # Thrift trivia: You can do this after the fact with
     # _grab_transport_from_wrapper(self.wrapped.transport).setTimeout(seconds*1000)
-    sock.setTimeout(conf.timeout_seconds*1000.0)
+    sock.setTimeout(conf.timeout_seconds * 1000.0)
   if conf.use_sasl:
     def sasl_factory():
       saslc = sasl.Client()
       saslc.setAttr("host", str(conf.host))
       saslc.setAttr("service", str(conf.kerberos_principal))
+      if conf.mechanism == 'PLAIN':
+        saslc.setAttr("username", str(conf.username))
+        saslc.setAttr("password", 'hue') # Just a non empty string
       saslc.init()
       return saslc
 
-    transport = TSaslClientTransport(sasl_factory, "GSSAPI", sock)
+    transport = TSaslClientTransport(sasl_factory, conf.mechanism, sock)
   else:
     transport = TBufferedTransport(sock)
 
@@ -432,7 +442,7 @@ def thrift2json(tft):
       N.B.: For maximal compatibility, the key type for map should be a basic type
       rather than a struct or container type. There are some languages which do not
       support more complex key types in their native map types. In addition the
-      JSON protocol only supports key types that are base types. 
+      JSON protocol only supports key types that are base types.
   I believe this ought to be true for sets, as well.
   """
   if isinstance(tft,type(None)):
@@ -468,7 +478,7 @@ def _jsonable2thrift_helper(jsonable, type_enum, spec_args, default, recursion_d
   Recursive implementation method of jsonable2thrift.
 
   type_enum corresponds to TType.  spec_args is part of the
-  thrift_spec explained in Thrift's code generator.  See 
+  thrift_spec explained in Thrift's code generator.  See
   compiler/cpp/src/generate/t_py_generator.cc .
   default is the default value.
 
@@ -490,7 +500,7 @@ def _jsonable2thrift_helper(jsonable, type_enum, spec_args, default, recursion_d
     """
     Helper function to check bounds.
 
-    The Thrift IDL specifies how many bytes numbers can be, and always uses 
+    The Thrift IDL specifies how many bytes numbers can be, and always uses
     signed integers.  This makes sure that the Thrift struct that comes out
     conforms to that schema.
     """
@@ -542,7 +552,7 @@ def _jsonable2thrift_helper(jsonable, type_enum, spec_args, default, recursion_d
         # thrift_spec is indexed by thrift tag id, so None shows up
         continue
       _, cur_type_enum, cur_name, cur_spec_args, cur_default = spec
-      value = _jsonable2thrift_helper(jsonable.get(cur_name), 
+      value = _jsonable2thrift_helper(jsonable.get(cur_name),
         cur_type_enum, cur_spec_args, cur_default, recursion_depth + 1)
       setattr(out, cur_name, value)
     return out
@@ -577,7 +587,7 @@ def _jsonable2thrift_helper(jsonable, type_enum, spec_args, default, recursion_d
 
   else:
     raise Exception("Unrecognized type: %s.  Value was %s." % (repr(type_enum), repr(jsonable)))
-    
+
 def jsonable2thrift(jsonable, thrift_class):
   """
   Converts a JSON-able x that represents a thrift struct
@@ -588,9 +598,9 @@ def jsonable2thrift(jsonable, thrift_class):
   This is compatible with thrift2json.
   """
   return _jsonable2thrift_helper(
-    jsonable, 
-    TType.STRUCT, 
-    (thrift_class, thrift_class.thrift_spec), 
+    jsonable,
+    TType.STRUCT,
+    (thrift_class, thrift_class.thrift_spec),
     default=None,
     recursion_depth=0
   )