Przeglądaj źródła

HUE-2534 [core] HTTP Thrift transport implementation

Re-using requests in order to get HTTPS and Kerberos for free.
Implement for HBase Thrift Server v1
Implement for HiveServer2

To activate the mode, set in their corresponding properties in the
hive-site.xml and hbase-site.xml pointed by Hue.
Romain Rigaux 11 lat temu
rodzic
commit
b3cebeccaa

+ 1 - 0
apps/beeswax/src/beeswax/conf.py

@@ -95,6 +95,7 @@ THRIFT_VERSION = Config(
   default=7
 )
 
+
 SSL = ConfigSection(
   key='ssl',
   help=_t('SSL configuration for the server.'),

+ 13 - 0
apps/beeswax/src/beeswax/hive_site.py

@@ -46,6 +46,10 @@ _CNF_HIVESERVER2_KERBEROS_PRINCIPAL = 'hive.server2.authentication.kerberos.prin
 _CNF_HIVESERVER2_AUTHENTICATION = 'hive.server2.authentication'
 _CNF_HIVESERVER2_IMPERSONATION = 'hive.server2.enable.doAs'
 
+_CNF_HIVESERVER2_TRANSPORT_MODE = 'hive.server2.transport.mode'
+_CNF_HIVESERVER2_THRIFT_HTTP_PORT = 'hive.server2.thrift.http.port'
+_CNF_HIVESERVER2_THRIFT_HTTP_PATH = 'hive.server2.thrift.http.path'
+
 
 # Host is whatever up to the colon. Allow and ignore a trailing slash.
 _THRIFT_URI_RE = re.compile("^thrift://([^:]+):(\d+)[/]?$")
@@ -126,6 +130,15 @@ def hiveserver2_impersonation_enabled():
 def hiveserver2_jdbc_url():
   return 'jdbc:hive2://%s:%s/default' % (beeswax.conf.HIVE_SERVER_HOST.get(), beeswax.conf.HIVE_SERVER_PORT.get())
 
+def hiveserver2_transport_mode():
+  return get_conf().get(_CNF_HIVESERVER2_TRANSPORT_MODE, 'TCP').upper()
+
+def hiveserver2_thrift_http_port():
+  return get_conf().get(_CNF_HIVESERVER2_THRIFT_HTTP_PORT, '10001')
+
+def hiveserver2_thrift_http_path():
+  return get_conf().get(_CNF_HIVESERVER2_THRIFT_HTTP_PATH, 'cliservice')
+
 
 def _parse_hive_site():
   """

+ 9 - 2
apps/beeswax/src/beeswax/server/dbms.py

@@ -24,7 +24,7 @@ from django.utils.encoding import force_unicode
 from django.utils.translation import ugettext as _
 
 from beeswax import hive_site
-from beeswax.conf import HIVE_SERVER_HOST, HIVE_SERVER_PORT, BROWSE_PARTITIONED_TABLE_LIMIT
+from beeswax.conf import HIVE_SERVER_HOST, HIVE_SERVER_PORT, BROWSE_PARTITIONED_TABLE_LIMIT, SSL
 from beeswax.design import hql_query
 from beeswax.models import QueryHistory, QUERY_TYPES
 
@@ -81,7 +81,14 @@ def get_query_server_config(name='beeswax', server=None):
         'server_name': 'beeswax', # Aka HiveServer2 now
         'server_host': HIVE_SERVER_HOST.get(),
         'server_port': HIVE_SERVER_PORT.get(),
-        'principal': kerberos_principal
+        'principal': kerberos_principal,
+        'http_url': '%(protocol)s://%(host)s:%(port)s/%(end_point)s' % {
+            'protocol': 'https' if SSL.ENABLED.get() else 'http',
+            'host': HIVE_SERVER_HOST.get(),
+            'port': hive_site.hiveserver2_thrift_http_port(),
+            'end_point': hive_site.hiveserver2_thrift_http_path()
+        },
+        'transport_mode': hive_site.hiveserver2_transport_mode(),
     }
 
   LOG.debug("Query Server: %s" % query_server)

+ 4 - 1
apps/beeswax/src/beeswax/server/hive_server2_lib.py

@@ -478,7 +478,10 @@ class HiveServerClient:
                                           ca_certs=ca_certs,
                                           keyfile=keyfile,
                                           certfile=certfile,
-                                          validate=validate)
+                                          validate=validate,
+                                          transport_mode=query_server.get('transport_mode', 'socket'),
+                                          http_url=query_server.get('http_url', '')
+    )
 
 
   def get_security(self):

+ 20 - 5
apps/hbase/src/hbase/api.py

@@ -37,6 +37,9 @@ LOG = logging.getLogger(__name__)
 # Format methods similar to Thrift API, for similarity with catch-all
 class HbaseApi(object):
 
+  def __init__(self, user):
+    self.user = user
+
   def query(self, action, *args):
     try:
       if hasattr(self, action):
@@ -57,15 +60,16 @@ class HbaseApi(object):
   def getClusters(self):
     clusters = []
     try:
-      full_config = json.loads(conf.HBASE_CLUSTERS.get().replace("'","\""))
+      full_config = json.loads(conf.HBASE_CLUSTERS.get().replace("'", "\""))
     except:
       full_config = [conf.HBASE_CLUSTERS.get()]
     for config in full_config: #hack cause get() is weird
-      match = re.match('\((?P<name>[^\(\)\|]+)\|(?P<host>.+):(?P<port>[0-9]+)\)', config)
+
+      match = re.match('\((?P<name>[^\(\)\|]+)\|(?P<protocol>https?://)?(?P<host>.+):(?P<port>[0-9]+)\)', config)
       if match:
         clusters += [{
           'name': match.group('name'),
-          'host': match.group('host'),
+          'host': (match.group('protocol') + match.group('host')) if match.group('protocol') else match.group('host'),
           'port': int(match.group('port'))
         }]
       else:
@@ -85,14 +89,25 @@ class HbaseApi(object):
   def connectCluster(self, name):
     _security = self._get_security()
     target = self.getCluster(name)
-    return thrift_util.get_client(get_client_type(),
+    client = thrift_util.get_client(get_client_type(),
                                   target['host'],
                                   target['port'],
                                   service_name="Hue HBase Thrift Client for %s" % name,
                                   kerberos_principal=_security['kerberos_principal_short_name'],
                                   use_sasl=_security['use_sasl'],
                                   timeout_seconds=None,
-                                  transport=conf.THRIFT_TRANSPORT.get())
+                                  transport=conf.THRIFT_TRANSPORT.get(),
+                                  transport_mode=conf.TRANSPORT_MODE.get(),
+                                  http_url=\
+                                      ('http://' if (conf.TRANSPORT_MODE.get() == 'http' and not target['host'].startswith('http')) else '') \
+                                      + target['host'] + ':' + str(target['port'])
+    )
+
+    if hasattr(client, 'setCustomHeaders'):
+      client.setCustomHeaders({'doAs': self.user.username})
+
+    return client
+
   @classmethod
   def _get_security(cls):
     principal = get_server_principal()

+ 11 - 3
apps/hbase/src/hbase/conf.py

@@ -25,14 +25,17 @@ from desktop.lib.conf import Config, validate_thrift_transport
 HBASE_CLUSTERS = Config(
   key="hbase_clusters",
   default="(Cluster|localhost:9090)",
-  help=_t("Comma-separated list of HBase Thrift servers for clusters in the format of '(name|host:port)'. Use full hostname with security."),
-  type=str)
+  help=_t("Comma-separated list of HBase Thrift servers for clusters in the format of '(name|host:port)'. Use full hostname with security."
+          "Prefix hostname with https:// if using SSL and http mode with impersonation."),
+  type=str
+)
 
 TRUNCATE_LIMIT = Config(
   key="truncate_limit",
   default="500",
   help=_t("Hard limit of rows or columns per row fetched before truncating."),
-  type=int)
+  type=int
+)
 
 THRIFT_TRANSPORT = Config(
   key="thrift_transport",
@@ -49,6 +52,11 @@ HBASE_CONF_DIR = Config(
   default=os.environ.get("HBASE_CONF_DIR", '/etc/hbase/conf')
 )
 
+TRANSPORT_MODE = Config(
+  key="transport_mode",
+  help=_t("Force the underlying mode of the Thrift Transport: socket|http. http is required for using the doAs impersonation."),
+  default='socket'
+)
 
 
 def config_validator(user):

+ 1 - 1
apps/hbase/src/hbase/views.py

@@ -73,7 +73,7 @@ def api_router(request, url): # On split, deserialize anything
   if request.POST.get('dest', False):
     url_params += [request.FILES.get(request.REQUEST.get('dest'))]
 
-  return api_dump(HbaseApi().query(*url_params))
+  return api_dump(HbaseApi(request.user).query(*url_params))
 
 def api_dump(response):
   ignored_fields = ('thrift_spec', '__.+__')

+ 4 - 0
desktop/conf.dist/hue.ini

@@ -941,6 +941,10 @@
   # which is useful when used in conjunction with the nonblocking server in Thrift.
   ## thrift_transport=buffered
 
+  # The underlying mode of the Thrift Transport: socket or http.
+  # http is required for using the doAs impersonation.
+  ## transport_mode=socket
+
 
 ###########################################################################
 # Settings to configure Solr Search

+ 3 - 0
desktop/conf/pseudo-distributed.ini.tmpl

@@ -821,6 +821,9 @@
   # Thrift version to use when communicating with HiveServer2.
   ## thrift_version=7
 
+  # The underlying mode of the Thrift Transport: socket or http
+  ## transport_mode=socket
+
   [[ssl]]
     # SSL communication enabled for this server.
     ## enabled=false

+ 5 - 2
desktop/core/src/desktop/lib/rest/http_client.py

@@ -22,12 +22,14 @@ import urllib
 from django.utils.encoding import iri_to_uri, smart_str
 
 from requests import exceptions
+from requests.auth import HTTPBasicAuth
 from requests_kerberos import HTTPKerberosAuth, OPTIONAL
 
 __docformat__ = "epytext"
 
 LOG = logging.getLogger(__name__)
 
+
 class RestException(Exception):
   """
   Any error result from the Rest API is converted into this exception type.
@@ -75,8 +77,6 @@ class HttpClient(object):
     """
     @param base_url: The base url to the API.
     @param exc_class: An exception class to handle non-200 results.
-
-    Creates an HTTP(S) client to connect to the Cloudera Manager API.
     """
     self._base_url = base_url.rstrip('/')
     self._exc_class = exc_class or RestException
@@ -88,6 +88,9 @@ class HttpClient(object):
     self._session.auth = HTTPKerberosAuth(mutual_authentication=OPTIONAL)
     return self
 
+  def set_basic_auth(self, username, password):
+    self._session.auth = HTTPBasicAuth(username, password)
+    return self
 
   def set_headers(self, headers):
     """

+ 17 - 0
desktop/core/src/desktop/lib/thrift_/__init__.py

@@ -0,0 +1,17 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+

+ 95 - 0
desktop/core/src/desktop/lib/thrift_/http_client.py

@@ -0,0 +1,95 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import httplib
+import logging
+import os
+import socket
+import sys
+import urllib
+import urlparse
+import warnings
+
+from cStringIO import StringIO
+
+from thrift.transport.TTransport import *
+
+from desktop.lib.rest.http_client import HttpClient
+from desktop.lib.rest.resource import Resource
+
+
+LOG = logging.getLogger(__name__)
+
+
+class THttpClient(TTransportBase):
+  """
+  HTTP transport mode for Thrift.
+
+  HTTPS and Kerberos support with Request.
+
+  e.g.
+  mode = THttpClient('http://hbase-thrift-v1.com:9090')
+  mode = THttpClient('http://hive-localhost:10001/cliservice')
+  """
+
+  def __init__(self, base_url):
+    self._base_url = base_url
+    self._client = HttpClient(self._base_url, logger=LOG)
+    self._data = None
+    self._headers = None
+    self._wbuf = StringIO()
+
+  def open(self):
+    pass
+
+  def set_kerberos_auth(self):
+    self._client.set_kerberos_auth()
+
+  def set_basic_auth(self, username, password):
+    self._client.set_basic_auth(username, password)
+
+  def close(self):
+    self._headers = None
+    # Close session too?
+
+  def isOpen(self):
+    return self._client is not None
+
+  def setTimeout(self, ms):
+    pass
+
+  def setCustomHeaders(self, headers):
+    self._headers = headers
+
+  def read(self, sz):
+    return self._data
+
+  def write(self, buf):
+    self._wbuf.write(buf)
+
+  def flush(self):
+    if self.isOpen():
+      self.close()
+    self.open()
+
+    data = self._wbuf.getvalue()
+    self._wbuf = StringIO()
+
+    # POST
+    self._root = Resource(self._client)
+    self._data = self._root.post('', data=data)

+ 34 - 14
desktop/core/src/desktop/lib/thrift_util.py

@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-# Utilities for Thrift
 
 import Queue
 import logging
@@ -32,7 +31,9 @@ from thrift.transport.TTransport import TBufferedTransport, TFramedTransport, TM
                                         TTransportException
 from thrift.protocol.TBinaryProtocol import TBinaryProtocol
 from thrift.protocol.TMultiplexedProtocol import TMultiplexedProtocol
+
 from desktop.lib.python_util import create_synchronous_io_multiplexer
+from desktop.lib.thrift_.http_client import THttpClient
 from desktop.lib.thrift_sasl import TSaslClientTransport
 from desktop.lib.exceptions import StructuredException, StructuredThriftTransportException
 
@@ -46,6 +47,7 @@ MAX_RECURSION_DEPTH = 50
 WARN_LEVEL_CALL_DURATION_MS = 5000
 INFO_LEVEL_CALL_DURATION_MS = 1000
 
+
 class LifoQueue(Queue.Queue):
     '''
     Variant of Queue that retrieves most recently added entries first.
@@ -67,6 +69,7 @@ class LifoQueue(Queue.Queue):
     def _get(self):
         return self.queue.pop()
 
+
 class ConnectionConfig(object):
   """ Struct-like class encapsulating the configuration of a Thrift client. """
   def __init__(self, klass, host, port, service_name,
@@ -82,7 +85,9 @@ class ConnectionConfig(object):
                validate=False,
                timeout_seconds=45,
                transport='buffered',
-               multiple=False):
+               multiple=False,
+               transport_mode='socket',
+               http_url=''):
     """
     @param klass The thrift client class
     @param host Host to connect to
@@ -103,6 +108,8 @@ class ConnectionConfig(object):
     @param timeout_seconds Timeout for thrift calls
     @param transport string representation of thrift transport to use
     @param multiple Whether Use MultiplexedProtocol
+    @param transport_mode Can be socket or http
+    @param Url used when using http transport mode
     """
     self.klass = klass
     self.host = host
@@ -121,10 +128,13 @@ class ConnectionConfig(object):
     self.timeout_seconds = timeout_seconds
     self.transport = transport
     self.multiple = multiple
+    self.transport_mode = transport_mode
+    self.http_url = http_url
 
   def __str__(self):
     return ', '.join(map(str, [self.klass, self.host, self.port, self.service_name, self.use_sasl, self.kerberos_principal, self.timeout_seconds,
-                               self.mechanism, self.username, self.use_ssl, self.ca_certs, self.keyfile, self.certfile, self.validate, self.transport, self.multiple]))
+                               self.mechanism, self.username, self.use_ssl, self.ca_certs, self.keyfile, self.certfile, self.validate, self.transport,
+                               self.multiple, self.transport_mode, self.http_url]))
 
 class ConnectionPooler(object):
   """
@@ -238,30 +248,40 @@ def connect_to_thrift(conf):
 
   Returns a tuple of (service, protocol, transport)
   """
-  if conf.use_ssl:
-    sock = TSSLSocket(conf.host, conf.port, validate=conf.validate, ca_certs=conf.ca_certs, keyfile=conf.keyfile, certfile=conf.certfile)
+  if conf.transport_mode == 'TCP':
+    if conf.use_ssl:
+      mode = TSSLSocket(conf.host, conf.port, validate=conf.validate, ca_certs=conf.ca_certs, keyfile=conf.keyfile, certfile=conf.certfile)
+    else:
+      mode = TSocket(conf.host, conf.port)
   else:
-    sock = TSocket(conf.host, conf.port)
+    mode = THttpClient(conf.http_url)
+
   if conf.timeout_seconds:
     # Thrift trivia: You can do this after the fact with
     # _grab_transport_from_wrapper(self.wrapped.transport).setTimeout(seconds*1000)
-    sock.setTimeout(conf.timeout_seconds * 1000.0)
-  if conf.use_sasl:
+    mode.setTimeout(conf.timeout_seconds * 1000.0)
+
+  if conf.transport_mode == 'HTTP':
+    if conf.use_sasl and conf.mechanism != 'PLAIN':
+      mode.set_kerberos_auth()
+    else:
+      mode.set_basic_auth(conf.username, conf.password)
+
+  if conf.transport_mode == 'TCP' and conf.use_sasl:
     def sasl_factory():
       saslc = sasl.Client()
       saslc.setAttr("host", str(conf.host))
       saslc.setAttr("service", str(conf.kerberos_principal))
       if conf.mechanism == 'PLAIN':
         saslc.setAttr("username", str(conf.username))
-        saslc.setAttr("password", str(conf.password)) # defaults to hue for a non-empty string unless using ldap
+        saslc.setAttr("password", str(conf.password)) # Defaults to 'hue' for a non-empty string unless using LDAP
       saslc.init()
       return saslc
-
-    transport = TSaslClientTransport(sasl_factory, conf.mechanism, sock)
+    transport = TSaslClientTransport(sasl_factory, conf.mechanism, mode)
   elif conf.transport == 'framed':
-    transport = TFramedTransport(sock)
+    transport = TFramedTransport(mode)
   else:
-    transport = TBufferedTransport(sock)
+    transport = TBufferedTransport(mode)
 
   protocol = TBinaryProtocol(transport)
   if conf.multiple:
@@ -324,7 +344,7 @@ class PooledClient(object):
         try:
           # Poke it to see if it's closed on the other end. This can happen if a connection
           # sits in the connection pool longer than the read timeout of the server.
-          sock = _grab_transport_from_wrapper(superclient.transport).handle
+          sock = self.conf.transport_mode == 'TCP' and _grab_transport_from_wrapper(superclient.transport).handle
           if sock and create_synchronous_io_multiplexer().read([sock]):
             # the socket is readable, meaning there is either data from a previous call
             # (i.e our protocol is out of sync), or the connection was shut down on the