
Refactor thrift client config into a utility class

Todd Lipcon, 15 years ago
commit 7db5306726

+ 45 - 43
desktop/core/src/desktop/lib/thrift_util.py

@@ -30,7 +30,7 @@ from thrift.Thrift import TType
 from thrift.transport.TSocket import TSocket
 from thrift.transport.TTransport import TBufferedTransport, TMemoryBuffer,\
                                         TTransportException
-from thrift.protocol.TBinaryProtocol import TBinaryProtocol
+from thrift.protocol.TBinaryProtocol import TBinaryProtocol, TBinaryProtocolAccelerated
 from desktop.lib.thrift_sasl import TSaslClientTransport
 
 # The maximum depth that we will recurse through a "jsonable" structure
@@ -43,6 +43,20 @@ MAX_RECURSION_DEPTH = 50
 WARN_LEVEL_CALL_DURATION_MS = 5000
 INFO_LEVEL_CALL_DURATION_MS = 1000
 
+class ConnectionConfig(object):
+  def __init__(self, klass, host, port, service_name,
+               use_sasl=False,
+               kerberos_principal="thrift",
+               timeout_seconds=45):
+    self.klass = klass
+    self.host = host
+    self.port = port
+    self.service_name = service_name
+    self.use_sasl = use_sasl
+    self.kerberos_principal = kerberos_principal
+    self.timeout_seconds = timeout_seconds
+
+
 class ConnectionPooler(object):
   """
   Thread-safe connection pooling for thrift. (With about 3 changes,
@@ -65,8 +79,7 @@ class ConnectionPooler(object):
     self.poolsize = poolsize
     self.dictlock = threading.Lock()
 
-  def get_client(self, klass, host, port, service_name="Unknown",
-                 kerberos_principal="thrift",
+  def get_client(self, conf,
                  get_client_timeout=None):
     """
     Could block while we wait for the pool to become non-empty.
@@ -75,7 +88,7 @@ class ConnectionPooler(object):
                                to get a client before failing
     """
     # First up, check to see if we have a pool for this endpoint
-    if (host,port) not in self.pooldict:
+    if (conf.host, conf.port) not in self.pooldict:
       # Uh-oh, we need to initialise the queue. Take the dict lock.
       # Note that this is 'double-checked locking'.
 
@@ -91,13 +104,11 @@ class ConnectionPooler(object):
 
       self.dictlock.acquire()
       try:
-        if (host, port) not in self.pooldict:
+        if (conf.host, conf.port) not in self.pooldict:
           q = Queue.Queue(self.poolsize)
-          self.pooldict[(host, port)] = q
+          self.pooldict[(conf.host, conf.port)] = q
           for i in xrange(self.poolsize):
-            client = construct_client(klass, host, port,
-                                      service_name=service_name,
-                                      kerberos_principal=kerberos_principal)
+            client = construct_client(conf)
             client.CID = i
             q.put(client, False)
       finally:
@@ -114,16 +125,16 @@ class ConnectionPooler(object):
         this_round_timeout = None
 
       try:
-        connection = self.pooldict[(host, port)].get(
+        connection = self.pooldict[(conf.host, conf.port)].get(
           block=True, timeout=this_round_timeout)
       except Queue.Empty:
         has_waited_for = time.time() - start_pool_get_time
         if get_client_timeout is not None and has_waited_for > get_client_timeout:
           raise socket.timeout(
             ("Timed out after %.2f seconds waiting to retrieve a " +
-             "%s client from the pool.") % (has_waited_for, service_name))
+             "%s client from the pool.") % (has_waited_for, conf.service_name))
         logging.warn("Waited %d seconds for a thrift client to %s:%d" %
-          (has_waited_for, host, port))
+          (has_waited_for, conf.host, conf.port))
 
     return connection
 
@@ -135,35 +146,37 @@ class ConnectionPooler(object):
     """
     self.pooldict[(host, port)].put(client)
 
-def construct_client(klass, host, port, service_name, kerberos_principal="thrift", timeout_seconds=45):
+def construct_client(conf):
   """
   Constructs a thrift client, lazily.
   """
 
   def sasl_factory():
     saslc = sasl.Client()
-    saslc.setAttr("host", host)
-    saslc.setAttr("service", kerberos_principal)
+    saslc.setAttr("host", conf.host)
+    saslc.setAttr("service", conf.kerberos_principal)
     saslc.init()
     return saslc
 
-  logging.info("service: %s   host: %s" % (kerberos_principal, host))
-  sock = TSocket(host, port)
-  if timeout_seconds:
+  logging.info("service: %s   host: %s" % (conf.kerberos_principal, conf.host))
+  sock = TSocket(conf.host, conf.port)
+  if conf.timeout_seconds:
     # Thrift trivia: You can do this after the fact with
     # self.wrapped.transport._TBufferedTransport__trans.setTimeout(seconds*1000)
-    sock.setTimeout(timeout_seconds*1000.0)
+    sock.setTimeout(conf.timeout_seconds*1000.0)
   transport = TSaslClientTransport(sasl_factory, "GSSAPI", sock)
-  protocol = TBinaryProtocol(transport)
-  service = klass(protocol)
-  return SuperClient(service, transport, timeout_seconds=timeout_seconds)
+  protocol = TBinaryProtocolAccelerated(transport)
+  service = conf.klass(protocol)
+  return SuperClient(service, transport, timeout_seconds=conf.timeout_seconds)
 
 _connection_pool = ConnectionPooler()
 
-def get_client(klass, host, port, service_name, kerberos_principal="thrift", timeout_seconds=None):
-  return PooledClient(klass,host,port,service_name,
-                      kerberos_principal=kerberos_principal,
-                      timeout_seconds=timeout_seconds)
+def get_client(klass, host, port, service_name,
+               **kwargs):
+  conf = ConnectionConfig(
+    klass,host,port,service_name,
+    **kwargs)
+  return PooledClient(conf)
 
 def _grab_transport_from_wrapper(outer_transport):
   if isinstance(outer_transport, TBufferedTransport):
@@ -177,26 +190,15 @@ class PooledClient(object):
   """
   A wrapper for a SuperClient
   """
-  def __init__(self, klass, host, port,
-               service_name = "Unknown",
-               kerberos_principal="thrift",
-               timeout_seconds=None):
-    self.klass = klass
-    self.host = host
-    self.port = port
-    self.kerberos_principal = kerberos_principal
-    self.timeout_seconds = timeout_seconds
-    self.service_name = service_name
+  def __init__(self, conf):
+    self.conf = conf
 
   def __getattr__(self,attr):
     if attr in self.__dict__:
       return self.__dict__[attr]
 
     # Fetch the thrift client from the pool
-    superclient = _connection_pool.get_client(self.klass, self.host, self.port,
-                                              kerberos_principal=self.kerberos_principal,
-                                              get_client_timeout=self.timeout_seconds,
-                                              service_name=self.service_name)
+    superclient = _connection_pool.get_client(self.conf)
 
     res = getattr(superclient, attr)
     if hasattr(res,"__call__"):
@@ -215,18 +217,18 @@ class PooledClient(object):
                 superclient.transport.close()
                 superclient.transport.open()
 
-            superclient.set_timeout(self.timeout_seconds)
+            superclient.set_timeout(self.conf.timeout_seconds)
             ret = res(*args, **kwargs)
             return ret
           except Exception, e:
             # Stack tends to be only noisy here.
             logging.info("Thrift saw exception: " + str(e), exc_info=False)
             msg = "Exception communicating with %s at %s:%d: %s" % (
-              self.service_name, self.host, self.port, str(e))
+              self.conf.service_name, self.conf.host, self.conf.port, str(e))
             e.response_data = dict(code="THRIFT_EXCEPTION", message=msg, data="")
             raise
         finally:
-          _connection_pool.return_client(self.host,self.port,superclient)
+          _connection_pool.return_client(self.conf.host,self.conf.port,superclient)
       return wrapper
     return res
 

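With this refactor the connection parameters travel in a single ConnectionConfig object; get_client() builds it from its keyword arguments, so call sites only add the new use_sasl flag when they want SASL. A minimal usage sketch against the refactored API follows; the endpoint values and the import path for the generated Namenode client are illustrative assumptions, not part of this commit.

# Hedged sketch: calling the refactored get_client().
# The Namenode import path, host, and port below are assumptions.
from desktop.lib import thrift_util
from hadoop.api.hdfs import Namenode

client = thrift_util.get_client(
  Namenode.Client, "namenode.example.com", 10090,
  service_name="HDFS Namenode HUE Plugin",
  use_sasl=True,                 # forwarded into ConnectionConfig via **kwargs
  kerberos_principal="hdfs",
  timeout_seconds=45)

# client is a PooledClient; each attribute access checks a SuperClient out of
# the shared ConnectionPooler, runs the call, and returns the connection to
# the pool in the finally block.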
+ 2 - 5
desktop/libs/hadoop/src/hadoop/cluster.py

@@ -34,11 +34,8 @@ def _make_filesystem(identifier):
     return LocalSubFileSystem(path)
   else:
     cluster_conf = conf.HDFS_CLUSTERS[identifier]
-    return hadoopfs.HadoopFileSystem(
-      cluster_conf.NN_HOST.get(),
-      cluster_conf.NN_THRIFT_PORT.get(),
-      cluster_conf.NN_HDFS_PORT.get(),
-      kerberos_principal=cluster_conf.NN_KERBEROS_PRINCIPAL.get(),
+    return hadoopfs.HadoopFileSystem.from_config(
+      cluster_conf,
       hadoop_bin_path=conf.HADOOP_BIN.get())
     raise Exception("Unknown choice: %s" % choice)
 

+ 2 - 0
desktop/libs/hadoop/src/hadoop/conf.py

@@ -117,6 +117,8 @@ HDFS_CLUSTERS = UnspecifiedConfigSection(
                             type=int),
       NN_KERBEROS_PRINCIPAL=Config("kerberos_principal", help="Kerberos principal for NameNode",
                                    default="hdfs", type=str),
+      SECURITY_ENABLED=Config("security_enabled", help="Is running with Kerberos authentication",
+                              default=False, type=bool),
     )
   )
 )

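The new SECURITY_ENABLED flag is read like the other per-cluster settings. A hedged sketch of consuming it (the cluster identifier "default" is illustrative):

# Hedged sketch: reading the new flag from a bound HDFS cluster section.
from hadoop import conf

cluster_conf = conf.HDFS_CLUSTERS["default"]   # illustrative cluster identifier
security_enabled = cluster_conf.SECURITY_ENABLED.get()
principal = cluster_conf.NN_KERBEROS_PRINCIPAL.get()
# HadoopFileSystem.from_config() reads both values and passes
# use_sasl=security_enabled down to thrift_util.get_client().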
+ 20 - 9
desktop/libs/hadoop/src/hadoop/fs/hadoopfs.py

@@ -93,11 +93,8 @@ def test_fs_configuration(fs_config, hadoop_bin_conf):
 
   # Check thrift plugin
   try:
-    fs = HadoopFileSystem(host=fs_config.NN_HOST.get(),
-                          thrift_port=fs_config.NN_THRIFT_PORT.get(),
-                          hdfs_port=fs_config.NN_HDFS_PORT.get(),
-                          kerberos_principal=fs_config.NN_KERBEROS_PRINCIPAL.get(),
-                          hadoop_bin_path=hadoop_bin_conf.get())
+    fs = HadoopFileSystem.from_config(
+      fs_config, hadoop_bin_path=hadoop_bin_conf.get())
 
     fs.setuser(fs.superuser)
     ls = fs.listdir('/')
@@ -179,6 +176,7 @@ class HadoopFileSystem(object):
 
   def __init__(self, host, thrift_port, hdfs_port=8020,
                kerberos_principal="hdfs",
+               security_enabled=False,
                hadoop_bin_path="hadoop"):
     """
     @param host hostname or IP of the namenode
@@ -191,14 +189,17 @@ class HadoopFileSystem(object):
     self.host = host
     self.thrift_port = thrift_port
     self.hdfs_port = hdfs_port
+    self.security_enabled = security_enabled
     self.kerberos_principal = kerberos_principal
     self.hadoop_bin_path = hadoop_bin_path
     self._resolve_hadoop_path()
 
-    self.nn_client = thrift_util.get_client(Namenode.Client, host, thrift_port,
-        service_name="HDFS Namenode HUE Plugin",
-        kerberos_principal=kerberos_principal,
-        timeout_seconds=NN_THRIFT_TIMEOUT)
+    self.nn_client = thrift_util.get_client(
+      Namenode.Client, host, thrift_port,
+      service_name="HDFS Namenode HUE Plugin",
+      use_sasl=security_enabled,
+      kerberos_principal=kerberos_principal,
+      timeout_seconds=NN_THRIFT_TIMEOUT)
 
     # The file systems are cached globally.  We store
     # user information in a thread-local variable so that
@@ -207,6 +208,16 @@ class HadoopFileSystem(object):
     self.setuser(DEFAULT_USER)
     LOG.debug("Initialized HadoopFS: %s:%d (%s)", host, thrift_port, hadoop_bin_path)
 
+  @classmethod
+  def from_config(cls, fs_config, hadoop_bin_path="hadoop"):
+    return cls(host=fs_config.NN_HOST.get(),
+               thrift_port=fs_config.NN_THRIFT_PORT.get(),
+               hdfs_port=fs_config.NN_HDFS_PORT.get(),
+               security_enabled=fs_config.SECURITY_ENABLED.get(),
+               kerberos_principal=fs_config.NN_KERBEROS_PRINCIPAL.get(),
+               hadoop_bin_path=hadoop_bin_path)
+
+
   def _get_hdfs_base(self):
     return "hdfs://%s:%d" % (self.host, self.hdfs_port) # TODO(todd) fetch the port from the NN thrift