Browse Source

[livy] Clean up livy-server code

Erick Tryzelaar 10 years ago
parent
commit
c2e190af1f

+ 17 - 4
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/Statement.scala

@@ -6,11 +6,24 @@ import org.json4s.JValue
 import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
 import scala.util.{Failure, Success}
 
-class Statement(val id: Int, val request: ExecuteRequest, val output: Future[JValue]) {
+object Statement {
   sealed trait State
-  case class Running() extends State
-  case class Available() extends State
-  case class Error() extends State
+
+  case class Running() extends State {
+    override def toString = "running"
+  }
+
+  case class Available() extends State {
+    override def toString = "available"
+  }
+
+  case class Error() extends State {
+    override def toString = "error"
+  }
+}
+
+class Statement(val id: Int, val request: ExecuteRequest, val output: Future[JValue]) {
+  import Statement._
 
   protected implicit def executor: ExecutionContextExecutor = ExecutionContext.global
 

+ 86 - 42
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/WebApp.scala

@@ -8,30 +8,25 @@ import com.cloudera.hue.livy.msgs.ExecuteRequest
 import com.cloudera.hue.livy.server.sessions.Session
 import com.cloudera.hue.livy.server.sessions.Session.SessionFailedToStart
 import com.fasterxml.jackson.core.JsonParseException
-import org.json4s.{DefaultFormats, Formats, MappingException}
+import org.json4s.JsonAST.JString
+import org.json4s._
 import org.scalatra._
 import org.scalatra.json.JacksonJsonSupport
 
 import scala.concurrent._
 import scala.concurrent.duration._
 
-object WebApp extends Logging {
-  case class CreateSessionRequest(lang: String)
-}
-
-case class CallbackRequest(url: String)
+object WebApp extends Logging
 
 class WebApp(sessionManager: SessionManager)
   extends ScalatraServlet
   with FutureSupport
   with MethodOverride
   with JacksonJsonSupport
-  with UrlGeneratorSupport {
-
-  import WebApp._
-
+  with UrlGeneratorSupport
+{
   override protected implicit def executor: ExecutionContextExecutor = ExecutionContext.global
-  override protected implicit def jsonFormats: Formats = DefaultFormats
+  override protected implicit def jsonFormats: Formats = DefaultFormats ++ Serializers.Formats
 
   before() {
     contentType = formats("json")
@@ -45,24 +40,21 @@ class WebApp(sessionManager: SessionManager)
 
   val getSession = get("/sessions/:sessionId") {
     sessionManager.get(params("sessionId")) match {
-      case Some(session) => formatSession(session)
+      case Some(session) => session
       case None => NotFound("Session not found")
     }
   }
 
   post("/sessions") {
     val createSessionRequest = parsedBody.extract[CreateSessionRequest]
+    val sessionFuture = sessionManager.createSession(createSessionRequest.lang)
 
-    val sessionFuture = createSessionRequest.lang match {
-      case "spark" | "scala" => sessionManager.createSession(Session.Spark())
-      case "pyspark" | "python" => sessionManager.createSession(Session.PySpark())
-      case lang => halt(400, "unsupported language: " + lang)
-    }
-
-    val rep = sessionFuture.map {
-      case session =>
-        Created(formatSession(session),
-          headers = Map("Location" -> url(getSession, "sessionId" -> session.id.toString)))
+    val rep = sessionFuture.map { case session =>
+      Created(session,
+        headers = Map(
+          "Location" -> url(getSession, "sessionId" -> session.id.toString)
+        )
+      )
     }
 
     new AsyncResult { val is = rep }
@@ -118,7 +110,7 @@ class WebApp(sessionManager: SessionManager)
     sessionManager.get(params("sessionId")) match {
       case Some(session: Session) =>
         Map(
-          "statements" -> session.statements().map(formatStatement)
+          "statements" -> session.statements()
         )
       case None => NotFound("Session not found")
     }
@@ -128,7 +120,7 @@ class WebApp(sessionManager: SessionManager)
     sessionManager.get(params("sessionId")) match {
       case Some(session) =>
         session.statement(params("statementId").toInt) match {
-          case Some(statement) => formatStatement(statement)
+          case Some(statement) => statement
           case None => NotFound("Statement not found")
         }
       case None => NotFound("Session not found")
@@ -142,7 +134,7 @@ class WebApp(sessionManager: SessionManager)
       case Some(session) =>
         val statement = session.executeStatement(req)
 
-        Created(formatStatement(statement),
+        Created(statement,
           headers = Map(
             "Location" -> url(getStatement,
               "sessionId" -> session.id.toString,
@@ -160,27 +152,79 @@ class WebApp(sessionManager: SessionManager)
       WebApp.error("internal error", e)
       InternalServerError(e.toString)
   }
+}
 
-  private def formatSession(session: Session) = {
-    Map(
-      "id" -> session.id,
-      "kind" -> session.kind.toString,
-      "state" -> session.state.getClass.getSimpleName.toLowerCase
-    )
+private case class CreateSessionRequest(lang: Session.Kind, proxyUser: Option[String])
+private case class CallbackRequest(url: String)
+
+private object Serializers {
+  import JsonDSL._
+
+  def SessionFormats: List[CustomSerializer[_]] = List(SessionSerializer, SessionKindSerializer, SessionStateSerializer)
+  def StatementFormats: List[CustomSerializer[_]] = List(StatementSerializer, StatementStateSerializer)
+  def Formats: List[CustomSerializer[_]] = SessionFormats ++ StatementFormats
+
+  private def serializeSessionState(state: Session.State) = JString(state.toString)
+
+  private def serializeSessionKind(kind: Session.Kind) = JString(kind.toString)
+
+  private def serializeStatementState(state: Statement.State) = JString(state.toString)
+
+  case object SessionSerializer extends CustomSerializer[Session](implicit formats => ( {
+    // We don't support deserialization.
+    PartialFunction.empty
+  }, {
+    case session: Session =>
+      import JsonDSL._
+
+      ("id", session.id) ~
+      ("state", serializeSessionState(session.state)) ~
+      ("kind", serializeSessionKind(session.kind))
   }
+    )
+  )
 
-  private def formatStatement(statement: Statement) = {
-    // Take a couple milliseconds to see if the statement has finished.
-    val output = try {
-      Await.result(statement.output, Duration(100, TimeUnit.MILLISECONDS))
-    } catch {
-      case _: TimeoutException => null
-    }
+  case object SessionKindSerializer extends CustomSerializer[Session.Kind](implicit formats => ( {
+    case JString("spark") | JString("scala") => Session.Spark()
+    case JString("pyspark") | JString("python") => Session.PySpark()
+  }, {
+    case kind: Session.Kind => serializeSessionKind(kind)
+  }
+    )
+  )
 
-    Map(
-      "id" -> statement.id,
-      "state" -> statement.state.getClass.getSimpleName.toLowerCase,
-      "output" -> output
+  case object SessionStateSerializer extends CustomSerializer[Session.State](implicit formats => ( {
+    // We don't support deserialization.
+    PartialFunction.empty
+  }, {
+    case state: Session.State => JString(state.toString)
+  }
     )
+  )
+
+  case object StatementSerializer extends CustomSerializer[Statement](implicit formats => ( {
+    // We don't support deserialization.
+    PartialFunction.empty
+  }, {
+    case statement: Statement =>
+      // Take a couple milliseconds to see if the statement has finished.
+      val output = try {
+        Await.result(statement.output, Duration(100, TimeUnit.MILLISECONDS))
+      } catch {
+        case _: TimeoutException => null
+      }
+
+      ("id" -> statement.id) ~
+        ("state" -> serializeStatementState(statement.state)) ~
+        ("output" -> output)
+  }))
+
+  case object StatementStateSerializer extends CustomSerializer[Statement.State](implicit formats => ( {
+    // We don't support deserialization.
+    PartialFunction.empty
+  }, {
+    case state: Statement.State => JString(state.toString)
   }
+    )
+  )
 }