Sfoglia il codice sorgente

HUE-553. Thrift pooled client is not thread safe

bc Wong 14 anni fa
parent
commit
30bf4e049b

+ 2 - 1
desktop/core/src/desktop/lib/djangothrift_test.thrift

@@ -53,5 +53,6 @@ struct TestManyTypes {
 
 
 service TestService {
-  i32 ping(1:i32 input)
+  // Multiply the input by 2 and return the result
+  i32 ping(1:i32 input);
 }

+ 14 - 9
desktop/core/src/desktop/lib/thrift_util.py

@@ -246,17 +246,23 @@ class PooledClient(object):
   def __init__(self, conf):
     self.conf = conf
 
-  def __getattr__(self,attr):
+  def __getattr__(self, attr):
     if attr in self.__dict__:
       return self.__dict__[attr]
 
     # Fetch the thrift client from the pool
     superclient = _connection_pool.get_client(self.conf)
 
-    try:
-      res = getattr(superclient, attr)
-      if hasattr(res,"__call__"):
-        def wrapper(*args, **kwargs):
+    res = getattr(superclient, attr)
+    if not callable(res):
+      # It's a simple attribute. We can put the superclient back in the pool.
+      _connection_pool.return_client(self.conf, superclient)
+      return res
+    else:
+      # It's gonna be a thrift call. Add wrapping logic to reopen the transport,
+      # and return the connection to the pool when done.
+      def wrapper(*args, **kwargs):
+        try:
           try:
             # Poke it to see if it's closed on the other end. This can happen if a connection
             # sits in the connection pool longer than the read timeout of the server.
@@ -283,10 +289,9 @@ class PooledClient(object):
               self.conf.service_name, self.conf.host, self.conf.port, str(e))
             e.response_data = dict(code="THRIFT_EXCEPTION", message=msg, data="")
             raise
-        return wrapper
-      return res
-    finally:
-      _connection_pool.return_client(self.conf, superclient)
+        finally:
+          _connection_pool.return_client(self.conf, superclient)
+      return wrapper
 
 
 

+ 116 - 64
desktop/core/src/desktop/lib/thrift_util_test.py

@@ -13,13 +13,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import logging
 import os
 import socket
 import sys
 import threading
-import unittest
 import time
-import logging
+import unittest
 
 gen_py_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "gen-py"))
 if not gen_py_path in sys.path:
@@ -28,79 +29,130 @@ if not gen_py_path in sys.path:
 from djangothrift_test_gen.ttypes import TestStruct, TestNesting, TestEnum, TestManyTypes
 from djangothrift_test_gen import TestService
 
+import hadoop
 import thrift_util
 from thrift_util import jsonable2thrift, thrift2json
 
+from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
 from thrift.server import TServer
 from thrift.transport import TSocket
+from thrift.transport.TTransport import TBufferedTransportFactory
 
 from nose.tools import assert_equal
-from nose.plugins.skip import SkipTest
-
-class TestSuperClient(unittest.TestCase):
-  class TestHandler(object):
-    def ping(self, in_val):
-      return in_val * 2
-
-    @classmethod
-    def start_server_thread(cls):
-      """Starts a test server, returns the ServerThread object started."""
-      handler = cls()
-      processor = TestService.Processor(handler)
-      transport = TSocket.TServerSocket(0)
-      server = TServer.TSimpleServer(processor, transport)
-
-      class ServerThread(threading.Thread):
-        def __init__(self, server):
-          threading.Thread.__init__(self)
-          self.server = server
-          self.stopped = False
-
-        def run(self):
-          try:
-            logging.info("About to serve...")
-            self.server.serve()
-            logging.info("Done serving...")
-          except:
-            assert self.stopped
-
-        def get_port(self):
-          return self.server.serverTransport.handle.getsockname()[1]
-
-        def stop(self):
-          # This closes the listening socket, but the current accept()
-          # call keeps going. So we have to ping that port
-          self.stopped = True
-          port = self.get_port()
-          self.server.serverTransport.close() # hopefully this works?
-          ping_s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-          ping_s.connect(('localhost', port))
-          ping_s.close()
-          logging.info("Waiting for server to stop")
-          self.join()
-
-      thr = ServerThread(server)
-      thr.start()
-      while not transport.isOpen():
-        logging.info("Waiting for server to start")
+
+
+class SimpleThriftServer(object):
+  """A simple thrift server impl"""
+  def __init__(self):
+    self.port = hadoop.mini_cluster.find_unused_port()
+    self.pid = 0
+
+  def ping(self, in_val):
+    return in_val * 2
+
+  def start_server_process(self):
+    """
+    Starts a test server, returns the (pid, port) pair.
+
+    The server needs to be in a subprocess because we need to run a
+    TThreadedServer for the concurrency tests. And the only way to stop a
+    TThreadedServer is to kill it. So we can't just use a thread.
+    """
+    self.pid = os.fork()
+    if self.pid != 0:
+      logging.info("Started SimpleThriftServer (pid %s) on port %s" %
+                   (self.pid, self.port))
+      self._ensure_online()
+      return
+
+    # Child process runs the thrift server loop
+    try:
+      processor = TestService.Processor(self)
+      transport = TSocket.TServerSocket(self.port)
+      server = TServer.TThreadedServer(processor,
+                                       transport,
+                                       TBufferedTransportFactory(),
+                                       TBinaryProtocolFactory())
+      server.serve()
+    except:
+      sys.exit(1)
+
+  def _ensure_online(self):
+    """Ensure that the child server is online"""
+    deadline = time.time() + 60
+    while time.time() < deadline:
+      logging.info("Waiting for service to come online")
+      try:
+        ping_s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        ping_s.connect(('localhost', self.port))
+        ping_s.close()
+        return
+      except:
+        _, status = os.waitpid(self.pid, os.WNOHANG)
+        if status != 0:
+          logging.info("SimpleThriftServer child process exited with %s" % (status,))
         time.sleep(0.1)
 
