
HUE-2588 [livy] Unify most of the servlet code

Erick Tryzelaar 10 years ago
Parent
Commit
d22622f130

+ 12 - 0
apps/spark/java/livy-core/src/main/scala/com/cloudera/hue/livy/sessions/Kind.scala

@@ -18,6 +18,9 @@
 
 package com.cloudera.hue.livy.sessions
 
+import org.json4s.CustomSerializer
+import org.json4s.JsonAST.JString
+
 sealed trait Kind
 case class Spark() extends Kind {
   override def toString = "spark"
@@ -26,3 +29,12 @@ case class Spark() extends Kind {
 case class PySpark() extends Kind {
   override def toString = "pyspark"
 }
+
+case object SessionKindSerializer extends CustomSerializer[Kind](implicit formats => ( {
+  case JString("spark") | JString("scala") => Spark()
+  case JString("pyspark") | JString("python") => PySpark()
+}, {
+  case kind: Kind => JString(kind.toString)
+}
+  )
+)
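
The serializer added above replaces the one that previously lived in InteractiveSessionServlet's private Serializers object (removed further down), so both servlets can share it. A minimal round-trip sketch, assuming json4s is available and the serializer is mixed into the implicit formats:

    import com.cloudera.hue.livy.sessions.{Kind, SessionKindSerializer, Spark}
    import org.json4s._

    implicit val formats: Formats = DefaultFormats + SessionKindSerializer

    // "spark" and "scala" both deserialize to Spark(); "pyspark" and "python" to PySpark()
    val kind: Kind = JString("scala").extract[Kind]          // Spark()
    val json: JValue = Extraction.decompose(Spark(): Kind)   // JString("spark")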

+ 2 - 2
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/Main.scala

@@ -117,8 +117,8 @@ object Main {
 
 class ScalatraBootstrap extends LifeCycle with Logging {
 
-  var sessionManager: SessionManager[InteractiveSession, CreateInteractiveRequest] = null
-  var batchManager: SessionManager[BatchSession, CreateBatchRequest] = null
+  var sessionManager: SessionManager[InteractiveSession] = null
+  var batchManager: SessionManager[BatchSession] = null
 
   override def init(context: ServletContext): Unit = {
     val livyConf = new LivyConf()

+ 6 - 2
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/SessionFactory.scala

@@ -18,11 +18,15 @@
 
 package com.cloudera.hue.livy.server
 
+import org.json4s.{DefaultFormats, Formats, JValue}
+
 import scala.concurrent.Future
 
-abstract class SessionFactory[S <: Session, C] {
+abstract class SessionFactory[S <: Session] {
+
+  protected implicit def jsonFormats: Formats = DefaultFormats
 
-  def create(id: Int, createRequest: C): Future[S]
+  def create(id: Int, createRequest: JValue): Future[S]
 
   def close(): Unit = {}
 }

+ 10 - 12
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/SessionManager.scala

@@ -22,8 +22,9 @@ import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicInteger
 
 import com.cloudera.hue.livy.Logging
+import org.json4s.JValue
 
-import scala.collection.JavaConversions._
+import scala.collection.convert.decorateAsScala._
 import scala.concurrent.{ExecutionContext, Future}
 
 object SessionManager {
@@ -34,7 +35,7 @@ object SessionManager {
   val GC_PERIOD = 1000 * 60 * 60
 }
 
-class SessionManager[S <: Session, C](factory: SessionFactory[S, C])
+class SessionManager[S <: Session](factory: SessionFactory[S])
   extends Logging {
 
   import SessionManager._
@@ -42,12 +43,12 @@ class SessionManager[S <: Session, C](factory: SessionFactory[S, C])
   private implicit def executor: ExecutionContext = ExecutionContext.global
 
   protected[this] val _idCounter = new AtomicInteger()
-  protected[this] val _sessions = new ConcurrentHashMap[Int, S]()
+  protected[this] val _sessions = new ConcurrentHashMap[Int, S]().asScala
 
   private val garbageCollector = new GarbageCollector
   garbageCollector.start()
 
-  def create(createRequest: C): Future[S] = {
+  def create(createRequest: JValue): Future[S] = {
     val id = _idCounter.getAndIncrement
     val session: Future[S] = factory.create(id, createRequest)
 
@@ -58,15 +59,12 @@ class SessionManager[S <: Session, C](factory: SessionFactory[S, C])
     })
   }
 
-  def get(id: Int): Option[S] = Option(_sessions.get(id))
+  def get(id: Int): Option[S] = _sessions.get(id)
 
-  def all(): Seq[S] = _sessions.values().toSeq
+  def all(): Iterable[S] = _sessions.values
 
-  def delete(id: Int): Future[Unit] = {
-    get(id) match {
-      case Some(session) => delete(session)
-      case None => Future.successful(())
-    }
+  def delete(id: Int): Option[Future[Unit]] = {
+    get(id).map(delete)
   }
 
   def delete(session: S): Future[Unit] = {
@@ -77,7 +75,7 @@ class SessionManager[S <: Session, C](factory: SessionFactory[S, C])
   }
 
   def remove(id: Int): Option[S] = {
-    Option(_sessions.remove(id))
+    _sessions.remove(id)
   }
 
   def shutdown(): Unit = {}
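
The map change above is why get() and remove() drop their Option(...) wrappers: decorateAsScala turns the ConcurrentHashMap into a scala.collection.concurrent.Map whose lookups already return Option, and values gives the Iterable now exposed by all(). A standalone sketch of that behaviour, with hypothetical keys and values rather than real sessions:

    import java.util.concurrent.ConcurrentHashMap
    import scala.collection.convert.decorateAsScala._

    val sessions = new ConcurrentHashMap[Int, String]().asScala

    sessions.put(0, "session-0")
    sessions.get(0)      // Some("session-0")
    sessions.remove(0)   // Some("session-0")
    sessions.get(0)      // None
    sessions.values      // Iterable[String], matching the new all() signature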

+ 96 - 3
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/SessionServlet.scala

@@ -18,13 +18,19 @@
 
 package com.cloudera.hue.livy.server
 
-import org.json4s.{DefaultFormats, Formats}
+import com.cloudera.hue.livy.Logging
+import com.cloudera.hue.livy.server.interactive.InteractiveSession.SessionFailedToStart
+import com.fasterxml.jackson.core.JsonParseException
+import org.json4s.JsonDSL._
+import org.json4s.{MappingException, DefaultFormats, Formats, JValue}
+import org.scalatra._
 import org.scalatra.json.JacksonJsonSupport
-import org.scalatra.{FutureSupport, MethodOverride, ScalatraServlet, UrlGeneratorSupport}
 
 import scala.concurrent.ExecutionContext
 
-abstract class SessionServlet[S <: Session, C](sessionManager: SessionManager[S, C])
+object SessionServlet extends Logging
+
+abstract class SessionServlet[S <: Session](sessionManager: SessionManager[S])
   extends ScalatraServlet
   with FutureSupport
   with MethodOverride
@@ -35,7 +41,94 @@ abstract class SessionServlet[S <: Session, C](sessionManager: SessionManager[S,
 
   override protected implicit def jsonFormats: Formats = DefaultFormats
 
+  protected def serializeSession(session: S): JValue
+
   before() {
     contentType = formats("json")
   }
+
+  get("/") {
+    val sessions = sessionManager.all().map(serializeSession)
+    Map("sessions" -> sessions)
+  }
+
+  val getSession = get("/:id") {
+    val id = params("id").toInt
+
+    sessionManager.get(id) match {
+      case None => NotFound("session not found")
+      case Some(session) => serializeSession(session)
+    }
+  }
+
+  get("/:id/state") {
+    val id = params("id").toInt
+
+    sessionManager.get(id) match {
+      case None => NotFound("batch not found")
+      case Some(batch) =>
+        ("id", batch.id) ~ ("state", batch.state.toString)
+    }
+  }
+
+  get("/:id/log") {
+    val id = params("id").toInt
+
+    sessionManager.get(id) match {
+      case None => NotFound("session not found")
+      case Some(session) =>
+        val from = params.get("from").map(_.toInt)
+        val size = params.get("size").map(_.toInt)
+        val (from_, total, logLines) = serializeLogs(session, from, size)
+
+        ("id", session.id) ~
+          ("from", from_) ~
+          ("total", total) ~
+          ("log", logLines)
+    }
+  }
+
+  delete("/:id") {
+    val id = params("id").toInt
+
+    sessionManager.delete(id) match {
+      case None => NotFound("session not found")
+      case Some(future) => new AsyncResult {
+        val is = future.map { case () => Ok(Map("msg" -> "deleted")) }
+      }
+    }
+  }
+
+  post("/") {
+    new AsyncResult {
+      val is = for {
+        session <- sessionManager.create(parsedBody)
+      } yield Created(session,
+          headers = Map("Location" -> url(getSession, "id" -> session.id.toString))
+        )
+    }
+  }
+
+  error {
+    case e: JsonParseException => BadRequest(e.getMessage)
+    case e: MappingException => BadRequest(e.getMessage)
+    case e: SessionFailedToStart => InternalServerError(e.getMessage)
+    case e: dispatch.StatusCode => ActionResult(ResponseStatus(e.code), e.getMessage, Map.empty)
+    case e =>
+      SessionServlet.error("internal error", e)
+      InternalServerError(e.toString)
+  }
+
+  private def serializeLogs(session: S, fromOpt: Option[Int], sizeOpt: Option[Int]) = {
+    val lines = session.logLines()
+
+    val size = sizeOpt.getOrElse(100)
+    var from = fromOpt.getOrElse(-1)
+    if (from < 0) {
+      from = math.max(0, lines.length - size)
+    }
+    val until = from + size
+
+    (from, lines.length, lines.view(from, until))
+  }
 }
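
With the routes and the error handler pulled up into this base class, a concrete servlet only has to provide an executor, its json formats, and serializeSession. A hypothetical minimal subclass showing that shape (the JSON layout here is illustrative; BatchSessionServlet below delegates to its own Serializers object instead):

    import com.cloudera.hue.livy.server.{SessionManager, SessionServlet}
    import com.cloudera.hue.livy.server.batch.BatchSession
    import org.json4s.JsonDSL._
    import org.json4s.{DefaultFormats, Formats, JValue}

    import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}

    class MinimalBatchServlet(manager: SessionManager[BatchSession])
      extends SessionServlet[BatchSession](manager)
    {
      override protected implicit def executor: ExecutionContextExecutor = ExecutionContext.global
      override protected implicit def jsonFormats: Formats = DefaultFormats

      // controls how a session appears in the GET / and GET /:id responses
      override protected def serializeSession(session: BatchSession): JValue =
        ("id" -> session.id) ~ ("state" -> session.state.toString)
    }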

+ 8 - 3
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/batch/BatchSessionFactory.scala

@@ -19,8 +19,13 @@
 package com.cloudera.hue.livy.server.batch
 
 import com.cloudera.hue.livy.server.SessionFactory
+import org.json4s.JValue
 
-abstract class BatchSessionFactory
-  extends SessionFactory[BatchSession, CreateBatchRequest]
-{
+import scala.concurrent.Future
+
+abstract class BatchSessionFactory extends SessionFactory[BatchSession] {
+  override def create(id: Int, createRequest: JValue) =
+    create(id, createRequest.extract[CreateBatchRequest])
+
+  def create(id: Int, createRequest: CreateBatchRequest): Future[BatchSession]
 }
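
The factory is now a small template method: the JValue overload added here extracts CreateBatchRequest once, so concrete factories only implement the typed overload. A hypothetical stub, just to show the shape a subclass takes:

    import com.cloudera.hue.livy.server.batch.{BatchSession, BatchSessionFactory, CreateBatchRequest}

    import scala.concurrent.Future

    class StubBatchSessionFactory extends BatchSessionFactory {
      // a real factory would launch the batch job and complete the Future with its session
      override def create(id: Int, createRequest: CreateBatchRequest): Future[BatchSession] =
        Future.failed(new UnsupportedOperationException("stub factory"))
    }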

+ 5 - 91
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/batch/BatchSessionServlet.scala

@@ -19,107 +19,21 @@
 package com.cloudera.hue.livy.server.batch
 
 import com.cloudera.hue.livy.Logging
-import com.cloudera.hue.livy.server.SessionManager
-import com.fasterxml.jackson.core.JsonParseException
-import org.json4s.JsonDSL._
+import com.cloudera.hue.livy.server.{SessionManager, SessionServlet}
 import org.json4s._
-import org.scalatra._
-import org.scalatra.json.JacksonJsonSupport
 
-import scala.concurrent.{Future, ExecutionContext, ExecutionContextExecutor}
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}
 
 object BatchSessionServlet extends Logging
 
-class BatchSessionServlet(batchManager: SessionManager[BatchSession, CreateBatchRequest])
-  extends ScalatraServlet
-  with FutureSupport
-  with MethodOverride
-  with JacksonJsonSupport
-  with UrlGeneratorSupport
+class BatchSessionServlet(batchManager: SessionManager[BatchSession])
+  extends SessionServlet[BatchSession](batchManager)
 {
   override protected implicit def executor: ExecutionContextExecutor = ExecutionContext.global
   override protected implicit def jsonFormats: Formats = DefaultFormats ++ Serializers.Formats
 
-  before() {
-    contentType = formats("json")
-  }
-
-  get("/") {
-    Map(
-      "batches" -> batchManager.all()
-    )
-  }
-
-  post("/") {
-    val createBatchRequest = parsedBody.extract[CreateBatchRequest]
-
-    new AsyncResult {
-      val is = for {
-        batch <- batchManager.create(createBatchRequest)
-      } yield Created(batch,
-          headers = Map("Location" -> url(getBatch, "id" -> batch.id.toString))
-        )
-    }
-  }
-
-  val getBatch = get("/:id") {
-    val id = params("id").toInt
-
-    batchManager.get(id) match {
-      case None => NotFound("batch not found")
-      case Some(batch) => Serializers.serializeBatch(batch)
-    }
-  }
-
-  get("/:id/state") {
-    val id = params("id").toInt
+  override protected def serializeSession(session: BatchSession) = Serializers.serializeBatch(session)
 
-    batchManager.get(id) match {
-      case None => NotFound("batch not found")
-      case Some(batch) =>
-        ("id", batch.id) ~ ("state", batch.state.toString)
-    }
-  }
-
-  get("/:id/log") {
-    val id = params("id").toInt
-
-    batchManager.get(id) match {
-      case None => NotFound("batch not found")
-      case Some(batch) =>
-        val from = params.get("from").map(_.toInt)
-        val size = params.get("size").map(_.toInt)
-        val (from_, total, logLines) = Serializers.getLogs(batch, from, size)
-
-        ("id", batch.id) ~
-          ("from", from_) ~
-          ("total", total) ~
-          ("log", logLines)
-    }
-  }
-
-  delete("/:id") {
-    val id = params("id").toInt
-
-    batchManager.remove(id) match {
-      case None => NotFound("batch not found")
-      case Some(batch) =>
-        new AsyncResult {
-          val is = batch.stop().map { case () =>
-            batchManager.delete(batch)
-            Ok(Map("msg" -> "deleted"))
-          }
-        }
-    }
-  }
-
-  error {
-    case e: JsonParseException => BadRequest(e.getMessage)
-    case e: MappingException => BadRequest(e.getMessage)
-    case e =>
-      BatchSessionServlet.error("internal error", e)
-      InternalServerError(e.toString)
-  }
 }
 
 private object Serializers {

+ 12 - 3
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/interactive/InteractiveSessionFactory.scala

@@ -19,8 +19,17 @@
 package com.cloudera.hue.livy.server.interactive
 
 import com.cloudera.hue.livy.server.SessionFactory
+import com.cloudera.hue.livy.sessions.SessionKindSerializer
+import org.json4s.{DefaultFormats, Formats, JValue}
 
-trait InteractiveSessionFactory
-  extends SessionFactory[InteractiveSession, CreateInteractiveRequest]
-{
+import scala.concurrent.Future
+
+trait InteractiveSessionFactory extends SessionFactory[InteractiveSession] {
+
+  override protected implicit def jsonFormats: Formats = DefaultFormats ++ List(SessionKindSerializer)
+
+  override def create(id: Int, createRequest: JValue) =
+    create(id, createRequest.extract[CreateInteractiveRequest])
+
+  def create(id: Int, createRequest: CreateInteractiveRequest): Future[InteractiveSession]
 }

+ 2 - 82
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/interactive/InteractiveSessionServlet.scala

@@ -37,43 +37,12 @@ import scala.concurrent.duration._
 
 object InteractiveSessionServlet extends Logging
 
-class InteractiveSessionServlet(sessionManager: SessionManager[InteractiveSession, CreateInteractiveRequest])
+class InteractiveSessionServlet(sessionManager: SessionManager[InteractiveSession])
   extends SessionServlet(sessionManager)
 {
   override protected implicit def jsonFormats: Formats = DefaultFormats ++ Serializers.Formats
 
-  get("/") {
-    Map(
-      "sessions" -> sessionManager.all
-    )
-  }
-
-  val getSession = get("/:sessionId") {
-    val sessionId = params("sessionId").toInt
-
-    sessionManager.get(sessionId) match {
-      case Some(session) => session
-      case None => NotFound("Session not found")
-    }
-  }
-
-  post("/") {
-    val createInteractiveRequest = parsedBody.extract[CreateInteractiveRequest]
-
-    new AsyncResult {
-      val is = {
-        val sessionFuture = sessionManager.create(createInteractiveRequest)
-
-        sessionFuture.map { case session =>
-          Created(session,
-            headers = Map(
-              "Location" -> url(getSession, "sessionId" -> session.id.toString)
-            )
-          )
-        }
-      }
-    }
-  }
+  override protected def serializeSession(session: InteractiveSession) = Serializers.serializeSession(session)
 
   post("/:sessionId/callback") {
     val sessionId = params("sessionId").toInt
@@ -116,36 +85,6 @@ class InteractiveSessionServlet(sessionManager: SessionManager[InteractiveSessio
     }
   }
 
-  delete("/:sessionId") {
-    val sessionId = params("sessionId").toInt
-    sessionManager.get(sessionId) match {
-      case Some(session) =>
-        val future = for {
-          _ <- sessionManager.delete(session)
-        } yield Ok(Map("msg" -> "deleted"))
-
-        new AsyncResult { val is = future }
-      case None => NotFound("Session not found")
-    }
-  }
-
-  get("/:sessionId/log") {
-    val sessionId = params("sessionId").toInt
-
-    sessionManager.get(sessionId) match {
-      case None => NotFound("Session not found")
-      case Some(session: InteractiveSession) =>
-        val from = params.get("from").map(_.toInt)
-        val size = params.get("size").map(_.toInt)
-        val (from_, total, logLines) = Serializers.getLogs(session, from, size)
-
-        ("id", session.id) ~
-          ("from", from_) ~
-          ("total", total) ~
-          ("log", logLines)
-    }
-  }
-
   get("/:sessionId/statements") {
     val sessionId = params("sessionId").toInt
 
@@ -196,16 +135,6 @@ class InteractiveSessionServlet(sessionManager: SessionManager[InteractiveSessio
       case None => NotFound("Session not found")
     }
   }
-
-  error {
-    case e: JsonParseException => BadRequest(e.getMessage)
-    case e: MappingException => BadRequest(e.getMessage)
-    case e: SessionFailedToStart => InternalServerError(e.getMessage)
-    case e: dispatch.StatusCode => ActionResult(ResponseStatus(e.code), e.getMessage, Map.empty)
-    case e =>
-      InteractiveSessionServlet.error("internal error", e)
-      InternalServerError(e.toString)
-  }
 }
 
 private case class CallbackRequest(url: String)
@@ -267,15 +196,6 @@ private object Serializers {
     )
   )
 
-  case object SessionKindSerializer extends CustomSerializer[Kind](implicit formats => ( {
-    case JString("spark") | JString("scala") => Spark()
-    case JString("pyspark") | JString("python") => PySpark()
-  }, {
-    case kind: Kind => serializeSessionKind(kind)
-  }
-    )
-  )
-
   case object SessionStateSerializer extends CustomSerializer[State](implicit formats => ( {
     // We don't support deserialization.
     PartialFunction.empty

+ 1 - 1
apps/spark/java/livy-server/src/test/scala/com/cloudera/hue/livy/server/batches/BatchServletSpec.scala

@@ -70,7 +70,7 @@ class BatchServletSpec extends ScalatraSuite with FunSpecLike with BeforeAndAfte
         status should equal (200)
         header("Content-Type") should include("application/json")
         val parsedBody = parse(body)
-        parsedBody \ "batches" should equal (JArray(List()))
+        parsedBody \ "sessions" should equal (JArray(List()))
       }
 
       val createBatchRequest = write(CreateBatchRequest(