Sfoglia il codice sorgente

[livy] Simplify creating a scala/spark interpreter

Erick Tryzelaar 10 anni fa
parent
commit
f46ee77

+ 1 - 0
apps/spark/java/livy-repl/pom.xml

@@ -96,6 +96,7 @@
                     <systemProperties>
                         <spark.master>local</spark.master>
                         <spark.driver.allowMultipleContexts>true</spark.driver.allowMultipleContexts>
+                        <settings.usejavacp.value>true</settings.usejavacp.value>
                     </systemProperties>
                 </configuration>
             </plugin>

+ 36 - 18
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/Main.scala

@@ -11,6 +11,7 @@ import org.json4s.{DefaultFormats, Formats}
 import org.scalatra.LifeCycle
 import org.scalatra.servlet.ScalatraListener
 
+import _root_.scala.annotation.tailrec
 import _root_.scala.concurrent.duration._
 import _root_.scala.concurrent.{Await, ExecutionContext}
 
@@ -56,10 +57,11 @@ object Main extends Logging {
 
     server.start()
 
-    println("Starting livy-repl on port %s" format server.port)
-    System.setProperty("livy.repl.url", s"http://${server.host}:${server.port}")
-
     try {
+      val replUrl = s"http://${server.host}:${server.port}"
+      println(s"Starting livy-repl on $replUrl")
+      System.setProperty("livy.repl.url", replUrl)
+
       server.join()
       server.stop()
     } finally {
@@ -90,29 +92,45 @@ class ScalatraBootstrap extends LifeCycle with Logging {
       .orElse(sys.env.get("LIVY_CALLBACK_URL"))
 
     // See if we want to notify someone that we've started on a url
-    callbackUrl.foreach { case callbackUrl_ =>
-      info(s"Notifying $callbackUrl_ that we're up")
+    callbackUrl.foreach(notifyCallback)
+  }
+
+  override def destroy(context: ServletContext): Unit = {
+    if (session != null) {
+      Await.result(session.close(), Duration.Inf)
+    }
+  }
 
-      Future {
-        session.waitForStateChange(Session.Starting())
+  private def notifyCallback(callbackUrl: String): Unit = {
+    info(s"Notifying $callbackUrl that we're up")
 
-        val replUrl = System.getProperty("livy.repl.url")
-        var req = url(callbackUrl_).setContentType("application/json", "UTF-8")
-        req = req << write(Map("url" -> replUrl))
+    Future {
+      session.waitForStateChange(Session.Starting())
 
-        val rep = Http(req OK as.String)
-        rep.onFailure {
-          case _ => System.exit(1)
-        }
+      // Wait for our url to be discovered.
+      val replUrl = waitForReplUrl()
 
-        Await.result(rep, 10 seconds)
+      var req = url(callbackUrl).setContentType("application/json", "UTF-8")
+      req = req << write(Map("url" -> replUrl))
+
+      val rep = Http(req OK as.String)
+      rep.onFailure {
+        case _ => System.exit(1)
       }
+
+      Await.result(rep, 10 seconds)
     }
   }
 
-  override def destroy(context: ServletContext): Unit = {
-    if (session != null) {
-      Await.result(session.close(), Duration.Inf)
+  /** Spin until The server may start up  */
+  @tailrec
+  private def waitForReplUrl(): String = {
+    val replUrl = System.getProperty("livy.repl.url")
+    if (replUrl == null) {
+      Thread.sleep(10)
+      waitForReplUrl()
+    } else {
+      replUrl
     }
   }
 }

+ 1 - 0
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/Session.scala

@@ -7,6 +7,7 @@ import _root_.scala.concurrent.Future
 
 object Session {
   sealed trait State
+  case class NotStarted() extends State
   case class Starting() extends State
   case class Idle() extends State
   case class Busy() extends State

+ 1 - 0
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/WebApp.scala

@@ -27,6 +27,7 @@ class WebApp(session: Session) extends ScalatraServlet with FutureSupport with J
 
   get("/") {
     val state = session.state match {
+      case Session.NotStarted() => "not_started"
       case Session.Starting() => "starting"
       case Session.Idle() => "idle"
       case Session.Busy() => "busy"

+ 37 - 12
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/scala/SparkSession.scala

@@ -1,7 +1,7 @@
 package com.cloudera.hue.livy.repl.scala
 
 import com.cloudera.hue.livy.repl.Session
-import com.cloudera.hue.livy.repl.scala.interpreter.Interpreter
+import com.cloudera.hue.livy.repl.scala.interpreter._
 import org.json4s.jackson.JsonMethods._
 import org.json4s.jackson.Serialization.write
 import org.json4s.{JValue, _}
@@ -20,8 +20,10 @@ private class SparkSession extends Session {
 
   private var _history = new mutable.ArrayBuffer[JValue]
   private val interpreter = new Interpreter()
+  interpreter.start()
 
   override def state: Session.State = interpreter.state match {
+    case Interpreter.NotStarted() => Session.NotStarted()
     case Interpreter.Starting() => Session.Starting()
     case Interpreter.Idle() => Session.Idle()
     case Interpreter.Busy() => Session.Busy()
@@ -39,22 +41,45 @@ private class SparkSession extends Session {
   }
 
   override def execute(code: String): Future[JValue] = {
-    interpreter.execute(code).map {
-      case rep =>
-        val content = parse(write(Map(
-          "status" -> "ok",
-          "execution_count" -> rep.executionCount,
-          "data" -> Map(
-            "text/plain" -> rep.data
+    Future {
+      val content = interpreter.execute(code) match {
+        case ExecuteComplete(executeCount, output) =>
+          Map(
+            "status" -> "ok",
+            "execution_count" -> executeCount,
+            "data" -> Map(
+              "text/plain" -> output
+            )
           )
-        )))
+        case ExecuteIncomplete(executeCount, output) =>
+          Map(
+            "status" -> "error",
+            "execution_count" -> executeCount,
+            "ename" -> "Error",
+            "evalue" -> "output",
+            "traceback" -> List()
+          )
+        case ExecuteError(executeCount, output) =>
+          Map(
+            "status" -> "error",
+            "execution_count" -> executeCount,
+            "ename" -> "Error",
+            "evalue" -> "output",
+            "traceback" -> List()
+          )
+      }
+
+      val jsonContent = parse(write(content))
 
-        _history += content
-        content
+      _history += jsonContent
+
+      jsonContent
     }
   }
 
   override def close(): Future[Unit] = {
-    interpreter.shutdown()
+    Future {
+      interpreter.shutdown()
+    }
   }
 }

+ 55 - 146
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/scala/interpreter/Interpreter.scala

@@ -1,189 +1,98 @@
 package com.cloudera.hue.livy.repl.scala.interpreter
 
-import java.io.{StringWriter, BufferedReader, StringReader}
-import java.util.concurrent.SynchronousQueue
+import java.io._
 
-import org.apache.spark.repl.SparkILoop
+import org.apache.spark.repl.SparkIMain
+
+import scala.concurrent.ExecutionContext
+import scala.tools.nsc.Settings
+import scala.tools.nsc.interpreter.{JPrintWriter, Results}
 
-import scala.annotation.tailrec
-import scala.concurrent.{ExecutionContext, Future, Promise}
-import scala.tools.nsc.SparkHelper
-import scala.tools.nsc.interpreter.{Formatting, JPrintWriter}
-import scala.tools.nsc.util.ClassPath
 
 object Interpreter {
   sealed trait State
+  case class NotStarted() extends State
   case class Starting() extends State
   case class Idle() extends State
   case class Busy() extends State
   case class ShuttingDown() extends State
 }
 
+sealed abstract class ExecuteResponse(executeCount: Int)
+case class ExecuteComplete(executeCount: Int, output: String) extends ExecuteResponse(executeCount)
+case class ExecuteIncomplete(executeCount: Int, output: String) extends ExecuteResponse(executeCount)
+case class ExecuteError(executeCount: Int, output: String) extends ExecuteResponse(executeCount)
+
 class Interpreter {
   private implicit def executor: ExecutionContext = ExecutionContext.global
 
-  private val queue = new SynchronousQueue[Request]()
-
-  // We start up the ILoop in it's own class loader because the SparkILoop store
-  // itself in a global variable.
-  private val iloop = {
-    val classLoader = new ILoopClassLoader(classOf[Interpreter].getClassLoader)
-    val cls = classLoader.loadClass(classOf[ILoop].getName)
-    val constructor = cls.getConstructor(classOf[SynchronousQueue[Request]])
-    constructor.newInstance(queue).asInstanceOf[ILoop]
-  }
-
-  // We also need to start the ILoop in it's own thread because it wants to run
-  // inside a loop.
-  private val thread = new Thread {
-    override def run() = {
-      val args = Array("-usejavacp")
-      iloop.process(args)
-    }
-  }
-
-  thread.start()
-
-  def state = iloop.state
-
-  def execute(code: String): Future[ExecuteResponse] = {
-    val promise = Promise[ExecuteResponse]()
-    queue.put(ExecuteRequest(code, promise))
-    promise.future
-  }
-
-  def shutdown(): Future[Unit] = {
-    val promise = Promise[Unit]()
-    queue.put(ShutdownRequest(promise))
-    promise.future.map({ case () => thread.join() })
-  }
-}
-
-private class ILoopClassLoader(classLoader: ClassLoader) extends ClassLoader(classLoader) { }
-
-private sealed trait Request
-private case class ExecuteRequest(code: String, promise: Promise[ExecuteResponse]) extends Request
-private case class ShutdownRequest(promise: Promise[Unit]) extends Request
-
-case class ExecuteResponse(executionCount: Int, data: String)
-
-private class ILoop(queue: SynchronousQueue[Request], outWriter: StringWriter) extends SparkILoop(
-  new BufferedReader(new StringReader("")),
-  new JPrintWriter(outWriter)
-) {
-  def this(queue: SynchronousQueue[Request]) = this(queue, new StringWriter)
-
-  var _state: Interpreter.State = Interpreter.Starting()
-
-  var _executionCount = 0
+  private var _state: Interpreter.State = Interpreter.NotStarted()
+  private val outputStream = new ByteArrayOutputStream()
+  private var sparkIMain: SparkIMain = _
+  private var executeCount = 0
 
   def state = _state
 
-  org.apache.spark.repl.Main.interp = this
+  def start() = {
+    require(_state == Interpreter.NotStarted() && sparkIMain == null)
 
-  private class ILoopInterpreter extends SparkILoopInterpreter {
-    override lazy val formatting = new Formatting {
-      def prompt = ILoop.this.prompt
-    }
-    override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader)
-  }
+    _state = Interpreter.Starting()
 
-  /** Create a new interpreter. */
-  override def createInterpreter() {
-    require(settings != null)
+    class InterpreterClassLoader(classLoader: ClassLoader) extends ClassLoader(classLoader) {}
+    val classLoader = new InterpreterClassLoader(classOf[Interpreter].getClassLoader)
 
-    if (addedClasspath != "") settings.classpath.append(addedClasspath)
-    // work around for Scala bug
-    val totalClassPath = SparkILoop.getAddedJars.foldLeft(
-      settings.classpath.value)((l, r) => ClassPath.join(l, r))
-    this.settings.classpath.value = totalClassPath
+    val settings = new Settings()
+    settings.usejavacp.value = true
 
-    intp = new ILoopInterpreter
-  }
+    sparkIMain = createSparkIMain(classLoader, settings)
 
-  private val replayQuestionMessage =
-    """|That entry seems to have slain the compiler.  Shall I replay
-      |your session? I can re-run each line except the last one.
-      |[y/n]
-    """.trim.stripMargin
-
-  private def crashRecovery(ex: Throwable): Boolean = {
-    echo(ex.toString)
-    ex match {
-      case _: NoSuchMethodError | _: NoClassDefFoundError =>
-        echo("\nUnrecoverable error.")
-        throw ex
-      case _  =>
-        def fn(): Boolean =
-          try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() })
-          catch { case _: RuntimeException => false }
-
-        if (fn()) replay()
-        else echo("\nAbandoning crashed session.")
-    }
-    true
+    _state = Interpreter.Idle()
   }
 
-  override def prompt = ""
+  private def createSparkIMain(classLoader: ClassLoader, settings: Settings) = {
+    val out = new JPrintWriter(outputStream, true)
+    val cls = classLoader.loadClass(classOf[SparkIMain].getName)
+    val constructor = cls.getConstructor(classOf[Settings], classOf[JPrintWriter], java.lang.Boolean.TYPE)
+    constructor.newInstance(settings, out, false: java.lang.Boolean).asInstanceOf[SparkIMain]
+  }
 
-  override def loop(): Unit = {
-    def readOneLine() = queue.take()
+  def execute(code: String): ExecuteResponse = {
+    synchronized {
+      executeCount += 1
 
-    // return false if repl should exit
-    def processLine(request: Request): Boolean = {
       _state = Interpreter.Busy()
 
-      if (isAsync) {
-        if (!awaitInitialized()) return false
-        runThunks()
-      }
-
-      request match {
-        case ExecuteRequest(statement, promise) =>
-          _executionCount += 1
+      val result = sparkIMain.interpret(code) match {
+        case Results.Success =>
+          val output = outputStream.toString("UTF-8").trim
+          outputStream.reset()
 
-          command(statement) match {
-            case Result(false, _) => false
-            case Result(true, finalLine) =>
-              finalLine match {
-                case Some(line) => addReplay(line)
-                case None =>
-              }
+          ExecuteComplete(executeCount - 1, output)
 
-              var output = outWriter.getBuffer.toString
+        case Results.Incomplete =>
+          val output = outputStream.toString("UTF-8").trim
+          outputStream.reset()
 
-              // Strip the trailing '\n'
-              output = output.stripSuffix("\n")
+          ExecuteIncomplete(executeCount - 1, output)
 
-              outWriter.getBuffer.setLength(0)
-
-              promise.success(ExecuteResponse(_executionCount - 1, output))
-
-              true
-          }
-        case ShutdownRequest(promise) =>
-          promise.success(())
-          false
+        case Results.Error =>
+          val output = outputStream.toString("UTF-8").trim
+          outputStream.reset()
+          ExecuteError(executeCount - 1, output)
       }
-    }
 
-    @tailrec
-    def innerLoop() {
       _state = Interpreter.Idle()
 
-      outWriter.getBuffer.setLength(0)
+      result
+    }
+  }
 
-      val shouldContinue = try {
-        processLine(readOneLine())
-      } catch {
-        case t: Throwable => crashRecovery(t)
-      }
+  def shutdown(): Unit = {
+    _state = Interpreter.ShuttingDown()
 
-      if (shouldContinue) {
-        innerLoop()
-      }
+    if (sparkIMain != null) {
+      sparkIMain.close()
+      sparkIMain = null
     }
-
-    innerLoop()
   }
 }

+ 18 - 9
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/sessions/ProcessSession.scala

@@ -1,14 +1,12 @@
 package com.cloudera.hue.livy.server.sessions
 
 import java.lang.ProcessBuilder.Redirect
-import java.net.URL
 
-import com.cloudera.hue.livy.{Utils, Logging}
-import com.cloudera.hue.livy.server.sessions.Session.SessionFailedToStart
+import com.cloudera.hue.livy.{Logging, Utils}
 
-import scala.annotation.tailrec
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.Future
-import scala.io.Source
 
 object ProcessSession extends Logging {
   def create(id: String, lang: String): Session = {
@@ -18,16 +16,27 @@ object ProcessSession extends Logging {
 
   // Loop until we've started a process with a valid port.
   private def startProcess(id: String, lang: String): Process = {
-    val pb = new ProcessBuilder(
+    val args = ArrayBuffer(
       "spark-submit",
-      "--class", "com.cloudera.hue.livy.repl.Main",
-      Utils.jarOfClass(getClass).head,
-      lang)
+      "--class",
+      "com.cloudera.hue.livy.repl.Main"
+    )
+
+    sys.env.get("LIVY_REPL_JAVA_OPTS").foreach { case javaOpts =>
+      args += "--driver-java-options"
+      args += javaOpts
+    }
+
+    args += Utils.jarOfClass(getClass).head
+    args += lang
+
+    val pb = new ProcessBuilder(args)
 
     val callbackUrl = System.getProperty("livy.server.callback-url")
     pb.environment().put("LIVY_CALLBACK_URL", f"$callbackUrl/sessions/$id/callback")
     pb.environment().put("LIVY_PORT", "0")
 
+
     pb.redirectOutput(Redirect.INHERIT)
     pb.redirectError(Redirect.INHERIT)
 

+ 1 - 0
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/sessions/Session.scala

@@ -9,6 +9,7 @@ import scala.concurrent.Future
 
 object Session {
   sealed trait State
+  case class NotStarted() extends State
   case class Starting() extends State
   case class Idle() extends State
   case class Busy() extends State

+ 1 - 0
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/sessions/ThreadSession.scala

@@ -35,6 +35,7 @@ private class ThreadSession(val id: String, session: com.cloudera.hue.livy.repl.
 
   override def state: State = {
     session.state match {
+      case repl.Session.NotStarted() => NotStarted()
       case repl.Session.Starting() => Starting()
       case repl.Session.Idle() => Idle()
       case repl.Session.Busy() => Busy()