Эх сурвалжийг харах

HUE-1750 [core] Make IPv4 or IPv6 pluggable for Thrift clients

Update thrift lib to use configurable socket family.
Update thrift utils tests to use IPv4.
Abraham Elmahrek 12 жил өмнө
parent
commit
07de444

+ 8 - 4
desktop/core/ext-py/thrift-0.9.1/src/transport/TSocket.py

@@ -33,7 +33,7 @@ class TSocketBase(TTransportBase):
     else:
       return socket.getaddrinfo(self.host,
                                 self.port,
-                                socket.AF_UNSPEC,
+                                socket.self._socket_family,
                                 socket.SOCK_STREAM,
                                 0,
                                 socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
@@ -47,19 +47,21 @@ class TSocketBase(TTransportBase):
 class TSocket(TSocketBase):
   """Socket implementation of TTransport base."""
 
-  def __init__(self, host='localhost', port=9090, unix_socket=None):
+  def __init__(self, host='localhost', port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
     """Initialize a TSocket
 
     @param host(str)  The host to connect to.
     @param port(int)  The (TCP) port to connect to.
     @param unix_socket(str)  The filename of a unix socket to connect to.
                              (host and port will be ignored.)
+    @param socket_family(int)  The socket family to use with this socket.
     """
     self.host = host
     self.port = port
     self.handle = None
     self._unix_socket = unix_socket
     self._timeout = None
+    self._socket_family = socket_family
 
   def setHandle(self, h):
     self.handle = h
@@ -139,16 +141,18 @@ class TSocket(TSocketBase):
 class TServerSocket(TSocketBase, TServerTransportBase):
   """Socket implementation of TServerTransport base."""
 
-  def __init__(self, host=None, port=9090, unix_socket=None):
+  def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
     self.host = host
     self.port = port
     self._unix_socket = unix_socket
+    self._socket_family = socket_family
     self.handle = None
 
   def listen(self):
     res0 = self._resolveAddr()
+    socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
     for res in res0:
-      if res[0] is socket.AF_INET6 or res is res0[-1]:
+      if res[0] is socket_family or res is res0[-1]:
         break
 
     # We need remove the old unix socket if the file exists and

+ 6 - 4
desktop/core/src/desktop/lib/thrift_util_test.py

@@ -29,7 +29,6 @@ if not gen_py_path in sys.path:
 from djangothrift_test_gen.ttypes import TestStruct, TestNesting, TestEnum, TestManyTypes
 from djangothrift_test_gen import TestService
 
-import hadoop
 import python_util
 import thrift_util
 from thrift_util import jsonable2thrift, thrift2json
@@ -43,6 +42,8 @@ from nose.tools import assert_equal
 
 
 class SimpleThriftServer(object):
+  socket_family = socket.AF_INET
+
   """A simple thrift server impl"""
   def __init__(self):
     self.port = python_util.find_unused_port()
@@ -70,7 +71,7 @@ class SimpleThriftServer(object):
     # Child process runs the thrift server loop
     try:
       processor = TestService.Processor(self)
-      transport = TSocket.TServerSocket('localhost', self.port)
+      transport = TSocket.TServerSocket('localhost', self.port, socket_family=self.socket_family)
       server = TServer.TThreadedServer(processor,
                                        transport,
                                        TBufferedTransportFactory(),
@@ -82,11 +83,12 @@ class SimpleThriftServer(object):
   def _ensure_online(self):
     """Ensure that the child server is online"""
     deadline = time.time() + 60
+    logging.debug("Socket Info: " + str(socket.getaddrinfo('localhost', self.port, socket.AF_UNSPEC, socket.SOCK_STREAM)))
     while time.time() < deadline:
       logging.info("Waiting for service to come online")
       try:
-        ping_s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
-        ping_s.connect(('localhost', self.port, 0, 0))
+        ping_s = socket.socket(self.socket_family, socket.SOCK_STREAM)
+        ping_s.connect(('localhost', self.port))
         ping_s.close()
         return
       except: