Przeglądaj źródła

[multi-python] pure-sasl for py3.11

thrift-sasl does not work for py3.11. So, we fallback to pure-sasl.

This is based on impyla commit https://github.com/cloudera/impyla/commit/426e9d3e000c1eb5c7cd9e6d3236542a6c8a9d5b
Amit Srivastava 7 miesięcy temu
rodzic
commit
bb31ccbbb0

+ 67 - 0
desktop/core/src/desktop/lib/sasl_compat.py

@@ -0,0 +1,67 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from contextlib import contextmanager
+
+from puresasl.client import SASLClient, SASLError
+
+
+@contextmanager
+def error_catcher(self, Exc=Exception):
+    try:
+        self.error = None
+        yield
+    except Exc as e:
+        self.error = e.message
+
+
+class PureSASLClient(SASLClient):
+    def __init__(self, *args, **kwargs):
+        self.error = None
+        super(PureSASLClient, self).__init__(*args, **kwargs)
+
+    def start(self, mechanism):
+        with error_catcher(self, SASLError):
+            if isinstance(mechanism, list):
+                self.choose_mechanism(mechanism)
+            else:
+                self.choose_mechanism([mechanism])
+            return True, self.mechanism, self.process()
+        # else
+        return False, mechanism, None
+
+    def encode(self, incoming):
+        with error_catcher(self):
+            return True, self.unwrap(incoming)
+        # else
+        return False, None
+
+    def decode(self, outgoing):
+        with error_catcher(self):
+            return True, self.wrap(outgoing)
+        # else
+        return False, None
+
+    def step(self, challenge):
+        with error_catcher(self):
+            return True, self.process(challenge)
+        # else
+        return False, None
+
+    def getError(self):
+        return self.error

+ 2 - 0
desktop/core/src/desktop/lib/thrift_sasl.py

@@ -94,6 +94,8 @@ class TSaslClientTransport(TTransportBase, CReadableTransport):
 
   def _send_message(self, status, body):
     header = struct.pack(">BI", status, len(body))
+    if isinstance(body, str):
+      body = body.encode('utf-8')
     self._trans.write(header + body)
     self._trans.flush()
 

+ 29 - 11
desktop/core/src/desktop/lib/thrift_util.py

@@ -42,6 +42,7 @@ from desktop.conf import CHERRYPY_SERVER_THREADS, ENABLE_ORGANIZATIONS, ENABLE_S
 from desktop.lib.apputil import INFO_LEVEL_CALL_DURATION_MS, WARN_LEVEL_CALL_DURATION_MS
 from desktop.lib.exceptions import StructuredException, StructuredThriftTransportException
 from desktop.lib.python_util import create_synchronous_io_multiplexer
+from desktop.lib.sasl_compat import PureSASLClient
 from desktop.lib.thrift_.http_client import THttpClient
 from desktop.lib.thrift_.TSSLSocketWithWildcardSAN import TSSLSocketWithWildcardSAN
 from desktop.lib.thrift_sasl import TSaslClientTransport
@@ -351,17 +352,34 @@ def connect_to_thrift(conf):
       mode.set_basic_auth(conf.username, conf.password)
 
   if conf.transport_mode == 'socket' and conf.use_sasl:
-    def sasl_factory():
-      saslc = sasl.Client()
-      saslc.setAttr("host", str(conf.host))
-      saslc.setAttr("service", str(conf.kerberos_principal))
-      if conf.mechanism == 'PLAIN':
-        saslc.setAttr("username", str(conf.username))
-        saslc.setAttr("password", str(conf.password))  # Defaults to 'hue' for a non-empty string unless using LDAP
-      else:
-        saslc.setAttr("maxbufsize", SASL_MAX_BUFFER.get())
-      saslc.init()
-      return saslc
+    try:
+      import sasl  # pylint: disable=import-error
+
+      def sasl_factory():
+        saslc = sasl.Client()
+        saslc.setAttr("host", str(conf.host))
+        saslc.setAttr("service", str(conf.kerberos_principal))
+        if conf.mechanism == 'PLAIN':
+          saslc.setAttr("username", str(conf.username))
+          saslc.setAttr("password", str(conf.password))  # Defaults to 'hue' for a non-empty string unless using LDAP
+        else:
+          saslc.setAttr("maxbufsize", SASL_MAX_BUFFER.get())
+        saslc.init()
+        return saslc
+
+    except Exception as e:
+      LOG.debug("Unable to import 'sasl'. Fallback to 'puresasl'.")
+      from desktop.lib.sasl_compat import PureSASLClient
+
+      def sasl_factory():
+        return PureSASLClient(
+          host=str(conf.host),
+          username=str(conf.username) if conf.mechanism == 'PLAIN' else None,
+          password=str(conf.password) if conf.mechanism == 'PLAIN' else None,
+          maxbufsize=SASL_MAX_BUFFER.get() if conf.mechanism != 'PLAIN' else None,
+          service=str(conf.kerberos_principal)
+        )
+
     transport = TSaslClientTransport(sasl_factory, conf.mechanism, mode)
   elif conf.transport == 'framed':
     transport = TFramedTransport(mode)