@@ -20,8 +20,10 @@ package com.cloudera.hue.livy.repl.scala.interpreter

import java.io._

-import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.repl.SparkIMain
+import org.apache.spark.{SparkConf, SparkContext}
+import org.json4s.JsonAST._
+import org.json4s.{DefaultFormats, Extraction}

import scala.concurrent.ExecutionContext
import scala.tools.nsc.Settings
@@ -36,15 +38,21 @@ object Interpreter {
  case class Busy() extends State
  case class ShuttingDown() extends State
  case class ShutDown() extends State
+
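+  // Matches magic lines such as "%json df": group 1 is the magic name, group 2 its argument.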
+  private val MAGIC_REGEX = "^%(\\w+)\\W*(.*)".r
}

sealed abstract class ExecuteResponse(executeCount: Int)
case class ExecuteComplete(executeCount: Int, output: String) extends ExecuteResponse(executeCount)
case class ExecuteIncomplete(executeCount: Int, output: String) extends ExecuteResponse(executeCount)
case class ExecuteError(executeCount: Int, output: String) extends ExecuteResponse(executeCount)
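+// Returned when a magic command produces structured JSON content rather than console output.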
+case class ExecuteMagic(executeCount: Int, content: JValue) extends ExecuteResponse(executeCount)

class Interpreter {
+  import Interpreter._
+
  private implicit def executor: ExecutionContext = ExecutionContext.global
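+  // json4s formats used by the Extraction.decompose calls below.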
+  private implicit def formats = DefaultFormats

  private var _state: Interpreter.State = Interpreter.NotStarted()
  private val outputStream = new ByteArrayOutputStream()
@@ -92,37 +100,184 @@ class Interpreter {
    constructor.newInstance(settings, out, false: java.lang.Boolean).asInstanceOf[SparkIMain]
  }

-  def execute(code: String): ExecuteResponse = synchronized {
-    executeCount += 1
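+  // Dispatches a magic command to its handler; unknown magics yield an ExecuteError.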
+  private def executeMagic(magic: String, rest: String): ExecuteResponse = {
+    magic match {
+      case "json" => executeJsonMagic(rest)
+      case "table" => executeTableMagic(rest)
+      case _ =>
+        ExecuteError(executeCount, f"Unknown magic command $magic")
+    }
+  }

-    _state = Interpreter.Busy()
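+  // %json <name>: looks up the named term in the REPL and returns it under the "application/json" key.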
+  private def executeJsonMagic(name: String): ExecuteResponse = {
+    sparkIMain.valueOfTerm(name) match {
+      case Some(value) =>
+        ExecuteMagic(
+          executeCount,
+          Extraction.decompose(Map(
+            "application/json" -> value
+          ))
+        )
+      case None =>
+        ExecuteError(executeCount, f"Value $name does not exist")
+    }
+  }
+
+  private class TypesDoNotMatch extends Exception
+
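+  // Maps a JSON value to a Hive-like type name; containers with mixed element types throw TypesDoNotMatch.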
+  private def convertTableType(value: JValue): String = {
+    value match {
+      case (JNothing | JNull) => "NULL_TYPE"
+      case JBool(_) => "BOOLEAN_TYPE"
+      case JString(_) => "STRING_TYPE"
+      case JInt(_) => "BIGINT_TYPE"
+      case JDouble(_) => "DOUBLE_TYPE"
+      case JDecimal(_) => "DECIMAL_TYPE"
+      case JArray(arr) =>
+        if (allSameType(arr.iterator)) {
+          "ARRAY_TYPE"
+        } else {
+          throw new TypesDoNotMatch
+        }
+      case JObject(obj) =>
+        if (allSameType(obj.iterator.map(_._2))) {
+          "MAP_TYPE"
+        } else {
+          throw new TypesDoNotMatch
+        }
+    }
+  }

-    val result = scala.Console.withOut(outputStream) {
-      sparkIMain.interpret(code) match {
-        case Results.Success =>
-          val output = outputStream.toString("UTF-8").trim
-          outputStream.reset()
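+  // True when every value in the iterator converts to the same table type (vacuously true when empty).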
+  private def allSameType(values: Iterator[JValue]): Boolean = {
+    if (values.hasNext) {
+      val type_name = convertTableType(values.next())
+      values.forall { case value => type_name.equals(convertTableType(value)) }
+    } else {
+      true
+    }
+  }

-          ExecuteComplete(executeCount - 1, output)
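+  // %table <name>: decomposes the named value into rows plus per-column headers under
+  // "application/vnd.livy.table.v1+json", erroring if a column's type differs across rows.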
+  private def executeTableMagic(name: String): ExecuteResponse = {
+    sparkIMain.valueOfTerm(name) match {
+      case None =>
+        ExecuteError(executeCount, f"Value $name does not exist")
+      case Some(valueRef) =>
+        // Convert the value into JSON and map it to a table.
+        val rows: List[JValue] = Extraction.decompose(valueRef) match {
+          case JArray(arr) => arr
+          case value => List(value)
+        }
+
+        try {
+          val headers = scala.collection.mutable.Map[String, Map[String, String]]()
+
+          val data = rows.map { case row =>
+            val cols: List[JField] = row match {
+              case JArray(arr: List[JValue]) =>
+                arr.zipWithIndex.map { case (v, index) => JField(index.toString, v) }
+              case JObject(obj) => obj.sortBy(_._1)
+              case value: JValue => List(JField("0", value))
+            }
+
+            cols.map { case (name, value) =>
+              val typeName = convertTableType(value)
+
+              headers.get(name) match {
+                case Some(header) =>
+                  if (header.get("type").get != typeName) {
+                    throw new TypesDoNotMatch
+                  }
+                case None =>
+                  headers.put(name, Map(
+                    "name" -> name,
+                    "type" -> typeName
+                  ))
+              }
+
+              value
+            }
+          }
+
+          ExecuteMagic(
+            executeCount,
+            Extraction.decompose(Map(
+              "application/vnd.livy.table.v1+json" -> Map(
+                "headers" -> headers.toSeq.sortBy(_._1).map(_._2),
+                "data" -> data
+              )
+            ))
+          )
+        } catch {
+          case _: TypesDoNotMatch =>
+            ExecuteError(
+              executeCount,
+              "table rows have different types"
+            )
+        }
+    }
+  }

-        case Results.Incomplete =>
-          val output = outputStream.toString("UTF-8").trim
-          outputStream.reset()
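+  // Public entry point: splits the code into lines so that magic commands can be intercepted per line.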
+  def execute(code: String): ExecuteResponse = synchronized {
+    executeCount += 1

-          ExecuteIncomplete(executeCount - 1, output)
+    _state = Interpreter.Busy()

-        case Results.Error =>
-          val output = outputStream.toString("UTF-8").trim
-          outputStream.reset()
-          ExecuteError(executeCount - 1, output)
-      }
-    }
+    val result = executeLines(code.trim.split("\n").toList, ExecuteComplete(executeCount, ""))

    _state = Interpreter.Idle()

    result
  }
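+  // Runs lines in order: an Incomplete result appends the next line and retries; an error stops execution early.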
+  private def executeLines(lines: List[String], result: ExecuteResponse): ExecuteResponse = {
+    lines match {
+      case Nil => result
+      case head :: tail =>
+        val result = executeLine(head)
+
+        result match {
+          case ExecuteIncomplete(_, _) =>
+            tail match {
+              case Nil => result
+              case next :: nextTail => executeLines(head + "\n" + next :: nextTail, result)
+            }
+          case ExecuteError(_, _) =>
+            result
+
+          case _ =>
+            executeLines(tail, result)
+        }
+    }
+  }
+
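+  // Executes a single line: magic lines are dispatched, anything else goes to the Spark interpreter.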
+  def executeLine(code: String) = {
+    code match {
+      case MAGIC_REGEX(magic, rest) =>
+        executeMagic(magic, rest)
+      case _ =>
+        scala.Console.withOut(outputStream) {
+          sparkIMain.interpret(code) match {
+            case Results.Success =>
+              val output = outputStream.toString("UTF-8").trim
+              outputStream.reset()
+
+              ExecuteComplete(executeCount - 1, output)
+
+            case Results.Incomplete =>
+              val output = outputStream.toString("UTF-8").trim
+              outputStream.reset()
+
+              ExecuteIncomplete(executeCount - 1, output)
+
+            case Results.Error =>
+              val output = outputStream.toString("UTF-8").trim
+              outputStream.reset()
+              ExecuteError(executeCount - 1, output)
+          }
+        }
+    }
+  }
+
  def shutdown(): Unit = synchronized {
    _state = Interpreter.ShuttingDown()