|
|
@@ -17,6 +17,7 @@
|
|
|
# under the License.
|
|
|
#
|
|
|
""" SASL transports for Thrift. """
|
|
|
+from __future__ import absolute_import
|
|
|
|
|
|
from future import standard_library
|
|
|
standard_library.install_aliases()
|
|
|
@@ -27,8 +28,11 @@ import sasl
|
|
|
import struct
|
|
|
import sys
|
|
|
|
|
|
+# TODO: Check whether the following distinction is necessary. Does not appear to
|
|
|
+# break anything when `io.BytesIO` is used everywhere, but there may be some edge
|
|
|
+# cases where things break down.
|
|
|
if sys.version_info[0] > 2:
|
|
|
- from io import StringIO as string_io
|
|
|
+ from io import BytesIO as string_io
|
|
|
else:
|
|
|
from cStringIO import StringIO as string_io
|
|
|
|
|
|
@@ -42,7 +46,7 @@ class TSaslClientTransport(TTransportBase, CReadableTransport):
|
|
|
def __init__(self, sasl_client_factory, mechanism, trans):
|
|
|
"""
|
|
|
@param sasl_client_factory: a callable that returns a new sasl.Client object
|
|
|
- @param mechanism: the SASL mechanism (e.g. "GSSAPI", "PLAIN")
|
|
|
+ @param mechanism: the SASL mechanism (e.g. "GSSAPI")
|
|
|
@param trans: the underlying transport over which to communicate.
|
|
|
"""
|
|
|
self._trans = trans
|
|
|
@@ -57,6 +61,9 @@ class TSaslClientTransport(TTransportBase, CReadableTransport):
|
|
|
def isOpen(self):
|
|
|
return self._trans.isOpen()
|
|
|
|
|
|
+ def is_open(self):
|
|
|
+ return self.isOpen()
|
|
|
+
|
|
|
def open(self):
|
|
|
if not self._trans.isOpen():
|
|
|
self._trans.open()
|
|
|
@@ -154,11 +161,11 @@ class TSaslClientTransport(TTransportBase, CReadableTransport):
|
|
|
|
|
|
def read(self, sz):
|
|
|
ret = self.__rbuf.read(sz)
|
|
|
- if len(ret) != 0:
|
|
|
+ if len(ret) == sz:
|
|
|
return ret
|
|
|
|
|
|
self._read_frame()
|
|
|
- return self.__rbuf.read(sz)
|
|
|
+ return ret + self.__rbuf.read(sz - len(ret))
|
|
|
|
|
|
def _read_frame(self):
|
|
|
header = self._trans.readAll(4)
|