소스 검색

[hue] Fix parsing classes in pyspark

This changes how python code is compiled. Before, we compiled
line by line until a chunk compiled. This was done in order to
catch magic lines while still allowing for '%' to be used inside
a multiline string as in:

```
x = """
%string"""
%json x
```

This would technique would properly return the json value of
`{"text/plain": "\n%string"}`.

Unfortunately this approach doesn't work with classes with
multiple methods, because a class with a single method
is syntatically valid. This patch fixes this bug by parsing
the whole snippet at once, and only if there's a syntax error
will it try to see if there's a magic command in the
snippet.
Erick Tryzelaar 10 년 전
부모
커밋
577a0b6

+ 144 - 72
apps/spark/java/livy-repl/src/main/resources/fake_shell.py

@@ -16,7 +16,6 @@
 
 import ast
 import cStringIO
-import collections
 import datetime
 import decimal
 import json
@@ -53,98 +52,166 @@ def execute_reply_error(exc_type, exc_value, tb):
     })
 
 
-def execute(code):
-    try:
-        to_run_exec, to_run_single = code.body[:-1], code.body[-1:]
-
-        for node in to_run_exec:
-            mod = ast.Module([node])
-            code = compile(mod, '<stdin>', 'exec')
-            exec code in global_dict
-
-        for node in to_run_single:
-            mod = ast.Interactive([node])
-            code = compile(mod, '<stdin>', 'single')
-            exec code in global_dict
-    except:
-        # We don't need to log the exception because we're just executing user
-        # code and passing the error along.
-        return execute_reply_error(*sys.exc_info())
+def execute_reply_internal_error(message, exc_info=None):
+    LOG.error('execute_reply_internal_error', exc_info=exc_info)
+    return execute_reply('error', {
+        'ename': 'InternalError',
+        'evalue': message,
+        'traceback': [],
+    })
 
-    stdout = fake_stdout.getvalue()
-    fake_stdout.truncate(0)
 
-    stderr = fake_stderr.getvalue()
-    fake_stderr.truncate(0)
+def ast_parse(code, filename='<stdin>', symbol='exec'):
+    return compile(code, filename, symbol, ast.PyCF_ONLY_AST, 1)
 
-    output = ''
 
-    if stdout:
-        output += stdout
+class ExecutionError(Exception):
+    def __init__(self, exc_info):
+        self.exc_info = exc_info
 
-    if stderr:
-        output += stderr
 
-    return execute_reply_ok({
-        'text/plain': output.rstrip(),
-    })
+class NormalNode(object):
+    def __init__(self, code):
+        self.code = ast_parse(code)
 
+    def execute(self):
+        to_run_exec, to_run_single = self.code.body[:-1], self.code.body[-1:]
 
-def execute_magic(line):
-    parts = line[1:].split(' ', 1)
-    if len(parts) == 1:
-        magic, rest = parts[0], ()
-    else:
-        magic, rest = parts[0], (parts[1],)
+        try:
+            for node in to_run_exec:
+                mod = ast.Module([node])
+                code = compile(mod, '<stdin>', 'exec')
+                exec code in global_dict
+
+            for node in to_run_single:
+                mod = ast.Interactive([node])
+                code = compile(mod, '<stdin>', 'single')
+                exec code in global_dict
+        except:
+            # We don't need to log the exception because we're just executing user
+            # code and passing the error along.
+            raise ExecutionError(sys.exc_info())
+
+
+class UnknownMagic(Exception):
+    pass
+
+
+class MagicNode(object):
+    def __init__(self, line):
+        parts = line[1:].split(' ', 1)
+        if len(parts) == 1:
+            self.magic, self.rest = parts[0], ()
+        else:
+            self.magic, self.rest = parts[0], (parts[1],)
+
+
+    def execute(self):
+        try:
+            self.handler = magic_router[self.magic]
+        except KeyError:
+            raise UnknownMagic(self.magic)
+
+        return self.handler(*self.rest)
 
+
+def parse_code_into_nodes(code):
+    nodes = []
     try:
-        handler = magic_router[magic]
-    except KeyError:
-        exc_type, exc_value, tb = sys.exc_info()
-        return execute_reply_error(exc_type, exc_value, [])
+        nodes.append(NormalNode(code))
+    except SyntaxError:
+        # It's possible we hit a syntax error because of a magic command. Split the code groups
+        # of 'normal code', and code that starts with a '%'. possibly magic code
+        # lines, and see if any of the lines
+        # Remove lines until we find a node that parses, then check if the next line is a magic
+        # line
+        # .
+
+        # Split the code into chunks of normal code, and possibly magic code, which starts with a '%'.
+        normal = []
+        chunks = []
+        for i, line in enumerate(code.rstrip().split('\n')):
+            if line.startswith('%'):
+                if normal:
+                    chunks.append(''.join(normal))
+                    normal = []
+
+                chunks.append(line)
+            else:
+                normal.append(line)
+
+        if normal:
+            chunks.append('\n'.join(normal))
+
+        # Convert the chunks into AST nodes. Let exceptions propagate.
+        for chunk in chunks:
+            if chunk.startswith('%'):
+                nodes.append(MagicNode(chunk))
+            else:
+                nodes.append(NormalNode(chunk))
+
+    return nodes
+
+
+def execute_code(code):
+    try:
+        code = ast.parse(code)
+    except SyntaxError, syntax_error:
+        # It's possible we hit a syntax error because of a magic command. So see if one seems
+        # to be present.
+        try:
+            execute_handling_magic(code)
+        except SyntaxError, syntax_error:
+            pass
     else:
