pidbox.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. """Generic process mailbox."""
  2. from __future__ import absolute_import, unicode_literals
  3. import socket
  4. import warnings
  5. from collections import defaultdict, deque
  6. from contextlib import contextmanager
  7. from copy import copy
  8. from itertools import count
  9. from threading import local
  10. from time import time
  11. from . import Exchange, Queue, Consumer, Producer
  12. from .clocks import LamportClock
  13. from .common import maybe_declare, oid_from
  14. from .exceptions import InconsistencyError
  15. from .five import range
  16. from .log import get_logger
  17. from .utils.functional import maybe_evaluate, reprcall
  18. from .utils.objects import cached_property
  19. from .utils.uuid import uuid
  20. W_PIDBOX_IN_USE = """\
  21. A node named {node.hostname} is already using this process mailbox!
  22. Maybe you forgot to shutdown the other node or did not do so properly?
  23. Or if you meant to start multiple nodes on the same host please make sure
  24. you give each node a unique node name!
  25. """
  26. __all__ = ('Node', 'Mailbox')
  27. logger = get_logger(__name__)
  28. debug, error = logger.debug, logger.error
  29. class Node(object):
  30. """Mailbox node."""
  31. #: hostname of the node.
  32. hostname = None
  33. #: the :class:`Mailbox` this is a node for.
  34. mailbox = None
  35. #: map of method name/handlers.
  36. handlers = None
  37. #: current context (passed on to handlers)
  38. state = None
  39. #: current channel.
  40. channel = None
  41. def __init__(self, hostname, state=None, channel=None,
  42. handlers=None, mailbox=None):
  43. self.channel = channel
  44. self.mailbox = mailbox
  45. self.hostname = hostname
  46. self.state = state
  47. self.adjust_clock = self.mailbox.clock.adjust
  48. if handlers is None:
  49. handlers = {}
  50. self.handlers = handlers
  51. def Consumer(self, channel=None, no_ack=True, accept=None, **options):
  52. queue = self.mailbox.get_queue(self.hostname)
  53. def verify_exclusive(name, messages, consumers):
  54. if consumers:
  55. warnings.warn(W_PIDBOX_IN_USE.format(node=self))
  56. queue.on_declared = verify_exclusive
  57. return Consumer(
  58. channel or self.channel, [queue], no_ack=no_ack,
  59. accept=self.mailbox.accept if accept is None else accept,
  60. **options
  61. )
  62. def handler(self, fun):
  63. self.handlers[fun.__name__] = fun
  64. return fun
  65. def on_decode_error(self, message, exc):
  66. error('Cannot decode message: %r', exc, exc_info=1)
  67. def listen(self, channel=None, callback=None):
  68. consumer = self.Consumer(channel=channel,
  69. callbacks=[callback or self.handle_message],
  70. on_decode_error=self.on_decode_error)
  71. consumer.consume()
  72. return consumer
  73. def dispatch(self, method, arguments=None,
  74. reply_to=None, ticket=None, **kwargs):
  75. arguments = arguments or {}
  76. debug('pidbox received method %s [reply_to:%s ticket:%s]',
  77. reprcall(method, (), kwargs=arguments), reply_to, ticket)
  78. handle = reply_to and self.handle_call or self.handle_cast
  79. try:
  80. reply = handle(method, arguments)
  81. except SystemExit:
  82. raise
  83. except Exception as exc:
  84. error('pidbox command error: %r', exc, exc_info=1)
  85. reply = {'error': repr(exc)}
  86. if reply_to:
  87. self.reply({self.hostname: reply},
  88. exchange=reply_to['exchange'],
  89. routing_key=reply_to['routing_key'],
  90. ticket=ticket)
  91. return reply
  92. def handle(self, method, arguments={}):
  93. return self.handlers[method](self.state, **arguments)
  94. def handle_call(self, method, arguments):
  95. return self.handle(method, arguments)
  96. def handle_cast(self, method, arguments):
  97. return self.handle(method, arguments)
  98. def handle_message(self, body, message=None):
  99. destination = body.get('destination')
  100. if message:
  101. self.adjust_clock(message.headers.get('clock') or 0)
  102. if not destination or self.hostname in destination:
  103. return self.dispatch(**body)
  104. dispatch_from_message = handle_message
  105. def reply(self, data, exchange, routing_key, ticket, **kwargs):
  106. self.mailbox._publish_reply(data, exchange, routing_key, ticket,
  107. channel=self.channel,
  108. serializer=self.mailbox.serializer)
  109. class Mailbox(object):
  110. """Process Mailbox."""
  111. node_cls = Node
  112. exchange_fmt = '%s.pidbox'
  113. reply_exchange_fmt = 'reply.%s.pidbox'
  114. #: Name of application.
  115. namespace = None
  116. #: Connection (if bound).
  117. connection = None
  118. #: Exchange type (usually direct, or fanout for broadcast).
  119. type = 'direct'
  120. #: mailbox exchange (init by constructor).
  121. exchange = None
  122. #: exchange to send replies to.
  123. reply_exchange = None
  124. #: Only accepts json messages by default.
  125. accept = ['json']
  126. #: Message serializer
  127. serializer = None
  128. def __init__(self, namespace,
  129. type='direct', connection=None, clock=None,
  130. accept=None, serializer=None, producer_pool=None,
  131. queue_ttl=None, queue_expires=None,
  132. reply_queue_ttl=None, reply_queue_expires=10.0):
  133. self.namespace = namespace
  134. self.connection = connection
  135. self.type = type
  136. self.clock = LamportClock() if clock is None else clock
  137. self.exchange = self._get_exchange(self.namespace, self.type)
  138. self.reply_exchange = self._get_reply_exchange(self.namespace)
  139. self._tls = local()
  140. self.unclaimed = defaultdict(deque)
  141. self.accept = self.accept if accept is None else accept
  142. self.serializer = self.serializer if serializer is None else serializer
  143. self.queue_ttl = queue_ttl
  144. self.queue_expires = queue_expires
  145. self.reply_queue_ttl = reply_queue_ttl
  146. self.reply_queue_expires = reply_queue_expires
  147. self._producer_pool = producer_pool
  148. def __call__(self, connection):
  149. bound = copy(self)
  150. bound.connection = connection
  151. return bound
  152. def Node(self, hostname=None, state=None, channel=None, handlers=None):
  153. hostname = hostname or socket.gethostname()
  154. return self.node_cls(hostname, state, channel, handlers, mailbox=self)
  155. def call(self, destination, command, kwargs={},
  156. timeout=None, callback=None, channel=None):
  157. return self._broadcast(command, kwargs, destination,
  158. reply=True, timeout=timeout,
  159. callback=callback,
  160. channel=channel)
  161. def cast(self, destination, command, kwargs={}):
  162. return self._broadcast(command, kwargs, destination, reply=False)
  163. def abcast(self, command, kwargs={}):
  164. return self._broadcast(command, kwargs, reply=False)
  165. def multi_call(self, command, kwargs={}, timeout=1,
  166. limit=None, callback=None, channel=None):
  167. return self._broadcast(command, kwargs, reply=True,
  168. timeout=timeout, limit=limit,
  169. callback=callback,
  170. channel=channel)
  171. def get_reply_queue(self):
  172. oid = self.oid
  173. return Queue(
  174. '%s.%s' % (oid, self.reply_exchange.name),
  175. exchange=self.reply_exchange,
  176. routing_key=oid,
  177. durable=False,
  178. auto_delete=True,
  179. expires=self.reply_queue_expires,
  180. message_ttl=self.reply_queue_ttl,
  181. )
  182. @cached_property
  183. def reply_queue(self):
  184. return self.get_reply_queue()
  185. def get_queue(self, hostname):
  186. return Queue(
  187. '%s.%s.pidbox' % (hostname, self.namespace),
  188. exchange=self.exchange,
  189. durable=False,
  190. auto_delete=True,
  191. expires=self.queue_expires,
  192. message_ttl=self.queue_ttl,
  193. )
  194. @contextmanager
  195. def producer_or_acquire(self, producer=None, channel=None):
  196. if producer:
  197. yield producer
  198. elif self.producer_pool:
  199. with self.producer_pool.acquire() as producer:
  200. yield producer
  201. else:
  202. yield Producer(channel, auto_declare=False)
  203. def _publish_reply(self, reply, exchange, routing_key, ticket,
  204. channel=None, producer=None, **opts):
  205. chan = channel or self.connection.default_channel
  206. exchange = Exchange(exchange, exchange_type='direct',
  207. delivery_mode='transient',
  208. durable=False)
  209. with self.producer_or_acquire(producer, chan) as producer:
  210. try:
  211. producer.publish(
  212. reply, exchange=exchange, routing_key=routing_key,
  213. declare=[exchange], headers={
  214. 'ticket': ticket, 'clock': self.clock.forward(),
  215. }, retry=True,
  216. **opts
  217. )
  218. except InconsistencyError:
  219. # queue probably deleted and no one is expecting a reply.
  220. pass
  221. def _publish(self, type, arguments, destination=None,
  222. reply_ticket=None, channel=None, timeout=None,
  223. serializer=None, producer=None):
  224. message = {'method': type,
  225. 'arguments': arguments,
  226. 'destination': destination}
  227. chan = channel or self.connection.default_channel
  228. exchange = self.exchange
  229. if reply_ticket:
  230. maybe_declare(self.reply_queue(channel))
  231. message.update(ticket=reply_ticket,
  232. reply_to={'exchange': self.reply_exchange.name,
  233. 'routing_key': self.oid})
  234. serializer = serializer or self.serializer
  235. with self.producer_or_acquire(producer, chan) as producer:
  236. producer.publish(
  237. message, exchange=exchange.name, declare=[exchange],
  238. headers={'clock': self.clock.forward(),
  239. 'expires': time() + timeout if timeout else 0},
  240. serializer=serializer, retry=True,
  241. )
  242. def _broadcast(self, command, arguments=None, destination=None,
  243. reply=False, timeout=1, limit=None,
  244. callback=None, channel=None, serializer=None):
  245. if destination is not None and \
  246. not isinstance(destination, (list, tuple)):
  247. raise ValueError(
  248. 'destination must be a list/tuple not {0}'.format(
  249. type(destination)))
  250. arguments = arguments or {}
  251. reply_ticket = reply and uuid() or None
  252. chan = channel or self.connection.default_channel
  253. # Set reply limit to number of destinations (if specified)
  254. if limit is None and destination:
  255. limit = destination and len(destination) or None
  256. serializer = serializer or self.serializer
  257. self._publish(command, arguments, destination=destination,
  258. reply_ticket=reply_ticket,
  259. channel=chan,
  260. timeout=timeout,
  261. serializer=serializer)
  262. if reply_ticket:
  263. return self._collect(reply_ticket, limit=limit,
  264. timeout=timeout,
  265. callback=callback,
  266. channel=chan)
  267. def _collect(self, ticket,
  268. limit=None, timeout=1, callback=None,
  269. channel=None, accept=None):
  270. if accept is None:
  271. accept = self.accept
  272. chan = channel or self.connection.default_channel
  273. queue = self.reply_queue
  274. consumer = Consumer(channel, [queue], accept=accept, no_ack=True)
  275. responses = []
  276. unclaimed = self.unclaimed
  277. adjust_clock = self.clock.adjust
  278. try:
  279. return unclaimed.pop(ticket)
  280. except KeyError:
  281. pass
  282. def on_message(body, message):
  283. # ticket header added in kombu 2.5
  284. header = message.headers.get
  285. adjust_clock(header('clock') or 0)
  286. expires = header('expires')
  287. if expires and time() > expires:
  288. return
  289. this_id = header('ticket', ticket)
  290. if this_id == ticket:
  291. if callback:
  292. callback(body)
  293. responses.append(body)
  294. else:
  295. unclaimed[this_id].append(body)
  296. consumer.register_callback(on_message)
  297. try:
  298. with consumer:
  299. for i in limit and range(limit) or count():
  300. try:
  301. self.connection.drain_events(timeout=timeout)
  302. except socket.timeout:
  303. break
  304. return responses
  305. finally:
  306. chan.after_reply_message_received(queue.name)
  307. def _get_exchange(self, namespace, type):
  308. return Exchange(self.exchange_fmt % namespace,
  309. type=type,
  310. durable=False,
  311. delivery_mode='transient')
  312. def _get_reply_exchange(self, namespace):
  313. return Exchange(self.reply_exchange_fmt % namespace,
  314. type='direct',
  315. durable=False,
  316. delivery_mode='transient')
  317. @cached_property
  318. def oid(self):
  319. try:
  320. return self._tls.OID
  321. except AttributeError:
  322. oid = self._tls.OID = oid_from(self)
  323. return oid
  324. @cached_property
  325. def producer_pool(self):
  326. return maybe_evaluate(self._producer_pool)