reduction.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. #
  2. # Module which deals with pickling of objects.
  3. #
  4. # multiprocessing/reduction.py
  5. #
  6. # Copyright (c) 2006-2008, R Oudkerk
  7. # Licensed to PSF under a Contributor Agreement.
  8. #
  9. from __future__ import absolute_import
  10. import functools
  11. import io
  12. import os
  13. import pickle
  14. import socket
  15. import sys
  16. from . import context
  17. __all__ = ['send_handle', 'recv_handle', 'ForkingPickler', 'register', 'dump']
  18. PY3 = sys.version_info[0] == 3
  19. HAVE_SEND_HANDLE = (sys.platform == 'win32' or
  20. (hasattr(socket, 'CMSG_LEN') and
  21. hasattr(socket, 'SCM_RIGHTS') and
  22. hasattr(socket.socket, 'sendmsg')))
  23. #
  24. # Pickler subclass
  25. #
  26. if PY3:
  27. import copyreg
  28. class ForkingPickler(pickle.Pickler):
  29. '''Pickler subclass used by multiprocessing.'''
  30. _extra_reducers = {}
  31. _copyreg_dispatch_table = copyreg.dispatch_table
  32. def __init__(self, *args):
  33. super(ForkingPickler, self).__init__(*args)
  34. self.dispatch_table = self._copyreg_dispatch_table.copy()
  35. self.dispatch_table.update(self._extra_reducers)
  36. @classmethod
  37. def register(cls, type, reduce):
  38. '''Register a reduce function for a type.'''
  39. cls._extra_reducers[type] = reduce
  40. @classmethod
  41. def dumps(cls, obj, protocol=None):
  42. buf = io.BytesIO()
  43. cls(buf, protocol).dump(obj)
  44. return buf.getbuffer()
  45. @classmethod
  46. def loadbuf(cls, buf, protocol=None):
  47. return cls.loads(buf.getbuffer())
  48. loads = pickle.loads
  49. else:
  50. class ForkingPickler(pickle.Pickler): # noqa
  51. '''Pickler subclass used by multiprocessing.'''
  52. dispatch = pickle.Pickler.dispatch.copy()
  53. @classmethod
  54. def register(cls, type, reduce):
  55. '''Register a reduce function for a type.'''
  56. def dispatcher(self, obj):
  57. rv = reduce(obj)
  58. self.save_reduce(obj=obj, *rv)
  59. cls.dispatch[type] = dispatcher
  60. @classmethod
  61. def dumps(cls, obj, protocol=None):
  62. buf = io.BytesIO()
  63. cls(buf, protocol).dump(obj)
  64. return buf.getvalue()
  65. @classmethod
  66. def loadbuf(cls, buf, protocol=None):
  67. return cls.loads(buf.getvalue())
  68. @classmethod
  69. def loads(cls, buf, loads=pickle.loads):
  70. if isinstance(buf, io.BytesIO):
  71. buf = buf.getvalue()
  72. return loads(buf)
  73. register = ForkingPickler.register
  74. def dump(obj, file, protocol=None):
  75. '''Replacement for pickle.dump() using ForkingPickler.'''
  76. ForkingPickler(file, protocol).dump(obj)
  77. #
  78. # Platform specific definitions
  79. #
  80. if sys.platform == 'win32':
  81. # Windows
  82. __all__ += ['DupHandle', 'duplicate', 'steal_handle']
  83. from .compat import _winapi
  84. def duplicate(handle, target_process=None, inheritable=False):
  85. '''Duplicate a handle. (target_process is a handle not a pid!)'''
  86. if target_process is None:
  87. target_process = _winapi.GetCurrentProcess()
  88. return _winapi.DuplicateHandle(
  89. _winapi.GetCurrentProcess(), handle, target_process,
  90. 0, inheritable, _winapi.DUPLICATE_SAME_ACCESS)
  91. def steal_handle(source_pid, handle):
  92. '''Steal a handle from process identified by source_pid.'''
  93. source_process_handle = _winapi.OpenProcess(
  94. _winapi.PROCESS_DUP_HANDLE, False, source_pid)
  95. try:
  96. return _winapi.DuplicateHandle(
  97. source_process_handle, handle,
  98. _winapi.GetCurrentProcess(), 0, False,
  99. _winapi.DUPLICATE_SAME_ACCESS | _winapi.DUPLICATE_CLOSE_SOURCE)
  100. finally:
  101. _winapi.CloseHandle(source_process_handle)
  102. def send_handle(conn, handle, destination_pid):
  103. '''Send a handle over a local connection.'''
  104. dh = DupHandle(handle, _winapi.DUPLICATE_SAME_ACCESS, destination_pid)
  105. conn.send(dh)
  106. def recv_handle(conn):
  107. '''Receive a handle over a local connection.'''
  108. return conn.recv().detach()
  109. class DupHandle(object):
  110. '''Picklable wrapper for a handle.'''
  111. def __init__(self, handle, access, pid=None):
  112. if pid is None:
  113. # We just duplicate the handle in the current process and
  114. # let the receiving process steal the handle.
  115. pid = os.getpid()
  116. proc = _winapi.OpenProcess(_winapi.PROCESS_DUP_HANDLE, False, pid)
  117. try:
  118. self._handle = _winapi.DuplicateHandle(
  119. _winapi.GetCurrentProcess(),
  120. handle, proc, access, False, 0)
  121. finally:
  122. _winapi.CloseHandle(proc)
  123. self._access = access
  124. self._pid = pid
  125. def detach(self):
  126. '''Get the handle. This should only be called once.'''
  127. # retrieve handle from process which currently owns it
  128. if self._pid == os.getpid():
  129. # The handle has already been duplicated for this process.
  130. return self._handle
  131. # We must steal the handle from the process whose pid is self._pid.
  132. proc = _winapi.OpenProcess(_winapi.PROCESS_DUP_HANDLE, False,
  133. self._pid)
  134. try:
  135. return _winapi.DuplicateHandle(
  136. proc, self._handle, _winapi.GetCurrentProcess(),
  137. self._access, False, _winapi.DUPLICATE_CLOSE_SOURCE)
  138. finally:
  139. _winapi.CloseHandle(proc)
  140. else:
  141. # Unix
  142. __all__ += ['DupFd', 'sendfds', 'recvfds']
  143. import array
  144. # On macOS we should acknowledge receipt of fds -- see Issue14669
  145. ACKNOWLEDGE = sys.platform == 'darwin'
  146. def sendfds(sock, fds):
  147. '''Send an array of fds over an AF_UNIX socket.'''
  148. fds = array.array('i', fds)
  149. msg = bytes([len(fds) % 256])
  150. sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)])
  151. if ACKNOWLEDGE and sock.recv(1) != b'A':
  152. raise RuntimeError('did not receive acknowledgement of fd')
  153. def recvfds(sock, size):
  154. '''Receive an array of fds over an AF_UNIX socket.'''
  155. a = array.array('i')
  156. bytes_size = a.itemsize * size
  157. msg, ancdata, flags, addr = sock.recvmsg(
  158. 1, socket.CMSG_LEN(bytes_size),
  159. )
  160. if not msg and not ancdata:
  161. raise EOFError
  162. try:
  163. if ACKNOWLEDGE:
  164. sock.send(b'A')
  165. if len(ancdata) != 1:
  166. raise RuntimeError(
  167. 'received %d items of ancdata' % len(ancdata),
  168. )
  169. cmsg_level, cmsg_type, cmsg_data = ancdata[0]
  170. if (cmsg_level == socket.SOL_SOCKET and
  171. cmsg_type == socket.SCM_RIGHTS):
  172. if len(cmsg_data) % a.itemsize != 0:
  173. raise ValueError
  174. a.frombytes(cmsg_data)
  175. assert len(a) % 256 == msg[0]
  176. return list(a)
  177. except (ValueError, IndexError):
  178. pass
  179. raise RuntimeError('Invalid data received')
  180. def send_handle(conn, handle, destination_pid): # noqa
  181. '''Send a handle over a local connection.'''
  182. fd = conn.fileno()
  183. with socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) as s:
  184. sendfds(s, [handle])
  185. def recv_handle(conn): # noqa
  186. '''Receive a handle over a local connection.'''
  187. fd = conn.fileno()
  188. with socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) as s:
  189. return recvfds(s, 1)[0]
  190. def DupFd(fd):
  191. '''Return a wrapper for an fd.'''
  192. popen_obj = context.get_spawning_popen()
  193. if popen_obj is not None:
  194. return popen_obj.DupFd(popen_obj.duplicate_for_child(fd))
  195. elif HAVE_SEND_HANDLE:
  196. from . import resource_sharer
  197. return resource_sharer.DupFd(fd)
  198. else:
  199. raise ValueError('SCM_RIGHTS appears not to be available')
  200. #
  201. # Try making some callable types picklable
  202. #
  203. def _reduce_method(m):
  204. if m.__self__ is None:
  205. return getattr, (m.__class__, m.__func__.__name__)
  206. else:
  207. return getattr, (m.__self__, m.__func__.__name__)
  208. class _C:
  209. def f(self):
  210. pass
  211. register(type(_C().f), _reduce_method)
  212. def _reduce_method_descriptor(m):
  213. return getattr, (m.__objclass__, m.__name__)
  214. register(type(list.append), _reduce_method_descriptor)
  215. register(type(int.__add__), _reduce_method_descriptor)
  216. def _reduce_partial(p):
  217. return _rebuild_partial, (p.func, p.args, p.keywords or {})
  218. def _rebuild_partial(func, args, keywords):
  219. return functools.partial(func, *args, **keywords)
  220. register(functools.partial, _reduce_partial)
  221. #
  222. # Make sockets picklable
  223. #
  224. if sys.platform == 'win32':
  225. def _reduce_socket(s):
  226. from .resource_sharer import DupSocket
  227. return _rebuild_socket, (DupSocket(s),)
  228. def _rebuild_socket(ds):
  229. return ds.detach()
  230. register(socket.socket, _reduce_socket)
  231. else:
  232. def _reduce_socket(s): # noqa
  233. df = DupFd(s.fileno())
  234. return _rebuild_socket, (df, s.family, s.type, s.proto)
  235. def _rebuild_socket(df, family, type, proto): # noqa
  236. fd = df.detach()
  237. return socket.socket(family, type, proto, fileno=fd)
  238. register(socket.socket, _reduce_socket)