Explorar o código

[spark] Factor the spark web session into it's own class

Erick Tryzelaar %!s(int64=11) %!d(string=hai) anos
pai
achega
486f652bc1

+ 2 - 2
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/Session.scala

@@ -24,7 +24,7 @@ trait Session {
 
   def statements(fromIndex: Integer, toIndex: Integer): Future[List[ExecuteResponse]]
 
-  def interrupt(): Unit
+  def interrupt(): Future[Unit]
 
-  def close(): Unit
+  def close(): Future[Unit]
 }

+ 3 - 3
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/SessionFactory.scala

@@ -2,7 +2,7 @@ package com.cloudera.hue.livy.server
 
 import java.util.UUID
 
-import scala.concurrent.{ExecutionContext, Future, future}
+import scala.concurrent.{ExecutionContext, Future}
 
 trait SessionFactory {
   def createSparkSession: Future[Session]
@@ -13,9 +13,9 @@ class ProcessSessionFactory extends SessionFactory {
   implicit def executor: ExecutionContext = ExecutionContext.global
 
   override def createSparkSession: Future[Session] = {
-    future {
+    Future {
       val id = UUID.randomUUID().toString
-      new SparkProcessSession(id)
+      SparkProcessSession.create(id)
     }
   }
 }

+ 12 - 9
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/SessionManager.scala

@@ -1,7 +1,8 @@
 package com.cloudera.hue.livy.server
 
 import scala.collection.concurrent.TrieMap
-import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
+import scala.concurrent.duration.Duration
+import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future}
 
 object SessionManager {
   // Time in milliseconds; TODO: make configurable
@@ -38,20 +39,22 @@ class SessionManager(factory: SessionFactory) {
   }
 
   def close(): Unit = {
-    sessions.values.foreach(close)
+    Await.result(Future.sequence(sessions.values.map(close)), Duration.Inf)
     garbageCollector.shutdown()
   }
 
-  def close(sessionId: String): Unit = {
-    sessions.remove(sessionId) match {
-      case Some(session) => session.close()
-      case None =>
+  def close(sessionId: String): Future[Unit] = {
+    sessions.get(sessionId) match {
+      case Some(session) => close(session)
+      case None => Future.successful(Unit)
     }
   }
 
-  def close(session: Session): Unit = {
-    sessions.remove(session.id)
-    session.close()
+  def close(session: Session): Future[Unit] = {
+    session.close().map { case _ =>
+        sessions.remove(session.id)
+        Unit
+    }
   }
 
   def collectGarbage() = {

+ 12 - 106
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/SparkProcessSession.scala

@@ -1,23 +1,18 @@
 package com.cloudera.hue.livy.server
 
-import java.util.concurrent.TimeoutException
-
-import com.cloudera.hue.livy.{ExecuteRequest, ExecuteResponse, Logging}
-import dispatch._, Defaults._
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods._
-import org.json4s.jackson.Serialization.write
-import org.json4s.{DefaultFormats, Formats}
-
 import scala.annotation.tailrec
-import scala.concurrent.duration._
-import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future}
+import scala.concurrent.Future
 import scala.io.Source
 
 object SparkProcessSession {
   val LIVY_HOME = System.getenv("LIVY_HOME")
   val SPARK_SHELL = LIVY_HOME + "/spark-shell"
 
+  def create(id: String): Session = {
+    val (process, port) = startProcess()
+    new SparkProcessSession(id, process, port)
+  }
+
   // Loop until we've started a process with a valid port.
   private def startProcess(): (Process, Int) = {
     val regex = """Starting livy-repl on port (\d+)""".r
@@ -60,103 +55,14 @@ object SparkProcessSession {
   }
 }
 
-class SparkProcessSession(val id: String) extends Session with Logging {
-
-  import com.cloudera.hue.livy.server.SparkProcessSession._
-
-  private[this] implicit def executor: ExecutionContextExecutor = ExecutionContext.global
-  private[this] implicit def jsonFormats: Formats = DefaultFormats
-
-  private[this] var _lastActivity = Long.MaxValue
-  private[this] var _state: State = Running()
-  private[this] val (process, port) = startProcess()
-  private[this] val svc = host("localhost", port)
-
-  override def lastActivity: Long = _lastActivity
-
-  override def state: State = _state
-
-  override def executeStatement(statement: String): Future[ExecuteResponse] = {
-    ensureRunning {
-      touchLastActivity()
-
-      var req = (svc / "statements").setContentType("application/json", "UTF-8")
-      req = req << write(ExecuteRequest(statement))
-
-      for {
-        body <- Http(req OK as.json4s.Json)
-      } yield body.extract[ExecuteResponse]
-    }
-  }
-
-  override def statement(statementId: Int): Future[ExecuteResponse] = {
-    ensureRunning {
-      val req = svc / "statements" / statementId
-
-      for {
-        body <- Http(req OK as.json4s.Json)
-      } yield body.extract[ExecuteResponse]
-    }
-  }
-
-  override def statements(): Future[List[ExecuteResponse]] = {
-    ensureRunning {
-      val req = svc / "statements"
-
-      for {
-        body <- Http(req OK as.json4s.Json)
-      } yield body.extract[List[ExecuteResponse]]
-    }
-  }
-
-  override def statements(fromIndex: Integer, toIndex: Integer): Future[List[ExecuteResponse]] = {
-    ensureRunning {
-      val req = (svc / "statements")
-        .addQueryParameter("from", fromIndex.toString)
-        .addQueryParameter("to", toIndex.toString)
-
-      for {
-        body <- Http(req OK as.json4s.Json)
-      } yield body.extract[List[ExecuteResponse]]
-    }
-  }
-    override def interrupt(): Unit = {
-    close()
-  }
-
-  override def close(): Unit = {
-    synchronized {
-      _state match {
-        case Running() =>
-          _state = Stopping()
-
-          // Give the repl some time to shut down cleanly.
-          try {
-            Await.ready(Http(svc.DELETE OK as.String), 5 seconds)
-          } catch {
-            // Ignore timeouts
-            case _: TimeoutException =>
-            case _: InterruptedException =>
-          }
+private class SparkProcessSession(id: String, process: Process, port: Int) extends SparkWebSession(id, "localhost", port) {
 
-          process.destroy()
-          _state = Stopped()
-        case Stopping() | Stopped() =>
-      }
-    }
-  }
+  override def close(): Future[Unit] = {
+    super.close() andThen { case r =>
+      // Make sure the process is reaped.
+      process.waitFor()
 
-  private def touchLastActivity() = {
-    _lastActivity = System.currentTimeMillis()
-  }
-
-  private def ensureRunning[A](f: => A) = {
-    synchronized {
-      if (_state == Running()) {
-        f
-      } else {
-        throw new IllegalStateException("Session is in state %s" format _state)
-      }
+      r
     }
   }
 }

+ 124 - 0
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/SparkWebSession.scala

@@ -0,0 +1,124 @@
+package com.cloudera.hue.livy.server
+
+import com.cloudera.hue.livy._
+import dispatch._
+import org.json4s.jackson.Serialization.write
+import org.json4s.{DefaultFormats, Formats}
+
+import scala.annotation.tailrec
+import scala.concurrent.{Future, _}
+
+abstract class SparkWebSession(val id: String, hostname: String, port: Int)
+  extends Session
+  with Logging {
+
+  protected implicit def executor: ExecutionContextExecutor = ExecutionContext.global
+  protected implicit def jsonFormats: Formats = DefaultFormats
+
+  private[this] var _lastActivity = Long.MaxValue
+  private[this] var _state: State = Running()
+  private[this] val svc = host(hostname, port)
+
+  override def lastActivity: Long = _lastActivity
+
+  override def state: State = _state
+
+  override def executeStatement(statement: String): Future[ExecuteResponse] = {
+    ensureRunning {
+      touchLastActivity()
+
+      var req = (svc / "statements").setContentType("application/json", "UTF-8")
+      req = req << write(ExecuteRequest(statement))
+
+      for {
+        body <- Http(req OK as.json4s.Json)
+      } yield body.extract[ExecuteResponse]
+    }
+  }
+
+  override def statement(statementId: Int): Future[ExecuteResponse] = {
+    ensureRunning {
+      val req = svc / "statements" / statementId
+
+      for {
+        body <- Http(req OK as.json4s.Json)
+      } yield body.extract[ExecuteResponse]
+    }
+  }
+
+  override def statements(): Future[List[ExecuteResponse]] = {
+    ensureRunning {
+      val req = svc / "statements"
+
+      for {
+        body <- Http(req OK as.json4s.Json)
+      } yield body.extract[List[ExecuteResponse]]
+    }
+  }
+
+  override def statements(fromIndex: Integer, toIndex: Integer): Future[List[ExecuteResponse]] = {
+    ensureRunning {
+      val req = (svc / "statements")
+        .addQueryParameter("from", fromIndex.toString)
+        .addQueryParameter("to", toIndex.toString)
+
+      for {
+        body <- Http(req OK as.json4s.Json)
+      } yield body.extract[List[ExecuteResponse]]
+    }
+  }
+  override def interrupt(): Future[Unit] = {
+    close()
+  }
+
+  override def close(): Future[Unit] = {
+    synchronized {
+      _state match {
+        case Running() =>
+          _state = Stopping()
+
+          Http(svc.DELETE OK as.String).map { case rep =>
+            synchronized {
+              _state = Stopped()
+            }
+
+            Unit
+          }
+        case Stopping() =>
+          @tailrec
+          def waitForStateChange(state: State): Unit = {
+            if (_state == state) {
+              Thread.sleep(1000)
+              waitForStateChange(state)
+            }
+          }
+
+          Future {
+            waitForStateChange(Stopping())
+
+            if (_state == Stopped()) {
+              Future.successful(Unit)
+            } else {
+              Future.failed(new IllegalStateException("livy-repl did not stop: %s" format _state))
+            }
+          }
+        case Stopped() =>
+          Future.successful(Unit)
+      }
+    }
+  }
+
+  private def touchLastActivity() = {
+    _lastActivity = System.currentTimeMillis()
+  }
+
+  private def ensureRunning[A](f: => A) = {
+    synchronized {
+      if (_state == Running()) {
+        f
+      } else {
+        throw new IllegalStateException("Session is in state %s" format _state)
+      }
+    }
+  }
+}

+ 5 - 2
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/WebApp.scala

@@ -60,8 +60,11 @@ class WebApp(sessionManager: SessionManager)
   }
 
   delete("/sessions/:sessionId") {
-    sessionManager.close(params("sessionId"))
-    NoContent
+    new AsyncResult() {
+      val is = for {
+      _ <- sessionManager.close(params("sessionId"))
+      } yield NoContent
+    }
   }
 
   post("/sessions/:sessionId/statements") {