testcase.py 8.2 KB


  1. import threading
  2. from contextlib import contextmanager
  3. import pytest
  4. from tornado import ioloop, web
  5. from dummyserver.handlers import TestingApp
  6. from dummyserver.proxy import ProxyHandler
  7. from dummyserver.server import (
  8. DEFAULT_CERTS,
  9. HAS_IPV6,
  10. SocketServerThread,
  11. run_loop_in_thread,
  12. run_tornado_app,
  13. )
  14. from urllib3.connection import HTTPConnection
  def consume_socket(sock, chunks=65536):
      """
      Read from *sock* until the data received so far ends with a blank line.

      Stops as soon as the accumulated bytes end with a CRLF CRLF sequence
      (the end of an HTTP header block). It also stops when the peer closes
      the connection, i.e. when ``recv`` returns an empty bytes object.

      :param sock: connected socket to read from.
      :param chunks: maximum number of bytes requested per ``recv`` call.
      :returns: ``bytearray`` containing every byte consumed.
      """
      consumed = bytearray()
      while True:
          b = sock.recv(chunks)
          if not b:
              # Peer closed the connection: without this check the loop would
              # spin forever on empty reads.
              break
          consumed += b
          # Check the accumulated buffer rather than only the latest chunk, so
          # a terminator split across two recv() calls is still detected.
          if consumed.endswith(b"\r\n\r\n"):
              break
      return consumed
  class SocketDummyServerTestCase(object):
      """
      A simple socket-based server is created for this class.

      By default the server handles exactly one request (see the ``num``
      argument of ``start_response_handler``). The listening port is stored
      on ``cls.port`` once the server thread is ready.
      """

      scheme = "http"
      host = "localhost"

      @classmethod
      def _start_server(cls, socket_handler):
          """
          Start a ``SocketServerThread`` that runs *socket_handler*.

          Stores the thread on ``cls.server_thread`` and its port on
          ``cls.port``.

          :raises Exception: if the thread does not signal readiness within
              5 seconds.
          """
          ready_event = threading.Event()
          cls.server_thread = SocketServerThread(
              socket_handler=socket_handler, ready_event=ready_event, host=cls.host
          )
          cls.server_thread.start()
          ready_event.wait(5)
          if not ready_event.is_set():
              raise Exception("most likely failed to start server")
          cls.port = cls.server_thread.port

      @classmethod
      def start_response_handler(cls, response, num=1, block_send=None):
          """
          Serve *num* connections, answering each with the raw bytes *response*.

          For every accepted connection the handler first reads the request
          headers. It then optionally waits on (and clears) the *block_send*
          event, sends *response*, and closes the socket.

          :returns: an event that is set just before the server starts waiting
              for each connection.
          """
          ready_event = threading.Event()

          def socket_handler(listener):
              for _ in range(num):
                  ready_event.set()
                  sock = listener.accept()[0]
                  try:
                      consume_socket(sock)
                      if block_send:
                          # Lets the test decide when the response goes out.
                          block_send.wait()
                          block_send.clear()
                      # send() may write only part of the buffer; sendall()
                      # keeps writing until the whole response is sent.
                      sock.sendall(response)
                  finally:
                      # Close the connection even if reading or sending fails.
                      sock.close()

          cls._start_server(socket_handler)
          return ready_event

      @classmethod
      def start_basic_handler(cls, **kw):
          """Serve an empty ``200 OK`` response; *kw* goes to start_response_handler."""
          return cls.start_response_handler(
              b"HTTP/1.1 200 OK\r\n" b"Content-Length: 0\r\n" b"\r\n", **kw
          )

      @classmethod
      def teardown_class(cls):
          # Only wait briefly: the handler may still be blocked in accept() if
          # fewer than `num` connections arrived, and teardown must not hang.
          if hasattr(cls, "server_thread"):
              cls.server_thread.join(0.1)

      def assert_header_received(
          self, received_headers, header_name, expected_value=None
      ):
          """
          Assert that *header_name* appears in *received_headers*.

          :param received_headers: iterable of raw ``b"Name: value"`` lines.
          :param header_name: header name (ASCII ``str``) that must be present.
          :param expected_value: if given, every occurrence of the header must
              carry exactly this (ASCII ``str``) value.
          """
          header_name = header_name.encode("ascii")
          if expected_value is not None:
              expected_value = expected_value.encode("ascii")
          header_titles = []
          for header in received_headers:
              # maxsplit=1: header values may themselves contain ": ".
              key, value = header.split(b": ", 1)
              header_titles.append(key)
              if key == header_name and expected_value is not None:
                  assert value == expected_value
          assert header_name in header_titles
  class IPV4SocketDummyServerTestCase(SocketDummyServerTestCase):
      """Socket dummy server whose server thread has ``USE_IPV6`` turned off."""

      @classmethod
      def _start_server(cls, socket_handler):
          """
          Start the server thread with ``USE_IPV6 = False`` and record its port.

          :raises Exception: if the thread is not ready within 5 seconds.
          """
          started = threading.Event()
          thread = SocketServerThread(
              socket_handler=socket_handler, ready_event=started, host=cls.host
          )
          # Set before start(); presumably read by the thread when it binds
          # its listener (see SocketServerThread) - confirm.
          thread.USE_IPV6 = False
          cls.server_thread = thread
          thread.start()
          started.wait(5)
          if not started.is_set():
              raise Exception("most likely failed to start server")
          cls.port = thread.port
  class HTTPDummyServerTestCase(object):
      """
      A simple HTTP server that runs while your test class runs.

      Inherit from this class and a Tornado server serving ``TestingApp`` is
      started in ``setup_class`` and stopped in ``teardown_class``. For
      examples of requests you can send to it, see the TestingApp in
      dummyserver/handlers.py.
      """

      scheme = "http"
      host = "localhost"
      host_alt = "127.0.0.1"  # Some tests need two hosts
      certs = DEFAULT_CERTS

      @classmethod
      def setup_class(cls):
          cls._start_server()

      @classmethod
      def teardown_class(cls):
          cls._stop_server()

      @classmethod
      def _start_server(cls):
          """Start the app on the current IOLoop and run that loop in a thread."""
          loop = ioloop.IOLoop.current()
          cls.io_loop = loop
          application = web.Application([(r".*", TestingApp)])
          cls.server, cls.port = run_tornado_app(
              application, loop, cls.certs, cls.scheme, cls.host
          )
          cls.server_thread = run_loop_in_thread(loop)

      @classmethod
      def _stop_server(cls):
          """Schedule server and loop shutdown on the loop, then join its thread."""
          loop = cls.io_loop
          # Stop the server first, then the loop itself.
          for callback in (cls.server.stop, loop.stop):
              loop.add_callback(callback)
          cls.server_thread.join()
  class HTTPSDummyServerTestCase(HTTPDummyServerTestCase):
      """HTTPDummyServerTestCase started with scheme "https" and DEFAULT_CERTS."""

      certs = DEFAULT_CERTS
      host = "localhost"
      scheme = "https"
  class HTTPDummyProxyTestCase(object):
      """
      Start HTTP and HTTPS origin servers plus HTTP and HTTPS proxies.

      All four Tornado servers share one IOLoop, which runs in a background
      thread from ``setup_class`` until ``teardown_class``. Bound ports are
      stored on ``http_port``, ``https_port``, ``proxy_port`` and
      ``https_proxy_port``.
      """

      http_host = "localhost"
      http_host_alt = "127.0.0.1"
      https_host = "localhost"
      https_host_alt = "127.0.0.1"
      https_certs = DEFAULT_CERTS
      proxy_host = "localhost"
      proxy_host_alt = "127.0.0.1"

      @classmethod
      def setup_class(cls):
          cls.io_loop = ioloop.IOLoop.current()

          app = web.Application([(r".*", TestingApp)])
          cls.http_server, cls.http_port = run_tornado_app(
              app, cls.io_loop, None, "http", cls.http_host
          )

          app = web.Application([(r".*", TestingApp)])
          # Bind the HTTPS origin to https_host so subclasses overriding that
          # attribute actually change where the server listens.
          cls.https_server, cls.https_port = run_tornado_app(
              app, cls.io_loop, cls.https_certs, "https", cls.https_host
          )

          app = web.Application([(r".*", ProxyHandler)])
          cls.proxy_server, cls.proxy_port = run_tornado_app(
              app, cls.io_loop, None, "http", cls.proxy_host
          )

          # Passed to ProxyHandler as an application setting; presumably used
          # to verify the upstream HTTPS origin - confirm in dummyserver/proxy.py.
          upstream_ca_certs = cls.https_certs.get("ca_certs")
          app = web.Application(
              [(r".*", ProxyHandler)], upstream_ca_certs=upstream_ca_certs
          )
          cls.https_proxy_server, cls.https_proxy_port = run_tornado_app(
              app, cls.io_loop, cls.https_certs, "https", cls.proxy_host
          )

          cls.server_thread = run_loop_in_thread(cls.io_loop)

      @classmethod
      def teardown_class(cls):
          servers = (
              cls.http_server,
              cls.https_server,
              cls.proxy_server,
              cls.https_proxy_server,
          )
          for server in servers:
              cls.io_loop.add_callback(server.stop)
          # Stop the loop only after every server stop has been scheduled.
          cls.io_loop.add_callback(cls.io_loop.stop)
          cls.server_thread.join()
  @pytest.mark.skipif(not HAS_IPV6, reason="IPv6 not available")
  class IPv6HTTPDummyServerTestCase(HTTPDummyServerTestCase):
      """HTTPDummyServerTestCase bound to "::1"; skipped when HAS_IPV6 is false."""

      host = "::1"
  @pytest.mark.skipif(not HAS_IPV6, reason="IPv6 not available")
  class IPv6HTTPDummyProxyTestCase(HTTPDummyProxyTestCase):
      """
      Proxy test case whose proxies listen on "::1".

      The origin servers keep "localhost"/"127.0.0.1". Skipped when HAS_IPV6 is
      false.
      """

      # Origin servers.
      http_host = "localhost"
      http_host_alt = "127.0.0.1"
      https_host = "localhost"
      https_host_alt = "127.0.0.1"
      https_certs = DEFAULT_CERTS

      # Proxies: primary on the IPv6 loopback, alternate on the IPv4 loopback.
      proxy_host = "::1"
      proxy_host_alt = "127.0.0.1"
  class ConnectionMarker(object):
      """
      Marks an HTTP(S)Connection's socket after a request was made.

      Helps a test server understand when a client finished a request,
      without implementing a complete HTTP server.
      """

      # Formatted with a port number as lowercase hex, zero-padded to 4 digits.
      MARK_FORMAT = b"$#MARK%04x*!"

      @classmethod
      @contextmanager
      def mark(cls, monkeypatch):
          """
          Mark connections made within this context.

          While active, ``HTTPConnection.request`` and ``request_chunked`` are
          patched through ``monkeypatch.context()``. Each patched method sends
          a marker on the connection's socket after the original method
          returns. The patches are undone when the context exits.
          """
          import functools

          orig_request = HTTPConnection.request
          orig_request_chunked = HTTPConnection.request_chunked

          def call_and_mark(target):
              # wraps() keeps the patched method's name/docstring intact.
              @functools.wraps(target)
              def part(self, *args, **kwargs):
                  result = target(self, *args, **kwargs)
                  self.sock.sendall(cls._get_socket_mark(self.sock, False))
                  return result

              return part

          with monkeypatch.context() as m:
              m.setattr(HTTPConnection, "request", call_and_mark(orig_request))
              m.setattr(
                  HTTPConnection, "request_chunked", call_and_mark(orig_request_chunked)
              )
              yield

      @classmethod
      def consume_request(cls, sock, chunks=65536):
          """
          Consume a socket until after the HTTP request is sent.

          Reads until the accumulated data ends with the marker for this
          connection, or until the peer closes the socket.

          :returns: ``bytearray`` of everything read, including the marker.
          """
          consumed = bytearray()
          mark = cls._get_socket_mark(sock, True)
          while True:
              b = sock.recv(chunks)
              if not b:
                  break
              consumed += b
              if consumed.endswith(mark):
                  break
          return consumed

      @classmethod
      def _get_socket_mark(cls, sock, server):
          """
          Return the marker bytes identifying the client side of *sock*.

          The client port is the peer port on the server side and the local
          port on the client side, so both ends compute the same marker.
          """
          if server:
              port = sock.getpeername()[1]
          else:
              port = sock.getsockname()[1]
          return cls.MARK_FORMAT % (port,)