rpc.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # -*- coding: utf-8 -*-
  2. from __future__ import absolute_import
  3. import contextlib
  4. import warnings
  5. from thriftpy.protocol import TBinaryProtocolFactory
  6. from thriftpy.server import TThreadedServer
  7. from thriftpy.thrift import TProcessor, TClient
  8. from thriftpy.transport import (
  9. TBufferedTransportFactory,
  10. TServerSocket,
  11. TSSLServerSocket,
  12. TSocket,
  13. TSSLSocket,
  14. )
  15. def make_client(service, host="localhost", port=9090, unix_socket=None,
  16. proto_factory=TBinaryProtocolFactory(),
  17. trans_factory=TBufferedTransportFactory(),
  18. timeout=None,
  19. cafile=None, ssl_context=None, certfile=None, keyfile=None):
  20. if unix_socket:
  21. socket = TSocket(unix_socket=unix_socket)
  22. if certfile:
  23. warnings.warn("SSL only works with host:port, not unix_socket.")
  24. elif host and port:
  25. if cafile or ssl_context:
  26. socket = TSSLSocket(host, port, socket_timeout=timeout,
  27. cafile=cafile,
  28. certfile=certfile, keyfile=keyfile,
  29. ssl_context=ssl_context)
  30. else:
  31. socket = TSocket(host, port, socket_timeout=timeout)
  32. else:
  33. raise ValueError("Either host/port or unix_socket must be provided.")
  34. transport = trans_factory.get_transport(socket)
  35. protocol = proto_factory.get_protocol(transport)
  36. transport.open()
  37. return TClient(service, protocol)
  38. def make_server(service, handler,
  39. host="localhost", port=9090, unix_socket=None,
  40. proto_factory=TBinaryProtocolFactory(),
  41. trans_factory=TBufferedTransportFactory(),
  42. client_timeout=3000, certfile=None):
  43. processor = TProcessor(service, handler)
  44. if unix_socket:
  45. server_socket = TServerSocket(unix_socket=unix_socket)
  46. if certfile:
  47. warnings.warn("SSL only works with host:port, not unix_socket.")
  48. elif host and port:
  49. if certfile:
  50. server_socket = TSSLServerSocket(
  51. host=host, port=port, client_timeout=client_timeout,
  52. certfile=certfile)
  53. else:
  54. server_socket = TServerSocket(
  55. host=host, port=port, client_timeout=client_timeout)
  56. else:
  57. raise ValueError("Either host/port or unix_socket must be provided.")
  58. server = TThreadedServer(processor, server_socket,
  59. iprot_factory=proto_factory,
  60. itrans_factory=trans_factory)
  61. return server
  62. @contextlib.contextmanager
  63. def client_context(service, host="localhost", port=9090, unix_socket=None,
  64. proto_factory=TBinaryProtocolFactory(),
  65. trans_factory=TBufferedTransportFactory(),
  66. timeout=None, socket_timeout=3000, connect_timeout=3000,
  67. cafile=None, ssl_context=None, certfile=None, keyfile=None):
  68. if timeout:
  69. warnings.warn("`timeout` deprecated, use `socket_timeout` and "
  70. "`connect_timeout` instead.")
  71. socket_timeout = connect_timeout = timeout
  72. if unix_socket:
  73. socket = TSocket(unix_socket=unix_socket,
  74. connect_timeout=connect_timeout,
  75. socket_timeout=socket_timeout)
  76. if certfile:
  77. warnings.warn("SSL only works with host:port, not unix_socket.")
  78. elif host and port:
  79. if cafile or ssl_context:
  80. socket = TSSLSocket(host, port,
  81. connect_timeout=connect_timeout,
  82. socket_timeout=socket_timeout,
  83. cafile=cafile,
  84. certfile=certfile, keyfile=keyfile,
  85. ssl_context=ssl_context)
  86. else:
  87. socket = TSocket(host, port,
  88. connect_timeout=connect_timeout,
  89. socket_timeout=socket_timeout)
  90. else:
  91. raise ValueError("Either host/port or unix_socket must be provided.")
  92. try:
  93. transport = trans_factory.get_transport(socket)
  94. protocol = proto_factory.get_protocol(transport)
  95. transport.open()
  96. yield TClient(service, protocol)
  97. finally:
  98. transport.close()