synchronize.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. #
  2. # Module implementing synchronization primitives
  3. #
  4. # multiprocessing/synchronize.py
  5. #
  6. # Copyright (c) 2006-2008, R Oudkerk
  7. # Licensed to PSF under a Contributor Agreement.
  8. #
  9. from __future__ import absolute_import
  10. import errno
  11. import sys
  12. import tempfile
  13. import threading
  14. from . import context
  15. from . import process
  16. from . import util
  17. from ._ext import _billiard, ensure_SemLock
  18. from .five import range, monotonic
  __all__ = [
      'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', 'Event',
  ]

  # Import-time probe: on platforms lacking a working sem_open
  # implementation, ensure_SemLock() raises ImportError so that importing
  # this module fails cleanly (see issue 3770).
  ensure_SemLock()

  #
  # Constants
  #

  # Semaphore kinds passed to ``_billiard.SemLock``.
  RECURSIVE_MUTEX, SEMAPHORE = 0, 1
  SEM_VALUE_MAX = _billiard.SemLock.SEM_VALUE_MAX


  def _find_sem_unlink():
      # Prefer the extension's own sem_unlink; otherwise fall back to the
      # stdlib one (on Py3.4+ the semaphore must be named); otherwise None,
      # in which case SemLock creates unnamed semaphores.
      try:
          return _billiard.SemLock.sem_unlink
      except AttributeError:  # pragma: no cover
          try:
              from _multiprocessing import sem_unlink as fallback
          except ImportError:
              return None
          return fallback


  sem_unlink = _find_sem_unlink()
  del _find_sem_unlink
  39. #
  40. # Base class for semaphores and mutexes; wraps `_billiard.SemLock`
  41. #
  42. def _semname(sl):
  43. try:
  44. return sl.name
  45. except AttributeError:
  46. pass
  class SemLock(object):
      """Wrapper around a ``_billiard.SemLock`` handle.

      Base class for the lock and semaphore types in this module.
      Responsibilities:

      * create the native handle;
      * expose its ``acquire``/``release`` and context-manager protocol;
      * support pickling for child processes;
      * when ``sem_unlink`` is available and the handle has a name,
        register that name so it is unlinked on finalization.
      """

      # Random suffix generator consumed by _make_name().
      _rand = tempfile._RandomNameSequence()

      def __init__(self, kind, value, maxvalue, ctx=None):
          if ctx is None:
              ctx = context._default_context.get_context()
          start_method = ctx.get_start_method()
          # Forwarded to the native constructor: True on Windows or when
          # the start method is 'fork'.
          unlink_now = sys.platform == 'win32' or start_method == 'fork'

          if not sem_unlink:
              sl = self._semlock = _billiard.SemLock(kind, value, maxvalue)
          else:
              # Named semaphore: keep drawing fresh random names while the
              # chosen one already exists (EEXIST); give up after 100 tries.
              for _attempt in range(100):
                  try:
                      sl = _billiard.SemLock(
                          kind, value, maxvalue, self._make_name(), unlink_now,
                      )
                  except (OSError, IOError) as err:
                      if getattr(err, 'errno', None) == errno.EEXIST:
                          continue
                      raise
                  self._semlock = sl
                  break
              else:
                  exhausted = IOError('cannot find file for semaphore')
                  exhausted.errno = errno.EEXIST
                  raise exhausted

          util.debug('created semlock with handle %s', sl.handle)
          self._make_methods()

          if not sem_unlink:
              return

          if sys.platform != 'win32':
              def _after_fork(obj):
                  obj._semlock._after_fork()
              util.register_after_fork(self, _after_fork)

          if _semname(self._semlock) is not None:
              # Only reached on Unix with forking disabled: unlink the
              # semaphore name when this object is garbage collected or the
              # process shuts down.
              from .semaphore_tracker import register
              register(self._semlock.name)
              util.Finalize(self, SemLock._cleanup, (self._semlock.name,),
                            exitpriority=0)

      @staticmethod
      def _cleanup(name):
          """Unlink semaphore *name*, then drop it from the semaphore tracker."""
          from .semaphore_tracker import unregister
          sem_unlink(name)
          unregister(name)

      def _make_methods(self):
          # Expose the native handle's acquire/release as instance attributes.
          self.acquire = self._semlock.acquire
          self.release = self._semlock.release

      def __enter__(self):
          return self._semlock.__enter__()

      def __exit__(self, *args):
          return self._semlock.__exit__(*args)

      def __getstate__(self):
          """Return ``(handle, kind, maxvalue[, name])`` for pickling.

          Guarded by ``context.assert_spawning``.
          """
          context.assert_spawning(self)
          sl = self._semlock
          if sys.platform == 'win32':
              # The Windows handle is duplicated for the child process.
              handle = context.get_spawning_popen().duplicate_for_child(sl.handle)
          else:
              handle = sl.handle
          base = (handle, sl.kind, sl.maxvalue)
          try:
              extra = (sl.name,)
          except AttributeError:
              extra = ()
          return base + extra

      def __setstate__(self, state):
          """Rebuild the native handle from state produced by __getstate__."""
          self._semlock = _billiard.SemLock._rebuild(*state)
          util.debug('recreated blocker with handle %r', state[0])
          self._make_methods()

      @staticmethod
      def _make_name():
          """Return '<semprefix>-<random suffix>' for a new named semaphore."""
          prefix = process.current_process()._config['semprefix']
          return '%s-%s' % (prefix, next(SemLock._rand))
  class Semaphore(SemLock):
      """Counting semaphore whose ceiling is ``SEM_VALUE_MAX``."""

      def __init__(self, value=1, ctx=None):
          SemLock.__init__(self, SEMAPHORE, value, SEM_VALUE_MAX, ctx=ctx)

      def get_value(self):
          """Return the semaphore's current counter value."""
          return self._semlock._get_value()

      def __repr__(self):
          try:
              current = self._semlock._get_value()
          except Exception:
              # Show a placeholder rather than letting repr() fail.
              current = 'unknown'
          return '<%s(value=%s)>' % (self.__class__.__name__, current)
  class BoundedSemaphore(Semaphore):
      """Semaphore whose maximum value equals its initial *value*."""

      def __init__(self, value=1, ctx=None):
          # Skip Semaphore.__init__ on purpose: the ceiling here is *value*
          # rather than SEM_VALUE_MAX.
          SemLock.__init__(self, SEMAPHORE, value, value, ctx=ctx)

      def __repr__(self):
          try:
              current = self._semlock._get_value()
          except Exception:
              # Show a placeholder rather than letting repr() fail.
              current = 'unknown'
          ceiling = self._semlock.maxvalue
          return '<%s(value=%s, maxvalue=%s)>' % (
              self.__class__.__name__, current, ceiling)
  class Lock(SemLock):
      '''
      Non-recursive lock: a SEMAPHORE-kind SemLock with value and maximum 1.
      '''

      def __init__(self, ctx=None):
          SemLock.__init__(self, SEMAPHORE, 1, 1, ctx=ctx)

      def __repr__(self):
          # Owner label: current process (plus thread unless MainThread),
          # 'None' when free, otherwise another thread or another process.
          try:
              sl = self._semlock
              if sl._is_mine():
                  owner = process.current_process().name
                  thread_name = threading.current_thread().name
                  if thread_name != 'MainThread':
                      owner += '|' + thread_name
              elif sl._get_value() == 1:
                  owner = 'None'
              elif sl._count() > 0:
                  owner = 'SomeOtherThread'
              else:
                  owner = 'SomeOtherProcess'
          except Exception:
              owner = 'unknown'
          return '<%s(owner=%s)>' % (self.__class__.__name__, owner)
  class RLock(SemLock):
      '''
      Recursive lock: a RECURSIVE_MUTEX-kind SemLock with value and maximum 1.
      '''

      def __init__(self, ctx=None):
          SemLock.__init__(self, RECURSIVE_MUTEX, 1, 1, ctx=ctx)

      def __repr__(self):
          # Shows the owner label and the recursion depth. The depth is
          # exact only when the caller is the owner; otherwise it is
          # 0 / 'nonzero' / 'unknown'.
          try:
              sl = self._semlock
              if sl._is_mine():
                  owner = process.current_process().name
                  thread_name = threading.current_thread().name
                  if thread_name != 'MainThread':
                      owner += '|' + thread_name
                  depth = sl._count()
              elif sl._get_value() == 1:
                  owner, depth = 'None', 0
              elif sl._count() > 0:
                  owner, depth = 'SomeOtherThread', 'nonzero'
              else:
                  owner, depth = 'SomeOtherProcess', 'nonzero'
          except Exception:
              owner, depth = 'unknown', 'unknown'
          return '<%s(%s, %s)>' % (self.__class__.__name__, owner, depth)
  class Condition(object):
      '''
      Condition variable built on multiprocessing-style primitives.

      State consists of a lock (by default ``ctx.RLock()``) plus three
      semaphores, all starting at 0:

      * ``_sleeping_count`` -- released by each waiter that goes to sleep;
      * ``_woken_count`` -- released by each waiter once it wakes up;
      * ``_wait_semaphore`` -- wake-up tokens handed out by ``notify``.
      '''

      def __init__(self, lock=None, ctx=None):
          assert ctx
          self._lock = lock or ctx.RLock()
          self._sleeping_count = ctx.Semaphore(0)
          self._woken_count = ctx.Semaphore(0)
          self._wait_semaphore = ctx.Semaphore(0)
          self._make_methods()

      def __getstate__(self):
          context.assert_spawning(self)
          return (self._lock, self._sleeping_count,
                  self._woken_count, self._wait_semaphore)

      def __setstate__(self, state):
          (self._lock, self._sleeping_count,
           self._woken_count, self._wait_semaphore) = state
          self._make_methods()

      def __enter__(self):
          return self._lock.__enter__()

      def __exit__(self, *args):
          return self._lock.__exit__(*args)

      def _make_methods(self):
          # acquire()/release() on the condition act on the underlying lock.
          self.acquire = self._lock.acquire
          self.release = self._lock.release

      def __repr__(self):
          try:
              num_waiters = (self._sleeping_count._semlock._get_value() -
                             self._woken_count._semlock._get_value())
          except Exception:
              num_waiters = 'unknown'
          return '<%s(%s, %s)>' % (
              self.__class__.__name__, self._lock, num_waiters)

      def wait(self, timeout=None):
          '''Release the lock, wait for a notification, re-acquire the lock.

          The caller must hold the lock. A recursively held lock is
          released ``_count()`` times and re-acquired the same number of
          times afterwards, even if the wait raises.

          Returns the result of acquiring the internal wait semaphore
          with *timeout* (presumably true when notified and false on
          timeout; this follows the semaphore's ``acquire`` contract).
          '''
          assert self._lock._semlock._is_mine(), \
              'must acquire() condition before using wait()'

          # indicate that this thread is going to sleep
          self._sleeping_count.release()

          # release lock (all recursion levels)
          count = self._lock._semlock._count()
          for i in range(count):
              self._lock.release()

          try:
              # wait for notification or timeout
              return self._wait_semaphore.acquire(True, timeout)
          finally:
              # indicate that this thread has woken
              self._woken_count.release()

              # reacquire lock
              for i in range(count):
                  self._lock.acquire()

      def notify(self, n=1):
          '''Wake up at most *n* waiters (default 1).

          The caller must hold the lock. Blocks until every waiter it woke
          has signalled, via ``_woken_count``, that it is awake. A value of
          *n* <= 0 wakes nobody.
          '''
          assert self._lock._semlock._is_mine(), 'lock is not owned'
          assert not self._wait_semaphore.acquire(False)

          # to take account of timeouts since last notify*() we subtract
          # woken_count from sleeping_count and rezero woken_count
          while self._woken_count.acquire(False):
              res = self._sleeping_count.acquire(False)
              assert res

          sleepers = 0
          while sleepers < n and self._sleeping_count.acquire(False):
              self._wait_semaphore.release()  # wake up one sleeper
              sleepers += 1

          if sleepers:
              for i in range(sleepers):
                  self._woken_count.acquire()  # wait for a sleeper to wake

              # rezero wait_semaphore in case some timeouts just happened
              while self._wait_semaphore.acquire(False):
                  pass

      def notify_all(self):
          '''Wake up every waiter currently sleeping on this condition.'''
          self.notify(n=sys.maxsize)

      def wait_for(self, predicate, timeout=None):
          '''Wait until ``predicate()`` is truthy or *timeout* seconds pass.

          Returns the last value produced by *predicate*. The caller must
          hold the lock, because wait() asserts that it does.
          '''
          result = predicate()
          if result:
              return result
          if timeout is not None:
              endtime = monotonic() + timeout
          else:
              endtime = None
          waittime = None
          while not result:
              if endtime is not None:
                  waittime = endtime - monotonic()
                  if waittime <= 0:
                      break
              self.wait(waittime)
              result = predicate()
          return result
  class Event(object):
      '''Event flag shared through a Condition and a 0/1 Semaphore.

      The flag is "set" while ``_flag`` holds a token. Every operation
      runs with ``_cond`` held.
      '''

      def __init__(self, ctx=None):
          assert ctx
          self._cond = ctx.Condition(ctx.Lock())
          self._flag = ctx.Semaphore(0)

      def _flag_is_up(self):
          # Peek at the flag without consuming it. Callers hold self._cond.
          if not self._flag.acquire(False):
              return False
          self._flag.release()
          return True

      def is_set(self):
          with self._cond:
              return self._flag_is_up()

      def set(self):
          with self._cond:
              # Drain, then post exactly one token, so the flag stays at 1.
              self._flag.acquire(False)
              self._flag.release()
              self._cond.notify_all()

      def clear(self):
          with self._cond:
              self._flag.acquire(False)

      def wait(self, timeout=None):
          '''Block until the flag is set or *timeout* expires.

          Returns True if the flag is set afterwards, otherwise False.
          '''
          with self._cond:
              if not self._flag_is_up():
                  self._cond.wait(timeout)
              return self._flag_is_up()
  #
  # Barrier
  #

  if hasattr(threading, 'Barrier'):

      class Barrier(threading.Barrier):
          '''Barrier whose ``_state`` and ``_count`` live in a buffer.

          Reuses the ``threading.Barrier`` algorithm, but stores the two
          integers it mutates in a ``BufferWrapper`` from ``.heap``. That
          buffer is presumably shared memory; it is carried through pickling
          via __getstate__/__setstate__.
          '''

          def __init__(self, parties, action=None, timeout=None, ctx=None):
              assert ctx
              import struct
              from .heap import BufferWrapper
              # Room for two C ints: _state at index 0, _count at index 1.
              wrapper = BufferWrapper(struct.calcsize('i') * 2)
              cond = ctx.Condition()
              self.__setstate__((parties, action, timeout, cond, wrapper))
              self._state = 0
              self._count = 0

          def __setstate__(self, state):
              parties, action, timeout, cond, wrapper = state
              self._parties = parties
              self._action = action
              self._timeout = timeout
              self._cond = cond
              self._wrapper = wrapper
              self._array = wrapper.create_memoryview().cast('i')

          def __getstate__(self):
              return (self._parties, self._action, self._timeout,
                      self._cond, self._wrapper)

          def _array_slot(index):
              # Build a read/write property backed by self._array[index].
              def fget(self):
                  return self._array[index]

              def fset(self, value):
                  self._array[index] = value

              return property(fget, fset)

          _state = _array_slot(0)
          _count = _array_slot(1)
          del _array_slot

  else:

      class Barrier(object):  # noqa
          def __init__(self, *args, **kwargs):
              raise NotImplementedError('Barrier only supported on Py3')