
HUE-2864 [livy] Initial support for sparkr interactive sessions

This adds basic support for running sparkr interactive sessions.
It does this by shelling out to "sparkr" and parsing the output.
This has some limitations. First, in YARN mode it is not
yet able to reuse the application master the way the Scala and
Python sessions do, so it allocates a redundant master. Second,
it does not yet honor the Spark configuration settings.
Erick Tryzelaar 10 years ago
parent
commit
673e676c65
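
To make the mechanism concrete, here is a rough editorial sketch (not part of the patch) built on the Session API added below. The session forks the sparkr driver, writes each line of R code plus an end-of-command marker to its stdin, and parses the text echoed back on stdout. The object name is illustrative; the called methods (SparkRSession.create, execute, Statement.result, close) all appear in the diffs that follow.

    import com.cloudera.hue.livy.repl.sparkr.SparkRSession

    import scala.concurrent.Await
    import scala.concurrent.duration.Duration

    object SparkRSessionDemo extends App {
      // Fork the sparkr driver process and wait for the initial "> " prompt.
      val session = SparkRSession.create()

      // Each code line goes to the child's stdin followed by an end-of-command
      // marker; the text captured before the next prompt becomes the result
      // (a json4s JValue carrying a "text/plain" entry).
      val statement = session.execute("1 + 2")
      val result = Await.result(statement.result, Duration.Inf)
      println(result) // expected to carry "[1] 3" as text/plain

      // close() sends q() to the R process and tears it down.
      session.close()
    }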

+ 5 - 0
apps/spark/java/livy-core/src/main/scala/com/cloudera/hue/livy/sessions/Kind.scala

@@ -30,9 +30,14 @@ case class PySpark() extends Kind {
   override def toString = "pyspark"
 }
 
+case class SparkR() extends Kind {
+  override def toString = "sparkr"
+}
+
 case object SessionKindSerializer extends CustomSerializer[Kind](implicit formats => ( {
   case JString("spark") | JString("scala") => Spark()
   case JString("pyspark") | JString("python") => PySpark()
+  case JString("sparkr") | JString("r") => SparkR()
 }, {
   case kind: Kind => JString(kind.toString)
 }
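
For context, a quick sketch (editorial, not part of the patch) of how the extended serializer behaves with json4s, assuming DefaultFormats plus SessionKindSerializer; the demo object name is illustrative.

    import com.cloudera.hue.livy.sessions._
    import org.json4s._

    object KindRoundTripDemo extends App {
      implicit val formats: Formats = DefaultFormats + SessionKindSerializer

      // Both "sparkr" and "r" deserialize to SparkR(); it serializes back as "sparkr".
      val kind = JString("r").extract[Kind]
      println(kind)                              // SparkR()
      println(Extraction.decompose(SparkR()))    // JString(sparkr)
    }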

+ 18 - 0
apps/spark/java/livy-repl/src/main/resources/fake_R.sh

@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+exec R --no-save --interactive --quiet "$@"

+ 3 - 7
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/process/ProcessInterpreter.scala

@@ -65,7 +65,7 @@ abstract class ProcessInterpreter(process: Process)
 
   protected def sendExecuteRequest(request: String): Option[JValue]
 
-  protected def sendShutdownRequest(): Option[JValue]
+  protected def sendShutdownRequest(): Unit = {}
 
   private[this] val thread = new Thread("process interpreter") {
     override def run() = {
@@ -112,11 +112,7 @@ abstract class ProcessInterpreter(process: Process)
           }
 
           try {
-            sendShutdownRequest() match {
-              case Some(rep) =>
-                warn(f"process failed to shut down while returning $rep")
-              case None =>
-            }
+            sendShutdownRequest()
 
             try {
               process.getInputStream.close()
@@ -162,7 +158,7 @@ abstract class ProcessInterpreter(process: Process)
 
     // Give ourselves 10 seconds to tear down the process.
     try {
-      Await.result(future, Duration(60, TimeUnit.SECONDS))
+      Await.result(future, Duration(10, TimeUnit.SECONDS))
     } catch {
       case e: Throwable =>
         // Make sure if there are any problems we make sure we kill the process.
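
Two things change here: sendShutdownRequest becomes a Unit method with a no-op default (the warning it used to return now lives in PythonInterpreter below), and the teardown timeout is brought in line with the comment (10 seconds, not 60). A hypothetical minimal subclass, assuming the protected stdin/stdout helpers that the other interpreters in this patch use, would now look roughly like this:

    import com.cloudera.hue.livy.repl.process.ProcessInterpreter
    import org.json4s._

    // Illustrative only: with the no-op default above, an interpreter that has
    // nothing special to say at shutdown no longer overrides sendShutdownRequest.
    class EchoInterpreter(process: Process) extends ProcessInterpreter(process) {
      override protected def waitUntilReady(): Unit = ()

      override protected def sendExecuteRequest(request: String): Option[JValue] = {
        stdin.println(request)
        stdin.flush()
        Option(stdout.readLine()).map(JString(_))  // echo one line back as the result
      }
      // No sendShutdownRequest override needed; the 10-second teardown window still applies.
    }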

+ 5 - 4
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/python/PythonInterpreter.scala

@@ -150,12 +150,13 @@ private class PythonInterpreter(process: Process, gatewayServer: GatewayServer)
     }
   }
 
-  override protected def sendShutdownRequest(): Option[JValue] = {
-    val rep = sendRequest(Map(
+  override protected def sendShutdownRequest(): Unit = {
+    sendRequest(Map(
       "msg_type" -> "shutdown_request",
       "content" -> ()
-    ))
-    rep
+    )).foreach { case rep =>
+      warn(f"process failed to shut down while returning $rep")
+    }
   }
 
   private def sendRequest(request: Map[String, Any]): Option[JValue] = {

+ 2 - 3
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/scala/SparkSession.scala

@@ -18,14 +18,13 @@
 
 package com.cloudera.hue.livy.repl.scala
 
-import com.cloudera.hue.livy.repl.{Statement, Session}
 import com.cloudera.hue.livy.repl.scala.interpreter._
+import com.cloudera.hue.livy.repl.{Session, Statement}
 import com.cloudera.hue.livy.sessions._
+import org.json4s._
 import org.json4s.jackson.JsonMethods._
 import org.json4s.jackson.Serialization.write
-import org.json4s.{JValue, _}
 
-import scala.collection.mutable
 import scala.concurrent.{ExecutionContext, Future}
 
 object SparkSession {

+ 96 - 0
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/sparkr/SparkRInterpreter.scala

@@ -0,0 +1,96 @@
+/*
+ * Licensed to Cloudera, Inc. under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  Cloudera, Inc. licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.cloudera.hue.livy.repl.sparkr
+
+import com.cloudera.hue.livy.repl.process.ProcessInterpreter
+import org.json4s.jackson.JsonMethods._
+import org.json4s.jackson.Serialization.write
+import org.json4s.{JValue, _}
+
+import scala.annotation.tailrec
+
+private object SparkRInterpreter {
+  val LIVY_END_MARKER = "# ----LIVY_END_OF_COMMAND----"
+  val EXPECTED_OUTPUT = f"\n> $LIVY_END_MARKER"
+}
+
+private class SparkRInterpreter(process: Process)
+  extends ProcessInterpreter(process)
+{
+  import SparkRInterpreter._
+
+  implicit val formats = DefaultFormats
+
+  private var executionCount = 0
+
+  final override protected def waitUntilReady(): Unit = {
+    readTo("\n> ")
+  }
+
+  override protected def sendExecuteRequest(commands: String): Option[JValue] = synchronized {
+    commands.split("\n").map { case code =>
+      stdin.println(code)
+      stdin.println(LIVY_END_MARKER)
+      stdin.flush()
+
+      executionCount += 1
+
+      // Skip the line we just entered in.
+      if (!code.isEmpty) {
+        readTo(code)
+      }
+
+      readTo(EXPECTED_OUTPUT)
+    }.last match {
+      case (true, output) =>
+        Some(parse(write(Map(
+          "status" -> "ok",
+          "execution_count" -> (executionCount - 1),
+          "data" -> Map(
+            "text/plain" -> output
+          )
+        ))))
+      case (false, output) =>
+        None
+    }
+  }
+
+  override protected def sendShutdownRequest() = {
+    stdin.println("q()")
+    stdin.flush()
+
+    while (stdout.readLine() != null) {}
+  }
+
+  @tailrec
+  private def readTo(marker: String, output: StringBuilder = StringBuilder.newBuilder): (Boolean, String) = {
+    val char = stdout.read()
+    if (char == -1) {
+      (false, output.toString())
+    } else {
+      output.append(char.toChar)
+      if (output.endsWith(marker)) {
+        val result = output.toString()
+        (true, result.substring(0, result.length - marker.length).stripPrefix("\n"))
+      } else {
+        readTo(marker, output)
+      }
+    }
+  }
+}
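
The interpreter drives a plain R prompt: every code line is followed by LIVY_END_MARKER on stdin, and readTo consumes stdout character by character until the accumulated text ends with EXPECTED_OUTPUT (the echoed marker after a fresh "> " prompt). Below is a self-contained editorial sketch of that scanning idea, run against a canned transcript instead of a live process; the object name is illustrative.

    import java.io.{BufferedReader, StringReader}
    import scala.annotation.tailrec

    object MarkerScanDemo extends App {
      val marker = "\n> # ----LIVY_END_OF_COMMAND----"   // mirrors EXPECTED_OUTPUT above
      val reader = new BufferedReader(new StringReader("[1] 3" + marker))

      // Same shape as readTo: accumulate characters until the buffer ends with
      // the marker, then return the text that preceded it.
      @tailrec
      def readTo(output: StringBuilder = new StringBuilder): (Boolean, String) = {
        val char = reader.read()
        if (char == -1) {
          (false, output.toString())
        } else {
          output.append(char.toChar)
          if (output.endsWith(marker)) {
            val result = output.toString()
            (true, result.substring(0, result.length - marker.length).stripPrefix("\n"))
          } else {
            readTo(output)
          }
        }
      }

      println(readTo())   // (true,[1] 3)
    }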

+ 94 - 0
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/sparkr/SparkRSession.scala

@@ -0,0 +1,94 @@
+/*
+ * Licensed to Cloudera, Inc. under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  Cloudera, Inc. licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.cloudera.hue.livy.repl.sparkr
+
+import java.io.{File, FileOutputStream}
+import java.lang.ProcessBuilder.Redirect
+import java.nio.file.Files
+
+import com.cloudera.hue.livy.repl.{Session, Statement}
+import com.cloudera.hue.livy.sessions._
+
+import scala.collection.JavaConversions._
+
+object SparkRSession {
+  def create(): Session = {
+    val sparkrExec = sys.env.getOrElse("SPARKR_DRIVER_R", "sparkr")
+
+    val builder = new ProcessBuilder(Seq(
+      sparkrExec
+    ))
+
+    val env = builder.environment()
+    env.put("SPARK_HOME", sys.env.getOrElse("SPARK_HOME", "."))
+    env.put("SPARKR_DRIVER_R", createFakeShell().toString)
+
+    builder.redirectErrorStream(true)
+
+    val process = builder.start()
+
+    val interpreter = new SparkRInterpreter(process)
+
+    new SparkRSession(interpreter)
+  }
+
+  private def createFakeShell(): File = {
+    val source = getClass.getClassLoader.getResourceAsStream("fake_R.sh")
+
+    val file = Files.createTempFile("", "").toFile
+    file.deleteOnExit()
+
+    val sink = new FileOutputStream(file)
+    val buf = new Array[Byte](1024)
+    var n = source.read(buf)
+
+    while (n > 0) {
+      sink.write(buf, 0, n)
+      n = source.read(buf)
+    }
+
+    source.close()
+    sink.close()
+
+    file.setExecutable(true)
+
+    file
+  }
+}
+
+private class SparkRSession(interpreter: SparkRInterpreter) extends Session {
+  private var _history = IndexedSeq[Statement]()
+
+  override def kind: Kind = SparkR()
+
+  override def state: State = interpreter.state
+
+  override def execute(code: String): Statement = {
+    val result = interpreter.execute(code)
+    val statement = Statement(_history.length, result)
+    _history :+= statement
+    statement
+  }
+
+  override def close(): Unit = interpreter.close()
+
+  override def history: IndexedSeq[Statement] = _history
+}
+
+

+ 164 - 0
apps/spark/java/livy-repl/src/test/scala/com/cloudera/hue/livy/repl/SparkRSessionSpec.scala

@@ -0,0 +1,164 @@
+/*
+ * Licensed to Cloudera, Inc. under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  Cloudera, Inc. licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.cloudera.hue.livy.repl
+
+import com.cloudera.hue.livy.repl.sparkr.SparkRSession
+import org.json4s.Extraction
+import org.json4s.JsonAST.JValue
+
+import _root_.scala.concurrent.Await
+import _root_.scala.concurrent.duration.Duration
+
+class SparkRSessionSpec extends BaseSessionSpec {
+
+  override def createSession() = SparkRSession.create()
+
+  describe("A sparkr session") {
+    it("should execute `1 + 2` == 3") {
+      val statement = session.execute("1 + 2")
+      statement.id should equal(0)
+
+      val result = Await.result(statement.result, Duration.Inf)
+      val expectedResult = Extraction.decompose(Map(
+        "status" -> "ok",
+        "execution_count" -> 0,
+        "data" -> Map(
+          "text/plain" -> "[1] 3"
+        )
+      ))
+
+      result should equal(expectedResult)
+    }
+
+    it("should execute `x = 1`, then `y = 2`, then `x + y`") {
+      var statement = session.execute("x = 1")
+      statement.id should equal (0)
+
+      var result = Await.result(statement.result, Duration.Inf)
+      var expectedResult = Extraction.decompose(Map(
+        "status" -> "ok",
+        "execution_count" -> 0,
+        "data" -> Map(
+          "text/plain" -> ""
+        )
+      ))
+
+      result should equal (expectedResult)
+
+      statement = session.execute("y = 2")
+      statement.id should equal (1)
+
+      result = Await.result(statement.result, Duration.Inf)
+      expectedResult = Extraction.decompose(Map(
+        "status" -> "ok",
+        "execution_count" -> 1,
+        "data" -> Map(
+          "text/plain" -> ""
+        )
+      ))
+
+      result should equal (expectedResult)
+
+      statement = session.execute("x + y")
+      statement.id should equal (2)
+
+      result = Await.result(statement.result, Duration.Inf)
+      expectedResult = Extraction.decompose(Map(
+        "status" -> "ok",
+        "execution_count" -> 2,
+        "data" -> Map(
+          "text/plain" -> "[1] 3"
+        )
+      ))
+
+      result should equal (expectedResult)
+    }
+
+    it("should capture stdout") {
+      val statement = session.execute("""print('Hello World')""")
+      statement.id should equal (0)
+
+      val result = Await.result(statement.result, Duration.Inf)
+      val expectedResult = Extraction.decompose(Map(
+        "status" -> "ok",
+        "execution_count" -> 0,
+        "data" -> Map(
+          "text/plain" -> "[1] \"Hello World\""
+        )
+      ))
+
+      result should equal (expectedResult)
+    }
+
+    it("should report an error if accessing an unknown variable") {
+      val statement = session.execute("""x""")
+      statement.id should equal (0)
+
+      val result = Await.result(statement.result, Duration.Inf)
+      val expectedResult = Extraction.decompose(Map(
+        "status" -> "ok",
+        "execution_count" -> 0,
+        "data" -> Map(
+          "text/plain" -> "Error: object 'x' not found"
+        )
+      ))
+
+      result should equal (expectedResult)
+    }
+
+    it("should access the spark context") {
+      val statement = session.execute("""sc""")
+      statement.id should equal (0)
+
+      val result = Await.result(statement.result, Duration.Inf)
+      val expectedResult = Extraction.decompose(Map(
+        "status" -> "ok",
+        "execution_count" -> 0,
+        "data" -> Map(
+          "text/plain" -> "Java ref type org.apache.spark.api.java.JavaSparkContext id 0"
+        )
+      ))
+
+      result should equal (expectedResult)
+    }
+
+    it("should execute spark commands") {
+      val statement = session.execute("""
+                                        |head(createDataFrame(sqlContext, faithful))
+                                        |""".stripMargin)
+      statement.id should equal (0)
+
+      val result = Await.result(statement.result, Duration.Inf)
+      val resultMap = result.extract[Map[String, JValue]]
+
+      // Manually extract since sparkr outputs a lot of spark logging information.
+      resultMap("status").extract[String] should equal ("ok")
+      resultMap("execution_count").extract[Int] should equal (1)
+
+      val data = resultMap("data").extract[Map[String, JValue]]
+      data("text/plain").extract[String] should include ("""  eruptions waiting
+                                                         |1     3.600      79
+                                                         |2     1.800      54
+                                                         |3     3.333      74
+                                                         |4     2.283      62
+                                                         |5     4.533      85
+                                                         |6     2.883      55""".stripMargin)
+    }
+  }
+}