-      return thr
+    logging.info("SimpleThriftServer took too long to come online")
+    self.stop_server_process()
+
+  def stop_server_process(self):
+    """Stop the server"""
+    if self.pid == 0:
+      return
 
-  # TODO(todd) I couldn't get this to work after much effort.
-  # Thrift's server doesn't really have a reasonable lifecycle
-  # interface, so hard to bring up a thrift server inside a test.
-  def test_basic_operation(self):
-    raise SkipTest()
-    server = TestSuperClient.TestHandler.start_server_thread()
     try:
-      test_client = thrift_util.get_client(TestService.Client,
-                                           '127.0.0.1',
-                                           server.get_port(),
-                                           timeout_seconds=1)
-      assert_equal(10, test_client.ping(5))
-    finally:
-      server.stop()
+      logging.info("Stopping SimpleThriftServer (pid %s)" % (self.pid,))
+      os.kill(self.pid, 15)
+    except Exception, ex:
+      logging.exception("(Potentially ok) Exception while stopping server")
+    os.waitpid(self.pid, 0)
+    self.pid = 0
+
+
+class TestWithThriftServer(object):
+  @classmethod
+  def setup_class(cls):
+    cls.server = SimpleThriftServer()
+    cls.server.start_server_process()
+    cls.client = thrift_util.get_client(TestService.Client,
+                                        '127.0.0.1',
+                                        cls.server.port,
+                                        'Hue Unit Test Client',
+                                        timeout_seconds=1)
+
+  @classmethod
+  def teardown_class(cls):
+    cls.server.stop_server_process()
+
+  def test_basic_operation(self):
+    assert_equal(10, self.client.ping(5))
+
+  def test_connection_race(self):
+    class Racer(threading.Thread):
+      def __init__(self, client, n_iter, begin):
+        threading.Thread.__init__(self)
+        self.setName("Racer%s" % (begin,))
+        self.client = client
+        self.n_iter = n_iter
+        self.begin = begin
+        self.errors = []
+
+      def run(self):
+        for i in range(self.begin, self.begin + self.n_iter):
+          res = self.client.ping(i)
+          if i * 2 != res:
+            self.errors.append(i)
+
+    racers = []
+    for i in range(10):
+      racer = Racer(self.client, n_iter=30, begin=(i * 100))
+      racers.append(racer)
+      racer.start()
+
+    for racer in racers:
+      racer.join()
+      assert_equal(0, len(racer.errors))
 
 class ThriftUtilTest(unittest.TestCase):
   def test_simpler_string(self):

+ 4 - 4
desktop/libs/hadoop/src/hadoop/mini_cluster.py

@@ -80,7 +80,7 @@ TEST_USER_GROUP_MAPPING = {
 
 LOGGER=logging.getLogger(__name__)
 
-def _find_unused_port():
+def find_unused_port():
   """
   Finds a port that's available.
   Unfortunately, this port may not be available by the time
@@ -178,13 +178,13 @@ rpc.class=org.apache.hadoop.metrics.spi.NoEmitMetricsContext
         "-D", "jobclient.progress.monitor.poll.interval=100",
         "-D", "fs.checkpoint.period=1",
         # For a reason I don't fully understand, this must be 0.0.0.0 and not 'localhost'
-        "-D", "dfs.secondary.http.address=0.0.0.0:%d" % _find_unused_port(),
+        "-D", "dfs.secondary.http.address=0.0.0.0:%d" % find_unused_port(),
         # We bind the NN's thrift interface to a port we find here.
         # This is suboptimal, since there's a race.  Alas, if we don't
         # do this here, the datanodes fail to discover the namenode's thrift
         # address, and there's a race there
-        "-D", "dfs.thrift.address=localhost:%d" % _find_unused_port(),
-        "-D", "jobtracker.thrift.address=localhost:%d" % _find_unused_port(),
+        "-D", "dfs.thrift.address=localhost:%d" % find_unused_port(),
+        "-D", "jobtracker.thrift.address=localhost:%d" % find_unused_port(),
         # Jobs realize they have finished faster with this timeout.
         "-D", "jobclient.completion.poll.interval=50",
         "-D", "hadoop.security.authorization=true",