HUE-2908 [livy] Partially address multiline commands

This allows an R snippet to be spread across multiple lines, but it
requires that the entire R snippet be a whole statement, or multiple
complete statements. If a statement is spread across multiple snippets,
the repl will hang. Unfortunately, the only real way to handle this is
to parse the R command for correctness, or to do the proper thing and
write a fake_shell.py-style wrapper for R.
Erick Tryzelaar, 10 years ago (commit 3f58f865a6)
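
For illustration only, a minimal sketch (not part of this commit; the object and value names are made up) of the kinds of input the commit message is describing:

    object MultilineSnippetExample {
      // Works after this change: the snippet spans several lines, but every
      // statement in it is complete, so R never waits for continuation input.
      val wholeStatements: String =
        """x <- c(1, 2, 3)
          |summary(x)""".stripMargin

      // Still hangs if the two pieces are sent as separate snippets: the first
      // piece leaves an open function body, so R sits at the continuation
      // prompt and never prints the end-of-command marker the interpreter
      // waits for.
      val openStatement: String = "f <- function(a) {"
      val closingPiece: String  = "  a + 1 }"
    }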

+ 1 - 1
apps/spark/java/livy-repl/src/main/resources/fake_R.sh

@@ -15,4 +15,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-exec R --no-save --interactive --quiet --slave "$@"
+exec R --slave "$@"

+ 60 - 75
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/sparkr/SparkRInterpreter.scala

@@ -33,7 +33,29 @@ import scala.io.Source
 private object SparkRInterpreter {
   val LIVY_END_MARKER = "----LIVY_END_OF_COMMAND----"
   val PRINT_MARKER = f"""print("$LIVY_END_MARKER")"""
-  val EXPECTED_OUTPUT = f"""$PRINT_MARKER\n[1] "$LIVY_END_MARKER""""
+  val EXPECTED_OUTPUT = f"""[1] "$LIVY_END_MARKER""""
+
+  val PLOT_REGEX = (
+    "(" +
+      "(?:bagplot)|" +
+      "(?:barplot)|" +
+      "(?:boxplot)|" +
+      "(?:dotchart)|" +
+      "(?:hist)|" +
+      "(?:lines)|" +
+      "(?:pie)|" +
+      "(?:pie3D)|" +
+      "(?:plot)|" +
+      "(?:qqline)|" +
+      "(?:qqnorm)|" +
+      "(?:scatterplot)|" +
+      "(?:scatterplot3d)|" +
+      "(?:scatterplot\\.matrix)|" +
+      "(?:splom)|" +
+      "(?:stripchart)|" +
+      "(?:vioplot)" +
+    ")"
+    ).r.unanchored
 }
 
 private class SparkRInterpreter(process: Process)
@@ -46,23 +68,46 @@ private class SparkRInterpreter(process: Process)
   private[this] var executionCount = 0
 
   final override protected def waitUntilReady(): Unit = {
-    sendExecuteRequest("")
+    // Set the option to catch and ignore errors instead of halting.
+    sendExecuteRequest("options(error = dump.frames)")
     executionCount = 0
   }
 
-  override protected def sendExecuteRequest(commands: String): Option[JValue] = synchronized {
+  override protected def sendExecuteRequest(command: String): Option[JValue] = synchronized {
+    var code = command
+
+    // Create an image file if this command is trying to plot.
+    val tempFile = PLOT_REGEX.findFirstIn(code).map { case _ =>
+      val tempFile = Files.createTempFile("", ".png")
+      val tempFileString = tempFile.toAbsolutePath
+
+      code = f"""png("$tempFileString")\n$code\ndev.off()"""
+
+      tempFile
+    }
+
     try {
-      commands.split("\n").map { case command =>
-        executionCount += 1
+      executionCount += 1
 
-        val content = sendSingleExecuteRequest(command)
-        Some(parse(write(
-          Map(
-            "status" -> "ok",
-            "execution_count" -> (executionCount - 1),
-            "data" -> content
-          ))))
-      }.last
+      var content = Map(
+        "text/plain" -> (sendRequest(code) + takeErrorLines())
+      )
+
+      // If we rendered anything, pass along the last image.
+      tempFile.foreach { case file =>
+        val bytes = Files.readAllBytes(file)
+        if (bytes.nonEmpty) {
+          val image = Base64.encodeBase64String(bytes)
+          content = content + (("image/png", image))
+        }
+      }
+
+      Some(parse(write(
+        Map(
+          "status" -> "ok",
+          "execution_count" -> (executionCount - 1),
+          "data" -> content
+        ))))
     } catch {
       case e: Error =>
         Some(parse(write(
@@ -76,76 +121,16 @@ private class SparkRInterpreter(process: Process)
         ))))
       case e: Exited =>
         None
+    } finally {
+      tempFile.foreach(Files.delete)
     }
-  }
 
-  private val plotRegex = (
-    "%(" +
-      "(?:" +
-        "(?:bagplot)|" +
-        "(?:barplot)|" +
-        "(?:boxplot)|" +
-        "(?:dotchart)|" +
-        "(?:hist)|" +
-        "(?:lines)|" +
-        "(?:pie)|" +
-        "(?:pie3D)|" +
-        "(?:plot)|" +
-        "(?:qqline)|" +
-        "(?:qqnorm)|" +
-        "(?:scatterplot)|" +
-        "(?:scatterplot3d)|" +
-        "(?:scatterplot\\.matrix)|" +
-        "(?:splom)|" +
-        "(?:stripchart)|" +
-        "(?:vioplot)" +
-      ")" +
-      "\\([^;)]*\\)" +
-    ")"
-  ).r
-
-  private def sendSingleExecuteRequest(command: String) = {
-    if (command.startsWith("%")) {
-      command match {
-        case plotRegex(plotCommand) =>
-          val tempFile = Files.createTempFile("", ".png")
-          try {
-            val tempFileString = tempFile.toAbsolutePath.toString
-
-            val output = Seq(
-              f"""png("$tempFileString")""",
-              f"""$plotCommand""",
-              "dev.off()"
-            ).map { case code =>
-              sendRequest(code)
-            }.mkString("\n")
-
-            // Encode the image as a base64 image.
-            Map(
-              "image/png" -> Base64.encodeBase64String(Files.readAllBytes(tempFile))
-            )
-          } finally {
-            Files.delete(tempFile)
-          }
-        case _ =>
-          throw new Error(f"unknown magic command `$command`")
-      }
-    } else {
-      Map(
-        "text/plain" -> (sendRequest(command) + takeErrorLines())
-      )
-    }
   }
 
   private def sendRequest(code: String): String = {
     stdin.println(code)
     stdin.flush()
 
-    // Skip the line we just entered in.
-    if (!code.isEmpty) {
-      readTo(code)
-    }
-
     stdin.println(PRINT_MARKER)
     stdin.flush()
 

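As a rough sketch of the new plot handling above (not part of the commit; the object, value names, and the example R command are illustrative, but the wrapping and base64 encoding mirror what the new sendExecuteRequest does when PLOT_REGEX matches):

    object PlotWrappingSketch {
      import java.nio.file.Files
      import org.apache.commons.codec.binary.Base64

      def main(args: Array[String]): Unit = {
        val command  = """hist(rnorm(100))"""           // user-supplied R snippet that plots
        val tempFile = Files.createTempFile("", ".png") // scratch PNG for R's graphics device

        // Redirect R's graphics device to the temp file around the user's code.
        val wrapped = f"""png("${tempFile.toAbsolutePath}")\n$command\ndev.off()"""

        // `wrapped` is what would be sent to R; afterwards, a non-empty file is
        // base64-encoded and attached under the "image/png" key next to the
        // usual "text/plain" output.
        val bytes = Files.readAllBytes(tempFile)
        val image = if (bytes.nonEmpty) Some(Base64.encodeBase64String(bytes)) else None

        println(s"wrapped command:\n$wrapped\nimage attached: ${image.isDefined}")
        Files.delete(tempFile)
      }
    }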
+ 1 - 1
apps/spark/java/livy-repl/src/test/scala/com/cloudera/hue/livy/repl/SparkRSessionSpec.scala

@@ -165,7 +165,7 @@ class SparkRSessionSpec extends BaseSessionSpec {
 
       // Manually extract since sparkr outputs a lot of spark logging information.
       resultMap("status").extract[String] should equal ("ok")
-      resultMap("execution_count").extract[Int] should equal (1)
+      resultMap("execution_count").extract[Int] should equal (0)
 
       val data = resultMap("data").extract[Map[String, JValue]]
       data("text/plain").extract[String] should include ("""  eruptions waiting