|
|
@@ -13,13 +13,14 @@
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
+
|
|
|
+import logging
|
|
|
import os
|
|
|
import socket
|
|
|
import sys
|
|
|
import threading
|
|
|
-import unittest
|
|
|
import time
|
|
|
-import logging
|
|
|
+import unittest
|
|
|
|
|
|
gen_py_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "gen-py"))
|
|
|
if not gen_py_path in sys.path:
|
|
|
@@ -28,79 +29,130 @@ if not gen_py_path in sys.path:
|
|
|
from djangothrift_test_gen.ttypes import TestStruct, TestNesting, TestEnum, TestManyTypes
|
|
|
from djangothrift_test_gen import TestService
|
|
|
|
|
|
+import hadoop
|
|
|
import thrift_util
|
|
|
from thrift_util import jsonable2thrift, thrift2json
|
|
|
|
|
|
+from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
|
|
|
from thrift.server import TServer
|
|
|
from thrift.transport import TSocket
|
|
|
+from thrift.transport.TTransport import TBufferedTransportFactory
|
|
|
|
|
|
from nose.tools import assert_equal
|
|
|
-from nose.plugins.skip import SkipTest
|
|
|
-
|
|
|
-class TestSuperClient(unittest.TestCase):
|
|
|
- class TestHandler(object):
|
|
|
- def ping(self, in_val):
|
|
|
- return in_val * 2
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def start_server_thread(cls):
|
|
|
- """Starts a test server, returns the ServerThread object started."""
|
|
|
- handler = cls()
|
|
|
- processor = TestService.Processor(handler)
|
|
|
- transport = TSocket.TServerSocket(0)
|
|
|
- server = TServer.TSimpleServer(processor, transport)
|
|
|
-
|
|
|
- class ServerThread(threading.Thread):
|
|
|
- def __init__(self, server):
|
|
|
- threading.Thread.__init__(self)
|
|
|
- self.server = server
|
|
|
- self.stopped = False
|
|
|
-
|
|
|
- def run(self):
|
|
|
- try:
|
|
|
- logging.info("About to serve...")
|
|
|
- self.server.serve()
|
|
|
- logging.info("Done serving...")
|
|
|
- except:
|
|
|
- assert self.stopped
|
|
|
-
|
|
|
- def get_port(self):
|
|
|
- return self.server.serverTransport.handle.getsockname()[1]
|
|
|
-
|
|
|
- def stop(self):
|
|
|
- # This closes the listening socket, but the current accept()
|
|
|
- # call keeps going. So we have to ping that port
|
|
|
- self.stopped = True
|
|
|
- port = self.get_port()
|
|
|
- self.server.serverTransport.close() # hopefully this works?
|
|
|
- ping_s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
|
- ping_s.connect(('localhost', port))
|
|
|
- ping_s.close()
|
|
|
- logging.info("Waiting for server to stop")
|
|
|
- self.join()
|
|
|
-
|
|
|
- thr = ServerThread(server)
|
|
|
- thr.start()
|
|
|
- while not transport.isOpen():
|
|
|
- logging.info("Waiting for server to start")
|
|
|
+
|
|
|
+
|
|
|
+class SimpleThriftServer(object):
|
|
|
+ """A simple thrift server impl"""
|
|
|
+ def __init__(self):
|
|
|
+ self.port = hadoop.mini_cluster.find_unused_port()
|
|
|
+ self.pid = 0
|
|
|
+
|
|
|
+ def ping(self, in_val):
|
|
|
+ return in_val * 2
|
|
|
+
|
|
|
+ def start_server_process(self):
|
|
|
+ """
|
|
|
+ Starts a test server, returns the (pid, port) pair.
|
|
|
+
|
|
|
+ The server needs to be in a subprocess because we need to run a
|
|
|
+ TThreadedServer for the concurrency tests. And the only way to stop a
|
|
|
+ TThreadedServer is to kill it. So we can't just use a thread.
|
|
|
+ """
|
|
|
+ self.pid = os.fork()
|
|
|
+ if self.pid != 0:
|
|
|
+ logging.info("Started SimpleThriftServer (pid %s) on port %s" %
|
|
|
+ (self.pid, self.port))
|
|
|
+ self._ensure_online()
|
|
|
+ return
|
|
|
+
|
|
|
+ # Child process runs the thrift server loop
|
|
|
+ try:
|
|
|
+ processor = TestService.Processor(self)
|
|
|
+ transport = TSocket.TServerSocket(self.port)
|
|
|
+ server = TServer.TThreadedServer(processor,
|
|
|
+ transport,
|
|
|
+ TBufferedTransportFactory(),
|
|
|
+ TBinaryProtocolFactory())
|
|
|
+ server.serve()
|
|
|
+ except:
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ def _ensure_online(self):
|
|
|
+ """Ensure that the child server is online"""
|
|
|
+ deadline = time.time() + 60
|
|
|
+ while time.time() < deadline:
|
|
|
+ logging.info("Waiting for service to come online")
|
|
|
+ try:
|
|
|
+ ping_s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
|
+ ping_s.connect(('localhost', self.port))
|
|
|
+ ping_s.close()
|
|
|
+ return
|
|
|
+ except:
|
|
|
+ _, status = os.waitpid(self.pid, os.WNOHANG)
|
|
|
+ if status != 0:
|
|
|
+ logging.info("SimpleThriftServer child process exited with %s" % (status,))
|
|
|
time.sleep(0.1)
|
|
|
|
|
|
- return thr
|
|
|
+ logging.info("SimpleThriftServer took too long to come online")
|
|
|
+ self.stop_server_process()
|
|
|
+
|
|
|
+ def stop_server_process(self):
|
|
|
+ """Stop the server"""
|
|
|
+ if self.pid == 0:
|
|
|
+ return
|
|
|
|
|
|
- # TODO(todd) I couldn't get this to work after much effort.
|
|
|
- # Thrift's server doesn't really have a reasonable lifecycle
|
|
|
- # interface, so hard to bring up a thrift server inside a test.
|
|
|
- def test_basic_operation(self):
|
|
|
- raise SkipTest()
|
|
|
- server = TestSuperClient.TestHandler.start_server_thread()
|
|
|
try:
|
|
|
- test_client = thrift_util.get_client(TestService.Client,
|
|
|
- '127.0.0.1',
|
|
|
- server.get_port(),
|
|
|
- timeout_seconds=1)
|
|
|
- assert_equal(10, test_client.ping(5))
|
|
|
- finally:
|
|
|
- server.stop()
|
|
|
+ logging.info("Stopping SimpleThriftServer (pid %s)" % (self.pid,))
|
|
|
+ os.kill(self.pid, 15)
|
|
|
+ except Exception, ex:
|
|
|
+ logging.exception("(Potentially ok) Exception while stopping server")
|
|
|
+ os.waitpid(self.pid, 0)
|
|
|
+ self.pid = 0
|
|
|
+
|
|
|
+
|
|
|
+class TestWithThriftServer(object):
|
|
|
+ @classmethod
|
|
|
+ def setup_class(cls):
|
|
|
+ cls.server = SimpleThriftServer()
|
|
|
+ cls.server.start_server_process()
|
|
|
+ cls.client = thrift_util.get_client(TestService.Client,
|
|
|
+ '127.0.0.1',
|
|
|
+ cls.server.port,
|
|
|
+ 'Hue Unit Test Client',
|
|
|
+ timeout_seconds=1)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def teardown_class(cls):
|
|
|
+ cls.server.stop_server_process()
|
|
|
+
|
|
|
+ def test_basic_operation(self):
|
|
|
+ assert_equal(10, self.client.ping(5))
|
|
|
+
|
|
|
+ def test_connection_race(self):
|
|
|
+ class Racer(threading.Thread):
|
|
|
+ def __init__(self, client, n_iter, begin):
|
|
|
+ threading.Thread.__init__(self)
|
|
|
+ self.setName("Racer%s" % (begin,))
|
|
|
+ self.client = client
|
|
|
+ self.n_iter = n_iter
|
|
|
+ self.begin = begin
|
|
|
+ self.errors = []
|
|
|
+
|
|
|
+ def run(self):
|
|
|
+ for i in range(self.begin, self.begin + self.n_iter):
|
|
|
+ res = self.client.ping(i)
|
|
|
+ if i * 2 != res:
|
|
|
+ self.errors.append(i)
|
|
|
+
|
|
|
+ racers = []
|
|
|
+ for i in range(10):
|
|
|
+ racer = Racer(self.client, n_iter=30, begin=(i * 100))
|
|
|
+ racers.append(racer)
|
|
|
+ racer.start()
|
|
|
+
|
|
|
+ for racer in racers:
|
|
|
+ racer.join()
|
|
|
+ assert_equal(0, len(racer.errors))
|
|
|
|
|
|
class ThriftUtilTest(unittest.TestCase):
|
|
|
def test_simpler_string(self):
|