[livy] Factor out the python interpreter

Erick Tryzelaar, 10 years ago
parent commit 05b431f

+ 40 - 0
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/Interpreter.scala

@@ -0,0 +1,40 @@
+/*
+ * Licensed to Cloudera, Inc. under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  Cloudera, Inc. licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.cloudera.hue.livy.repl
+
+import com.cloudera.hue.livy.Utils
+import com.cloudera.hue.livy.sessions.State
+import org.json4s._
+
+import _root_.scala.concurrent._
+import _root_.scala.concurrent.duration.Duration
+
+trait Interpreter {
+  def state: State
+
+  def execute(code: String): Future[JValue]
+
+  def close(): Unit
+
+  @throws(classOf[TimeoutException])
+  @throws(classOf[InterruptedException])
+  final def waitForStateChange(oldState: State, atMost: Duration) = {
+    Utils.waitUntil({ () => state != oldState }, atMost)
+  }
+}
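
The new Interpreter trait is the seam introduced by this commit: a backend reports its lifecycle through state, runs code asynchronously through execute, and tears down in close, while the final waitForStateChange polls the state via Utils.waitUntil. As a reading aid only, here is a hypothetical in-memory implementation; the EchoInterpreter name and its reply shape are invented and not part of this commit:

// Hypothetical sketch, not part of this commit: a trivial Interpreter
// implementation illustrating the trait's contract.
import com.cloudera.hue.livy.sessions.{Idle, State}
import org.json4s._
import org.json4s.JsonDSL._

import scala.concurrent.Future

class EchoInterpreter extends Interpreter {
  // A real backend moves through Starting/Idle/Busy/Error; this one is always Idle.
  override def state: State = Idle()

  // Completes immediately; a real backend resolves the Future once its process replies.
  override def execute(code: String): Future[JValue] =
    Future.successful(("msg_type" -> "execute_reply") ~ ("content" -> code))

  override def close(): Unit = ()
}

Because waitForStateChange is final, every backend inherits the same polling behaviour; only the three abstract members differ per language.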

+ 1 - 1
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/Main.scala

@@ -102,7 +102,7 @@ class ScalatraBootstrap extends LifeCycle with Logging {
 
   override def init(context: ServletContext): Unit = {
     session = context.getInitParameter(Main.SESSION_KIND) match {
-      case Main.PYSPARK_SESSION | Main.PYTHON_SESSION => PythonSession.createPySpark()
+      case Main.PYSPARK_SESSION | Main.PYTHON_SESSION => PythonSession.create()
       case Main.SPARK_SESSION | Main.SCALA_SESSION => SparkSession.create()
     }
 

+ 245 - 0
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/python/PythonInterpreter.scala

@@ -0,0 +1,245 @@
+/*
+ * Licensed to Cloudera, Inc. under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  Cloudera, Inc. licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.cloudera.hue.livy.repl.python
+
+import java.io._
+import java.lang.ProcessBuilder.Redirect
+import java.nio.file.Files
+import java.util.concurrent.{SynchronousQueue, TimeUnit}
+
+import com.cloudera.hue.livy.repl.Interpreter
+import com.cloudera.hue.livy.sessions._
+import com.cloudera.hue.livy.{Logging, Utils}
+import org.apache.spark.SparkContext
+import org.json4s.{DefaultFormats, JValue}
+import org.json4s.jackson.JsonMethods._
+import org.json4s.jackson.Serialization.write
+import py4j.GatewayServer
+
+import scala.annotation.tailrec
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration.Duration
+import scala.concurrent.{Await, Future, Promise}
+
+object PythonInterpreter {
+  def create(): Interpreter = {
+    val pythonExec = sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python")
+
+    val gatewayServer = new GatewayServer(null, 0)
+    gatewayServer.start()
+
+    val builder = new ProcessBuilder(Seq(
+      pythonExec,
+      createFakeShell().toString
+    ))
+
+    val env = builder.environment()
+    env.put("PYTHONPATH", pythonPath.mkString(File.pathSeparator))
+    env.put("PYTHONUNBUFFERED", "YES")
+    env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
+    env.put("SPARK_HOME", sys.env.getOrElse("SPARK_HOME", "."))
+
+    builder.redirectError(Redirect.INHERIT)
+
+    val process = builder.start()
+
+    new PythonInterpreter(process, gatewayServer)
+  }
+
+  private def pythonPath = {
+    val pythonPath = new ArrayBuffer[String]
+    pythonPath ++= Utils.jarOfClass(classOf[SparkContext])
+    pythonPath
+  }
+
+  private def createFakeShell(): File = {
+    val source: InputStream = getClass.getClassLoader.getResourceAsStream("fake_shell.py")
+
+    val file = Files.createTempFile("", "").toFile
+    file.deleteOnExit()
+
+    val sink = new FileOutputStream(file)
+    val buf = new Array[Byte](1024)
+    var n = source.read(buf)
+
+    while (n > 0) {
+      sink.write(buf, 0, n)
+      n = source.read(buf)
+    }
+
+    source.close()
+    sink.close()
+
+    file
+  }
+
+  private def createFakePySpark(): File = {
+    val source: InputStream = getClass.getClassLoader.getResourceAsStream("fake_pyspark.sh")
+
+    val file = Files.createTempFile("", "").toFile
+    file.deleteOnExit()
+
+    file.setExecutable(true)
+
+    val sink = new FileOutputStream(file)
+    val buf = new Array[Byte](1024)
+    var n = source.read(buf)
+
+    while (n > 0) {
+      sink.write(buf, 0, n)
+      n = source.read(buf)
+    }
+
+    source.close()
+    sink.close()
+
+    file
+  }
+}
+
+private class PythonInterpreter(process: Process, gatewayServer: GatewayServer)
+  extends Interpreter
+  with Logging
+{
+  implicit val formats = DefaultFormats
+
+  private val stdin = new PrintWriter(process.getOutputStream)
+  private val stdout = new BufferedReader(new InputStreamReader(process.getInputStream), 1)
+
+  private[this] var _state: State = Starting()
+  private[this] val _queue = new SynchronousQueue[Request]
+
+  override def state: State = _state
+
+  override def execute(code: String): Future[JValue] = {
+    val promise = Promise[JValue]()
+    _queue.put(ExecuteRequest(code, promise))
+    promise.future
+  }
+
+  override def close(): Unit = {
+    _state match {
+      case Dead() =>
+      case ShuttingDown() =>
+        // Another thread must be tearing down the process.
+        waitForStateChange(ShuttingDown(), Duration(10, TimeUnit.SECONDS))
+      case _ =>
+        val promise = Promise[Unit]()
+        _queue.put(ShutdownRequest(promise))
+
+        // Give ourselves 10 seconds to tear down the process.
+        try {
+          Await.result(promise.future, Duration(10, TimeUnit.SECONDS))
+          thread.join()
+        } finally {
+          gatewayServer.shutdown()
+        }
+    }
+  }
+
+  @tailrec
+  private def waitUntilReady(): Unit = {
+    val line = stdout.readLine()
+    line match {
+      case null | "READY" =>
+      case _ => waitUntilReady()
+    }
+  }
+
+  private[this] val thread = new Thread {
+    override def run() = {
+      waitUntilReady()
+
+      _state = Idle()
+
+      loop()
+    }
+
+    @tailrec
+    private def waitUntilReady(): Unit = {
+      val line = stdout.readLine()
+      line match {
+        case null | "READY" =>
+        case _ => waitUntilReady()
+      }
+    }
+
+    private def sendRequest(request: Map[String, Any]): Option[JValue] = {
+      stdin.println(write(request))
+      stdin.flush()
+
+      Option(stdout.readLine()).map { case line => parse(line) }
+    }
+
+    @tailrec
+    def loop(): Unit = {
+      (_state, _queue.take()) match {
+        case (Error(), ExecuteRequest(code, promise)) =>
+          promise.failure(new Exception("session has been terminated"))
+          loop()
+
+        case (state, ExecuteRequest(code, promise)) =>
+          require(state == Idle())
+
+          _state = Busy()
+
+          sendRequest(Map("msg_type" -> "execute_request", "content" -> Map("code" -> code))) match {
+            case Some(rep) =>
+              assert((rep \ "msg_type").extract[String] == "execute_reply")
+
+              val content: JValue = rep \ "content"
+
+              _state = Idle()
+
+              promise.success(content)
+              loop()
+            case None =>
+              _state = Error()
+              promise.failure(new Exception("session has been terminated"))
+          }
+
+        case (_, ShutdownRequest(promise)) =>
+          require(state == Idle() || state == Error())
+
+          _state = ShuttingDown()
+
+          try {
+            sendRequest(Map("msg_type" -> "shutdown_request", "content" -> ())) match {
+              case Some(rep) =>
+                warn(f"process failed to shut down while returning $rep")
+              case None =>
+            }
+
+            process.getInputStream.close()
+            process.getOutputStream.close()
+
+            try {
+              process.destroy()
+            } finally {
+              _state = Dead()
+              promise.success(())
+            }
+          }
+      }
+    }
+  }
+
+  thread.start()
+}
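
For orientation: sendRequest above speaks a line-oriented JSON protocol to fake_shell.py over the child process's stdin and stdout, one JSON object per line in each direction. (Note that createFakePySpark() is carried over into this object but is not referenced by create(), which always launches fake_shell.py directly.) The sketch below uses the same json4s calls as the code; the reply's "content" payload is an assumption, since the diff only constrains its "msg_type":

// Sketch of the request/reply framing used by sendRequest above; the
// reply's "content" payload is hypothetical.
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.json4s.jackson.Serialization.write

object ProtocolSketch extends App {
  implicit val formats = DefaultFormats

  // One request per line on the child's stdin:
  val request: String = write(Map(
    "msg_type" -> "execute_request",
    "content" -> Map("code" -> "1 + 2")))
  println(request) // {"msg_type":"execute_request","content":{"code":"1 + 2"}}

  // One reply per line on the child's stdout; execute() resolves with its "content":
  val reply = parse("""{"msg_type":"execute_reply","content":{"status":"ok"}}""")
  assert((reply \ "msg_type").extract[String] == "execute_reply")
  println(compact(render(reply \ "content"))) // {"status":"ok"}
}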

+ 15 - 208
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/python/PythonSession.scala

@@ -18,207 +18,28 @@
 
 package com.cloudera.hue.livy.repl.python
 
-import java.io._
-import java.lang.ProcessBuilder.Redirect
-import java.nio.file.Files
-import java.util.concurrent.{SynchronousQueue, TimeUnit}
-
-import com.cloudera.hue.livy.repl.Session
+import com.cloudera.hue.livy.Logging
+import com.cloudera.hue.livy.repl.{Interpreter, Session}
 import com.cloudera.hue.livy.sessions._
-import com.cloudera.hue.livy.{Logging, Utils}
-import org.apache.spark.SparkContext
-import org.json4s.jackson.JsonMethods._
-import org.json4s.jackson.Serialization.write
-import org.json4s.{DefaultFormats, JValue}
-import py4j.GatewayServer
+import org.json4s.JValue
 
-import scala.annotation.tailrec
-import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent._
-import scala.concurrent.duration.Duration
 
 object PythonSession {
-  def createPython(): Session = {
-    create("python")
-  }
-
-  def createPySpark(): Session = {
-    create(createFakePySpark().toString)
-  }
-
-  private def pythonPath = {
-    val pythonPath = new ArrayBuffer[String]
-    pythonPath ++= Utils.jarOfClass(classOf[SparkContext])
-    pythonPath
-  }
-
-  private def create(driver: String) = {
-    val pythonExec = sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python")
-
-    val gatewayServer = new GatewayServer(null, 0)
-    gatewayServer.start()
-
-    val builder = new ProcessBuilder(Seq(
-      pythonExec,
-      createFakeShell().toString
-    ))
-
-    val env = builder.environment()
-    env.put("PYTHONPATH", pythonPath.mkString(File.pathSeparator))
-    env.put("PYTHONUNBUFFERED", "YES")
-    env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
-    env.put("SPARK_HOME", sys.env.getOrElse("SPARK_HOME", "."))
-
-    builder.redirectError(Redirect.INHERIT)
-
-    val process = builder.start()
-
-    new PythonSession(process, gatewayServer)
-  }
-
-  private def createFakeShell(): File = {
-    val source: InputStream = getClass.getClassLoader.getResourceAsStream("fake_shell.py")
-
-    val file = Files.createTempFile("", "").toFile
-    file.deleteOnExit()
-
-    val sink = new FileOutputStream(file)
-    val buf = new Array[Byte](1024)
-    var n = source.read(buf)
-
-    while (n > 0) {
-      sink.write(buf, 0, n)
-      n = source.read(buf)
-    }
-
-    source.close()
-    sink.close()
-
-    file
-  }
-
-  private def createFakePySpark(): File = {
-    val source: InputStream = getClass.getClassLoader.getResourceAsStream("fake_pyspark.sh")
-
-    val file = Files.createTempFile("", "").toFile
-    file.deleteOnExit()
-
-    file.setExecutable(true)
-
-    val sink = new FileOutputStream(file)
-    val buf = new Array[Byte](1024)
-    var n = source.read(buf)
-
-    while (n > 0) {
-      sink.write(buf, 0, n)
-      n = source.read(buf)
-    }
-
-    source.close()
-    sink.close()
-
-    file
+  def create(): Session = {
+    new PythonSession(PythonInterpreter.create())
   }
 }
 
-private class PythonSession(process: Process, gatewayServer: GatewayServer) extends Session with Logging {
+private class PythonSession(interpreter: Interpreter) extends Session with Logging {
   private implicit def executor: ExecutionContext = ExecutionContext.global
 
-  implicit val formats = DefaultFormats
-
-  private val stdin = new PrintWriter(process.getOutputStream)
-  private val stdout = new BufferedReader(new InputStreamReader(process.getInputStream), 1)
-
   private var _history = ArrayBuffer[JValue]()
-  private var _state: State = Starting()
-
-  private val queue = new SynchronousQueue[Request]
-
-  private val thread = new Thread {
-    override def run() = {
-      waitUntilReady()
-
-      _state = Idle()
-
-      loop()
-    }
-
-    @tailrec
-    private def waitUntilReady(): Unit = {
-      val line = stdout.readLine()
-      line match {
-        case null | "READY" =>
-        case _ => waitUntilReady()
-      }
-    }
-
-    private def sendRequest(request: Map[String, Any]): Option[JValue] = {
-      stdin.println(write(request))
-      stdin.flush()
-
-      Option(stdout.readLine()).map { case line => parse(line) }
-    }
-
-    @tailrec
-    def loop(): Unit = {
-      (_state, queue.take()) match {
-        case (Error(), ExecuteRequest(code, promise)) =>
-          promise.failure(new Exception("session has been terminated"))
-          loop()
-
-        case (state, ExecuteRequest(code, promise)) =>
-          require(state == Idle())
-
-          _state = Busy()
-
-          sendRequest(Map("msg_type" -> "execute_request", "content" -> Map("code" -> code))) match {
-            case Some(rep) =>
-              assert((rep \ "msg_type").extract[String] == "execute_reply")
-
-              val content: JValue = rep \ "content"
-              _history += content
-
-              _state = Idle()
-
-              promise.success(content)
-              loop()
-            case None =>
-              _state = Error()
-              promise.failure(new Exception("session has been terminated"))
-          }
-
-        case (_, ShutdownRequest(promise)) =>
-          require(state == Idle() || state == Error())
-
-          _state = ShuttingDown()
-
-          try {
-            sendRequest(Map("msg_type" -> "shutdown_request", "content" -> ())) match {
-              case Some(rep) =>
-                warn(f"process failed to shut down while returning $rep")
-              case None =>
-            }
-
-            process.getInputStream.close()
-            process.getOutputStream.close()
-
-            try {
-              process.destroy()
-            } finally {
-              _state = Dead()
-              promise.success(())
-            }
-          }
-      }
-    }
-  }
-
-  thread.start()
 
   override def kind = PySpark()
 
-  override def state = _state
+  override def state = interpreter.state
 
   override def history(): Seq[JValue] = _history
 
@@ -231,30 +52,16 @@ private class PythonSession(process: Process, gatewayServer: GatewayServer) extends Session with Logging {
   }
 
   override def execute(code: String): Future[JValue] = {
-    val promise = Promise[JValue]()
-    queue.put(ExecuteRequest(code, promise))
-    promise.future
-  }
-
-  override def close(): Unit = synchronized {
-    _state match {
-      case Dead() =>
-      case ShuttingDown() =>
-        // Another thread must be tearing down the process.
-        waitForStateChange(ShuttingDown(), Duration(10, TimeUnit.SECONDS))
-      case _ =>
-        val promise = Promise[Unit]()
-        queue.put(ShutdownRequest(promise))
-
-        // Give ourselves 10 seconds to tear down the process.
-        try {
-          Await.result(promise.future, Duration(10, TimeUnit.SECONDS))
-          thread.join()
-        } finally {
-          gatewayServer.shutdown()
-        }
+    val future = interpreter.execute(code)
+    future.onSuccess { case result =>
+      synchronized {
+        _history += result
+      }
     }
+    future
   }
+
+  override def close(): Unit = interpreter.close()
 }
 
 private sealed trait Request
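
With the process machinery gone, PythonSession keeps only the history buffer and delegates state, execution, and shutdown to the interpreter; recording of results moved from the interpreter loop into an onSuccess callback here, which keeps the interpreter unaware of history. One consequence worth noting: the callback runs on ExecutionContext.global after the future completes, so a result can be observable through the returned Future slightly before it appears in history(). A hypothetical caller (the timeout is illustrative):

// Hypothetical usage of the refactored session; not part of this commit.
import java.util.concurrent.TimeUnit

import scala.concurrent.Await
import scala.concurrent.duration.Duration

object SessionSketch extends App {
  // Spawns fake_shell.py via PythonInterpreter.create() under the hood.
  val session = PythonSession.create()

  // execute() forwards to the interpreter; the reply is also appended to history().
  val content = Await.result(session.execute("1 + 2"), Duration(30, TimeUnit.SECONDS))
  println(content)

  session.close() // delegates to interpreter.close(), tearing down the process
}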

+ 1 - 1
apps/spark/java/livy-repl/src/test/scala/com/cloudera/hue/livy/repl/PythonSessionSpec.scala

@@ -27,7 +27,7 @@ import _root_.scala.concurrent.duration.Duration
 
 class PythonSessionSpec extends BaseSessionSpec {
 
-  override def createSession() = PythonSession.createPySpark()
+  override def createSession() = PythonSession.create()
 
   describe("A python session") {
     it("should execute `1 + 2` == 3") {