Pārlūkot izejas kodu

[desktop] Update kazoo with with SASL

This uses the python-sasl library.
Erick Tryzelaar 10 gadi atpakaļ
vecāks
revīzija
0491ffdd3c

+ 5 - 3
desktop/core/ext-py/kazoo-2.0/kazoo/client.py

@@ -102,8 +102,8 @@ class KazooClient(object):
     """
     def __init__(self, hosts='127.0.0.1:2181',
                  timeout=10.0, client_id=None, handler=None,
-                 default_acl=None, auth_data=None, read_only=None,
-                 randomize_hosts=True, connection_retry=None,
+                 default_acl=None, auth_data=None, sasl_server_principal=None,
+                 read_only=None, randomize_hosts=True, connection_retry=None,
                  command_retry=None, logger=None, **kwargs):
         """Create a :class:`KazooClient` instance. All time arguments
         are in seconds.
@@ -121,6 +121,8 @@ class KazooClient(object):
             A list of authentication credentials to use for the
             connection. Should be a list of (scheme, credential)
             tuples as :meth:`add_auth` takes.
+        :param sasl_server_principal:
+            The name of SASL server principal.
         :param read_only: Allow connections to read only servers.
         :param randomize_hosts: By default randomize host selection.
         :param connection_retry:
@@ -258,7 +260,7 @@ class KazooClient(object):
 
         self._conn_retry.interrupt = lambda: self._stopped.is_set()
         self._connection = ConnectionHandler(self, self._conn_retry.copy(),
-            logger=self.logger)
+            logger=self.logger, sasl_server_principal=sasl_server_principal)
 
         # Every retry call should have its own copy of the retry helper
         # to avoid shared retry counts

+ 6 - 0
desktop/core/ext-py/kazoo-2.0/kazoo/exceptions.py

@@ -43,6 +43,12 @@ class WriterNotClosedException(KazooException):
     """
 
 
+class SaslException(KazooException):
+    """Raised if SASL encountered an error.
+    .. versionadded:: 2.1
+    """
+
+
 def _invalid_error_code():
     raise RuntimeError('Invalid error code')
 

+ 55 - 2
desktop/core/ext-py/kazoo-2.0/kazoo/protocol/connection.py

@@ -14,7 +14,8 @@ from kazoo.exceptions import (
     ConnectionDropped,
     EXCEPTIONS,
     SessionExpiredError,
-    NoNodeError
+    NoNodeError,
+    SaslException
 )
 from kazoo.handlers.utils import create_pipe
 from kazoo.loggingsupport import BLATHER
@@ -27,6 +28,7 @@ from kazoo.protocol.serialization import (
     Ping,
     PingInstance,
     ReplyHeader,
+    SASL,
     Transaction,
     Watch,
     int_struct
@@ -42,6 +44,8 @@ from kazoo.retry import (
     RetryFailedError
 )
 
+import sasl
+
 log = logging.getLogger(__name__)
 
 
@@ -132,11 +136,12 @@ class RWServerAvailable(Exception):
 
 class ConnectionHandler(object):
     """Zookeeper connection handler"""
-    def __init__(self, client, retry_sleeper, logger=None):
+    def __init__(self, client, retry_sleeper, logger=None, sasl_server_principal=None):
         self.client = client
         self.handler = client.handler
         self.retry_sleeper = retry_sleeper
         self.logger = logger or log
+        self.sasl_server_principal = sasl_server_principal
 
         # Our event objects
         self.connection_closed = client.handler.event_object()
@@ -155,6 +160,7 @@ class ConnectionHandler(object):
 
         self._connection_routine = None
 
+
     # This is instance specific to avoid odd thread bug issues in Python
     # during shutdown global cleanup
     @contextmanager
@@ -608,6 +614,9 @@ class ConnectionHandler(object):
                           negotiated_session_timeout, connect_timeout,
                           read_timeout)
 
+        if self.sasl_server_principal:
+            self._authenticate_with_sasl(host, connect_timeout / 1000.0)
+
         if connect_result.read_only:
             client._session_callback(KeeperState.CONNECTED_RO)
             self._ro_mode = iter(self._server_pinger())
@@ -620,4 +629,48 @@ class ConnectionHandler(object):
             zxid = self._invoke(connect_timeout, ap, xid=AUTH_XID)
             if zxid:
                 client.last_zxid = zxid
+
         return read_timeout, connect_timeout
+
+    def _authenticate_with_sasl(self, host, timeout):
+        saslc = sasl.Client()
+        saslc.setAttr('host', host)
+        saslc.setAttr('service', self.sasl_server_principal)
+        saslc.init()
+
+        ret, chosen_mech, initial_response = saslc.start('GSSAPI')
+        if not ret:
+            raise SaslException(saslc.getError())
+
+        response = initial_response
+
+        xid = 0
+
+        while True:
+            xid += 1
+
+            request = SASL(response)
+            self._submit(request, timeout, xid)
+
+            header, buffer, offset = self._read_header(timeout)
+            if header.xid != xid:
+                raise RuntimeError('xids do not match, expected %r '
+                                   'received %r', xid, header.xid)
+
+            if header.zxid > 0:
+                client.last_zxid = zxid
+
+            if header.err:
+                callback_exception = EXCEPTIONS[header.err]()
+                self.logger.debug(
+                    'Received error(xid=%s) %r', xid, callback_exception)
+                raise callback_exception
+
+            token, _ = SASL.deserialize(buffer, offset)
+
+            if not token:
+                break
+
+            ret, response = saslc.step(token)
+            if not ret:
+                raise SaslException(saslc.getError())

+ 14 - 0
desktop/core/ext-py/kazoo-2.0/kazoo/protocol/serialization.py

@@ -360,6 +360,20 @@ class Auth(namedtuple('Auth', 'auth_type scheme auth')):
                 write_string(self.auth))
 
 
+class SASL(namedtuple('SASL', 'token')):
+    type = 102
+
+    def serialize(self):
+        b = bytearray()
+        b.extend(write_buffer(self.token))
+        return b
+
+    @classmethod
+    def deserialize(cls, bytes, offset):
+        token, offset = read_buffer(bytes, offset)
+        return token, offset
+
+
 class Watch(namedtuple('Watch', 'type state path')):
     @classmethod
     def deserialize(cls, bytes, offset):