conftest.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. import collections
  2. import contextlib
  3. import platform
  4. import socket
  5. import ssl
  6. import sys
  7. import threading
  8. import pytest
  9. import trustme
  10. from tornado import ioloop, web
  11. from dummyserver.handlers import TestingApp
  12. from dummyserver.proxy import ProxyHandler
  13. from dummyserver.server import HAS_IPV6, run_tornado_app
  14. from dummyserver.testcase import HTTPSDummyServerTestCase
  15. from urllib3.util import ssl_
  16. from .tz_stub import stub_timezone_ctx
  17. # The Python 3.8+ default loop on Windows breaks Tornado
  18. @pytest.fixture(scope="session", autouse=True)
  19. def configure_windows_event_loop():
  20. if sys.version_info >= (3, 8) and platform.system() == "Windows":
  21. import asyncio
  22. asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
  23. ServerConfig = collections.namedtuple("ServerConfig", ["host", "port", "ca_certs"])
  24. def _write_cert_to_dir(cert, tmpdir, file_prefix="server"):
  25. cert_path = str(tmpdir / ("%s.pem" % file_prefix))
  26. key_path = str(tmpdir / ("%s.key" % file_prefix))
  27. cert.private_key_pem.write_to_path(key_path)
  28. cert.cert_chain_pems[0].write_to_path(cert_path)
  29. certs = {"keyfile": key_path, "certfile": cert_path}
  30. return certs
  31. @contextlib.contextmanager
  32. def run_server_in_thread(scheme, host, tmpdir, ca, server_cert):
  33. ca_cert_path = str(tmpdir / "ca.pem")
  34. ca.cert_pem.write_to_path(ca_cert_path)
  35. server_certs = _write_cert_to_dir(server_cert, tmpdir)
  36. io_loop = ioloop.IOLoop.current()
  37. app = web.Application([(r".*", TestingApp)])
  38. server, port = run_tornado_app(app, io_loop, server_certs, scheme, host)
  39. server_thread = threading.Thread(target=io_loop.start)
  40. server_thread.start()
  41. yield ServerConfig(host, port, ca_cert_path)
  42. io_loop.add_callback(server.stop)
  43. io_loop.add_callback(io_loop.stop)
  44. server_thread.join()
  45. @contextlib.contextmanager
  46. def run_server_and_proxy_in_thread(
  47. proxy_scheme, proxy_host, tmpdir, ca, proxy_cert, server_cert
  48. ):
  49. ca_cert_path = str(tmpdir / "ca.pem")
  50. ca.cert_pem.write_to_path(ca_cert_path)
  51. server_certs = _write_cert_to_dir(server_cert, tmpdir)
  52. proxy_certs = _write_cert_to_dir(proxy_cert, tmpdir, "proxy")
  53. io_loop = ioloop.IOLoop.current()
  54. server = web.Application([(r".*", TestingApp)])
  55. server, port = run_tornado_app(server, io_loop, server_certs, "https", "localhost")
  56. server_config = ServerConfig("localhost", port, ca_cert_path)
  57. proxy = web.Application([(r".*", ProxyHandler)])
  58. proxy_app, proxy_port = run_tornado_app(
  59. proxy, io_loop, proxy_certs, proxy_scheme, proxy_host
  60. )
  61. proxy_config = ServerConfig(proxy_host, proxy_port, ca_cert_path)
  62. server_thread = threading.Thread(target=io_loop.start)
  63. server_thread.start()
  64. yield (proxy_config, server_config)
  65. io_loop.add_callback(server.stop)
  66. io_loop.add_callback(proxy_app.stop)
  67. io_loop.add_callback(io_loop.stop)
  68. server_thread.join()
  69. @pytest.fixture
  70. def no_san_server(tmp_path_factory):
  71. tmpdir = tmp_path_factory.mktemp("certs")
  72. ca = trustme.CA()
  73. # only common name, no subject alternative names
  74. server_cert = ca.issue_cert(common_name=u"localhost")
  75. with run_server_in_thread("https", "localhost", tmpdir, ca, server_cert) as cfg:
  76. yield cfg
  77. @pytest.fixture()
  78. def no_san_server_with_different_commmon_name(tmp_path_factory):
  79. tmpdir = tmp_path_factory.mktemp("certs")
  80. ca = trustme.CA()
  81. server_cert = ca.issue_cert(common_name=u"example.com")
  82. with run_server_in_thread("https", "localhost", tmpdir, ca, server_cert) as cfg:
  83. yield cfg
  84. @pytest.fixture
  85. def no_san_proxy(tmp_path_factory):
  86. tmpdir = tmp_path_factory.mktemp("certs")
  87. ca = trustme.CA()
  88. # only common name, no subject alternative names
  89. proxy_cert = ca.issue_cert(common_name=u"localhost")
  90. server_cert = ca.issue_cert(u"localhost")
  91. with run_server_and_proxy_in_thread(
  92. "https", "localhost", tmpdir, ca, proxy_cert, server_cert
  93. ) as cfg:
  94. yield cfg
  95. @pytest.fixture
  96. def no_localhost_san_server(tmp_path_factory):
  97. tmpdir = tmp_path_factory.mktemp("certs")
  98. ca = trustme.CA()
  99. # non localhost common name
  100. server_cert = ca.issue_cert(u"example.com")
  101. with run_server_in_thread("https", "localhost", tmpdir, ca, server_cert) as cfg:
  102. yield cfg
  103. @pytest.fixture
  104. def ipv4_san_proxy(tmp_path_factory):
  105. tmpdir = tmp_path_factory.mktemp("certs")
  106. ca = trustme.CA()
  107. # IP address in Subject Alternative Name
  108. proxy_cert = ca.issue_cert(u"127.0.0.1")
  109. server_cert = ca.issue_cert(u"localhost")
  110. with run_server_and_proxy_in_thread(
  111. "https", "127.0.0.1", tmpdir, ca, proxy_cert, server_cert
  112. ) as cfg:
  113. yield cfg
  114. @pytest.fixture
  115. def ipv6_san_proxy(tmp_path_factory):
  116. tmpdir = tmp_path_factory.mktemp("certs")
  117. ca = trustme.CA()
  118. # IP addresses in Subject Alternative Name
  119. proxy_cert = ca.issue_cert(u"::1")
  120. server_cert = ca.issue_cert(u"localhost")
  121. with run_server_and_proxy_in_thread(
  122. "https", "::1", tmpdir, ca, proxy_cert, server_cert
  123. ) as cfg:
  124. yield cfg
  125. @pytest.fixture
  126. def ipv4_san_server(tmp_path_factory):
  127. tmpdir = tmp_path_factory.mktemp("certs")
  128. ca = trustme.CA()
  129. # IP address in Subject Alternative Name
  130. server_cert = ca.issue_cert(u"127.0.0.1")
  131. with run_server_in_thread("https", "127.0.0.1", tmpdir, ca, server_cert) as cfg:
  132. yield cfg
  133. @pytest.fixture
  134. def ipv6_addr_server(tmp_path_factory):
  135. if not HAS_IPV6:
  136. pytest.skip("Only runs on IPv6 systems")
  137. tmpdir = tmp_path_factory.mktemp("certs")
  138. ca = trustme.CA()
  139. # IP address in Common Name
  140. server_cert = ca.issue_cert(common_name=u"::1")
  141. with run_server_in_thread("https", "::1", tmpdir, ca, server_cert) as cfg:
  142. yield cfg
  143. @pytest.fixture
  144. def ipv6_san_server(tmp_path_factory):
  145. if not HAS_IPV6:
  146. pytest.skip("Only runs on IPv6 systems")
  147. tmpdir = tmp_path_factory.mktemp("certs")
  148. ca = trustme.CA()
  149. # IP address in Subject Alternative Name
  150. server_cert = ca.issue_cert(u"::1")
  151. with run_server_in_thread("https", "::1", tmpdir, ca, server_cert) as cfg:
  152. yield cfg
  153. @pytest.yield_fixture
  154. def stub_timezone(request):
  155. """
  156. A pytest fixture that runs the test with a stub timezone.
  157. """
  158. with stub_timezone_ctx(request.param):
  159. yield
  160. @pytest.fixture(scope="session")
  161. def supported_tls_versions():
  162. # We have to create an actual TLS connection
  163. # to test if the TLS version is not disabled by
  164. # OpenSSL config. Ubuntu 20.04 specifically
  165. # disables TLSv1 and TLSv1.1.
  166. tls_versions = set()
  167. _server = HTTPSDummyServerTestCase()
  168. _server._start_server()
  169. for _ssl_version_name in (
  170. "PROTOCOL_TLSv1",
  171. "PROTOCOL_TLSv1_1",
  172. "PROTOCOL_TLSv1_2",
  173. "PROTOCOL_TLS",
  174. ):
  175. _ssl_version = getattr(ssl, _ssl_version_name, 0)
  176. if _ssl_version == 0:
  177. continue
  178. _sock = socket.create_connection((_server.host, _server.port))
  179. try:
  180. _sock = ssl_.ssl_wrap_socket(
  181. _sock, cert_reqs=ssl.CERT_NONE, ssl_version=_ssl_version
  182. )
  183. except ssl.SSLError:
  184. pass
  185. else:
  186. tls_versions.add(_sock.version())
  187. _sock.close()
  188. _server._stop_server()
  189. return tls_versions
  190. @pytest.fixture(scope="function")
  191. def requires_tlsv1(supported_tls_versions):
  192. """Test requires TLSv1 available"""
  193. if not hasattr(ssl, "PROTOCOL_TLSv1") or "TLSv1" not in supported_tls_versions:
  194. pytest.skip("Test requires TLSv1")
  195. @pytest.fixture(scope="function")
  196. def requires_tlsv1_1(supported_tls_versions):
  197. """Test requires TLSv1.1 available"""
  198. if not hasattr(ssl, "PROTOCOL_TLSv1_1") or "TLSv1.1" not in supported_tls_versions:
  199. pytest.skip("Test requires TLSv1.1")
  200. @pytest.fixture(scope="function")
  201. def requires_tlsv1_2(supported_tls_versions):
  202. """Test requires TLSv1.2 available"""
  203. if not hasattr(ssl, "PROTOCOL_TLSv1_2") or "TLSv1.2" not in supported_tls_versions:
  204. pytest.skip("Test requires TLSv1.2")
  205. @pytest.fixture(scope="function")
  206. def requires_tlsv1_3(supported_tls_versions):
  207. """Test requires TLSv1.3 available"""
  208. if (
  209. not getattr(ssl, "HAS_TLSv1_3", False)
  210. or "TLSv1.3" not in supported_tls_versions
  211. ):
  212. pytest.skip("Test requires TLSv1.3")