Răsfoiți Sursa

HUE-1288 [core] Implement proper support for SASL QOP in Hue's Thrift SASL library

Joey Echeverria 12 ani în urmă
părinte
comite
a9898b4e81
1 a modificat fișierele cu 48 adăugiri și 9 ștergeri
  1. 48 9
      desktop/core/src/desktop/lib/thrift_sasl.py

+ 48 - 9
desktop/core/src/desktop/lib/thrift_sasl.py

@@ -45,6 +45,7 @@ class TSaslClientTransport(TTransportBase, CReadableTransport):
     self.__wbuf = StringIO()
     self.__rbuf = StringIO()
     self.opened = False
+    self.encode = None
 
   def isOpen(self):
     return self._trans.isOpen()
@@ -100,18 +101,49 @@ class TSaslClientTransport(TTransportBase, CReadableTransport):
     self.__wbuf.write(data)
 
   def flush(self):
-    success, encoded = self.sasl.encode(self.__wbuf.getvalue())
+    buffer = self.__wbuf.getvalue()
+    # The first time we flush data, we send it to sasl.encode()
+    # If the length doesn't change, then we must be using a QOP
+    # of auth and we should no longer call sasl.encode(), otherwise
+    # we encode every time.
+    if self.encode == None:
+      success, encoded = self.sasl.encode(buffer)
+      if not success:
+        raise TTransportException(type=TTransportException.UNKNOWN,
+                                  message=self.sasl.getError())
+      if (len(encoded)==len(buffer)):
+        self.encode = False
+        self._flushPlain(buffer)
+      else:
+        self.encode = True
+        self._trans.write(encoded)
+    elif self.encode:
+      self._flushEncoded(buffer)
+    else:
+      self._flushPlain(buffer)
+
+    self._trans.flush()
+    self.__wbuf = StringIO()
+
+  def _flushEncoded(self, buffer):
+    # sasl.ecnode() does the encoding and adds the length header, so nothing
+    # to do but call it and write the result.
+    success, encoded = self.sasl.encode(buffer)
     if not success:
       raise TTransportException(type=TTransportException.UNKNOWN,
                                 message=self.sasl.getError())
+    self._trans.write(encoded)
+
+  def _flushPlain(self, buffer):
+    # When we have QOP of auth, sasl.encode() will pass the input to the output
+    # but won't put a length header, so we have to do that.
+
     # Note stolen from TFramedTransport:
     # N.B.: Doing this string concatenation is WAY cheaper than making
     # two separate calls to the underlying socket object. Socket writes in
     # Python turn out to be REALLY expensive, but it seems to do a pretty
     # good job of managing string buffer operations without excessive copies
-    self._trans.write(struct.pack(">I", len(encoded)) + encoded)
-    self._trans.flush()
-    self.__wbuf = StringIO()
+    self._trans.write(struct.pack(">I", len(buffer)) + buffer)
 
   def read(self, sz):
     ret = self.__rbuf.read(sz)
@@ -124,11 +156,18 @@ class TSaslClientTransport(TTransportBase, CReadableTransport):
   def _read_frame(self):
     header = self._trans.readAll(4)
     (length,) = struct.unpack(">I", header)
-    encoded = self._trans.readAll(length)
-    success, decoded = self.sasl.decode(encoded)
-    if not success:
-      raise TTransportException(type=TTransportException.UNKNOWN,
-                                message=self.sasl.getError())
+    if self.encode:
+      # If the frames are encoded (i.e. you're using a QOP of auth-int or
+      # auth-conf), then make sure to include the header in the bytes you send to
+      # sasl.decode()
+      encoded = header + self._trans.readAll(length)
+      success, decoded = self.sasl.decode(encoded)
+      if not success:
+        raise TTransportException(type=TTransportException.UNKNOWN,
+                                  message=self.sasl.getError())
+    else:
+      # If the frames are not encoded, just pass it through
+      decoded = self._trans.readAll(length)
     self.__rbuf = StringIO(decoded)
 
   def close(self):