http.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. # -*- coding: utf-8 -*-
"""
# Run server:
>>> import thriftpy2
>>> from thriftpy2.http import make_server
>>> pingpong = thriftpy2.load("pingpong.thrift")
>>>
>>> class Dispatcher(object):
...     def ping(self):
...         return "pong"
>>> server = make_server(pingpong.PingService, Dispatcher(),
...                      host='127.0.0.1', port=6000)
>>> server.serve()

# Run client:
>>> import thriftpy2
>>> from thriftpy2.http import make_client
>>> pingpong = thriftpy2.load("pingpong.thrift")
>>> client = make_client(pingpong.PingService, host='127.0.0.1', port=6000)
>>> client.ping()

# Run an HTTPS client with an unverified SSL context (FOR TESTING ONLY):
>>> import ssl
>>> ssl_context_factory = ssl._create_unverified_context
>>> client = make_client(pingpong.PingService, host='example.com', port=443,
...                      scheme="https",
...                      ssl_context_factory=ssl_context_factory)
>>> client.ping()
"""
  28. from __future__ import absolute_import
  29. import os
  30. import socket
  31. import sys
  32. from contextlib import contextmanager
  33. from io import BytesIO
  34. from thriftpy2._compat import PY3
  35. if PY3:
  36. import http.client as http_client
  37. import http.server as http_server
  38. import urllib
  39. else:
  40. import httplib as http_client
  41. import BaseHTTPServer as http_server
  42. import urllib2 as urllib
  43. import urlparse
  44. urllib.parse = urlparse
  45. urllib.parse.quote = urllib.quote
  46. from thriftpy2.thrift import TProcessor, TClient
  47. from thriftpy2.server import TServer
  48. from thriftpy2.transport import TTransportBase, TMemoryBuffer
  49. from thriftpy2.protocol import TBinaryProtocolFactory
  50. from thriftpy2.transport import TBufferedTransportFactory
  51. HTTP_URI = '{scheme}://{host}:{port}{path}'
  52. DEFAULT_HTTP_CLIENT_TIMEOUT_MS = 30000 # 30 seconds
  53. class TFileObjectTransport(TTransportBase):
  54. """Wraps a file-like object to make it work as a Thrift transport."""
  55. def __init__(self, fileobj):
  56. self.fileobj = fileobj
  57. def isOpen(self):
  58. return True
  59. def close(self):
  60. self.fileobj.close()
  61. def read(self, sz):
  62. return self.fileobj.read(sz)
  63. def write(self, buf):
  64. self.fileobj.write(buf)
  65. def flush(self):
  66. self.fileobj.flush()
  67. class ResponseException(Exception):
  68. """Allows handlers to override the HTTP response
  69. Normally, THttpServer always sends a 200 response. If a handler wants
  70. to override this behavior (e.g., to simulate a misconfigured or
  71. overloaded web server during testing), it can raise a ResponseException.
  72. The function passed to the constructor will be called with the
  73. RequestHandler as its only argument.
  74. """
  75. def __init__(self, handler):
  76. self.handler = handler
  77. class THttpServer(TServer):
  78. """A simple HTTP-based Thrift server
  79. This class is not very performant, but it is useful (for example) for
  80. acting as a mock version of an Apache-based PHP Thrift endpoint.
  81. """
  82. def __init__(self,
  83. processor,
  84. server_address,
  85. itrans_factory,
  86. iprot_factory,
  87. server_class=http_server.HTTPServer):
  88. """Set up protocol factories and HTTP server.
  89. See http.server for server_address.
  90. See TServer for protocol factories.
  91. """
  92. TServer.__init__(self, processor, trans=None,
  93. itrans_factory=itrans_factory,
  94. iprot_factory=iprot_factory,
  95. otrans_factory=None, oprot_factory=None)
  96. thttpserver = self
  97. class RequestHander(http_server.BaseHTTPRequestHandler):
  98. # Don't care about the request path.
  99. def do_POST(self):
  100. # Don't care about the request path.
  101. # Pre-read all of the data into a BytesIO. Buffered transport
  102. # was previously configured to read everything on the first
  103. # consumption, but that was a hack relying on the internal
  104. # mechanism and prevents other transports from working, so
  105. # replicate that properly to prevent timeout issues
  106. content_len = int(self.headers['Content-Length'])
  107. buf = BytesIO(self.rfile.read(content_len))
  108. itrans = TFileObjectTransport(buf)
  109. itrans = thttpserver.itrans_factory.get_transport(itrans)
  110. iprot = thttpserver.iprot_factory.get_protocol(itrans)
  111. otrans = TMemoryBuffer()
  112. oprot = thttpserver.oprot_factory.get_protocol(otrans)
  113. try:
  114. thttpserver.processor.process(iprot, oprot)
  115. except ResponseException as exn:
  116. exn.handler(self)
  117. else:
  118. self.send_response(200)
  119. self.send_header("content-type", "application/x-thrift")
  120. self.end_headers()
  121. self.wfile.write(otrans.getvalue())
  122. self.httpd = server_class(server_address, RequestHander)
  123. def serve(self):
  124. self.httpd.serve_forever()
  125. class THttpClient(object):
  126. """Http implementation of TTransport base.
  127. """
  128. def __init__(self, uri, timeout=None, ssl_context_factory=None):
  129. """Initialize a HTTP Socket.
  130. @param uri(str) The http_scheme:://host:port/path to connect to.
  131. @param timeout timeout in ms
  132. """
  133. parsed = urllib.parse.urlparse(uri)
  134. self.scheme = parsed.scheme
  135. assert self.scheme in ('http', 'https')
  136. if self.scheme == 'http':
  137. self.port = parsed.port or http_client.HTTP_PORT
  138. elif self.scheme == 'https':
  139. self.port = parsed.port or http_client.HTTPS_PORT
  140. self.host = parsed.hostname
  141. self.path = parsed.path
  142. if parsed.query:
  143. self.path += '?%s' % parsed.query
  144. self.__wbuf = BytesIO()
  145. self.__http = None
  146. self.__custom_headers = None
  147. self.__timeout = None
  148. if timeout:
  149. self.setTimeout(timeout)
  150. self._ssl_context_factory = ssl_context_factory
  151. def open(self):
  152. if self.scheme == "https":
  153. ssl_context = self._ssl_context_factory() \
  154. if self._ssl_context_factory else None
  155. self.__http = http_client.HTTPSConnection(self.host, self.port,
  156. context=ssl_context)
  157. else:
  158. self.__http = http_client.HTTPConnection(self.host, self.port)
  159. def close(self):
  160. self.__http.close()
  161. self.__http = None
  162. def isOpen(self):
  163. return self.__http is not None
  164. def setTimeout(self, ms):
  165. if not hasattr(socket, 'getdefaulttimeout'):
  166. raise NotImplementedError
  167. self.__timeout = ms / 1000.0 if (ms and ms > 0) else None
  168. def setCustomHeaders(self, headers):
  169. self.__custom_headers = headers
  170. def read(self, sz):
  171. content = self.response.read(sz)
  172. return content
  173. def write(self, buf):
  174. self.__wbuf.write(buf)
  175. def flush(self):
  176. # Pull data out of buffer
  177. # Do this before opening a new connection in case there isn't data
  178. data = self.__wbuf.getvalue()
  179. self.__wbuf = BytesIO()
  180. if not data: # No data to flush, ignore
  181. return
  182. if self.isOpen():
  183. self.close()
  184. self.open()
  185. # HTTP request
  186. self.__http.putrequest('POST', self.path, skip_host=True)
  187. # Write headers
  188. self.__http.putheader('Host', self.host)
  189. self.__http.putheader('Content-Type', 'application/x-thrift')
  190. self.__http.putheader('Content-Length', str(len(data)))
  191. if (not self.__custom_headers or
  192. 'User-Agent' not in self.__custom_headers):
  193. user_agent = 'Python/THttpClient'
  194. script = os.path.basename(sys.argv[0])
  195. if script:
  196. user_agent = '%s (%s)' % (
  197. user_agent, urllib.parse.quote(script))
  198. self.__http.putheader('User-Agent', user_agent)
  199. if self.__custom_headers:
  200. for key, val in self.__custom_headers.items():
  201. self.__http.putheader(key, val)
  202. self.__http.endheaders()
  203. # Write payload
  204. self.__http.send(data)
  205. # Get reply to flush the request
  206. response = self.__http.getresponse()
  207. self.code, self.message, self.headers = (
  208. response.status, response.msg, response.getheaders())
  209. self.response = response
  210. def __with_timeout(f):
  211. def _f(*args, **kwargs):
  212. orig_timeout = socket.getdefaulttimeout()
  213. socket.setdefaulttimeout(args[0].__timeout)
  214. result = None
  215. try:
  216. result = f(*args, **kwargs)
  217. finally:
  218. socket.setdefaulttimeout(orig_timeout)
  219. return result
  220. return _f
  221. # Decorate if we know how to timeout
  222. if hasattr(socket, 'getdefaulttimeout'):
  223. flush = __with_timeout(flush)
  224. def make_client(service, host='localhost', port=9090, path='', scheme='http',
  225. proto_factory=TBinaryProtocolFactory(),
  226. trans_factory=TBufferedTransportFactory(),
  227. ssl_context_factory=None,
  228. timeout=DEFAULT_HTTP_CLIENT_TIMEOUT_MS, url=''):
  229. if url:
  230. parsed_url = urllib.parse.urlparse(url)
  231. host = parsed_url.hostname or host
  232. port = parsed_url.port or port
  233. scheme = parsed_url.scheme or scheme
  234. path = parsed_url.path or path
  235. uri = HTTP_URI.format(scheme=scheme, host=host, port=port, path=path)
  236. http_socket = THttpClient(uri, timeout, ssl_context_factory)
  237. transport = trans_factory.get_transport(http_socket)
  238. iprot = proto_factory.get_protocol(transport)
  239. transport.open()
  240. return TClient(service, iprot)
  241. @contextmanager
  242. def client_context(service, host='localhost', port=9090, path='', scheme='http',
  243. proto_factory=TBinaryProtocolFactory(),
  244. trans_factory=TBufferedTransportFactory(),
  245. ssl_context_factory=None,
  246. timeout=DEFAULT_HTTP_CLIENT_TIMEOUT_MS, url=''):
  247. if url:
  248. parsed_url = urllib.parse.urlparse(url)
  249. host = parsed_url.hostname or host
  250. port = parsed_url.port or port
  251. scheme = parsed_url.scheme or scheme
  252. path = parsed_url.path or path
  253. uri = HTTP_URI.format(scheme=scheme, host=host, port=port, path=path)
  254. http_socket = THttpClient(uri, timeout, ssl_context_factory)
  255. transport = trans_factory.get_transport(http_socket)
  256. try:
  257. iprot = proto_factory.get_protocol(transport)
  258. transport.open()
  259. yield TClient(service, iprot)
  260. finally:
  261. transport.close()
  262. def make_server(service, handler, host, port,
  263. proto_factory=TBinaryProtocolFactory(),
  264. trans_factory=TBufferedTransportFactory()):
  265. processor = TProcessor(service, handler)
  266. server = THttpServer(processor, (host, port),
  267. itrans_factory=trans_factory,
  268. iprot_factory=proto_factory)
  269. return server