소스 검색

HUE-1750 [core] Make IPv4 or IPv6 pluggable for Thrift clients

Update thrift lib to use configurable socket family.
Update thrift utils tests to use IPv4.
Abraham Elmahrek 12 년 전
부모
커밋
07de444
2개의 변경된 파일14개의 추가작업 그리고 8개의 파일을 삭제
  1. 8 4
      desktop/core/ext-py/thrift-0.9.1/src/transport/TSocket.py
  2. 6 4
      desktop/core/src/desktop/lib/thrift_util_test.py

+ 8 - 4
desktop/core/ext-py/thrift-0.9.1/src/transport/TSocket.py

@@ -33,7 +33,7 @@ class TSocketBase(TTransportBase):
     else:
       return socket.getaddrinfo(self.host,
                                 self.port,
-                                socket.AF_UNSPEC,
+                                socket.self._socket_family,
                                 socket.SOCK_STREAM,
                                 0,
                                 socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
@@ -47,19 +47,21 @@ class TSocketBase(TTransportBase):
 class TSocket(TSocketBase):
   """Socket implementation of TTransport base."""
 
-  def __init__(self, host='localhost', port=9090, unix_socket=None):
+  def __init__(self, host='localhost', port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
     """Initialize a TSocket
 
     @param host(str)  The host to connect to.
     @param port(int)  The (TCP) port to connect to.
     @param unix_socket(str)  The filename of a unix socket to connect to.
                              (host and port will be ignored.)
+    @param socket_family(int)  The socket family to use with this socket.
     """
     self.host = host
     self.port = port
     self.handle = None
     self._unix_socket = unix_socket
     self._timeout = None
+    self._socket_family = socket_family
 
   def setHandle(self, h):
     self.handle = h
@@ -139,16 +141,18 @@ class TSocket(TSocketBase):
 class TServerSocket(TSocketBase, TServerTransportBase):
   """Socket implementation of TServerTransport base."""
 
-  def __init__(self, host=None, port=9090, unix_socket=None):
+  def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
     self.host = host
     self.port = port
     self._unix_socket = unix_socket
+    self._socket_family = socket_family
     self.handle = None
 
   def listen(self):
     res0 = self._resolveAddr()
+    socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
     for res in res0:
-      if res[0] is socket.AF_INET6 or res is res0[-1]:
+      if res[0] is socket_family or res is res0[-1]:
         break
 
     # We need remove the old unix socket if the file exists and

+ 6 - 4
desktop/core/src/desktop/lib/thrift_util_test.py

@@ -29,7 +29,6 @@ if not gen_py_path in sys.path:
 from djangothrift_test_gen.ttypes import TestStruct, TestNesting, TestEnum, TestManyTypes
 from djangothrift_test_gen import TestService
 
-import hadoop
 import python_util
 import thrift_util
 from thrift_util import jsonable2thrift, thrift2json
@@ -43,6 +42,8 @@ from nose.tools import assert_equal
 
 
 class SimpleThriftServer(object):
+  socket_family = socket.AF_INET
+
   """A simple thrift server impl"""
   def __init__(self):
     self.port = python_util.find_unused_port()
@@ -70,7 +71,7 @@ class SimpleThriftServer(object):
     # Child process runs the thrift server loop
     try:
       processor = TestService.Processor(self)
-      transport = TSocket.TServerSocket('localhost', self.port)
+      transport = TSocket.TServerSocket('localhost', self.port, socket_family=self.socket_family)
       server = TServer.TThreadedServer(processor,
                                        transport,
                                        TBufferedTransportFactory(),
@@ -82,11 +83,12 @@ class SimpleThriftServer(object):
   def _ensure_online(self):
     """Ensure that the child server is online"""
     deadline = time.time() + 60
+    logging.debug("Socket Info: " + str(socket.getaddrinfo('localhost', self.port, socket.AF_UNSPEC, socket.SOCK_STREAM)))
     while time.time() < deadline:
       logging.info("Waiting for service to come online")
       try:
-        ping_s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
-        ping_s.connect(('localhost', self.port, 0, 0))
+        ping_s = socket.socket(self.socket_family, socket.SOCK_STREAM)
+        ping_s.connect(('localhost', self.port))
         ping_s.close()
         return
       except: