瀏覽代碼

[livy] Factor out running a session interpreter from PythonInterpreter

Erick Tryzelaar 10 年之前
父節點
當前提交
6f86bb2b1a

+ 174 - 0
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/process/ProcessInterpreter.scala

@@ -0,0 +1,174 @@
+/*
+ * Licensed to Cloudera, Inc. under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  Cloudera, Inc. licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.cloudera.hue.livy.repl.process
+
+import java.io.{BufferedReader, IOException, InputStreamReader, PrintWriter}
+import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
+
+import com.cloudera.hue.livy.Logging
+import com.cloudera.hue.livy.repl.Interpreter
+import com.cloudera.hue.livy.sessions._
+import org.json4s.JValue
+
+import scala.annotation.tailrec
+import scala.concurrent.duration.Duration
+import scala.concurrent.{Await, ExecutionContext, Future, Promise}
+import scala.util
+
+private sealed trait Request
+private case class ExecuteRequest(code: String, promise: Promise[JValue]) extends Request
+private case class ShutdownRequest(promise: Promise[Unit]) extends Request
+
+abstract class ProcessInterpreter(process: Process)
+  extends Interpreter
+  with Logging
+{
+  implicit val executor: ExecutionContext = ExecutionContext.global
+
+  protected[this] var _state: State = Starting()
+
+  protected[this] val stdin = new PrintWriter(process.getOutputStream)
+  protected[this] val stdout = new BufferedReader(new InputStreamReader(process.getInputStream), 1)
+
+  private[this] val _queue = new LinkedBlockingQueue[Request]
+
+  override def state: State = _state
+
+  override def execute(code: String): Future[JValue] = {
+    _state match {
+      case (Dead() | ShuttingDown() | Error()) =>
+        Future.failed(new IllegalStateException("interpreter is not running"))
+      case _ =>
+        val promise = Promise[JValue]()
+        _queue.add(ExecuteRequest(code, promise))
+        promise.future
+    }
+  }
+
+  protected def waitUntilReady(): Unit
+
+  protected def sendExecuteRequest(request: String): Option[JValue]
+
+  protected def sendShutdownRequest(): Option[JValue]
+
+  private[this] val thread = new Thread("process interpreter") {
+    override def run() = {
+      waitUntilReady()
+
+      _state = Idle()
+
+      loop()
+    }
+
+    @tailrec
+    private def loop(): Unit = {
+      (_state, _queue.take()) match {
+        case (Error(), ExecuteRequest(code, promise)) =>
+          promise.failure(new Exception("session has been terminated"))
+          loop()
+
+        case (state, ExecuteRequest(code, promise)) =>
+          require(state == Idle())
+
+          _state = Busy()
+
+          sendExecuteRequest(code) match {
+            case Some(rep) =>
+              synchronized {
+                _state = Idle()
+              }
+
+              promise.success(rep)
+            case None =>
+              synchronized {
+                _state = Error()
+              }
+
+              promise.failure(new Exception("session has been terminated"))
+          }
+          loop()
+
+        case (_, ShutdownRequest(promise)) =>
+          require(state == Idle() || state == Error())
+
+          synchronized {
+            _state = ShuttingDown()
+          }
+
+          try {
+            sendShutdownRequest() match {
+              case Some(rep) =>
+                warn(f"process failed to shut down while returning $rep")
+              case None =>
+            }
+
+            try {
+              process.getInputStream.close()
+              process.getOutputStream.close()
+            } catch {
+              case _: IOException =>
+            }
+
+            try {
+              process.destroy()
+            } finally {
+              synchronized {
+                _state = Dead()
+              }
+
+              promise.success(())
+            }
+          }
+      }
+    }
+  }
+
+  thread.start()
+
+  override def close(): Unit = {
+    val future = synchronized {
+      _state match {
+        case (Dead() | ShuttingDown()) =>
+          Future.successful()
+        case _ =>
+          val promise = Promise[Unit]()
+          _queue.add(ShutdownRequest(promise))
+
+          promise.future.andThen {
+            case util.Success(_) =>
+              thread.join()
+            case util.Failure(_) =>
+              thread.interrupt()
+              thread.join()
+          }
+      }
+    }
+
+    // Give ourselves 10 seconds to tear down the process.
+    try {
+      Await.result(future, Duration(60, TimeUnit.SECONDS))
+    } catch {
+      case e: Throwable =>
+        // Make sure if there are any problems we make sure we kill the process.
+        process.destroy()
+        thread.interrupt()
+        throw e
+    }
+  }
+}

+ 27 - 136
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/python/PythonInterpreter.scala

@@ -21,23 +21,20 @@ package com.cloudera.hue.livy.repl.python
 import java.io._
 import java.lang.ProcessBuilder.Redirect
 import java.nio.file.Files
-import java.util.concurrent.{LinkedBlockingQueue, ConcurrentLinkedQueue, SynchronousQueue, TimeUnit}
 
 import com.cloudera.hue.livy.repl.Interpreter
-import com.cloudera.hue.livy.sessions._
+import com.cloudera.hue.livy.repl.process.ProcessInterpreter
 import com.cloudera.hue.livy.{Logging, Utils}
 import org.apache.spark.SparkContext
-import org.json4s.{DefaultFormats, JValue}
 import org.json4s.jackson.JsonMethods._
 import org.json4s.jackson.Serialization.write
+import org.json4s.{DefaultFormats, JValue}
 import py4j.GatewayServer
 
 import scala.annotation.tailrec
 import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.duration.Duration
-import scala.concurrent.{ExecutionContext, Await, Future, Promise}
-import scala.util
+import scala.concurrent.ExecutionContext
 
 object PythonInterpreter {
   def create(): Interpreter = {
@@ -120,33 +117,21 @@ object PythonInterpreter {
 }
 
 private class PythonInterpreter(process: Process, gatewayServer: GatewayServer)
-  extends Interpreter
+  extends ProcessInterpreter(process)
   with Logging
 {
-  implicit val executor: ExecutionContext = ExecutionContext.global
   implicit val formats = DefaultFormats
 
-  private val stdin = new PrintWriter(process.getOutputStream)
-  private val stdout = new BufferedReader(new InputStreamReader(process.getInputStream), 1)
-
-  private[this] var _state: State = Starting()
-  private[this] val _queue = new LinkedBlockingQueue[Request]
-
-  override def state: State = _state
-
-  override def execute(code: String): Future[JValue] = synchronized {
-    _state match {
-      case (Dead() | ShuttingDown() | Error()) =>
-        Future.failed(new IllegalStateException("interpreter is not running"))
-      case _ =>
-        val promise = Promise[JValue]()
-        _queue.add(ExecuteRequest(code, promise))
-        promise.future
+  override def close(): Unit = {
+    try {
+      super.close()
+    } finally {
+      gatewayServer.shutdown()
     }
   }
 
   @tailrec
-  private def waitUntilReady(): Unit = {
+  final override protected def waitUntilReady(): Unit = {
     val line = stdout.readLine()
     line match {
       case null | "READY" =>
@@ -154,123 +139,29 @@ private class PythonInterpreter(process: Process, gatewayServer: GatewayServer)
     }
   }
 
-  private[this] val thread = new Thread {
-    override def run() = {
-      waitUntilReady()
-
-      _state = Idle()
+  override protected def sendExecuteRequest(code: String): Option[JValue] = {
+    val rep = sendRequest(Map("msg_type" -> "execute_request", "content" -> Map("code" -> code)))
+    rep.map { case rep =>
+      assert((rep \ "msg_type").extract[String] == "execute_reply")
 
-      loop()
-    }
+      val content: JValue = rep \ "content"
 
-    @tailrec
-    private def waitUntilReady(): Unit = {
-      val line = stdout.readLine()
-      line match {
-        case null | "READY" =>
-        case _ => waitUntilReady()
-      }
-    }
-
-    private def sendRequest(request: Map[String, Any]): Option[JValue] = {
-      stdin.println(write(request))
-      stdin.flush()
-
-      Option(stdout.readLine()).map { case line => parse(line) }
-    }
-
-    @tailrec
-    def loop(): Unit = {
-      (_state, _queue.take()) match {
-        case (Error(), ExecuteRequest(code, promise)) =>
-          promise.failure(new Exception("session has been terminated"))
-          loop()
-
-        case (state, ExecuteRequest(code, promise)) =>
-          require(state == Idle())
-
-          _state = Busy()
-
-          sendRequest(Map("msg_type" -> "execute_request", "content" -> Map("code" -> code))) match {
-            case Some(rep) =>
-              assert((rep \ "msg_type").extract[String] == "execute_reply")
-
-              val content: JValue = rep \ "content"
-
-              synchronized {
-                _state = Idle()
-              }
-
-              promise.success(content)
-              loop()
-            case None =>
-              synchronized {
-                _state = Error()
-              }
-
-              promise.failure(new Exception("session has been terminated"))
-          }
-
-        case (_, ShutdownRequest(promise)) =>
-          require(state == Idle() || state == Error())
-
-          synchronized {
-            _state = ShuttingDown()
-          }
-
-          try {
-            sendRequest(Map("msg_type" -> "shutdown_request", "content" ->())) match {
-              case Some(rep) =>
-                warn(f"process failed to shut down while returning $rep")
-              case None =>
-            }
-
-            // Ignore IO errors, such as if the stream is already closed.
-            try {
-              process.getInputStream.close()
-              process.getOutputStream.close()
-            } catch {
-              case _: IOException =>
-            }
-
-            try {
-              process.destroy()
-            } finally {
-              synchronized {
-                _state = Dead()
-              }
-
-              promise.success(())
-            }
-          }
-      }
+      content
     }
   }
 
-  thread.start()
+  override protected def sendShutdownRequest(): Option[JValue] = {
+    val rep = sendRequest(Map(
+      "msg_type" -> "shutdown_request",
+      "content" -> ()
+    ))
+    rep
+  }
 
-  override def close(): Unit = {
-    val future = synchronized {
-      _state match {
-        case (Dead() | ShuttingDown() | Error()) =>
-          Future.successful()
-        case _ =>
-          val promise = Promise[Unit]()
-          _queue.add(ShutdownRequest(promise))
-
-          promise.future
-            .andThen {
-            case util.Success(_) =>
-              thread.join()
-            case util.Failure(_) =>
-              thread.interrupt()
-              thread.join()
-          }
-            .andThen { case _ => gatewayServer.shutdown() }
-      }
-    }
+  private def sendRequest(request: Map[String, Any]): Option[JValue] = {
+    stdin.println(write(request))
+    stdin.flush()
 
-    // Give ourselves 10 seconds to tear down the process.
-    Await.result(future, Duration(10, TimeUnit.SECONDS))
+    Option(stdout.readLine()).map { case line => parse(line) }
   }
 }