|
|
@@ -18,9 +18,12 @@
|
|
|
|
|
|
package com.cloudera.hue.livy.repl.sparkr
|
|
|
|
|
|
+import java.io.File
|
|
|
+import java.nio.file.Files
|
|
|
import java.util.concurrent.locks.ReentrantLock
|
|
|
|
|
|
import com.cloudera.hue.livy.repl.process.ProcessInterpreter
|
|
|
+import org.apache.commons.codec.binary.Base64
|
|
|
import org.json4s.jackson.JsonMethods._
|
|
|
import org.json4s.jackson.Serialization.write
|
|
|
import org.json4s.{JValue, _}
|
|
|
@@ -29,8 +32,9 @@ import scala.annotation.tailrec
|
|
|
import scala.io.Source
|
|
|
|
|
|
private object SparkRInterpreter {
|
|
|
- val LIVY_END_MARKER = "# ----LIVY_END_OF_COMMAND----"
|
|
|
- val EXPECTED_OUTPUT = f"> $LIVY_END_MARKER"
|
|
|
+ val LIVY_END_MARKER = "----LIVY_END_OF_COMMAND----"
|
|
|
+ val PRINT_MARKER = f"""print("$LIVY_END_MARKER")"""
|
|
|
+ val EXPECTED_OUTPUT = f"""\n$PRINT_MARKER\n[1] "$LIVY_END_MARKER""""
|
|
|
}
|
|
|
|
|
|
private class SparkRInterpreter(process: Process)
|
|
|
@@ -48,35 +52,96 @@ private class SparkRInterpreter(process: Process)
|
|
|
}
|
|
|
|
|
|
override protected def sendExecuteRequest(commands: String): Option[JValue] = synchronized {
|
|
|
- commands.split("\n").map { case code =>
|
|
|
- stdin.println(code)
|
|
|
- stdin.println(LIVY_END_MARKER)
|
|
|
- stdin.flush()
|
|
|
-
|
|
|
- executionCount += 1
|
|
|
-
|
|
|
- // Skip the line we just entered in.
|
|
|
- if (!code.isEmpty) {
|
|
|
- readTo(code)
|
|
|
- }
|
|
|
-
|
|
|
- readTo(EXPECTED_OUTPUT)
|
|
|
- }.last match {
|
|
|
- case (true, output) =>
|
|
|
- val data = (output + takeErrorLines())
|
|
|
-
|
|
|
- Some(parse(write(Map(
|
|
|
- "status" -> "ok",
|
|
|
- "execution_count" -> (executionCount - 1),
|
|
|
+ try {
|
|
|
+ commands.split("\n").map { case command =>
|
|
|
+ executionCount += 1
|
|
|
+
|
|
|
+ val content = sendSingleExecuteRequest(command)
|
|
|
+ Some(parse(write(
|
|
|
+ Map(
|
|
|
+ "status" -> "ok",
|
|
|
+ "execution_count" -> (executionCount - 1),
|
|
|
+ "data" -> content
|
|
|
+ ))))
|
|
|
+ }.last
|
|
|
+ } catch {
|
|
|
+ case e: Error =>
|
|
|
+ Some(parse(write(
|
|
|
+ Map(
|
|
|
+ "status" -> "error",
|
|
|
+ "ename" -> "Error",
|
|
|
+ "evalue" -> e.output,
|
|
|
"data" -> Map(
|
|
|
- "text/plain" -> data
|
|
|
+ "text/plain" -> takeErrorLines()
|
|
|
)
|
|
|
))))
|
|
|
- case (false, output) =>
|
|
|
+ case e: Exited =>
|
|
|
None
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ private val plotRegex = (
|
|
|
+ "%(" +
|
|
|
+ "(?:" +
|
|
|
+ "(?:stripchart)|" +
|
|
|
+ "(?:hist)|" +
|
|
|
+ "(?:boxplot)|" +
|
|
|
+ "(?:plot)|" +
|
|
|
+ "(?:qqnorm)|" +
|
|
|
+ "(?:qqline)" +
|
|
|
+ ")" +
|
|
|
+ "\\([^;)]*\\)" +
|
|
|
+ ")"
|
|
|
+ ).r
|
|
|
+
|
|
|
+ private def sendSingleExecuteRequest(command: String) = {
|
|
|
+ if (command.startsWith("%")) {
|
|
|
+ command match {
|
|
|
+ case plotRegex(plotCommand) =>
|
|
|
+ val tempFile = Files.createTempFile("", ".png")
|
|
|
+ try {
|
|
|
+ val tempFileString = tempFile.toAbsolutePath.toString
|
|
|
+
|
|
|
+ val output = Seq(
|
|
|
+ f"""png("$tempFileString")""",
|
|
|
+ f"""$plotCommand""",
|
|
|
+ "dev.off()"
|
|
|
+ ).map { case code =>
|
|
|
+ sendRequest(code)
|
|
|
+ }.mkString("\n")
|
|
|
+
|
|
|
+ // Encode the image as a base64 image.
|
|
|
+ Map(
|
|
|
+ "image/png" -> Base64.encodeBase64String(Files.readAllBytes(tempFile))
|
|
|
+ )
|
|
|
+ } finally {
|
|
|
+ Files.delete(tempFile)
|
|
|
+ }
|
|
|
+ case _ =>
|
|
|
+ throw new Error(f"unknown magic command `$command`")
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ Map(
|
|
|
+ "text/plain" -> (sendRequest(command) + takeErrorLines())
|
|
|
+ )
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private def sendRequest(code: String): String = {
|
|
|
+ stdin.println(code)
|
|
|
+ stdin.flush()
|
|
|
+
|
|
|
+ // Skip the line we just entered in.
|
|
|
+ if (!code.isEmpty) {
|
|
|
+ readTo(code)
|
|
|
+ }
|
|
|
+
|
|
|
+ stdin.println(PRINT_MARKER)
|
|
|
+ stdin.flush()
|
|
|
+
|
|
|
+ readTo(EXPECTED_OUTPUT)
|
|
|
+ }
|
|
|
+
|
|
|
override protected def sendShutdownRequest() = {
|
|
|
stdin.println("q()")
|
|
|
stdin.flush()
|
|
|
@@ -85,26 +150,27 @@ private class SparkRInterpreter(process: Process)
|
|
|
}
|
|
|
|
|
|
@tailrec
|
|
|
- private def readTo(marker: String, output: StringBuilder = StringBuilder.newBuilder): (Boolean, String) = {
|
|
|
+ private def readTo(marker: String, output: StringBuilder = StringBuilder.newBuilder): String = {
|
|
|
val char = stdout.read()
|
|
|
if (char == -1) {
|
|
|
- (false, output.toString())
|
|
|
+ throw new Exited(output.toString())
|
|
|
} else {
|
|
|
output.append(char.toChar)
|
|
|
if (output.endsWith(marker)) {
|
|
|
val result = output.toString()
|
|
|
- (
|
|
|
- true,
|
|
|
- result.substring(0, result.length - marker.length)
|
|
|
- .replaceAll("\033\\[[0-9;]*[mG]", "") // Remove any ANSI color codes
|
|
|
- .stripPrefix("\n")
|
|
|
- .stripSuffix("\n"))
|
|
|
+ result.substring(0, result.length - marker.length)
|
|
|
+ .replaceAll("\033\\[[0-9;]*[mG]", "") // Remove any ANSI color codes
|
|
|
+ .stripPrefix("\n")
|
|
|
+ .stripSuffix("\n")
|
|
|
} else {
|
|
|
readTo(marker, output)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ private class Exited(val output: String) extends Exception {}
|
|
|
+ private class Error(val output: String) extends Exception {}
|
|
|
+
|
|
|
private[this] val _lock = new ReentrantLock()
|
|
|
private[this] var stderrLines = Seq[String]()
|
|
|
|