|
|
@@ -45,6 +45,7 @@ class TSaslClientTransport(TTransportBase, CReadableTransport):
|
|
|
self.__wbuf = StringIO()
|
|
|
self.__rbuf = StringIO()
|
|
|
self.opened = False
|
|
|
+ self.encode = None
|
|
|
|
|
|
def isOpen(self):
|
|
|
return self._trans.isOpen()
|
|
|
@@ -100,18 +101,49 @@ class TSaslClientTransport(TTransportBase, CReadableTransport):
|
|
|
self.__wbuf.write(data)
|
|
|
|
|
|
def flush(self):
|
|
|
- success, encoded = self.sasl.encode(self.__wbuf.getvalue())
|
|
|
+ buffer = self.__wbuf.getvalue()
|
|
|
+ # The first time we flush data, we send it to sasl.encode()
|
|
|
+ # If the length doesn't change, then we must be using a QOP
|
|
|
+ # of auth and we should no longer call sasl.encode(), otherwise
|
|
|
+ # we encode every time.
|
|
|
+ if self.encode == None:
|
|
|
+ success, encoded = self.sasl.encode(buffer)
|
|
|
+ if not success:
|
|
|
+ raise TTransportException(type=TTransportException.UNKNOWN,
|
|
|
+ message=self.sasl.getError())
|
|
|
+ if (len(encoded)==len(buffer)):
|
|
|
+ self.encode = False
|
|
|
+ self._flushPlain(buffer)
|
|
|
+ else:
|
|
|
+ self.encode = True
|
|
|
+ self._trans.write(encoded)
|
|
|
+ elif self.encode:
|
|
|
+ self._flushEncoded(buffer)
|
|
|
+ else:
|
|
|
+ self._flushPlain(buffer)
|
|
|
+
|
|
|
+ self._trans.flush()
|
|
|
+ self.__wbuf = StringIO()
|
|
|
+
|
|
|
+ def _flushEncoded(self, buffer):
|
|
|
+ # sasl.ecnode() does the encoding and adds the length header, so nothing
|
|
|
+ # to do but call it and write the result.
|
|
|
+ success, encoded = self.sasl.encode(buffer)
|
|
|
if not success:
|
|
|
raise TTransportException(type=TTransportException.UNKNOWN,
|
|
|
message=self.sasl.getError())
|
|
|
+ self._trans.write(encoded)
|
|
|
+
|
|
|
+ def _flushPlain(self, buffer):
|
|
|
+ # When we have QOP of auth, sasl.encode() will pass the input to the output
|
|
|
+ # but won't put a length header, so we have to do that.
|
|
|
+
|
|
|
# Note stolen from TFramedTransport:
|
|
|
# N.B.: Doing this string concatenation is WAY cheaper than making
|
|
|
# two separate calls to the underlying socket object. Socket writes in
|
|
|
# Python turn out to be REALLY expensive, but it seems to do a pretty
|
|
|
# good job of managing string buffer operations without excessive copies
|
|
|
- self._trans.write(struct.pack(">I", len(encoded)) + encoded)
|
|
|
- self._trans.flush()
|
|
|
- self.__wbuf = StringIO()
|
|
|
+ self._trans.write(struct.pack(">I", len(buffer)) + buffer)
|
|
|
|
|
|
def read(self, sz):
|
|
|
ret = self.__rbuf.read(sz)
|
|
|
@@ -124,11 +156,18 @@ class TSaslClientTransport(TTransportBase, CReadableTransport):
|
|
|
def _read_frame(self):
|
|
|
header = self._trans.readAll(4)
|
|
|
(length,) = struct.unpack(">I", header)
|
|
|
- encoded = self._trans.readAll(length)
|
|
|
- success, decoded = self.sasl.decode(encoded)
|
|
|
- if not success:
|
|
|
- raise TTransportException(type=TTransportException.UNKNOWN,
|
|
|
- message=self.sasl.getError())
|
|
|
+ if self.encode:
|
|
|
+ # If the frames are encoded (i.e. you're using a QOP of auth-int or
|
|
|
+ # auth-conf), then make sure to include the header in the bytes you send to
|
|
|
+ # sasl.decode()
|
|
|
+ encoded = header + self._trans.readAll(length)
|
|
|
+ success, decoded = self.sasl.decode(encoded)
|
|
|
+ if not success:
|
|
|
+ raise TTransportException(type=TTransportException.UNKNOWN,
|
|
|
+ message=self.sasl.getError())
|
|
|
+ else:
|
|
|
+ # If the frames are not encoded, just pass it through
|
|
|
+ decoded = self._trans.readAll(length)
|
|
|
self.__rbuf = StringIO(decoded)
|
|
|
|
|
|
def close(self):
|