thrift.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. # -*- coding: utf-8 -*-
  2. """
  3. thriftpy2.thrift
  4. ~~~~~~~~~~~~~~~~~~
  5. Thrift simplified.
  6. """
  7. from __future__ import absolute_import
  8. import functools
  9. import linecache
  10. import types
  11. from ._compat import with_metaclass, PY3
  12. if PY3:
  13. from itertools import zip_longest
  14. else:
  15. from itertools import izip_longest as zip_longest
  16. def args_to_kwargs(thrift_spec, *args, **kwargs):
  17. for item, value in zip_longest(sorted(thrift_spec.items()), args):
  18. arg_name = item[1][1]
  19. required = item[1][-1]
  20. if value is not None:
  21. kwargs[item[1][1]] = value
  22. if required and arg_name not in kwargs:
  23. raise ValueError(arg_name)
  24. return kwargs
  25. def parse_spec(ttype, spec=None):
  26. name_map = TType._VALUES_TO_NAMES
  27. def _type(s):
  28. return parse_spec(*s) if isinstance(s, tuple) else name_map[s]
  29. if spec is None:
  30. return name_map[ttype]
  31. if ttype == TType.STRUCT:
  32. return spec.__name__
  33. if ttype in (TType.LIST, TType.SET):
  34. return "%s<%s>" % (name_map[ttype], _type(spec))
  35. if ttype == TType.MAP:
  36. return "MAP<%s, %s>" % (_type(spec[0]), _type(spec[1]))
  37. def init_func_generator(cls, spec):
  38. """Generate `__init__` function based on TPayload.default_spec
  39. For example::
  40. spec = [('name', 'Alice'), ('number', None)]
  41. will generate a types.FunctionType object representing::
  42. def __init__(self, name='Alice', number=None):
  43. self.name = name
  44. self.number = number
  45. """
  46. if not spec:
  47. def __init__(self):
  48. pass
  49. return __init__
  50. varnames, defaults = zip(*spec)
  51. args = ', '.join(map('{0[0]}={0[1]!r}'.format, spec))
  52. init = "def __init__(self, {}):\n".format(args)
  53. init += "\n".join(map(' self.{0} = {0}'.format, varnames))
  54. name = '<generated {}.__init__>'.format(cls.__name__)
  55. code = compile(init, name, 'exec')
  56. func = next(c for c in code.co_consts if isinstance(c, types.CodeType))
  57. # Add a fake linecache entry so debuggers and the traceback module can
  58. # better understand our generated code.
  59. linecache.cache[name] = (len(init), None, init.splitlines(True), name)
  60. return types.FunctionType(func, {}, argdefs=defaults)
  61. class TType(object):
  62. STOP = 0
  63. VOID = 1
  64. BOOL = 2
  65. BYTE = 3
  66. I08 = 3
  67. DOUBLE = 4
  68. I16 = 6
  69. I32 = 8
  70. I64 = 10
  71. STRING = 11
  72. UTF7 = 11
  73. BINARY = 11 # This here just for parsing. For all purposes, it's a string
  74. STRUCT = 12
  75. MAP = 13
  76. SET = 14
  77. LIST = 15
  78. UTF8 = 16
  79. UTF16 = 17
  80. _VALUES_TO_NAMES = {
  81. STOP: 'STOP',
  82. VOID: 'VOID',
  83. BOOL: 'BOOL',
  84. BYTE: 'BYTE',
  85. I08: 'BYTE',
  86. DOUBLE: 'DOUBLE',
  87. I16: 'I16',
  88. I32: 'I32',
  89. I64: 'I64',
  90. STRING: 'STRING',
  91. UTF7: 'STRING',
  92. BINARY: 'STRING',
  93. STRUCT: 'STRUCT',
  94. MAP: 'MAP',
  95. SET: 'SET',
  96. LIST: 'LIST',
  97. UTF8: 'UTF8',
  98. UTF16: 'UTF16'
  99. }
  100. class TMessageType(object):
  101. CALL = 1
  102. REPLY = 2
  103. EXCEPTION = 3
  104. ONEWAY = 4
  105. class TPayloadMeta(type):
  106. def __new__(cls, name, bases, attrs):
  107. if "default_spec" in attrs:
  108. spec = attrs.pop("default_spec")
  109. attrs["__init__"] = init_func_generator(cls, spec)
  110. return super(TPayloadMeta, cls).__new__(cls, name, bases, attrs)
  111. def gen_init(cls, thrift_spec=None, default_spec=None):
  112. if thrift_spec is not None:
  113. cls.thrift_spec = thrift_spec
  114. if default_spec is not None:
  115. cls.__init__ = init_func_generator(cls, default_spec)
  116. return cls
  117. class TPayload(with_metaclass(TPayloadMeta, object)):
  118. __hash__ = None
  119. def read(self, iprot):
  120. iprot.read_struct(self)
  121. def write(self, oprot):
  122. oprot.write_struct(self)
  123. def __repr__(self):
  124. l = ['%s=%r' % (key, value) for key, value in self.__dict__.items()]
  125. return '%s(%s)' % (self.__class__.__name__, ', '.join(l))
  126. def __str__(self):
  127. return repr(self)
  128. def __eq__(self, other):
  129. return isinstance(other, self.__class__) and \
  130. self.__dict__ == other.__dict__
  131. def __ne__(self, other):
  132. return not self.__eq__(other)
  133. class TClient(object):
  134. def __init__(self, service, iprot, oprot=None):
  135. self._service = service
  136. self._iprot = self._oprot = iprot
  137. if oprot is not None:
  138. self._oprot = oprot
  139. self._seqid = 0
  140. def __getattr__(self, _api):
  141. if _api in self._service.thrift_services:
  142. return functools.partial(self._req, _api)
  143. # close method is a reserved method name defined as below
  144. # so we need to handle it alone
  145. if _api == 'tclose':
  146. return functools.partial(self._req, 'close')
  147. raise AttributeError("{} instance has no attribute '{}'".format(
  148. self.__class__.__name__, _api))
  149. def __dir__(self):
  150. return self._service.thrift_services
  151. def _req(self, _api, *args, **kwargs):
  152. try:
  153. kwargs = args_to_kwargs(getattr(self._service, _api + "_args").thrift_spec,
  154. *args, **kwargs)
  155. except ValueError as e:
  156. raise TApplicationException(
  157. TApplicationException.UNKNOWN_METHOD,
  158. '{arg} is required argument for {service}.{api}'.format(
  159. arg=e.args[0], service=self._service.__name__, api=_api))
  160. result_cls = getattr(self._service, _api + "_result")
  161. self._send(_api, **kwargs)
  162. # wait result only if non-oneway
  163. if not getattr(result_cls, "oneway"):
  164. return self._recv(_api)
  165. def _send(self, _api, **kwargs):
  166. self._oprot.write_message_begin(_api, TMessageType.CALL, self._seqid)
  167. args = getattr(self._service, _api + "_args")()
  168. for k, v in kwargs.items():
  169. setattr(args, k, v)
  170. args.write(self._oprot)
  171. self._oprot.write_message_end()
  172. self._oprot.trans.flush()
  173. def _recv(self, _api):
  174. fname, mtype, rseqid = self._iprot.read_message_begin()
  175. if mtype == TMessageType.EXCEPTION:
  176. x = TApplicationException()
  177. x.read(self._iprot)
  178. self._iprot.read_message_end()
  179. raise x
  180. result = getattr(self._service, _api + "_result")()
  181. result.read(self._iprot)
  182. self._iprot.read_message_end()
  183. if hasattr(result, "success") and result.success is not None:
  184. return result.success
  185. # void api without throws
  186. if len(result.thrift_spec) == 0:
  187. return
  188. # check throws
  189. for k, v in result.__dict__.items():
  190. if k != "success" and v:
  191. raise v
  192. # no throws & not void api
  193. if hasattr(result, "success"):
  194. raise TApplicationException(TApplicationException.MISSING_RESULT)
  195. def close(self):
  196. self._iprot.trans.close()
  197. if self._iprot != self._oprot:
  198. self._oprot.trans.close()
  199. class TProcessor(object):
  200. """Base class for processor, which works on two streams."""
  201. def __init__(self, service, handler):
  202. self._service = service
  203. self._handler = handler
  204. def process_in(self, iprot):
  205. api, type, seqid = iprot.read_message_begin()
  206. if api not in self._service.thrift_services:
  207. iprot.skip(TType.STRUCT)
  208. iprot.read_message_end()
  209. return api, seqid, TApplicationException(TApplicationException.UNKNOWN_METHOD), None # noqa
  210. args = getattr(self._service, api + "_args")()
  211. args.read(iprot)
  212. iprot.read_message_end()
  213. result = getattr(self._service, api + "_result")()
  214. # convert kwargs to args
  215. api_args = [args.thrift_spec[k][1] for k in sorted(args.thrift_spec)]
  216. def call():
  217. f = getattr(self._handler, api)
  218. return f(*(args.__dict__[k] for k in api_args))
  219. return api, seqid, result, call
  220. def send_exception(self, oprot, api, exc, seqid):
  221. oprot.write_message_begin(api, TMessageType.EXCEPTION, seqid)
  222. exc.write(oprot)
  223. oprot.write_message_end()
  224. oprot.trans.flush()
  225. def send_result(self, oprot, api, result, seqid):
  226. oprot.write_message_begin(api, TMessageType.REPLY, seqid)
  227. result.write(oprot)
  228. oprot.write_message_end()
  229. oprot.trans.flush()
  230. def handle_exception(self, e, result):
  231. for k in sorted(result.thrift_spec):
  232. if result.thrift_spec[k][1] == "success":
  233. continue
  234. _, exc_name, exc_cls, _ = result.thrift_spec[k]
  235. if isinstance(e, exc_cls):
  236. setattr(result, exc_name, e)
  237. return True
  238. return False
  239. def process(self, iprot, oprot):
  240. api, seqid, result, call = self.process_in(iprot)
  241. if isinstance(result, TApplicationException):
  242. return self.send_exception(oprot, api, result, seqid)
  243. try:
  244. result.success = call()
  245. except Exception as e:
  246. # raise if api don't have throws
  247. if not self.handle_exception(e, result):
  248. raise
  249. if not result.oneway:
  250. self.send_result(oprot, api, result, seqid)
  251. class TMultiplexedProcessor(TProcessor):
  252. SEPARATOR = ":"
  253. def __init__(self):
  254. self.processors = {}
  255. def register_processor(self, service_name, processor):
  256. if service_name in self.processors:
  257. raise TApplicationException(
  258. type=TApplicationException.INTERNAL_ERROR,
  259. message='processor for `{}` already registered'
  260. .format(service_name))
  261. self.processors[service_name] = processor
  262. def process_in(self, iprot):
  263. api, type, seqid = iprot.read_message_begin()
  264. if type not in (TMessageType.CALL, TMessageType.ONEWAY):
  265. raise TException("TMultiplexed protocol only supports CALL & ONEWAY") # noqa
  266. if TMultiplexedProcessor.SEPARATOR not in api:
  267. raise TException("Service name not found in message. "
  268. "You should use TMultiplexedProtocol in client.")
  269. service_name, api = api.split(TMultiplexedProcessor.SEPARATOR)
  270. if service_name not in self.processors:
  271. iprot.skip(TType.STRUCT)
  272. iprot.read_message_end()
  273. e = TApplicationException(TApplicationException.UNKNOWN_METHOD)
  274. return api, seqid, e, None
  275. proc = self.processors[service_name]
  276. args = getattr(proc._service, api + "_args")()
  277. args.read(iprot)
  278. iprot.read_message_end()
  279. result = getattr(proc._service, api + "_result")()
  280. # convert kwargs to args
  281. api_args = [args.thrift_spec[k][1] for k in sorted(args.thrift_spec)]
  282. def call():
  283. f = getattr(proc._handler, api)
  284. return f(*(args.__dict__[k] for k in api_args))
  285. return api, seqid, result, call
  286. class TProcessorFactory(object):
  287. def __init__(self, processor_class, *args, **kwargs):
  288. self.args = args
  289. self.kwargs = kwargs
  290. self.processor_class = processor_class
  291. def get_processor(self):
  292. return self.processor_class(*self.args, **self.kwargs)
  293. class TException(TPayload, Exception):
  294. """Base class for all thrift exceptions."""
  295. def __hash__(self):
  296. return id(self)
  297. def __eq__(self, other):
  298. return id(self) == id(other)
  299. class TDecodeException(TException):
  300. def __init__(self, name, fid, field, value, ttype, spec=None):
  301. self.struct_name = name
  302. self.fid = fid
  303. self.field = field
  304. self.value = value
  305. self.type_repr = parse_spec(ttype, spec)
  306. def __str__(self):
  307. return (
  308. "Field '%s(%s)' of '%s' needs type '%s', "
  309. "but the value is `%r`"
  310. ) % (self.field, self.fid, self.struct_name, self.type_repr,
  311. self.value)
  312. class TApplicationException(TException):
  313. """Application level thrift exceptions."""
  314. thrift_spec = {
  315. 1: (TType.STRING, 'message', False),
  316. 2: (TType.I32, 'type', False),
  317. }
  318. UNKNOWN = 0
  319. UNKNOWN_METHOD = 1
  320. INVALID_MESSAGE_TYPE = 2
  321. WRONG_METHOD_NAME = 3
  322. BAD_SEQUENCE_ID = 4
  323. MISSING_RESULT = 5
  324. INTERNAL_ERROR = 6
  325. PROTOCOL_ERROR = 7
  326. def __init__(self, type=UNKNOWN, message=None):
  327. super(TApplicationException, self).__init__()
  328. self.type = type
  329. self.message = message
  330. def __str__(self):
  331. if self.message:
  332. return self.message
  333. if self.type == self.UNKNOWN_METHOD:
  334. return 'Unknown method'
  335. elif self.type == self.INVALID_MESSAGE_TYPE:
  336. return 'Invalid message type'
  337. elif self.type == self.WRONG_METHOD_NAME:
  338. return 'Wrong method name'
  339. elif self.type == self.BAD_SEQUENCE_ID:
  340. return 'Bad sequence ID'
  341. elif self.type == self.MISSING_RESULT:
  342. return 'Missing result'
  343. else:
  344. return 'Default (unknown) TApplicationException'