|
|
@@ -14,7 +14,8 @@ from kazoo.exceptions import (
|
|
|
ConnectionDropped,
|
|
|
EXCEPTIONS,
|
|
|
SessionExpiredError,
|
|
|
- NoNodeError
|
|
|
+ NoNodeError,
|
|
|
+ SaslException
|
|
|
)
|
|
|
from kazoo.handlers.utils import create_pipe
|
|
|
from kazoo.loggingsupport import BLATHER
|
|
|
@@ -27,6 +28,7 @@ from kazoo.protocol.serialization import (
|
|
|
Ping,
|
|
|
PingInstance,
|
|
|
ReplyHeader,
|
|
|
+ SASL,
|
|
|
Transaction,
|
|
|
Watch,
|
|
|
int_struct
|
|
|
@@ -42,6 +44,8 @@ from kazoo.retry import (
|
|
|
RetryFailedError
|
|
|
)
|
|
|
|
|
|
+import sasl
|
|
|
+
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
@@ -132,11 +136,12 @@ class RWServerAvailable(Exception):
|
|
|
|
|
|
class ConnectionHandler(object):
|
|
|
"""Zookeeper connection handler"""
|
|
|
- def __init__(self, client, retry_sleeper, logger=None):
|
|
|
+ def __init__(self, client, retry_sleeper, logger=None, sasl_server_principal=None):
|
|
|
self.client = client
|
|
|
self.handler = client.handler
|
|
|
self.retry_sleeper = retry_sleeper
|
|
|
self.logger = logger or log
|
|
|
+ self.sasl_server_principal = sasl_server_principal
|
|
|
|
|
|
# Our event objects
|
|
|
self.connection_closed = client.handler.event_object()
|
|
|
@@ -155,6 +160,7 @@ class ConnectionHandler(object):
|
|
|
|
|
|
self._connection_routine = None
|
|
|
|
|
|
+
|
|
|
# This is instance specific to avoid odd thread bug issues in Python
|
|
|
# during shutdown global cleanup
|
|
|
@contextmanager
|
|
|
@@ -608,6 +614,9 @@ class ConnectionHandler(object):
|
|
|
negotiated_session_timeout, connect_timeout,
|
|
|
read_timeout)
|
|
|
|
|
|
+ if self.sasl_server_principal:
|
|
|
+ self._authenticate_with_sasl(host, connect_timeout / 1000.0)
|
|
|
+
|
|
|
if connect_result.read_only:
|
|
|
client._session_callback(KeeperState.CONNECTED_RO)
|
|
|
self._ro_mode = iter(self._server_pinger())
|
|
|
@@ -620,4 +629,48 @@ class ConnectionHandler(object):
|
|
|
zxid = self._invoke(connect_timeout, ap, xid=AUTH_XID)
|
|
|
if zxid:
|
|
|
client.last_zxid = zxid
|
|
|
+
|
|
|
return read_timeout, connect_timeout
|
|
|
+
|
|
|
+ def _authenticate_with_sasl(self, host, timeout):
|
|
|
+ saslc = sasl.Client()
|
|
|
+ saslc.setAttr('host', host)
|
|
|
+ saslc.setAttr('service', self.sasl_server_principal)
|
|
|
+ saslc.init()
|
|
|
+
|
|
|
+ ret, chosen_mech, initial_response = saslc.start('GSSAPI')
|
|
|
+ if not ret:
|
|
|
+ raise SaslException(saslc.getError())
|
|
|
+
|
|
|
+ response = initial_response
|
|
|
+
|
|
|
+ xid = 0
|
|
|
+
|
|
|
+ while True:
|
|
|
+ xid += 1
|
|
|
+
|
|
|
+ request = SASL(response)
|
|
|
+ self._submit(request, timeout, xid)
|
|
|
+
|
|
|
+ header, buffer, offset = self._read_header(timeout)
|
|
|
+ if header.xid != xid:
|
|
|
+ raise RuntimeError('xids do not match, expected %r '
|
|
|
+ 'received %r', xid, header.xid)
|
|
|
+
|
|
|
+ if header.zxid > 0:
|
|
|
+ client.last_zxid = zxid
|
|
|
+
|
|
|
+ if header.err:
|
|
|
+ callback_exception = EXCEPTIONS[header.err]()
|
|
|
+ self.logger.debug(
|
|
|
+ 'Received error(xid=%s) %r', xid, callback_exception)
|
|
|
+ raise callback_exception
|
|
|
+
|
|
|
+ token, _ = SASL.deserialize(buffer, offset)
|
|
|
+
|
|
|
+ if not token:
|
|
|
+ break
|
|
|
+
|
|
|
+ ret, response = saslc.step(token)
|
|
|
+ if not ret:
|
|
|
+ raise SaslException(saslc.getError())
|