Browse Source

[livy] Make sure to shut down the pyspark SparkContext

This confuses the tests, as spark really only wants one SparkContext
in a JVM, and the old behavior was relying on the GC to shutdown the
context.
Erick Tryzelaar 10 years ago
parent
commit
19cdb06f9c

+ 7 - 0
apps/spark/java/livy-repl/src/main/resources/fake_shell.py

@@ -235,6 +235,10 @@ def magic_table(name):
     })
 
 
+def shutdown_request(content):
+    sys.exit()
+
+
 magic_router = {
     'table': magic_table,
 }
@@ -242,6 +246,7 @@ magic_router = {
 
 msg_type_router = {
     'execute_request': execute_request,
+    'shutdown_request': shutdown_request,
 }
 
 sys_stdin = sys.stdin
@@ -319,6 +324,8 @@ try:
         print >> sys_stdout, response
         sys_stdout.flush()
 finally:
+    global_dict['sc'].stop()
+
     sys.stdin = sys_stdin
     sys.stdout = sys_stdout
     sys.stderr = sys_stderr

+ 27 - 24
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/python/PythonSession.scala

@@ -6,7 +6,7 @@ import java.nio.file.Files
 import java.util.concurrent.{TimeUnit, SynchronousQueue}
 
 
-import com.cloudera.hue.livy.Utils
+import com.cloudera.hue.livy.{Logging, Utils}
 import com.cloudera.hue.livy.repl.Session
 import org.apache.spark.SparkContext
 import org.json4s.jackson.JsonMethods._
@@ -104,7 +104,7 @@ object PythonSession {
   }
 }
 
-private class PythonSession(process: Process, gatewayServer: GatewayServer) extends Session {
+private class PythonSession(process: Process, gatewayServer: GatewayServer) extends Session with Logging {
   private implicit def executor: ExecutionContext = ExecutionContext.global
 
   implicit val formats = DefaultFormats
@@ -127,7 +127,7 @@ private class PythonSession(process: Process, gatewayServer: GatewayServer) exte
     }
 
     @tailrec
-    def waitUntilReady(): Unit = {
+    private def waitUntilReady(): Unit = {
       val line = stdout.readLine()
       line match {
         case null | "READY" =>
@@ -135,6 +135,13 @@ private class PythonSession(process: Process, gatewayServer: GatewayServer) exte
       }
     }
 
+    private def sendRequest(request: Map[String, Any]): Option[JValue] = {
+      stdin.println(write(request))
+      stdin.flush()
+
+      Option(stdout.readLine()).map { case line => parse(line) }
+    }
+
     @tailrec
     def loop(): Unit = {
       (_state, queue.take()) match {
@@ -147,38 +154,34 @@ private class PythonSession(process: Process, gatewayServer: GatewayServer) exte
 
           _state = Session.Busy()
 
-          val msg = Map(
-            "msg_type" -> "execute_request",
-            "content" -> Map("code" -> code))
-
-          stdin.println(write(msg))
-          stdin.flush()
-
-          val line = stdout.readLine()
-          // The python process shut down
-          if (line == null) {
-            _state = Session.Error()
-            promise.failure(new Exception("session has been terminated"))
-          } else {
-            val rep = parse(line)
-            assert((rep \ "msg_type").extract[String] == "execute_reply")
+          sendRequest(Map("msg_type" -> "execute_request", "content" -> Map("code" -> code))) match {
+            case Some(rep) =>
+              assert((rep \ "msg_type").extract[String] == "execute_reply")
 
-            val content: JValue = rep \ "content"
-            _history += content
+              val content: JValue = rep \ "content"
+              _history += content
 
-            _state = Session.Idle()
+              _state = Session.Idle()
 
-            promise.success(content)
+              promise.success(content)
+              loop()
+            case None =>
+              _state = Session.Error()
+              promise.failure(new Exception("session has been terminated"))
           }
 
-          loop()
-
         case (_, ShutdownRequest(promise)) =>
           require(state == Session.Idle() || state == Session.Error())
 
           _state = Session.ShuttingDown()
 
           try {
+            sendRequest(Map("msg_type" -> "shutdown_request", "content" -> ())) match {
+              case Some(rep) =>
+                warn(f"process failed to shut down while returning $rep")
+              case None =>
+            }
+
             process.getInputStream.close()
             process.getOutputStream.close()