Преглед изворни кода

[livy] Fix a deadlock if the python interpreter thread dies unexpectedly

Erick Tryzelaar пре 10 година
родитељ
комит
ae34f65

+ 51 - 29
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/python/PythonInterpreter.scala

@@ -21,7 +21,7 @@ package com.cloudera.hue.livy.repl.python
 import java.io._
 import java.lang.ProcessBuilder.Redirect
 import java.nio.file.Files
-import java.util.concurrent.{SynchronousQueue, TimeUnit}
+import java.util.concurrent.{LinkedBlockingQueue, ConcurrentLinkedQueue, SynchronousQueue, TimeUnit}
 
 import com.cloudera.hue.livy.repl.Interpreter
 import com.cloudera.hue.livy.sessions._
@@ -36,7 +36,8 @@ import scala.annotation.tailrec
 import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration.Duration
-import scala.concurrent.{Await, Future, Promise}
+import scala.concurrent.{ExecutionContext, Await, Future, Promise}
+import scala.util
 
 object PythonInterpreter {
   def create(): Interpreter = {
@@ -118,39 +119,25 @@ private class PythonInterpreter(process: Process, gatewayServer: GatewayServer)
   extends Interpreter
   with Logging
 {
+  implicit val executor: ExecutionContext = ExecutionContext.global
   implicit val formats = DefaultFormats
 
   private val stdin = new PrintWriter(process.getOutputStream)
   private val stdout = new BufferedReader(new InputStreamReader(process.getInputStream), 1)
 
   private[this] var _state: State = Starting()
-  private[this] val _queue = new SynchronousQueue[Request]
+  private[this] val _queue = new LinkedBlockingQueue[Request]
 
   override def state: State = _state
 
-  override def execute(code: String): Future[JValue] = {
-    val promise = Promise[JValue]()
-    _queue.put(ExecuteRequest(code, promise))
-    promise.future
-  }
-
-  override def close(): Unit = {
+  override def execute(code: String): Future[JValue] = synchronized {
     _state match {
-      case Dead() =>
-      case ShuttingDown() =>
-        // Another thread must be tearing down the process.
-        waitForStateChange(ShuttingDown(), Duration(10, TimeUnit.SECONDS))
+      case (Dead() | ShuttingDown() | Error()) =>
+        Future.failed(new IllegalStateException("interpreter is not running"))
       case _ =>
-        val promise = Promise[Unit]()
-        _queue.put(ShutdownRequest(promise))
-
-        // Give ourselves 10 seconds to tear down the process.
-        try {
-          Await.result(promise.future, Duration(10, TimeUnit.SECONDS))
-          thread.join()
-        } finally {
-          gatewayServer.shutdown()
-        }
+        val promise = Promise[JValue]()
+        _queue.add(ExecuteRequest(code, promise))
+        promise.future
     }
   }
 
@@ -206,22 +193,29 @@ private class PythonInterpreter(process: Process, gatewayServer: GatewayServer)
 
               val content: JValue = rep \ "content"
 
-              _state = Idle()
+              synchronized {
+                _state = Idle()
+              }
 
               promise.success(content)
               loop()
             case None =>
-              _state = Error()
+              synchronized {
+                _state = Error()
+              }
+
               promise.failure(new Exception("session has been terminated"))
           }
 
         case (_, ShutdownRequest(promise)) =>
           require(state == Idle() || state == Error())
 
-          _state = ShuttingDown()
+          synchronized {
+            _state = ShuttingDown()
+          }
 
           try {
-            sendRequest(Map("msg_type" -> "shutdown_request", "content" -> ())) match {
+            sendRequest(Map("msg_type" -> "shutdown_request", "content" ->())) match {
               case Some(rep) =>
                 warn(f"process failed to shut down while returning $rep")
               case None =>
@@ -238,7 +232,10 @@ private class PythonInterpreter(process: Process, gatewayServer: GatewayServer)
             try {
               process.destroy()
             } finally {
-              _state = Dead()
+              synchronized {
+                _state = Dead()
+              }
+
               promise.success(())
             }
           }
@@ -247,4 +244,29 @@ private class PythonInterpreter(process: Process, gatewayServer: GatewayServer)
   }
 
   thread.start()
+
+  override def close(): Unit = {
+    val future = synchronized {
+      _state match {
+        case (Dead() | ShuttingDown() | Error()) =>
+          Future.successful()
+        case _ =>
+          val promise = Promise[Unit]()
+          _queue.add(ShutdownRequest(promise))
+
+          promise.future
+            .andThen {
+            case util.Success(_) =>
+              thread.join()
+            case util.Failure(_) =>
+              thread.interrupt()
+              thread.join()
+          }
+            .andThen { case _ => gatewayServer.shutdown() }
+      }
+    }
+
+    // Give ourselves 10 seconds to tear down the process.
+    Await.result(future, Duration(10, TimeUnit.SECONDS))
+  }
 }