-        return handler(*rest)
+        return execute(code)
 
 
 def execute_request(content):
     try:
         code = content['code']
     except KeyError:
+        return execute_reply_internal_error(
+            'Malformed message: content object missing "code"', sys.exc_info()
+        )
+
+    try:
+        nodes = parse_code_into_nodes(code)
+    except (SyntaxError, UnknownMagic):
         exc_type, exc_value, tb = sys.exc_info()
         return execute_reply_error(exc_type, exc_value, [])
 
-    lines = collections.deque(code.rstrip().split('\n'))
-    last_line = ''
-    result = None
+    try:
+        for node in nodes:
+            result = node.execute()
+    except ExecutionError, e:
+        return execute_reply_error(*e.exc_info)
 
-    while lines:
-        line = last_line + lines.popleft()
+    if result is None:
+        result = {}
 
-        if line.rstrip() == '':
-            continue
+    stdout = fake_stdout.getvalue()
+    fake_stdout.truncate(0)
 
-        if line.startswith('%'):
-            result = execute_magic(line)
-        else:
-            try:
-                code = ast.parse(line)
-            except SyntaxError:
-                last_line = line + '\n'
-                continue
-            else:
-                result = execute(code)
+    stderr = fake_stderr.getvalue()
+    fake_stderr.truncate(0)
 
-        if result['content']['status'] == 'ok':
-            last_line = ''
-        else:
-            return result
+    output = result.pop('text/plain', '')
+
+    if stdout:
+        output += stdout
+
+    if stderr:
+        output += stderr
+
+    output = output.rstrip()
+
+    # Only add the output if it exists, or if there are no other mimetypes in the result.
+    if output or not result:
+        result['text/plain'] = output.rstrip()
+
+    return execute_reply_ok(result)
 
-    if result is None:
-        return execute_reply_ok({
-            'text/plain': '',
-        })
-    else:
-        return result
 
 def magic_table_convert(value):
     try:
@@ -260,12 +327,12 @@ def magic_table(name):
 
     headers = [v for k, v in sorted(headers.iteritems())]
 
-    return execute_reply_ok({
+    return {
         'application/vnd.livy.table.v1+json': {
             'headers': headers,
             'data': data,
         }
-    })
+    }
 
 
 def magic_json(name):
@@ -275,9 +342,9 @@ def magic_json(name):
         exc_type, exc_value, tb = sys.exc_info()
         return execute_reply_error(exc_type, exc_value, [])
 
-    return execute_reply_ok({
+    return {
         'application/json': value,
-    })
+    }
 
 
 def shutdown_request(content):
@@ -346,6 +413,10 @@ try:
             LOG.error('missing content', exc_info=True)
             continue
 
+        if not isinstance(content, dict):
+            LOG.error('content is not a dictionary')
+            continue
+
         try:
             handler = msg_type_router[msg_type]
         except KeyError:
@@ -353,6 +424,7 @@ try:
             continue
 
         response = handler(content)
+
         try:
             response = json.dumps(response)
         except ValueError, e:

+ 52 - 0
apps/spark/java/livy-repl/src/test/scala/com/cloudera/hue/livy/repl/PythonInterpreterSpec.scala

@@ -67,6 +67,29 @@ class PythonInterpreterSpec extends BaseInterpreterSpec {
     ))
   }
 
+  it should "parse a class" in withInterpreter { interpreter =>
+    val response = interpreter.execute(
+      """
+        |class Counter(object):
+        |   def __init__(self):
+        |       self.count = 0
+        |
+        |   def add_one(self):
+        |       self.count += 1
+        |
+        |   def add_two(self):
+        |       self.count += 2
+        |
+        |counter = Counter()
+        |counter.add_one()
+        |counter.add_two()
+        |counter.count
+      """.stripMargin)
+    response should equal(Interpreter.ExecuteSuccess(
+      repl.TEXT_PLAIN -> "3"
+    ))
+  }
+
   it should "do json magic" in withInterpreter { interpreter =>
     val response = interpreter.execute(
       """x = [[1, 'a'], [3, 'b']]
@@ -132,6 +155,35 @@ class PythonInterpreterSpec extends BaseInterpreterSpec {
     ))
   }
 
+  it should "not execute part of the block if there is a syntax error" in withInterpreter { interpreter =>
+    var response = interpreter.execute(
+      """x = 1
+        |'
+      """.stripMargin)
+
+    response should equal(Interpreter.ExecuteError(
+      "SyntaxError",
+      "EOL while scanning string literal (<stdin>, line 2)",
+      List(
+        "  File \"<stdin>\", line 2\n",
+        "    '\n",
+        "    ^\n",
+        "SyntaxError: EOL while scanning string literal\n"
+      )
+    ))
+
+    response = interpreter.execute("x")
+    response should equal(Interpreter.ExecuteError(
+      "NameError",
+      "name 'x' is not defined",
+      List(
+        "Traceback (most recent call last):\n",
+        "NameError: name 'x' is not defined\n"
+      )
+    ))
+  }
+
+
   it should "execute spark commands" in withInterpreter { interpreter =>
     val response = interpreter.execute(
       """sc.parallelize(xrange(0, 2)).map(lambda i: i + 1).collect()""")