fake_shell.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. import ast
  2. import cStringIO
  3. import datetime
  4. import decimal
  5. import json
  6. import logging
  7. import os
  8. import sys
  9. import traceback
  10. logging.basicConfig()
  11. logger = logging.getLogger('fake_shell')
  12. global_dict = {}
  13. execution_count = 0
  14. def execute_reply(status, content):
  15. global execution_count
  16. execution_count += 1
  17. return {
  18. 'msg_type': 'execute_reply',
  19. 'content': dict(
  20. content,
  21. status=status,
  22. execution_count=execution_count - 1
  23. )
  24. }
  25. def execute_reply_ok(data):
  26. return execute_reply('ok', {
  27. 'data': data,
  28. })
  29. def execute_reply_error(exc_type, exc_value, tb):
  30. logger.error('execute_reply', exc_info=True)
  31. return execute_reply('error', {
  32. 'ename': unicode(exc_type.__name__),
  33. 'evalue': unicode(exc_value),
  34. 'traceback': traceback.format_exception(exc_type, exc_value, tb, -1),
  35. })
  36. def execute(code):
  37. try:
  38. code = ast.parse(code)
  39. to_run_exec, to_run_single = code.body[:-1], code.body[-1:]
  40. for node in to_run_exec:
  41. mod = ast.Module([node])
  42. code = compile(mod, '<stdin>', 'exec')
  43. exec code in global_dict
  44. for node in to_run_single:
  45. mod = ast.Interactive([node])
  46. code = compile(mod, '<stdin>', 'single')
  47. exec code in global_dict
  48. except:
  49. return execute_reply_error(*sys.exc_info())
  50. stdout = fake_stdout.getvalue()
  51. fake_stdout.truncate(0)
  52. stderr = fake_stderr.getvalue()
  53. fake_stderr.truncate(0)
  54. output = ''
  55. if stdout:
  56. output += stdout
  57. if stderr:
  58. output += stderr
  59. return execute_reply_ok({
  60. 'text/plain': output.rstrip(),
  61. })
  62. def execute_request(content):
  63. try:
  64. code = content['code']
  65. except KeyError:
  66. exc_type, exc_value, tb = sys.exc_info()
  67. return execute_reply_error(exc_type, exc_value, [])
  68. lines = code.split('\n')
  69. if lines and lines[-1].startswith('%'):
  70. code, magic = lines[:-1], lines[-1]
  71. # Make sure to execute the other lines first.
  72. if code:
  73. result = execute('\n'.join(code))
  74. if result['content']['status'] != 'ok':
  75. return result
  76. parts = magic[1:].split(' ', 1)
  77. if len(parts) == 1:
  78. magic, rest = parts[0], ()
  79. else:
  80. magic, rest = parts[0], (parts[1],)
  81. try:
  82. handler = magic_router[magic]
  83. except KeyError:
  84. exc_type, exc_value, tb = sys.exc_info()
  85. return execute_reply_error(exc_type, exc_value, [])
  86. else:
  87. return handler(*rest)
  88. else:
  89. return execute(code)
  90. def magic_table_convert(value):
  91. try:
  92. converter = magic_table_types[type(value)]
  93. except KeyError:
  94. converter = magic_table_types[str]
  95. return converter(value)
  96. def magic_table_convert_seq(items):
  97. last_item_type = None
  98. converted_items = []
  99. for item in items:
  100. item_type, item = magic_table_convert(item)
  101. if last_item_type is None:
  102. last_item_type = item_type
  103. elif last_item_type != item_type:
  104. raise ValueError('value has inconsistent types')
  105. converted_items.append(item)
  106. return 'ARRAY_TYPE', converted_items
  107. def magic_table_convert_map(m):
  108. last_key_type = None
  109. last_value_type = None
  110. converted_items = {}
  111. for key, value in m:
  112. key_type, key = magic_table_convert(key)
  113. value_type, value = magic_table_convert(value)
  114. if last_key_type is None:
  115. last_key_type = key_type
  116. elif last_value_type != value_type:
  117. raise ValueError('value has inconsistent types')
  118. if last_value_type is None:
  119. last_value_type = value_type
  120. elif last_value_type != value_type:
  121. raise ValueError('value has inconsistent types')
  122. converted_items[key] = value
  123. return 'MAP_TYPE', items
  124. magic_table_types = {
  125. type(None): lambda x: ('NULL_TYPE', x),
  126. bool: lambda x: ('BOOLEAN_TYPE', x),
  127. int: lambda x: ('INT_TYPE', x),
  128. long: lambda x: ('BIGINT_TYPE', x),
  129. float: lambda x: ('DOUBLE_TYPE', x),
  130. str: lambda x: ('STRING_TYPE', str(x)),
  131. unicode: lambda x: ('STRING_TYPE', x.encode('utf-8')),
  132. datetime.date: lambda x: ('DATE_TYPE', str(x)),
  133. datetime.datetime: lambda x: ('TIMESTAMP_TYPE', str(x)),
  134. decimal.Decimal: lambda x: ('DECIMAL_TYPE', str(x)),
  135. tuple: magic_table_convert_seq,
  136. list: magic_table_convert_seq,
  137. dict: magic_table_convert_map,
  138. }
  139. def magic_table(name):
  140. try:
  141. value = global_dict[name]
  142. except KeyError:
  143. exc_type, exc_value, tb = sys.exc_info()
  144. return execute_reply_error(exc_type, exc_value, [])
  145. if not isinstance(value, (list, tuple)):
  146. value = [value]
  147. headers = {}
  148. data = []
  149. for row in value:
  150. cols = []
  151. data.append(cols)
  152. if not isinstance(row, (list, tuple, dict)):
  153. row = [row]
  154. if isinstance(row, (list, tuple)):
  155. iterator = enumerate(row)
  156. else:
  157. iterator = row.iteritems()
  158. for name, col in iterator:
  159. col_type, col = magic_table_convert(col)
  160. try:
  161. header = headers[name]
  162. except KeyError:
  163. header = {
  164. 'name': str(name),
  165. 'type': col_type,
  166. }
  167. headers[name] = header
  168. else:
  169. # Reject columns that have a different type.
  170. if header['type'] != col_type:
  171. exc_type = Exception
  172. exc_value = 'table rows have different types'
  173. return execute_reply_error(exc_type, exc_value, [])
  174. cols.append(col)
  175. headers = [v for k, v in sorted(headers.iteritems())]
  176. return execute_reply_ok({
  177. 'application/vnd.livy.table.v1+json': {
  178. 'headers': headers,
  179. 'data': data,
  180. }
  181. })
  182. magic_router = {
  183. 'table': magic_table,
  184. }
  185. msg_type_router = {
  186. 'execute_request': execute_request,
  187. }
  188. sys_stdin = sys.stdin
  189. sys_stdout = sys.stdout
  190. sys_stderr = sys.stderr
  191. fake_stdin = cStringIO.StringIO()
  192. fake_stdout = cStringIO.StringIO()
  193. fake_stderr = cStringIO.StringIO()
  194. sys.stdin = fake_stdin
  195. sys.stdout = fake_stdout
  196. sys.stderr = fake_stderr
  197. print >> sys_stdout, 'READY'
  198. sys_stdout.flush()
  199. try:
  200. # Load any startup files
  201. try:
  202. startup = os.environ['PYTHONSTARTUP']
  203. except KeyError:
  204. pass
  205. else:
  206. execfile(startup, global_dict)
  207. fake_stdout.truncate(0)
  208. fake_stderr.truncate(0)
  209. while True:
  210. line = sys_stdin.readline()
  211. if line == '':
  212. break
  213. elif line == '\n':
  214. continue
  215. try:
  216. msg = json.loads(line)
  217. except ValueError:
  218. logger.error('failed to parse message', exc_info=True)
  219. continue
  220. try:
  221. msg_type = msg['msg_type']
  222. except KeyError:
  223. logger.error('missing message type', exc_info=True)
  224. continue
  225. try:
  226. content = msg['content']
  227. except KeyError:
  228. logger.error('missing content', exc_info=True)
  229. continue
  230. try:
  231. handler = msg_type_router[msg_type]
  232. except KeyError:
  233. logger.error('unknown message type: %s', msg_type)
  234. continue
  235. response = handler(content)
  236. try:
  237. response = json.dumps(response)
  238. except ValueError, e:
  239. response = json.dumps({
  240. 'msg_type': 'inspect_reply',
  241. 'execution_count': execution_count - 1,
  242. 'content': {
  243. 'status': 'error',
  244. 'ename': 'ValueError',
  245. 'evalue': 'cannot json-ify %s' % response,
  246. 'traceback': [],
  247. }
  248. })
  249. print >> sys_stdout, response
  250. sys_stdout.flush()
  251. finally:
  252. sys.stdin = sys_stdin
  253. sys.stdout = sys_stdout
  254. sys.stderr = sys_stderr