pool.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. #
  2. # Module providing the `Pool` class for managing a process pool
  3. #
  4. # processing/pool.py
  5. #
  6. # Copyright (c) 2007-2008, R Oudkerk --- see COPYING.txt
  7. #
  8. __all__ = ['Pool']
  9. #
  10. # Imports
  11. #
  12. import processing
  13. import threading
  14. import Queue
  15. import itertools
  16. import collections
  17. import time
  18. from processing import Process
  19. from processing.logger import debug
  20. from processing.finalize import Finalize
  21. from processing.queue import SimpleQueue
  22. #
  23. # Constants representing the state of a pool
  24. #
  25. RUN = 0
  26. CLOSE = 1
  27. TERMINATE = 2
  28. #
  29. # Miscellaneous
  30. #
  31. newJobId = itertools.count().next
  32. def mapstar(args):
  33. return map(*args)
  34. #
  35. # Code run by worker processes
  36. #
  37. def worker(inqueue, outqueue, initializer=None, initargs=()):
  38. put = outqueue.put
  39. if initializer is not None:
  40. initializer(*initargs)
  41. for job, i, func, args, kwds in iter(inqueue.get, None):
  42. try:
  43. result = (True, func(*args, **kwds))
  44. except Exception, e:
  45. result = (False, e)
  46. put((job, i, result))
  47. debug('worker got sentinel -- exiting')
  48. #
  49. # Class representing a process pool
  50. #
  51. class Pool(object):
  52. '''
  53. Class which supports an async version of the `apply()` builtin
  54. '''
  55. def __init__(self, processes=None, initializer=None, initargs=()):
  56. self._inqueue = SimpleQueue()
  57. self._outqueue = SimpleQueue()
  58. self._taskqueue = Queue.Queue()
  59. self._cache = {}
  60. self._state = RUN
  61. if processes is None:
  62. try:
  63. processes = processing.cpuCount()
  64. except NotImplementedError:
  65. processes = 1
  66. self._pool = [
  67. Process(target=worker, args=(self._inqueue, self._outqueue,
  68. initializer, initargs))
  69. for i in range(processes)
  70. ]
  71. for i, w in enumerate(self._pool):
  72. w.setName('PoolWorker-' + ':'.join(map(str, w._identity)))
  73. w.start()
  74. self._task_handler = threading.Thread(
  75. target=Pool._handleTasks,
  76. args=(self._taskqueue, self._inqueue, self._outqueue, self._pool)
  77. )
  78. self._task_handler.setDaemon(True)
  79. self._task_handler._state = RUN
  80. self._task_handler.start()
  81. self._result_handler = threading.Thread(
  82. target=Pool._handleResults,
  83. args=(self._outqueue, self._cache)
  84. )
  85. self._result_handler.setDaemon(True)
  86. self._result_handler._state = RUN
  87. self._result_handler.start()
  88. self._terminate = Finalize(
  89. self, Pool._terminatePool,
  90. args=(self._taskqueue, self._inqueue, self._outqueue,
  91. self._cache, self._pool, self._task_handler,
  92. self._result_handler),
  93. exitpriority=5
  94. )
  95. def apply(self, func, args=(), kwds={}):
  96. '''
  97. Equivalent of `apply()` builtin
  98. '''
  99. assert self._state == RUN
  100. return self.applyAsync(func, args, kwds).get()
  101. def map(self, func, iterable, chunksize=None):
  102. '''
  103. Equivalent of `map()` builtin
  104. '''
  105. assert self._state == RUN
  106. return self.mapAsync(func, iterable, chunksize).get()
  107. def imap(self, func, iterable, chunksize=1):
  108. '''
  109. Equivalent of `itertool.imap()` -- can be MUCH slower than `Pool.map()`
  110. '''
  111. assert self._state == RUN
  112. if chunksize == 1:
  113. result = IMapIterator(self._cache)
  114. self._taskqueue.put((((result._job, i, func, (x,), {})
  115. for i, x in enumerate(iterable)), result._setLength))
  116. return result
  117. else:
  118. assert chunksize > 1
  119. task_batches = Pool._getTasks(func, iterable, chunksize)
  120. result = IMapIterator(self._cache)
  121. self._taskqueue.put((((result._job, i, mapstar, (x,), {})
  122. for i, x in enumerate(task_batches)), result._setLength))
  123. return (item for chunk in result for item in chunk)
  124. def imapUnordered(self, func, iterable, chunksize=1):
  125. '''
  126. Like `imap()` method but ordering of results is arbitrary
  127. '''
  128. assert self._state == RUN
  129. if chunksize == 1:
  130. result = IMapUnorderedIterator(self._cache)
  131. self._taskqueue.put((((result._job, i, func, (x,), {})
  132. for i, x in enumerate(iterable)), result._setLength))
  133. return result
  134. else:
  135. assert chunksize > 1
  136. task_batches = Pool._getTasks(func, iterable, chunksize)
  137. result = IMapUnorderedIterator(self._cache)
  138. self._taskqueue.put((((result._job, i, mapstar, (x,), {})
  139. for i, x in enumerate(task_batches)), result._setLength))
  140. return (item for chunk in result for item in chunk)
  141. def applyAsync(self, func, args=(), kwds={}, callback=None):
  142. '''
  143. Asynchronous equivalent of `apply()` builtin
  144. '''
  145. assert self._state == RUN
  146. result = ApplyResult(self._cache, callback)
  147. self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
  148. return result
  149. def mapAsync(self, func, iterable, chunksize=None, callback=None):
  150. '''
  151. Asynchronous equivalent of `map()` builtin
  152. '''
  153. assert self._state == RUN
  154. if not hasattr(iterable, '__len__'):
  155. iterable = list(iterable)
  156. if chunksize is None:
  157. chunksize, extra = divmod(len(iterable), len(self._pool) * 4)
  158. if extra:
  159. chunksize += 1
  160. task_batches = Pool._getTasks(func, iterable, chunksize)
  161. result = MapResult(self._cache, chunksize, len(iterable), callback)
  162. self._taskqueue.put((((result._job, i, mapstar, (x,), {})
  163. for i, x in enumerate(task_batches)), None))
  164. return result
  165. @staticmethod
  166. def _handleTasks(taskqueue, inqueue, outqueue, pool):
  167. thread = threading.currentThread()
  168. put = inqueue._writer.send
  169. for taskseq, setLength in iter(taskqueue.get, None):
  170. i = -1
  171. for i, task in enumerate(taskseq):
  172. if thread._state:
  173. debug('task handler found thread._state != RUN')
  174. break
  175. put(task)
  176. else:
  177. if setLength:
  178. debug('doing setLength()')
  179. setLength(i+1)
  180. continue
  181. break
  182. else:
  183. debug('task handler got sentinel')
  184. # tell result handler to finish when cache is empty
  185. outqueue.put(None)
  186. # tell workers there is no more work
  187. debug('task handler sending sentinel to workers')
  188. for p in pool:
  189. put(None)
  190. debug('task handler exiting')
  191. @staticmethod
  192. def _handleResults(outqueue, cache):
  193. thread = threading.currentThread()
  194. get = outqueue._reader.recv
  195. for job, i, obj in iter(get, None):
  196. if thread._state:
  197. assert thread._state == TERMINATE
  198. debug('result handler found thread._state=TERMINATE')
  199. return
  200. try:
  201. cache[job]._set(i, obj)
  202. except KeyError:
  203. pass
  204. else:
  205. debug('result handler got sentinel')
  206. while cache and thread._state != TERMINATE:
  207. item = get()
  208. if item is None:
  209. debug('result handler ignoring extra sentinel')
  210. continue
  211. job, i, obj = item
  212. try:
  213. cache[job]._set(i, obj)
  214. except KeyError:
  215. pass
  216. debug('result handler exiting: len(cache)=%s, thread._state=%s',
  217. len(cache), thread._state)
  218. @staticmethod
  219. def _getTasks(func, it, size):
  220. it = iter(it)
  221. while 1:
  222. x = tuple(itertools.islice(it, size))
  223. if not x:
  224. return
  225. yield (func, x)
  226. def __reduce__(self):
  227. raise NotImplementedError, \
  228. 'pool objects cannot be passed between processes or pickled'
  229. def close(self):
  230. debug('closing pool')
  231. self._state = CLOSE
  232. self._taskqueue.put(None)
  233. def terminate(self):
  234. debug('terminating pool')
  235. self._state = TERMINATE
  236. self._terminate()
  237. def join(self):
  238. debug('joining pool')
  239. assert self._state in (CLOSE, TERMINATE)
  240. self._task_handler.join()
  241. self._result_handler.join()
  242. for p in self._pool:
  243. p.join()
  244. @staticmethod
  245. def _terminatePool(taskqueue, inqueue, outqueue, cache, pool,
  246. task_handler, result_handler):
  247. debug('finalizing pool')
  248. if not result_handler.isAlive():
  249. debug('result handler already finished -- no need to terminate')
  250. return
  251. cache = {}
  252. task_handler._state = TERMINATE
  253. result_handler._state = TERMINATE
  254. debug('sending sentinels')
  255. taskqueue.put(None)
  256. outqueue.put(None)
  257. debug('getting read lock on inqueue')
  258. inqueue._rlock.acquire()
  259. debug('terminating workers')
  260. for p in pool:
  261. p.terminate()
  262. if task_handler.isAlive():
  263. debug('removing tasks from inqueue until task handler finished')
  264. while task_handler.isAlive() and inqueue._reader.poll():
  265. inqueue._reader.recv()
  266. time.sleep(0)
  267. debug('joining result handler')
  268. result_handler.join()
  269. debug('joining task handler')
  270. task_handler.join()
  271. debug('joining pool workers')
  272. for p in pool:
  273. p.join()
  274. debug('closing connections')
  275. inqueue._reader.close()
  276. outqueue._reader.close()
  277. inqueue._writer.close()
  278. outqueue._writer.close()
  279. # deprecated
  280. apply_async = applyAsync
  281. map_async = mapAsync
  282. imap_unordered = imapUnordered
  283. #
  284. # Class whose instances are returned by `Pool.applyAsync()`
  285. #
  286. class ApplyResult(object):
  287. def __init__(self, cache, callback):
  288. self._cond = threading.Condition(threading.Lock())
  289. self._job = newJobId()
  290. self._cache = cache
  291. self._ready = False
  292. self._callback = callback
  293. cache[self._job] = self
  294. def ready(self):
  295. return self._ready
  296. def successful(self):
  297. assert self._ready
  298. return self._success
  299. def wait(self, timeout=None):
  300. self._cond.acquire()
  301. try:
  302. if not self._ready:
  303. self._cond.wait(timeout)
  304. finally:
  305. self._cond.release()
  306. def get(self, timeout=None):
  307. self.wait(timeout)
  308. if not self._ready:
  309. raise processing.TimeoutError
  310. if self._success:
  311. return self._value
  312. else:
  313. raise self._value
  314. def _set(self, i, obj):
  315. self._success, self._value = obj
  316. if self._callback and self._success:
  317. self._callback(self._value)
  318. self._cond.acquire()
  319. try:
  320. self._ready = True
  321. self._cond.notify()
  322. finally:
  323. self._cond.release()
  324. del self._cache[self._job]
  325. #
  326. # Class whose instances are returned by `Pool.mapAsync()`
  327. #
  328. class MapResult(ApplyResult):
  329. def __init__(self, cache, chunksize, length, callback):
  330. ApplyResult.__init__(self, cache, callback)
  331. self._success = True
  332. self._value = [None] * length
  333. self._chunksize = chunksize
  334. if chunksize <= 0:
  335. self._number_left = 0
  336. self._ready = True
  337. else:
  338. self._number_left = length//chunksize + bool(length % chunksize)
  339. def _set(self, i, (success, result)):
  340. if success:
  341. self._value[i*self._chunksize:(i+1)*self._chunksize] = result
  342. self._number_left -= 1
  343. if self._number_left == 0:
  344. if self._callback:
  345. self._callback(self._value)
  346. del self._cache[self._job]
  347. self._cond.acquire()
  348. try:
  349. self._ready = True
  350. self._cond.notify()
  351. finally:
  352. self._cond.release()
  353. else:
  354. self._success = False
  355. self._value = result
  356. del self._cache[self._job]
  357. self._cond.acquire()
  358. try:
  359. self._ready = True
  360. self._cond.notify()
  361. finally:
  362. self._cond.release()
  363. #
  364. # Class whose instances are returned by `Pool.imap()`
  365. #
  366. class IMapIterator(object):
  367. def __init__(self, cache):
  368. self._cond = threading.Condition(threading.Lock())
  369. self._job = newJobId()
  370. self._cache = cache
  371. self._items = collections.deque()
  372. self._index = 0
  373. self._length = None
  374. self._unsorted = {}
  375. cache[self._job] = self
  376. def __iter__(self):
  377. return self
  378. def next(self, timeout=None):
  379. self._cond.acquire()
  380. try:
  381. try:
  382. item = self._items.popleft()
  383. except IndexError:
  384. if self._index == self._length:
  385. raise StopIteration
  386. self._cond.wait(timeout)
  387. try:
  388. item = self._items.popleft()
  389. except IndexError:
  390. if self._index == self._length:
  391. raise StopIteration
  392. raise processing.TimeoutError
  393. finally:
  394. self._cond.release()
  395. success, value = item
  396. if success:
  397. return value
  398. raise value
  399. def _set(self, i, obj):
  400. self._cond.acquire()
  401. try:
  402. if self._index == i:
  403. self._items.append(obj)
  404. self._index += 1
  405. while self._index in self._unsorted:
  406. obj = self._unsorted.pop(self._index)
  407. self._items.append(obj)
  408. self._index += 1
  409. self._cond.notify()
  410. else:
  411. self._unsorted[i] = obj
  412. if self._index == self._length:
  413. del self._cache[self._job]
  414. finally:
  415. self._cond.release()
  416. def _setLength(self, length):
  417. self._cond.acquire()
  418. try:
  419. self._length = length
  420. if self._index == self._length:
  421. self._cond.notify()
  422. del self._cache[self._job]
  423. finally:
  424. self._cond.release()
  425. #
  426. # Class whose instances are returned by `Pool.imapUnordered()`
  427. #
  428. class IMapUnorderedIterator(IMapIterator):
  429. def _set(self, i, obj):
  430. self._cond.acquire()
  431. try:
  432. self._items.append(obj)
  433. self._index += 1
  434. self._cond.notify()
  435. if self._index == self._length:
  436. del self._cache[self._job]
  437. finally:
  438. self._cond.release()