
[livy] Simplify creating a scala/spark interpreter

Erick Tryzelaar committed 10 years ago
commit f46ee77331

+ 1 - 0
apps/spark/java/livy-repl/pom.xml

@@ -96,6 +96,7 @@
                     <systemProperties>
                         <spark.master>local</spark.master>
                         <spark.driver.allowMultipleContexts>true</spark.driver.allowMultipleContexts>
+                        <settings.usejavacp.value>true</settings.usejavacp.value>
                     </systemProperties>
                 </configuration>
             </plugin>
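
Note: the settings.usejavacp.value property gives the forked test JVM the same behavior that Interpreter.start() now sets programmatically. A minimal sketch of why the flag matters when embedding the interpreter, using stock scala.tools.nsc rather than anything from this commit:

    import scala.tools.nsc.Settings
    import scala.tools.nsc.interpreter.IMain

    // Without usejavacp the embedded interpreter starts with an empty
    // classpath and cannot resolve even scala.Predef when launched
    // under Maven/surefire.
    val settings = new Settings()
    settings.usejavacp.value = true
    val interpreter = new IMain(settings)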

+ 36 - 18
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/Main.scala

@@ -11,6 +11,7 @@ import org.json4s.{DefaultFormats, Formats}
 import org.scalatra.LifeCycle
 import org.scalatra.servlet.ScalatraListener
 
+import _root_.scala.annotation.tailrec
 import _root_.scala.concurrent.duration._
 import _root_.scala.concurrent.{Await, ExecutionContext}
 
@@ -56,10 +57,11 @@
 
     server.start()
 
-    println("Starting livy-repl on port %s" format server.port)
-    System.setProperty("livy.repl.url", s"http://${server.host}:${server.port}")
-
     try {
+      val replUrl = s"http://${server.host}:${server.port}"
+      println(s"Starting livy-repl on $replUrl")
+      System.setProperty("livy.repl.url", replUrl)
+
       server.join()
       server.stop()
     } finally {
@@ -90,29 +92,45 @@ class ScalatraBootstrap extends LifeCycle with Logging {
       .orElse(sys.env.get("LIVY_CALLBACK_URL"))
 
     // See if we want to notify someone that we've started on a url
-    callbackUrl.foreach { case callbackUrl_ =>
-      info(s"Notifying $callbackUrl_ that we're up")
+    callbackUrl.foreach(notifyCallback)
+  }
+
+  override def destroy(context: ServletContext): Unit = {
+    if (session != null) {
+      Await.result(session.close(), Duration.Inf)
+    }
+  }
 
-      Future {
-        session.waitForStateChange(Session.Starting())
+  private def notifyCallback(callbackUrl: String): Unit = {
+    info(s"Notifying $callbackUrl that we're up")
 
-        val replUrl = System.getProperty("livy.repl.url")
-        var req = url(callbackUrl_).setContentType("application/json", "UTF-8")
-        req = req << write(Map("url" -> replUrl))
+    Future {
+      session.waitForStateChange(Session.Starting())
 
-        val rep = Http(req OK as.String)
-        rep.onFailure {
-          case _ => System.exit(1)
-        }
+      // Wait for our url to be discovered.
+      val replUrl = waitForReplUrl()
 
-        Await.result(rep, 10 seconds)
+      var req = url(callbackUrl).setContentType("application/json", "UTF-8")
+      req = req << write(Map("url" -> replUrl))
+
+      val rep = Http(req OK as.String)
+      rep.onFailure {
+        case _ => System.exit(1)
       }
+
+      Await.result(rep, 10 seconds)
     }
   }
 
-  override def destroy(context: ServletContext): Unit = {
-    if (session != null) {
-      Await.result(session.close(), Duration.Inf)
+  /** Spin until the server has started and published its url. */
+  @tailrec
+  private def waitForReplUrl(): String = {
+    val replUrl = System.getProperty("livy.repl.url")
+    if (replUrl == null) {
+      Thread.sleep(10)
+      waitForReplUrl()
+    } else {
+      replUrl
     }
   }
 }
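
Note: waitForReplUrl spins because ScalatraBootstrap.init can run (inside server.start()) before Main has published livy.repl.url. A bounded variant of the same poll, as a sketch (the retry count is an assumption, not part of the commit):

    import scala.annotation.tailrec

    @tailrec
    def waitForProperty(key: String, retriesLeft: Int = 1000): String = {
      val value = System.getProperty(key)
      if (value != null) value
      else if (retriesLeft <= 0) sys.error(s"$key was never set")
      else { Thread.sleep(10); waitForProperty(key, retriesLeft - 1) }
    }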

+ 1 - 0
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/Session.scala

@@ -7,6 +7,7 @@ import _root_.scala.concurrent.Future
 
 object Session {
   sealed trait State
+  case class NotStarted() extends State
   case class Starting() extends State
   case class Idle() extends State
   case class Busy() extends State
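
Note: the new parameterless state is a case class only for consistency with the existing hierarchy. The equivalent shape with case objects (not what this commit does) would avoid allocating a fresh instance per comparison:

    sealed trait State
    case object NotStarted extends State
    case object Starting extends State
    case object Idle extends State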

+ 1 - 0
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/WebApp.scala

@@ -27,6 +27,7 @@ class WebApp(session: Session) extends ScalatraServlet with FutureSupport with J
 
   get("/") {
     val state = session.state match {
+      case Session.NotStarted() => "not_started"
       case Session.Starting() => "starting"
       case Session.Idle() => "idle"
       case Session.Busy() => "busy"

+ 37 - 12
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/scala/SparkSession.scala

@@ -1,7 +1,7 @@
 package com.cloudera.hue.livy.repl.scala
 
 import com.cloudera.hue.livy.repl.Session
-import com.cloudera.hue.livy.repl.scala.interpreter.Interpreter
+import com.cloudera.hue.livy.repl.scala.interpreter._
 import org.json4s.jackson.JsonMethods._
 import org.json4s.jackson.Serialization.write
 import org.json4s.{JValue, _}
@@ -20,8 +20,10 @@ private class SparkSession extends Session {
 
   private var _history = new mutable.ArrayBuffer[JValue]
   private val interpreter = new Interpreter()
+  interpreter.start()
 
   override def state: Session.State = interpreter.state match {
+    case Interpreter.NotStarted() => Session.NotStarted()
     case Interpreter.Starting() => Session.Starting()
     case Interpreter.Idle() => Session.Idle()
     case Interpreter.Busy() => Session.Busy()
@@ -39,22 +41,45 @@ private class SparkSession extends Session {
   }
 
   override def execute(code: String): Future[JValue] = {
-    interpreter.execute(code).map {
-      case rep =>
-        val content = parse(write(Map(
-          "status" -> "ok",
-          "execution_count" -> rep.executionCount,
-          "data" -> Map(
-            "text/plain" -> rep.data
+    Future {
+      val content = interpreter.execute(code) match {
+        case ExecuteComplete(executeCount, output) =>
+          Map(
+            "status" -> "ok",
+            "execution_count" -> executeCount,
+            "data" -> Map(
+              "text/plain" -> output
+            )
           )
-        )))
+        case ExecuteIncomplete(executeCount, output) =>
+          Map(
+            "status" -> "error",
+            "execution_count" -> executeCount,
+            "ename" -> "Error",
+            "evalue" -> "output",
+            "traceback" -> List()
+          )
+        case ExecuteError(executeCount, output) =>
+          Map(
+            "status" -> "error",
+            "execution_count" -> executeCount,
+            "ename" -> "Error",
+            "evalue" -> "output",
+            "traceback" -> List()
+          )
+      }
+
+      val jsonContent = parse(write(content))
 
-        _history += content
-        content
+      _history += jsonContent
+
+      jsonContent
     }
   }
 
   override def close(): Future[Unit] = {
-    interpreter.shutdown()
+    Future {
+      interpreter.shutdown()
+    }
   }
 }
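
Note: execute now wraps a synchronous, synchronized Interpreter.execute in a Future, so callers are non-blocking while statements still run one at a time. A self-contained sketch of the pattern (names are stand-ins, not the commit's API):

    import scala.concurrent.{ExecutionContext, Future}

    object AsyncFacade {
      private implicit val ec: ExecutionContext = ExecutionContext.global

      // Stand-in for Interpreter.execute: blocking and synchronized.
      private def blockingExecute(code: String): String =
        synchronized { code.reverse }

      // Wrap the blocking call; pending Futures still execute serially
      // because the underlying call holds the lock.
      def execute(code: String): Future[String] =
        Future { blockingExecute(code) }
    }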

+ 55 - 146
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/scala/interpreter/Interpreter.scala

@@ -1,189 +1,98 @@
 package com.cloudera.hue.livy.repl.scala.interpreter
 
-import java.io.{StringWriter, BufferedReader, StringReader}
-import java.util.concurrent.SynchronousQueue
+import java.io._
 
-import org.apache.spark.repl.SparkILoop
+import org.apache.spark.repl.SparkIMain
+
+import scala.concurrent.ExecutionContext
+import scala.tools.nsc.Settings
+import scala.tools.nsc.interpreter.{JPrintWriter, Results}
 
-import scala.annotation.tailrec
-import scala.concurrent.{ExecutionContext, Future, Promise}
-import scala.tools.nsc.SparkHelper
-import scala.tools.nsc.interpreter.{Formatting, JPrintWriter}
-import scala.tools.nsc.util.ClassPath
 
 object Interpreter {
   sealed trait State
+  case class NotStarted() extends State
   case class Starting() extends State
   case class Idle() extends State
   case class Busy() extends State
   case class ShuttingDown() extends State
 }
 
+sealed abstract class ExecuteResponse(executeCount: Int)
+case class ExecuteComplete(executeCount: Int, output: String) extends ExecuteResponse(executeCount)
+case class ExecuteIncomplete(executeCount: Int, output: String) extends ExecuteResponse(executeCount)
+case class ExecuteError(executeCount: Int, output: String) extends ExecuteResponse(executeCount)
+
 class Interpreter {
   private implicit def executor: ExecutionContext = ExecutionContext.global
 
-  private val queue = new SynchronousQueue[Request]()
-
-  // We start up the ILoop in it's own class loader because the SparkILoop store
-  // itself in a global variable.
-  private val iloop = {
-    val classLoader = new ILoopClassLoader(classOf[Interpreter].getClassLoader)
-    val cls = classLoader.loadClass(classOf[ILoop].getName)
-    val constructor = cls.getConstructor(classOf[SynchronousQueue[Request]])
-    constructor.newInstance(queue).asInstanceOf[ILoop]
-  }
-
-  // We also need to start the ILoop in it's own thread because it wants to run
-  // inside a loop.
-  private val thread = new Thread {
-    override def run() = {
-      val args = Array("-usejavacp")
-      iloop.process(args)
-    }
-  }
-
-  thread.start()
-
-  def state = iloop.state
-
-  def execute(code: String): Future[ExecuteResponse] = {
-    val promise = Promise[ExecuteResponse]()
-    queue.put(ExecuteRequest(code, promise))
-    promise.future
-  }
-
-  def shutdown(): Future[Unit] = {
-    val promise = Promise[Unit]()
-    queue.put(ShutdownRequest(promise))
-    promise.future.map({ case () => thread.join() })
-  }
-}
-
-private class ILoopClassLoader(classLoader: ClassLoader) extends ClassLoader(classLoader) { }
-
-private sealed trait Request
-private case class ExecuteRequest(code: String, promise: Promise[ExecuteResponse]) extends Request
-private case class ShutdownRequest(promise: Promise[Unit]) extends Request
-
-case class ExecuteResponse(executionCount: Int, data: String)
-
-private class ILoop(queue: SynchronousQueue[Request], outWriter: StringWriter) extends SparkILoop(
-  new BufferedReader(new StringReader("")),
-  new JPrintWriter(outWriter)
-) {
-  def this(queue: SynchronousQueue[Request]) = this(queue, new StringWriter)
-
-  var _state: Interpreter.State = Interpreter.Starting()
-
-  var _executionCount = 0
+  private var _state: Interpreter.State = Interpreter.NotStarted()
+  private val outputStream = new ByteArrayOutputStream()
+  private var sparkIMain: SparkIMain = _
+  private var executeCount = 0
 
   def state = _state
 
-  org.apache.spark.repl.Main.interp = this
+  def start() = {
+    require(_state == Interpreter.NotStarted() && sparkIMain == null)
 
-  private class ILoopInterpreter extends SparkILoopInterpreter {
-    override lazy val formatting = new Formatting {
-      def prompt = ILoop.this.prompt
-    }
-    override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader)
-  }
+    _state = Interpreter.Starting()
 
-  /** Create a new interpreter. */
-  override def createInterpreter() {
-    require(settings != null)
+    class InterpreterClassLoader(classLoader: ClassLoader) extends ClassLoader(classLoader) {}
+    val classLoader = new InterpreterClassLoader(classOf[Interpreter].getClassLoader)
 
-    if (addedClasspath != "") settings.classpath.append(addedClasspath)
-    // work around for Scala bug
-    val totalClassPath = SparkILoop.getAddedJars.foldLeft(
-      settings.classpath.value)((l, r) => ClassPath.join(l, r))
-    this.settings.classpath.value = totalClassPath
+    val settings = new Settings()
+    settings.usejavacp.value = true
 
-    intp = new ILoopInterpreter
-  }
+    sparkIMain = createSparkIMain(classLoader, settings)
 
-  private val replayQuestionMessage =
-    """|That entry seems to have slain the compiler.  Shall I replay
-      |your session? I can re-run each line except the last one.
-      |[y/n]
-    """.trim.stripMargin
-
-  private def crashRecovery(ex: Throwable): Boolean = {
-    echo(ex.toString)
-    ex match {
-      case _: NoSuchMethodError | _: NoClassDefFoundError =>
-        echo("\nUnrecoverable error.")
-        throw ex
-      case _  =>
-        def fn(): Boolean =
-          try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() })
-          catch { case _: RuntimeException => false }
-
-        if (fn()) replay()
-        else echo("\nAbandoning crashed session.")
-    }
-    true
+    _state = Interpreter.Idle()
   }
 
-  override def prompt = ""
+  private def createSparkIMain(classLoader: ClassLoader, settings: Settings) = {
+    val out = new JPrintWriter(outputStream, true)
+    val cls = classLoader.loadClass(classOf[SparkIMain].getName)
+    val constructor = cls.getConstructor(classOf[Settings], classOf[JPrintWriter], java.lang.Boolean.TYPE)
+    constructor.newInstance(settings, out, false: java.lang.Boolean).asInstanceOf[SparkIMain]
+  }
 
-  override def loop(): Unit = {
-    def readOneLine() = queue.take()
+  def execute(code: String): ExecuteResponse = {
+    synchronized {
+      executeCount += 1
 
-    // return false if repl should exit
-    def processLine(request: Request): Boolean = {
       _state = Interpreter.Busy()
 
-      if (isAsync) {
-        if (!awaitInitialized()) return false
-        runThunks()
-      }
-
-      request match {
-        case ExecuteRequest(statement, promise) =>
-          _executionCount += 1
+      val result = sparkIMain.interpret(code) match {
+        case Results.Success =>
+          val output = outputStream.toString("UTF-8").trim
+          outputStream.reset()
 
-          command(statement) match {
-            case Result(false, _) => false
-            case Result(true, finalLine) =>
-              finalLine match {
-                case Some(line) => addReplay(line)
-                case None =>
-              }
+          ExecuteComplete(executeCount - 1, output)
 
-              var output = outWriter.getBuffer.toString
+        case Results.Incomplete =>
+          val output = outputStream.toString("UTF-8").trim
+          outputStream.reset()
 
-              // Strip the trailing '\n'
-              output = output.stripSuffix("\n")
+          ExecuteIncomplete(executeCount - 1, output)
 
-              outWriter.getBuffer.setLength(0)
-
-              promise.success(ExecuteResponse(_executionCount - 1, output))
-
-              true
-          }
-        case ShutdownRequest(promise) =>
-          promise.success(())
-          false
+        case Results.Error =>
+          val output = outputStream.toString("UTF-8").trim
+          outputStream.reset()
+          ExecuteError(executeCount - 1, output)
       }
-    }
 
-    @tailrec
-    def innerLoop() {
       _state = Interpreter.Idle()
 
-      outWriter.getBuffer.setLength(0)
+      result
+    }
+  }
 
-      val shouldContinue = try {
-        processLine(readOneLine())
-      } catch {
-        case t: Throwable => crashRecovery(t)
-      }
+  def shutdown(): Unit = {
+    _state = Interpreter.ShuttingDown()
 
-      if (shouldContinue) {
-        innerLoop()
-      }
+    if (sparkIMain != null) {
+      sparkIMain.close()
+      sparkIMain = null
     }
-
-    innerLoop()
   }
 }
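
Note: the rewrite replaces the SparkILoop-plus-SynchronousQueue thread with a directly driven SparkIMain whose output is captured in a ByteArrayOutputStream and drained after every statement. A minimal sketch of the same capture pattern using stock IMain (SparkIMain is Spark-internal, so this is an approximation):

    import java.io.ByteArrayOutputStream
    import scala.tools.nsc.Settings
    import scala.tools.nsc.interpreter.{IMain, JPrintWriter, Results}

    object CapturedRepl {
      private val buffer = new ByteArrayOutputStream()
      private val settings = new Settings()
      settings.usejavacp.value = true
      private val imain = new IMain(settings, new JPrintWriter(buffer, true))

      // Interpret one statement and return its result plus whatever the
      // interpreter printed, then reset the buffer for the next call.
      def run(code: String): (Results.Result, String) = {
        val result = imain.interpret(code)
        val output = buffer.toString("UTF-8").trim
        buffer.reset()
        (result, output)
      }
    }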

+ 18 - 9
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/sessions/ProcessSession.scala

@@ -1,14 +1,12 @@
 package com.cloudera.hue.livy.server.sessions
 
 import java.lang.ProcessBuilder.Redirect
-import java.net.URL
 
-import com.cloudera.hue.livy.{Utils, Logging}
-import com.cloudera.hue.livy.server.sessions.Session.SessionFailedToStart
+import com.cloudera.hue.livy.{Logging, Utils}
 
-import scala.annotation.tailrec
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.Future
-import scala.io.Source
 
 object ProcessSession extends Logging {
   def create(id: String, lang: String): Session = {
@@ -18,16 +16,27 @@ object ProcessSession extends Logging {
 
   // Loop until we've started a process with a valid port.
   private def startProcess(id: String, lang: String): Process = {
-    val pb = new ProcessBuilder(
+    val args = ArrayBuffer(
       "spark-submit",
-      "--class", "com.cloudera.hue.livy.repl.Main",
-      Utils.jarOfClass(getClass).head,
-      lang)
+      "--class",
+      "com.cloudera.hue.livy.repl.Main"
+    )
+
+    sys.env.get("LIVY_REPL_JAVA_OPTS").foreach { case javaOpts =>
+      args += "--driver-java-options"
+      args += javaOpts
+    }
+
+    args += Utils.jarOfClass(getClass).head
+    args += lang
+
+    val pb = new ProcessBuilder(args)
 
     val callbackUrl = System.getProperty("livy.server.callback-url")
     pb.environment().put("LIVY_CALLBACK_URL", f"$callbackUrl/sessions/$id/callback")
     pb.environment().put("LIVY_PORT", "0")
 
+
     pb.redirectOutput(Redirect.INHERIT)
     pb.redirectError(Redirect.INHERIT)
 

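Note: the argument list is now an ArrayBuffer so LIVY_REPL_JAVA_OPTS can be spliced in before the jar, and JavaConversions turns the buffer into the java.util.List that ProcessBuilder expects. A standalone sketch of the launch (jar path and language literal are placeholders):

    import java.lang.ProcessBuilder.Redirect
    import scala.collection.JavaConversions._
    import scala.collection.mutable.ArrayBuffer

    val args = ArrayBuffer("spark-submit", "--class", "com.cloudera.hue.livy.repl.Main")
    sys.env.get("LIVY_REPL_JAVA_OPTS").foreach { opts =>
      args += "--driver-java-options"
      args += opts
    }
    args += "/path/to/livy-repl.jar"  // placeholder jar path
    args += "scala"                   // placeholder language

    val pb = new ProcessBuilder(args) // implicit ArrayBuffer -> java.util.List
    pb.redirectOutput(Redirect.INHERIT)
    pb.redirectError(Redirect.INHERIT)
    val process = pb.start()
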
+ 1 - 0
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/sessions/Session.scala

@@ -9,6 +9,7 @@ import scala.concurrent.Future
 
 object Session {
   sealed trait State
+  case class NotStarted() extends State
   case class Starting() extends State
   case class Idle() extends State
   case class Busy() extends State

+ 1 - 0
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/sessions/ThreadSession.scala

@@ -35,6 +35,7 @@ private class ThreadSession(val id: String, session: com.cloudera.hue.livy.repl.
 
   override def state: State = {
     session.state match {
+      case repl.Session.NotStarted() => NotStarted()
       case repl.Session.Starting() => Starting()
       case repl.Session.Idle() => Idle()
       case repl.Session.Busy() => Busy()