thrift.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. # -*- coding: utf-8 -*-
  2. """
  3. thriftpy.thrift
  4. ~~~~~~~~~~~~~~~~~~
  5. Thrift simplified.
  6. """
  7. from __future__ import absolute_import
  8. import functools
  9. import linecache
  10. import types
  11. from ._compat import with_metaclass
  12. def args2kwargs(thrift_spec, *args):
  13. arg_names = [item[1][1] for item in sorted(thrift_spec.items())]
  14. return dict(zip(arg_names, args))
  15. def parse_spec(ttype, spec=None):
  16. name_map = TType._VALUES_TO_NAMES
  17. def _type(s):
  18. return parse_spec(*s) if isinstance(s, tuple) else name_map[s]
  19. if spec is None:
  20. return name_map[ttype]
  21. if ttype == TType.STRUCT:
  22. return spec.__name__
  23. if ttype in (TType.LIST, TType.SET):
  24. return "%s<%s>" % (name_map[ttype], _type(spec))
  25. if ttype == TType.MAP:
  26. return "MAP<%s, %s>" % (_type(spec[0]), _type(spec[1]))
  27. def init_func_generator(cls, spec):
  28. """Generate `__init__` function based on TPayload.default_spec
  29. For example::
  30. spec = [('name', 'Alice'), ('number', None)]
  31. will generate a types.FunctionType object representing::
  32. def __init__(self, name='Alice', number=None):
  33. self.name = name
  34. self.number = number
  35. """
  36. if not spec:
  37. def __init__(self):
  38. pass
  39. return __init__
  40. varnames, defaults = zip(*spec)
  41. args = ', '.join(map('{0[0]}={0[1]!r}'.format, spec))
  42. init = "def __init__(self, {0}):\n".format(args)
  43. init += "\n".join(map(' self.{0} = {0}'.format, varnames))
  44. name = '<generated {0}.__init__>'.format(cls.__name__)
  45. code = compile(init, name, 'exec')
  46. func = next(c for c in code.co_consts if isinstance(c, types.CodeType))
  47. # Add a fake linecache entry so debuggers and the traceback module can
  48. # better understand our generated code.
  49. linecache.cache[name] = (len(init), None, init.splitlines(True), name)
  50. return types.FunctionType(func, {}, argdefs=defaults)
  51. class TType(object):
  52. STOP = 0
  53. VOID = 1
  54. BOOL = 2
  55. BYTE = 3
  56. I08 = 3
  57. DOUBLE = 4
  58. I16 = 6
  59. I32 = 8
  60. I64 = 10
  61. STRING = 11
  62. UTF7 = 11
  63. BINARY = 11 # This here just for parsing. For all purposes, it's a string
  64. STRUCT = 12
  65. MAP = 13
  66. SET = 14
  67. LIST = 15
  68. UTF8 = 16
  69. UTF16 = 17
  70. _VALUES_TO_NAMES = {
  71. STOP: 'STOP',
  72. VOID: 'VOID',
  73. BOOL: 'BOOL',
  74. BYTE: 'BYTE',
  75. I08: 'BYTE',
  76. DOUBLE: 'DOUBLE',
  77. I16: 'I16',
  78. I32: 'I32',
  79. I64: 'I64',
  80. STRING: 'STRING',
  81. UTF7: 'STRING',
  82. BINARY: 'STRING',
  83. STRUCT: 'STRUCT',
  84. MAP: 'MAP',
  85. SET: 'SET',
  86. LIST: 'LIST',
  87. UTF8: 'UTF8',
  88. UTF16: 'UTF16'
  89. }
  90. class TMessageType(object):
  91. CALL = 1
  92. REPLY = 2
  93. EXCEPTION = 3
  94. ONEWAY = 4
  95. class TPayloadMeta(type):
  96. def __new__(cls, name, bases, attrs):
  97. if "default_spec" in attrs:
  98. spec = attrs.pop("default_spec")
  99. attrs["__init__"] = init_func_generator(cls, spec)
  100. return super(TPayloadMeta, cls).__new__(cls, name, bases, attrs)
  101. def gen_init(cls, thrift_spec=None, default_spec=None):
  102. if thrift_spec is not None:
  103. cls.thrift_spec = thrift_spec
  104. if default_spec is not None:
  105. cls.__init__ = init_func_generator(cls, default_spec)
  106. return cls
  107. class TPayload(with_metaclass(TPayloadMeta, object)):
  108. __hash__ = None
  109. def read(self, iprot):
  110. iprot.read_struct(self)
  111. def write(self, oprot):
  112. oprot.write_struct(self)
  113. def __repr__(self):
  114. l = ['%s=%r' % (key, value) for key, value in self.__dict__.items()]
  115. return '%s(%s)' % (self.__class__.__name__, ', '.join(l))
  116. def __str__(self):
  117. return repr(self)
  118. def __eq__(self, other):
  119. return isinstance(other, self.__class__) and \
  120. self.__dict__ == other.__dict__
  121. def __ne__(self, other):
  122. return not self.__eq__(other)
  123. class TClient(object):
  124. def __init__(self, service, iprot, oprot=None):
  125. self._service = service
  126. self._iprot = self._oprot = iprot
  127. if oprot is not None:
  128. self._oprot = oprot
  129. self._seqid = 0
  130. def __getattr__(self, _api):
  131. if _api in self._service.thrift_services:
  132. return functools.partial(self._req, _api)
  133. raise AttributeError("{} instance has no attribute '{}'".format(
  134. self.__class__.__name__, _api))
  135. def __dir__(self):
  136. return self._service.thrift_services
  137. def _req(self, _api, *args, **kwargs):
  138. _kw = args2kwargs(getattr(self._service, _api + "_args").thrift_spec,
  139. *args)
  140. kwargs.update(_kw)
  141. result_cls = getattr(self._service, _api + "_result")
  142. self._send(_api, **kwargs)
  143. # wait result only if non-oneway
  144. if not getattr(result_cls, "oneway"):
  145. return self._recv(_api)
  146. def _send(self, _api, **kwargs):
  147. self._oprot.write_message_begin(_api, TMessageType.CALL, self._seqid)
  148. args = getattr(self._service, _api + "_args")()
  149. for k, v in kwargs.items():
  150. setattr(args, k, v)
  151. args.write(self._oprot)
  152. self._oprot.write_message_end()
  153. self._oprot.trans.flush()
  154. def _recv(self, _api):
  155. fname, mtype, rseqid = self._iprot.read_message_begin()
  156. if mtype == TMessageType.EXCEPTION:
  157. x = TApplicationException()
  158. x.read(self._iprot)
  159. self._iprot.read_message_end()
  160. raise x
  161. result = getattr(self._service, _api + "_result")()
  162. result.read(self._iprot)
  163. self._iprot.read_message_end()
  164. if hasattr(result, "success") and result.success is not None:
  165. return result.success
  166. # void api without throws
  167. if len(result.thrift_spec) == 0:
  168. return
  169. # check throws
  170. for k, v in result.__dict__.items():
  171. if k != "success" and v:
  172. raise v
  173. # no throws & not void api
  174. if hasattr(result, "success"):
  175. raise TApplicationException(TApplicationException.MISSING_RESULT)
  176. def close(self):
  177. self._iprot.trans.close()
  178. if self._iprot != self._oprot:
  179. self._oprot.trans.close()
  180. class TProcessor(object):
  181. """Base class for procsessor, which works on two streams."""
  182. def __init__(self, service, handler):
  183. self._service = service
  184. self._handler = handler
  185. def process_in(self, iprot):
  186. api, type, seqid = iprot.read_message_begin()
  187. if api not in self._service.thrift_services:
  188. iprot.skip(TType.STRUCT)
  189. iprot.read_message_end()
  190. return api, seqid, TApplicationException(TApplicationException.UNKNOWN_METHOD), None # noqa
  191. args = getattr(self._service, api + "_args")()
  192. args.read(iprot)
  193. iprot.read_message_end()
  194. result = getattr(self._service, api + "_result")()
  195. # convert kwargs to args
  196. api_args = [args.thrift_spec[k][1] for k in sorted(args.thrift_spec)]
  197. def call():
  198. f = getattr(self._handler, api)
  199. return f(*(args.__dict__[k] for k in api_args))
  200. return api, seqid, result, call
  201. def send_exception(self, oprot, api, exc, seqid):
  202. oprot.write_message_begin(api, TMessageType.EXCEPTION, seqid)
  203. exc.write(oprot)
  204. oprot.write_message_end()
  205. oprot.trans.flush()
  206. def send_result(self, oprot, api, result, seqid):
  207. oprot.write_message_begin(api, TMessageType.REPLY, seqid)
  208. result.write(oprot)
  209. oprot.write_message_end()
  210. oprot.trans.flush()
  211. def handle_exception(self, e, result):
  212. for k in sorted(result.thrift_spec):
  213. if result.thrift_spec[k][1] == "success":
  214. continue
  215. _, exc_name, exc_cls, _ = result.thrift_spec[k]
  216. if isinstance(e, exc_cls):
  217. setattr(result, exc_name, e)
  218. break
  219. else:
  220. raise
  221. def process(self, iprot, oprot):
  222. api, seqid, result, call = self.process_in(iprot)
  223. if isinstance(result, TApplicationException):
  224. return self.send_exception(oprot, api, result, seqid)
  225. try:
  226. result.success = call()
  227. except Exception as e:
  228. # raise if api don't have throws
  229. self.handle_exception(e, result)
  230. if not result.oneway:
  231. self.send_result(oprot, api, result, seqid)
  232. class TMultiplexedProcessor(TProcessor):
  233. SEPARATOR = ":"
  234. def __init__(self):
  235. self.processors = {}
  236. def register_processor(self, service_name, processor):
  237. if service_name in self.processors:
  238. raise TApplicationException(
  239. type=TApplicationException.INTERNAL_ERROR,
  240. message='processor for `{0}` already registered'
  241. .format(service_name))
  242. self.processors[service_name] = processor
  243. def process_in(self, iprot):
  244. api, type, seqid = iprot.read_message_begin()
  245. if type not in (TMessageType.CALL, TMessageType.ONEWAY):
  246. raise TException("TMultiplex protocol only supports CALL & ONEWAY")
  247. if TMultiplexedProcessor.SEPARATOR not in api:
  248. raise TException("Service name not found in message. "
  249. "You should use TMultiplexedProtocol in client.")
  250. service_name, api = api.split(TMultiplexedProcessor.SEPARATOR)
  251. if service_name not in self.processors:
  252. iprot.skip(TType.STRUCT)
  253. iprot.read_message_end()
  254. e = TApplicationException(TApplicationException.UNKNOWN_METHOD)
  255. return api, seqid, e, None
  256. proc = self.processors[service_name]
  257. args = getattr(proc._service, api + "_args")()
  258. args.read(iprot)
  259. iprot.read_message_end()
  260. result = getattr(proc._service, api + "_result")()
  261. # convert kwargs to args
  262. api_args = [args.thrift_spec[k][1] for k in sorted(args.thrift_spec)]
  263. def call():
  264. f = getattr(proc._handler, api)
  265. return f(*(args.__dict__[k] for k in api_args))
  266. return api, seqid, result, call
  267. class TProcessorFactory(object):
  268. def __init__(self, processor_class, *args, **kwargs):
  269. self.args = args
  270. self.kwargs = kwargs
  271. self.processor_class = processor_class
  272. def get_processor(self):
  273. return self.processor_class(*self.args, **self.kwargs)
  274. class TException(TPayload, Exception):
  275. """Base class for all thrift exceptions."""
  276. def __hash__(self):
  277. return id(self)
  278. def __eq__(self, other):
  279. return id(self) == id(other)
  280. class TDecodeException(TException):
  281. def __init__(self, name, fid, field, value, ttype, spec=None):
  282. self.struct_name = name
  283. self.fid = fid
  284. self.field = field
  285. self.value = value
  286. self.type_repr = parse_spec(ttype, spec)
  287. def __str__(self):
  288. return (
  289. "Field '%s(%s)' of '%s' needs type '%s', "
  290. "but the value is `%r`"
  291. ) % (self.field, self.fid, self.struct_name, self.type_repr,
  292. self.value)
  293. class TApplicationException(TException):
  294. """Application level thrift exceptions."""
  295. thrift_spec = {
  296. 1: (TType.STRING, 'message', False),
  297. 2: (TType.I32, 'type', False),
  298. }
  299. UNKNOWN = 0
  300. UNKNOWN_METHOD = 1
  301. INVALID_MESSAGE_TYPE = 2
  302. WRONG_METHOD_NAME = 3
  303. BAD_SEQUENCE_ID = 4
  304. MISSING_RESULT = 5
  305. INTERNAL_ERROR = 6
  306. PROTOCOL_ERROR = 7
  307. def __init__(self, type=UNKNOWN, message=None):
  308. super(TApplicationException, self).__init__()
  309. self.type = type
  310. self.message = message
  311. def __str__(self):
  312. if self.message:
  313. return self.message
  314. if self.type == self.UNKNOWN_METHOD:
  315. return 'Unknown method'
  316. elif self.type == self.INVALID_MESSAGE_TYPE:
  317. return 'Invalid message type'
  318. elif self.type == self.WRONG_METHOD_NAME:
  319. return 'Wrong method name'
  320. elif self.type == self.BAD_SEQUENCE_ID:
  321. return 'Bad sequence ID'
  322. elif self.type == self.MISSING_RESULT:
  323. return 'Missing result'
  324. else:
  325. return 'Default (unknown) TApplicationException'