Browse files

HUE-2917 [livy] Implement scala spark introspection

Erick Tryzelaar 10 years ago
parent
commit 62420de

+ 1 - 1
apps/spark/java/livy-repl/src/main/resources/fake_shell.py

@@ -221,7 +221,7 @@ def magic_table(name):
         if isinstance(row, (list, tuple)):
             iterator = enumerate(row)
         else:
-            iterator = row.iteritems()
+            iterator = sorted(row.iteritems())
 
         for name, col in iterator:
             col_type, col = magic_table_convert(col)

+ 8 - 2
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/scala/SparkSession.scala

@@ -55,7 +55,7 @@ private class SparkSession extends Session {
 
   override def execute(code: String): Statement = synchronized {
     val result = Future {
-      val content = interpreter.execute(code) match {
+      val response = interpreter.execute(code) match {
         case ExecuteComplete(executeCount, output) =>
           Map(
             "status" -> "ok",
@@ -64,6 +64,12 @@ private class SparkSession extends Session {
               "text/plain" -> output
             )
           )
+        case ExecuteMagic(executeCount, content) =>
+          Map(
+            "status" -> "ok",
+            "execution_count" -> executeCount,
+            "data" -> content
+          )
         case ExecuteIncomplete(executeCount, output) =>
           Map(
             "status" -> "error",
@@ -80,7 +86,7 @@ private class SparkSession extends Session {
           )
       }
 
-      parse(write(content))
+      Extraction.decompose(response)
     }
 
     val statement = Statement(_history.length, result)
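
Note: replacing parse(write(content)) with Extraction.decompose(response) drops a serialize-then-reparse round trip; decompose maps Scala values straight onto the json4s AST. It also matters for the new ExecuteMagic case, whose content is already a JValue and no longer has to survive a write/parse cycle. A minimal sketch of the two paths, assuming json4s with the jackson backend:

    import org.json4s.{DefaultFormats, Extraction, Formats}
    import org.json4s.jackson.JsonMethods.parse
    import org.json4s.jackson.Serialization.write

    implicit val formats: Formats = DefaultFormats

    val response = Map("status" -> "ok", "execution_count" -> 1)
    val direct   = Extraction.decompose(response)  // Map -> JValue in one step
    val indirect = parse(write(response))          // Map -> String -> JValue, the old path
    // Both produce the same JValue for plain data; decompose also passes
    // JValue members (like the magic content) through untouched.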

+ 175 - 20
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/scala/interpreter/Interpreter.scala

@@ -20,8 +20,10 @@ package com.cloudera.hue.livy.repl.scala.interpreter
 
 import java.io._
 
-import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.repl.SparkIMain
+import org.apache.spark.{SparkConf, SparkContext}
+import org.json4s.JsonAST._
+import org.json4s.{DefaultFormats, Extraction}
 
 import scala.concurrent.ExecutionContext
 import scala.tools.nsc.Settings
@@ -36,15 +38,21 @@ object Interpreter {
   case class Busy() extends State
   case class ShuttingDown() extends State
   case class ShutDown() extends State
+
+  private val MAGIC_REGEX = "^%(\\w+)\\W*(.*)".r
 }
 
 sealed abstract class ExecuteResponse(executeCount: Int)
 case class ExecuteComplete(executeCount: Int, output: String) extends ExecuteResponse(executeCount)
 case class ExecuteIncomplete(executeCount: Int, output: String) extends ExecuteResponse(executeCount)
 case class ExecuteError(executeCount: Int, output: String) extends ExecuteResponse(executeCount)
+case class ExecuteMagic(executeCount: Int, content: JValue) extends ExecuteResponse(executeCount)
 
 class Interpreter {
+  import Interpreter._
+
   private implicit def executor: ExecutionContext = ExecutionContext.global
+  private implicit def formats = DefaultFormats
 
   private var _state: Interpreter.State = Interpreter.NotStarted()
   private val outputStream = new ByteArrayOutputStream()
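
Note: MAGIC_REGEX splits an input line such as "%table x" into the magic name ("table") and its argument ("x"); executeLine further down matches each line against it before falling back to the Spark interpreter. A small sketch of the extractor in action:

    val MAGIC_REGEX = "^%(\\w+)\\W*(.*)".r

    "%table x" match {
      case MAGIC_REGEX(magic, rest) => println(s"magic=$magic, rest=$rest")  // magic=table, rest=x
      case _                        => println("ordinary code, handed to sparkIMain.interpret")
    }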
@@ -92,37 +100,184 @@ class Interpreter {
     constructor.newInstance(settings, out, false: java.lang.Boolean).asInstanceOf[SparkIMain]
   }
 
-  def execute(code: String): ExecuteResponse = synchronized {
-    executeCount += 1
+  private def executeMagic(magic: String, rest: String): ExecuteResponse = {
+    magic match {
+      case "json" => executeJsonMagic(rest)
+      case "table" => executeTableMagic(rest)
+      case _ =>
+        ExecuteError(executeCount, f"Unknown magic command $magic")
+    }
+  }
 
-    _state = Interpreter.Busy()
+  private def executeJsonMagic(name: String): ExecuteResponse = {
+    sparkIMain.valueOfTerm(name) match {
+      case Some(value) =>
+        ExecuteMagic(
+          executeCount,
+          Extraction.decompose(Map(
+            "application/json" -> value
+          ))
+        )
+      case None =>
+        ExecuteError(executeCount, f"Value $name does not exist")
+    }
+  }
+
+  private class TypesDoNotMatch extends Exception
+
+  private def convertTableType(value: JValue): String = {
+    value match {
+      case (JNothing | JNull) => "NULL_TYPE"
+      case JBool(_) => "BOOLEAN_TYPE"
+      case JString(_) => "STRING_TYPE"
+      case JInt(_) => "BIGINT_TYPE"
+      case JDouble(_) => "DOUBLE_TYPE"
+      case JDecimal(_) => "DECIMAL_TYPE"
+      case JArray(arr) =>
+        if (allSameType(arr.iterator)) {
+          "ARRAY_TYPE"
+        } else {
+          throw new TypesDoNotMatch
+        }
+      case JObject(obj) =>
+        if (allSameType(obj.iterator.map(_._2))) {
+          "MAP_TYPE"
+        } else {
+          throw new TypesDoNotMatch
+        }
+    }
+  }
 
-    val result = scala.Console.withOut(outputStream) {
-      sparkIMain.interpret(code) match {
-        case Results.Success =>
-          val output = outputStream.toString("UTF-8").trim
-          outputStream.reset()
+  private def allSameType(values: Iterator[JValue]): Boolean = {
+    if (values.hasNext) {
+      val type_name = convertTableType(values.next())
+      values.forall { case value => type_name.equals(convertTableType(value)) }
+    } else {
+      true
+    }
+  }
 
-          ExecuteComplete(executeCount - 1, output)
+  private def executeTableMagic(name: String): ExecuteResponse = {
+    sparkIMain.valueOfTerm(name) match {
+      case None =>
+        ExecuteError(executeCount, f"Value $name does not exist")
+      case Some(valueRef) =>
+        // Convert the value into JSON and map it to a table.
+        val rows: List[JValue] = Extraction.decompose(valueRef) match {
+          case JArray(arr) => arr
+          case value => List(value)
+        }
+
+        try {
+          val headers = scala.collection.mutable.Map[String, Map[String, String]]()
+
+          val data = rows.map { case row =>
+            val cols: List[JField] = row match {
+              case JArray(arr: List[JValue]) =>
+                arr.zipWithIndex.map { case (v, index) => JField(index.toString, v) }
+              case JObject(obj) => obj.sortBy(_._1)
+              case value: JValue => List(JField("0", value))
+            }
+
+            cols.map { case (name, value) =>
+              val typeName = convertTableType(value)
+
+              headers.get(name) match {
+                case Some(header) =>
+                  if (header.get("type").get != typeName) {
+                    throw new TypesDoNotMatch
+                  }
+                case None =>
+                  headers.put(name, Map(
+                    "name" -> name,
+                    "type" -> typeName
+                  ))
+              }
+
+              value
+            }
+          }
+
+          ExecuteMagic(
+            executeCount,
+            Extraction.decompose(Map(
+              "application/vnd.livy.table.v1+json" -> Map(
+                "headers" -> headers.toSeq.sortBy(_._1).map(_._2),
+                "data" -> data
+              )
+            ))
+          )
+        } catch {
+          case _: TypesDoNotMatch =>
+            ExecuteError(
+              executeCount,
+              "table rows have different types"
+            )
+        }
+    }
+  }
 
-        case Results.Incomplete =>
-          val output = outputStream.toString("UTF-8").trim
-          outputStream.reset()
+  def execute(code: String): ExecuteResponse = synchronized {
+    executeCount += 1
 
-          ExecuteIncomplete(executeCount - 1, output)
+    _state = Interpreter.Busy()
 
-        case Results.Error =>
-          val output = outputStream.toString("UTF-8").trim
-          outputStream.reset()
-          ExecuteError(executeCount - 1, output)
-      }
-    }
+    val result = executeLines(code.trim.split("\n").toList, ExecuteComplete(executeCount, ""))
 
     _state = Interpreter.Idle()
 
     result
   }
 
+  private def executeLines(lines: List[String], result: ExecuteResponse): ExecuteResponse = {
+    lines match {
+      case Nil => result
+      case head :: tail =>
+        val result = executeLine(head)
+
+        result match {
+          case ExecuteIncomplete(_, _) =>
+            tail match {
+              case Nil => result
+              case next :: nextTail => executeLines(head + "\n" + next :: nextTail, result)
+            }
+          case ExecuteError(_, _) =>
+            result
+
+          case _ =>
+            executeLines(tail, result)
+        }
+    }
+  }
+
+  def executeLine(code: String) = {
+    code match {
+      case MAGIC_REGEX(magic, rest) =>
+        executeMagic(magic, rest)
+      case _ =>
+        scala.Console.withOut(outputStream) {
+          sparkIMain.interpret(code) match {
+            case Results.Success =>
+              val output = outputStream.toString("UTF-8").trim
+              outputStream.reset()
+
+              ExecuteComplete(executeCount - 1, output)
+
+            case Results.Incomplete =>
+              val output = outputStream.toString("UTF-8").trim
+              outputStream.reset()
+
+              ExecuteIncomplete(executeCount - 1, output)
+
+            case Results.Error =>
+              val output = outputStream.toString("UTF-8").trim
+              outputStream.reset()
+              ExecuteError(executeCount - 1, output)
+          }
+        }
+    }
+  }
+
   def shutdown(): Unit = synchronized {
     _state = Interpreter.ShuttingDown()
 
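Note: executeLines runs the code one line at a time; when the interpreter reports ExecuteIncomplete, the current line is glued to the next one and retried, which is what lets a multi-line statement precede a %table magic in the same execute call (see the new test below). A simplified, hypothetical mirror of that recursion, without the error propagation of the real method:

    sealed trait Result
    case object Complete extends Result
    case object Incomplete extends Result

    // When a line does not parse on its own, join it with the next line and retry.
    def run(lines: List[String], interpret: String => Result): Result = lines match {
      case Nil => Complete
      case head :: tail =>
        interpret(head) match {
          case Incomplete if tail.nonEmpty => run(head + "\n" + tail.head :: tail.tail, interpret)
          case Incomplete                  => Incomplete  // input ended mid-statement
          case Complete                    => run(tail, interpret)
        }
    }

    // Example: a parenthesis-counting stand-in for the interpreter joins the
    // two lines below into one statement and reports Complete.
    run(List("val xs = List(", "1, 2, 3)"),
        s => if (s.count(_ == '(') > s.count(_ == ')')) Incomplete else Complete)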

+ 24 - 3
apps/spark/java/livy-repl/src/test/scala/com/cloudera/hue/livy/repl/SparkSessionSpec.scala

@@ -156,9 +156,7 @@ class SparkSessionSpec extends BaseSessionSpec {
 
     it("should execute spark commands") {
       val statement = session.execute(
-        """
-          |sc.parallelize(0 to 1).map{i => i+1}.collect
-          |""".stripMargin)
+        """sc.parallelize(0 to 1).map{i => i+1}.collect""".stripMargin)
       statement.id should equal (0)
 
       val result = Await.result(statement.result, Duration.Inf)
@@ -173,5 +171,28 @@ class SparkSessionSpec extends BaseSessionSpec {
 
       result should equal (expectedResult)
     }
+
+    it("should do table magic") {
+      val statement = session.execute("val x = List((1, \"a\"), (3, \"b\"))\n%table x")
+      statement.id should equal (0)
+
+      val result = Await.result(statement.result, Duration.Inf)
+
+
+      val expectedResult = Extraction.decompose(Map(
+        "status" -> "ok",
+        "execution_count" -> 1,
+        "data" -> Map(
+          "application/vnd.livy.table.v1+json" -> Map(
+            "headers" -> List(
+              Map("type" -> "BIGINT_TYPE", "name" -> "_1"),
+              Map("type" -> "STRING_TYPE", "name" -> "_2")),
+            "data" -> List(List(1, "a"), List(3, "b"))
+          )
+        )
+      ))
+
+      result should equal (expectedResult)
+    }
   }
  }
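
Note: the expected headers _1 and _2 come from json4s reflection: Tuple2 is a case class, so Extraction.decompose turns each (Int, String) pair into a JSON object keyed by the tuple's field names, which executeTableMagic then sorts into the header list. A quick REPL check of that assumption:

    import org.json4s.{DefaultFormats, Extraction, Formats}
    implicit val formats: Formats = DefaultFormats

    Extraction.decompose((1, "a"))
    // res: org.json4s.JValue = JObject(List((_1,JInt(1)), (_2,JString(a))))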