Browse Source

HUE-2889 [livy] Add some magic to render server side R plots

Erick Tryzelaar 10 years ago
parent
commit
a723235

+ 1 - 1
apps/spark/java/livy-repl/src/main/resources/fake_R.sh

@@ -15,4 +15,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-exec R --no-save --interactive --quiet "$@"
+exec R --no-save --interactive --quiet --slave "$@"

+ 98 - 32
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/sparkr/SparkRInterpreter.scala

@@ -18,9 +18,12 @@
 
 package com.cloudera.hue.livy.repl.sparkr
 
+import java.io.File
+import java.nio.file.Files
 import java.util.concurrent.locks.ReentrantLock
 
 import com.cloudera.hue.livy.repl.process.ProcessInterpreter
+import org.apache.commons.codec.binary.Base64
 import org.json4s.jackson.JsonMethods._
 import org.json4s.jackson.Serialization.write
 import org.json4s.{JValue, _}
@@ -29,8 +32,9 @@ import scala.annotation.tailrec
 import scala.io.Source
 
 private object SparkRInterpreter {
-  val LIVY_END_MARKER = "# ----LIVY_END_OF_COMMAND----"
-  val EXPECTED_OUTPUT = f"> $LIVY_END_MARKER"
+  val LIVY_END_MARKER = "----LIVY_END_OF_COMMAND----"
+  val PRINT_MARKER = f"""print("$LIVY_END_MARKER")"""
+  val EXPECTED_OUTPUT = f"""\n$PRINT_MARKER\n[1] "$LIVY_END_MARKER""""
 }
 
 private class SparkRInterpreter(process: Process)
@@ -48,35 +52,96 @@ private class SparkRInterpreter(process: Process)
   }
 
   override protected def sendExecuteRequest(commands: String): Option[JValue] = synchronized {
-    commands.split("\n").map { case code =>
-      stdin.println(code)
-      stdin.println(LIVY_END_MARKER)
-      stdin.flush()
-
-      executionCount += 1
-
-      // Skip the line we just entered in.
-      if (!code.isEmpty) {
-        readTo(code)
-      }
-
-      readTo(EXPECTED_OUTPUT)
-    }.last match {
-      case (true, output) =>
-        val data = (output + takeErrorLines())
-
-        Some(parse(write(Map(
-          "status" -> "ok",
-          "execution_count" -> (executionCount - 1),
+    try {
+      commands.split("\n").map { case command =>
+        executionCount += 1
+
+        val content = sendSingleExecuteRequest(command)
+        Some(parse(write(
+          Map(
+            "status" -> "ok",
+            "execution_count" -> (executionCount - 1),
+            "data" -> content
+          ))))
+      }.last
+    } catch {
+      case e: Error =>
+        Some(parse(write(
+        Map(
+          "status" -> "error",
+          "ename" -> "Error",
+          "evalue" -> e.output,
           "data" -> Map(
-            "text/plain" -> data
+            "text/plain" -> takeErrorLines()
           )
         ))))
-      case (false, output) =>
+      case e: Exited =>
         None
     }
   }
 
+  private val plotRegex = (
+    "%(" +
+      "(?:" +
+        "(?:stripchart)|" +
+        "(?:hist)|" +
+        "(?:boxplot)|" +
+        "(?:plot)|" +
+        "(?:qqnorm)|" +
+        "(?:qqline)" +
+      ")" +
+      "\\([^;)]*\\)" +
+    ")"
+  ).r
+
+  private def sendSingleExecuteRequest(command: String) = {
+    if (command.startsWith("%")) {
+      command match {
+        case plotRegex(plotCommand) =>
+          val tempFile = Files.createTempFile("", ".png")
+          try {
+            val tempFileString = tempFile.toAbsolutePath.toString
+
+            val output = Seq(
+              f"""png("$tempFileString")""",
+              f"""$plotCommand""",
+              "dev.off()"
+            ).map { case code =>
+              sendRequest(code)
+            }.mkString("\n")
+
+            // Encode the image as a base64 image.
+            Map(
+              "image/png" -> Base64.encodeBase64String(Files.readAllBytes(tempFile))
+            )
+          } finally {
+            Files.delete(tempFile)
+          }
+        case _ =>
+          throw new Error(f"unknown magic command `$command`")
+      }
+    } else {
+      Map(
+        "text/plain" -> (sendRequest(command) + takeErrorLines())
+      )
+    }
+  }
+
+  private def sendRequest(code: String): String = {
+    stdin.println(code)
+    stdin.flush()
+
+    // Skip the line we just entered in.
+    if (!code.isEmpty) {
+      readTo(code)
+    }
+
+    stdin.println(PRINT_MARKER)
+    stdin.flush()
+
+    readTo(EXPECTED_OUTPUT)
+  }
+
   override protected def sendShutdownRequest() = {
     stdin.println("q()")
     stdin.flush()
@@ -85,26 +150,27 @@ private class SparkRInterpreter(process: Process)
   }
 
   @tailrec
-  private def readTo(marker: String, output: StringBuilder = StringBuilder.newBuilder): (Boolean, String) = {
+  private def readTo(marker: String, output: StringBuilder = StringBuilder.newBuilder): String = {
     val char = stdout.read()
     if (char == -1) {
-      (false, output.toString())
+      throw new Exited(output.toString())
     } else {
       output.append(char.toChar)
       if (output.endsWith(marker)) {
         val result = output.toString()
-        (
-          true,
-          result.substring(0, result.length - marker.length)
-            .replaceAll("\033\\[[0-9;]*[mG]", "") // Remove any ANSI color codes
-            .stripPrefix("\n")
-            .stripSuffix("\n"))
+        result.substring(0, result.length - marker.length)
+          .replaceAll("\033\\[[0-9;]*[mG]", "") // Remove any ANSI color codes
+          .stripPrefix("\n")
+          .stripSuffix("\n")
       } else {
         readTo(marker, output)
       }
     }
   }
 
+  private class Exited(val output: String) extends Exception {}
+  private class Error(val output: String) extends Exception {}
+
   private[this] val _lock = new ReentrantLock()
   private[this] var stderrLines = Seq[String]()