# fake_shell.py (9.0 KB)

  1. # Licensed to Cloudera, Inc. under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. Cloudera, Inc. licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import ast
  17. import cStringIO
  18. import datetime
  19. import decimal
  20. import json
  21. import logging
  22. import sys
  23. import traceback
  24. logging.basicConfig()
  25. LOG = logging.getLogger('fake_shell')
  26. global_dict = {}
  27. execution_count = 0
  28. def execute_reply(status, content):
  29. global execution_count
  30. execution_count += 1
  31. return {
  32. 'msg_type': 'execute_reply',
  33. 'content': dict(
  34. content,
  35. status=status,
  36. execution_count=execution_count - 1
  37. )
  38. }
  39. def execute_reply_ok(data):
  40. return execute_reply('ok', {
  41. 'data': data,
  42. })
  43. def execute_reply_error(exc_type, exc_value, tb):
  44. LOG.error('execute_reply', exc_info=True)
  45. return execute_reply('error', {
  46. 'ename': unicode(exc_type.__name__),
  47. 'evalue': unicode(exc_value),
  48. 'traceback': traceback.format_exception(exc_type, exc_value, tb, -1),
  49. })
  50. def execute(code):
  51. try:
  52. code = ast.parse(code)
  53. to_run_exec, to_run_single = code.body[:-1], code.body[-1:]
  54. for node in to_run_exec:
  55. mod = ast.Module([node])
  56. code = compile(mod, '<stdin>', 'exec')
  57. exec code in global_dict
  58. for node in to_run_single:
  59. mod = ast.Interactive([node])
  60. code = compile(mod, '<stdin>', 'single')
  61. exec code in global_dict
  62. except:
  63. # We don't need to log the exception because we're just executing user
  64. # code and passing the error along.
  65. return execute_reply_error(*sys.exc_info())
  66. stdout = fake_stdout.getvalue()
  67. fake_stdout.truncate(0)
  68. stderr = fake_stderr.getvalue()
  69. fake_stderr.truncate(0)
  70. output = ''
  71. if stdout:
  72. output += stdout
  73. if stderr:
  74. output += stderr
  75. return execute_reply_ok({
  76. 'text/plain': output.rstrip(),
  77. })
  78. def execute_request(content):
  79. try:
  80. code = content['code']
  81. except KeyError:
  82. exc_type, exc_value, tb = sys.exc_info()
  83. return execute_reply_error(exc_type, exc_value, [])
  84. lines = code.split('\n')
  85. if lines and lines[-1].startswith('%'):
  86. code, magic = lines[:-1], lines[-1]
  87. # Make sure to execute the other lines first.
  88. if code:
  89. result = execute('\n'.join(code))
  90. if result['content']['status'] != 'ok':
  91. return result
  92. parts = magic[1:].split(' ', 1)
  93. if len(parts) == 1:
  94. magic, rest = parts[0], ()
  95. else:
  96. magic, rest = parts[0], (parts[1],)
  97. try:
  98. handler = magic_router[magic]
  99. except KeyError:
  100. exc_type, exc_value, tb = sys.exc_info()
  101. return execute_reply_error(exc_type, exc_value, [])
  102. else:
  103. return handler(*rest)
  104. else:
  105. return execute(code)
  106. def magic_table_convert(value):
  107. try:
  108. converter = magic_table_types[type(value)]
  109. except KeyError:
  110. converter = magic_table_types[str]
  111. return converter(value)
  112. def magic_table_convert_seq(items):
  113. last_item_type = None
  114. converted_items = []
  115. for item in items:
  116. item_type, item = magic_table_convert(item)
  117. if last_item_type is None:
  118. last_item_type = item_type
  119. elif last_item_type != item_type:
  120. raise ValueError('value has inconsistent types')
  121. converted_items.append(item)
  122. return 'ARRAY_TYPE', converted_items
  123. def magic_table_convert_map(m):
  124. last_key_type = None
  125. last_value_type = None
  126. converted_items = {}
  127. for key, value in m:
  128. key_type, key = magic_table_convert(key)
  129. value_type, value = magic_table_convert(value)
  130. if last_key_type is None:
  131. last_key_type = key_type
  132. elif last_value_type != value_type:
  133. raise ValueError('value has inconsistent types')
  134. if last_value_type is None:
  135. last_value_type = value_type
  136. elif last_value_type != value_type:
  137. raise ValueError('value has inconsistent types')
  138. converted_items[key] = value
  139. return 'MAP_TYPE', items
  140. magic_table_types = {
  141. type(None): lambda x: ('NULL_TYPE', x),
  142. bool: lambda x: ('BOOLEAN_TYPE', x),
  143. int: lambda x: ('INT_TYPE', x),
  144. long: lambda x: ('BIGINT_TYPE', x),
  145. float: lambda x: ('DOUBLE_TYPE', x),
  146. str: lambda x: ('STRING_TYPE', str(x)),
  147. unicode: lambda x: ('STRING_TYPE', x.encode('utf-8')),
  148. datetime.date: lambda x: ('DATE_TYPE', str(x)),
  149. datetime.datetime: lambda x: ('TIMESTAMP_TYPE', str(x)),
  150. decimal.Decimal: lambda x: ('DECIMAL_TYPE', str(x)),
  151. tuple: magic_table_convert_seq,
  152. list: magic_table_convert_seq,
  153. dict: magic_table_convert_map,
  154. }
  155. def magic_table(name):
  156. try:
  157. value = global_dict[name]
  158. except KeyError:
  159. exc_type, exc_value, tb = sys.exc_info()
  160. return execute_reply_error(exc_type, exc_value, [])
  161. if not isinstance(value, (list, tuple)):
  162. value = [value]
  163. headers = {}
  164. data = []
  165. for row in value:
  166. cols = []
  167. data.append(cols)
  168. if not isinstance(row, (list, tuple, dict)):
  169. row = [row]
  170. if isinstance(row, (list, tuple)):
  171. iterator = enumerate(row)
  172. else:
  173. iterator = sorted(row.iteritems())
  174. for name, col in iterator:
  175. col_type, col = magic_table_convert(col)
  176. try:
  177. header = headers[name]
  178. except KeyError:
  179. header = {
  180. 'name': str(name),
  181. 'type': col_type,
  182. }
  183. headers[name] = header
  184. else:
  185. # Reject columns that have a different type.
  186. if header['type'] != col_type:
  187. exc_type = Exception
  188. exc_value = 'table rows have different types'
  189. return execute_reply_error(exc_type, exc_value, [])
  190. cols.append(col)
  191. headers = [v for k, v in sorted(headers.iteritems())]
  192. return execute_reply_ok({
  193. 'application/vnd.livy.table.v1+json': {
  194. 'headers': headers,
  195. 'data': data,
  196. }
  197. })
  198. def shutdown_request(content):
  199. sys.exit()
  200. magic_router = {
  201. 'table': magic_table,
  202. }
  203. msg_type_router = {
  204. 'execute_request': execute_request,
  205. 'shutdown_request': shutdown_request,
  206. }
  207. sys_stdin = sys.stdin
  208. sys_stdout = sys.stdout
  209. sys_stderr = sys.stderr
  210. fake_stdin = cStringIO.StringIO()
  211. fake_stdout = cStringIO.StringIO()
  212. fake_stderr = cStringIO.StringIO()
  213. sys.stdin = fake_stdin
  214. sys.stdout = fake_stdout
  215. sys.stderr = fake_stderr
  216. try:
  217. # Load spark into the context
  218. exec 'from pyspark.shell import sc' in global_dict
  219. print >> sys_stderr, fake_stdout.getvalue()
  220. print >> sys_stderr, fake_stderr.getvalue()
  221. fake_stdout.truncate(0)
  222. fake_stderr.truncate(0)
  223. print >> sys_stdout, 'READY'
  224. sys_stdout.flush()
  225. while True:
  226. line = sys_stdin.readline()
  227. if line == '':
  228. break
  229. elif line == '\n':
  230. continue
  231. try:
  232. msg = json.loads(line)
  233. except ValueError:
  234. LOG.error('failed to parse message', exc_info=True)
  235. continue
  236. try:
  237. msg_type = msg['msg_type']
  238. except KeyError:
  239. LOG.error('missing message type', exc_info=True)
  240. continue
  241. try:
  242. content = msg['content']
  243. except KeyError:
  244. LOG.error('missing content', exc_info=True)
  245. continue
  246. try:
  247. handler = msg_type_router[msg_type]
  248. except KeyError:
  249. LOG.error('unknown message type: %s', msg_type)
  250. continue
  251. response = handler(content)
  252. try:
  253. response = json.dumps(response)
  254. except ValueError, e:
  255. response = json.dumps({
  256. 'msg_type': 'inspect_reply',
  257. 'execution_count': execution_count - 1,
  258. 'content': {
  259. 'status': 'error',
  260. 'ename': 'ValueError',
  261. 'evalue': 'cannot json-ify %s' % response,
  262. 'traceback': [],
  263. }
  264. })
  265. print >> sys_stdout, response
  266. sys_stdout.flush()
  267. finally:
  268. global_dict['sc'].stop()
  269. sys.stdin = sys_stdin
  270. sys.stdout = sys_stdout
  271. sys.stderr = sys_stderr