mock.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
# mock.py
# Test tools for mocking and patching.
# Copyright (C) 2007-2009 Michael Foord
# E-mail: fuzzyman AT voidspace DOT org DOT uk
# mock 0.6.0
# http://www.voidspace.org.uk/python/mock/
# Released subject to the BSD License
# Please see http://www.voidspace.org.uk/python/license.shtml
# Scripts maintained at http://www.voidspace.org.uk/python/index.shtml
# Comments, suggestions and bug reports welcome.
  11. __all__ = (
  12. 'Mock',
  13. 'patch',
  14. 'patch_object',
  15. 'sentinel',
  16. 'DEFAULT'
  17. )
  18. __version__ = '0.6.0'
  19. class SentinelObject(object):
  20. def __init__(self, name):
  21. self.name = name
  22. def __repr__(self):
  23. return '<SentinelObject "%s">' % self.name
  24. class Sentinel(object):
  25. def __init__(self):
  26. self._sentinels = {}
  27. def __getattr__(self, name):
  28. return self._sentinels.setdefault(name, SentinelObject(name))
  29. sentinel = Sentinel()
  30. DEFAULT = sentinel.DEFAULT
  31. class OldStyleClass:
  32. pass
  33. ClassType = type(OldStyleClass)
  34. def _is_magic(name):
  35. return '__%s__' % name[2:-2] == name
  36. def _copy(value):
  37. if type(value) in (dict, list, tuple, set):
  38. return type(value)(value)
  39. return value
  40. class Mock(object):
  41. def __init__(self, spec=None, side_effect=None, return_value=DEFAULT,
  42. name=None, parent=None, wraps=None):
  43. self._parent = parent
  44. self._name = name
  45. if spec is not None and not isinstance(spec, list):
  46. spec = [member for member in dir(spec) if not _is_magic(member)]
  47. self._methods = spec
  48. self._children = {}
  49. self._return_value = return_value
  50. self.side_effect = side_effect
  51. self._wraps = wraps
  52. self.reset_mock()
  53. def reset_mock(self):
  54. self.called = False
  55. self.call_args = None
  56. self.call_count = 0
  57. self.call_args_list = []
  58. self.method_calls = []
  59. for child in self._children.itervalues():
  60. child.reset_mock()
  61. if isinstance(self._return_value, Mock):
  62. self._return_value.reset_mock()
  63. def __get_return_value(self):
  64. if self._return_value is DEFAULT:
  65. self._return_value = Mock()
  66. return self._return_value
  67. def __set_return_value(self, value):
  68. self._return_value = value
  69. return_value = property(__get_return_value, __set_return_value)
  70. def __call__(self, *args, **kwargs):
  71. self.called = True
  72. self.call_count += 1
  73. self.call_args = (args, kwargs)
  74. self.call_args_list.append((args, kwargs))
  75. parent = self._parent
  76. name = self._name
  77. while parent is not None:
  78. parent.method_calls.append((name, args, kwargs))
  79. if parent._parent is None:
  80. break
  81. name = parent._name + '.' + name
  82. parent = parent._parent
  83. ret_val = DEFAULT
  84. if self.side_effect is not None:
  85. if (isinstance(self.side_effect, Exception) or
  86. isinstance(self.side_effect, (type, ClassType)) and
  87. issubclass(self.side_effect, Exception)):
  88. raise self.side_effect
  89. ret_val = self.side_effect(*args, **kwargs)
  90. if ret_val is DEFAULT:
  91. ret_val = self.return_value
  92. if self._wraps is not None and self._return_value is DEFAULT:
  93. return self._wraps(*args, **kwargs)
  94. if ret_val is DEFAULT:
  95. ret_val = self.return_value
  96. return ret_val
  97. def __getattr__(self, name):
  98. if self._methods is not None:
  99. if name not in self._methods:
  100. raise AttributeError("Mock object has no attribute '%s'" % name)
  101. elif _is_magic(name):
  102. raise AttributeError(name)
  103. if name not in self._children:
  104. wraps = None
  105. if self._wraps is not None:
  106. wraps = getattr(self._wraps, name)
  107. self._children[name] = Mock(parent=self, name=name, wraps=wraps)
  108. return self._children[name]
  109. def assert_called_with(self, *args, **kwargs):
  110. assert self.call_args == (args, kwargs), 'Expected: %s\nCalled with: %s' % ((args, kwargs), self.call_args)
  111. def _dot_lookup(thing, comp, import_path):
  112. try:
  113. return getattr(thing, comp)
  114. except AttributeError:
  115. __import__(import_path)
  116. return getattr(thing, comp)
  117. def _importer(target):
  118. components = target.split('.')
  119. import_path = components.pop(0)
  120. thing = __import__(import_path)
  121. for comp in components:
  122. import_path += ".%s" % comp
  123. thing = _dot_lookup(thing, comp, import_path)
  124. return thing
  125. class _patch(object):
  126. def __init__(self, target, attribute, new, spec, create):
  127. self.target = target
  128. self.attribute = attribute
  129. self.new = new
  130. self.spec = spec
  131. self.create = create
  132. self.has_local = False
  133. def __call__(self, func):
  134. if hasattr(func, 'patchings'):
  135. func.patchings.append(self)
  136. return func
  137. def patched(*args, **keywargs):
  138. # don't use a with here (backwards compatability with 2.5)
  139. extra_args = []
  140. for patching in patched.patchings:
  141. arg = patching.__enter__()
  142. if patching.new is DEFAULT:
  143. extra_args.append(arg)
  144. args += tuple(extra_args)
  145. try:
  146. return func(*args, **keywargs)
  147. finally:
  148. for patching in getattr(patched, 'patchings', []):
  149. patching.__exit__()
  150. patched.patchings = [self]
  151. patched.__name__ = func.__name__
  152. patched.compat_co_firstlineno = getattr(func, "compat_co_firstlineno",
  153. func.func_code.co_firstlineno)
  154. return patched
  155. def get_original(self):
  156. target = self.target
  157. name = self.attribute
  158. create = self.create
  159. original = DEFAULT
  160. if _has_local_attr(target, name):
  161. try:
  162. original = target.__dict__[name]
  163. except AttributeError:
  164. # for instances of classes with slots, they have no __dict__
  165. original = getattr(target, name)
  166. elif not create and not hasattr(target, name):
  167. raise AttributeError("%s does not have the attribute %r" % (target, name))
  168. return original
  169. def __enter__(self):
  170. new, spec, = self.new, self.spec
  171. original = self.get_original()
  172. if new is DEFAULT:
  173. # XXXX what if original is DEFAULT - shouldn't use it as a spec
  174. inherit = False
  175. if spec == True:
  176. # set spec to the object we are replacing
  177. spec = original
  178. if isinstance(spec, (type, ClassType)):
  179. inherit = True
  180. new = Mock(spec=spec)
  181. if inherit:
  182. new.return_value = Mock(spec=spec)
  183. self.temp_original = original
  184. setattr(self.target, self.attribute, new)
  185. return new
  186. def __exit__(self, *_):
  187. if self.temp_original is not DEFAULT:
  188. setattr(self.target, self.attribute, self.temp_original)
  189. else:
  190. delattr(self.target, self.attribute)
  191. del self.temp_original
  192. def patch_object(target, attribute, new=DEFAULT, spec=None, create=False):
  193. return _patch(target, attribute, new, spec, create)
  194. def patch(target, new=DEFAULT, spec=None, create=False):
  195. try:
  196. target, attribute = target.rsplit('.', 1)
  197. except (TypeError, ValueError):
  198. raise TypeError("Need a valid target to patch. You supplied: %r" % (target,))
  199. target = _importer(target)
  200. return _patch(target, attribute, new, spec, create)
  201. def _has_local_attr(obj, name):
  202. try:
  203. return name in vars(obj)
  204. except TypeError:
  205. # objects without a __dict__
  206. return hasattr(obj, name)