[livy] Flesh out batch, add yarn-client mode

Erick Tryzelaar 10 years ago
parent
commit
242ee0558f

+ 52 - 0
apps/spark/java/livy-core/src/main/scala/com/cloudera/hue/livy/LineBufferedProcess.scala

@@ -0,0 +1,52 @@
+package com.cloudera.hue.livy
+
+import scala.io.Source
+
+class LineBufferedProcess(process: Process) extends Logging {
+
+  private[this] var _stdoutLines: IndexedSeq[String] = IndexedSeq()
+  private[this] var _stderrLines: IndexedSeq[String] = IndexedSeq()
+
+  private val stdoutThread = new Thread {
+    override def run() = {
+      val lines = Source.fromInputStream(process.getInputStream).getLines()
+      for (line <- lines) {
+        trace("stdout: ", line)
+        _stdoutLines +:= line
+      }
+    }
+  }
+  stdoutThread.setDaemon(true)
+  stdoutThread.start()
+
+  private val stderrThread = new Thread {
+    override def run() = {
+      val lines = Source.fromInputStream(process.getErrorStream).getLines()
+      for (line <- lines) {
+        trace("stderr: ", line)
+        _stderrLines +:= line
+      }
+    }
+  }
+  stderrThread.setDaemon(true)
+  stderrThread.start()
+
+  def stdoutLines: IndexedSeq[String] = _stdoutLines
+
+  def stderrLines: IndexedSeq[String] = _stderrLines
+
+  def destroy(): Unit = {
+    process.destroy()
+  }
+
+  def exitValue(): Int = {
+    process.exitValue()
+  }
+
+  def waitFor(): Int = {
+    val output = process.waitFor()
+    stdoutThread.join()
+    stderrThread.join()
+    output
+  }
+}

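A minimal usage sketch for the class above (outside this diff; assumes livy-core is on the classpath, and the echo command is an arbitrary example):

    import com.cloudera.hue.livy.LineBufferedProcess

    // Start a short-lived process and capture its output line by line.
    val process = new LineBufferedProcess(new ProcessBuilder("echo", "hello world").start())
    process.waitFor()                     // blocks until exit and both reader threads finish
    process.stdoutLines.foreach(println)  // prints: hello world
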
+ 6 - 0
apps/spark/java/livy-core/src/main/scala/com/cloudera/hue/livy/Logging.scala

@@ -5,6 +5,12 @@ import org.slf4j.LoggerFactory
 trait Logging {
   lazy val logger = LoggerFactory.getLogger(this.getClass)
 
+  def trace(message: => Any) = {
+    if (logger.isTraceEnabled) {
+      logger.trace(message.toString)
+    }
+  }
+
   def debug(message: => Any) = {
     if (logger.isDebugEnabled) {
       logger.debug(message.toString)

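Like the existing debug helper, trace takes its message by name (message: => Any), so the expression is only evaluated when the level is enabled. A small illustration of the idea (outside this diff):

    var enabled = false
    var evaluated = 0

    // By-name parameter: the argument is not evaluated at the call site,
    // only if the body forces it.
    def trace(message: => Any): Unit =
      if (enabled) println(message.toString)

    trace { evaluated += 1; "expensive details" }
    println(evaluated)  // 0: with tracing disabled, the message was never built
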
+ 10 - 3
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/Main.scala

@@ -2,7 +2,7 @@ package com.cloudera.hue.livy.server
 
 import javax.servlet.ServletContext
 
-import com.cloudera.hue.livy.server.batch.{BatchProcessFactory, BatchServlet, BatchManager}
+import com.cloudera.hue.livy.server.batch.{BatchYarnFactory, BatchProcessFactory, BatchServlet, BatchManager}
 import com.cloudera.hue.livy.server.sessions._
 import com.cloudera.hue.livy.{Utils, Logging, LivyConf, WebServer}
 import org.scalatra._
@@ -59,13 +59,20 @@ class ScalatraBootstrap extends LifeCycle with Logging {
       case "process" => new ProcessSessionFactory(livyConf)
       case "yarn" => new YarnSessionFactory(livyConf)
       case _ =>
-        println(f"Unknown session factory: $sessionFactoryKind}")
+        println(f"Unknown session factory: $sessionFactoryKind")
         sys.exit(1)
     }
 
     sessionManager = new SessionManager(sessionFactory)
 
-    val batchFactory = new BatchProcessFactory()
+    val batchFactory = sessionFactoryKind match {
+      case "thread" | "process" => new BatchProcessFactory()
+      case "yarn" => new BatchYarnFactory()
+      case _ =>
+        println(f"Unknown batch factory: $sessionFactoryKind")
+        sys.exit(1)
+    }
+
     batchManager = new BatchManager(batchFactory)
 
     context.mount(new SessionServlet(sessionManager), "/sessions/*")

+ 61 - 8
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/batch/BatchManager.scala

@@ -1,8 +1,10 @@
 package com.cloudera.hue.livy.server.batch
 
+import java.lang.ProcessBuilder.Redirect
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicInteger
 
+import com.cloudera.hue.livy.LineBufferedProcess
 import com.cloudera.hue.livy.spark.SparkSubmitProcessBuilder
 
 import scala.collection.JavaConversions._
@@ -56,17 +58,45 @@ abstract class BatchFactory {
 
 class BatchProcessFactory extends BatchFactory {
   def createBatch(id: Int, createBatchRequest: CreateBatchRequest): Batch =
-    BatchProcess(id, createBatchRequest)
+    BatchProcess(id, "local[*]", createBatchRequest)
+}
+
+class BatchYarnFactory extends BatchFactory {
+  def createBatch(id: Int, createBatchRequest: CreateBatchRequest): Batch =
+    BatchProcess(id, "yarn-client", createBatchRequest)
+}
+
+sealed trait State
+
+case class Running() extends State {
+  override def toString = "running"
+}
+
+case class Dead() extends State {
+  override def toString = "dead"
 }
 
 abstract class Batch {
   def id: Int
 
+  def state: State
+
+  def lines: IndexedSeq[String]
+
   def stop(): Future[Unit]
 }
 
 object BatchProcess {
-  def apply(id: Int, createBatchRequest: CreateBatchRequest): Batch = {
+  def apply(id: Int, master: String, createBatchRequest: CreateBatchRequest): Batch = {
+    val builder = sparkBuilder(createBatchRequest)
+
+    builder.master(master)
+
+    val process = builder.start(createBatchRequest.file, createBatchRequest.args)
+    new BatchProcess(id, new LineBufferedProcess(process))
+  }
+
+  private def sparkBuilder(createBatchRequest: CreateBatchRequest): SparkSubmitProcessBuilder = {
     val builder = SparkSubmitProcessBuilder()
 
     createBatchRequest.className.foreach(builder.className)
@@ -79,20 +109,43 @@ object BatchProcess {
     createBatchRequest.executorCores.foreach(builder.executorCores)
     createBatchRequest.archives.foreach(builder.archive)
 
-    val process = builder.start(createBatchRequest.file, createBatchRequest.args)
-    new BatchProcess(id, process)
+    builder.redirectOutput(Redirect.PIPE)
+
+    builder
   }
 }
 
 private class BatchProcess(val id: Int,
-                           @transient
-                           process: Process) extends Batch {
+                           process: LineBufferedProcess) extends Batch {
   protected implicit def executor: ExecutionContextExecutor = ExecutionContext.global
 
+  private[this] var isAlive = true
+
+  override def state: State = {
+    if (isAlive) {
+      try {
+        process.exitValue()
+      } catch {
+        case e: IllegalThreadStateException => return Running()
+      }
+
+      destroyProcess()
+    }
+
+    Dead()
+  }
+
+  override def lines: IndexedSeq[String] = process.stdoutLines
+
   override def stop(): Future[Unit] = {
     Future {
-      process.destroy()
-      process.waitFor()
+      destroyProcess()
     }
   }
+
+  private def destroyProcess() = {
+    process.destroy()
+    process.waitFor()
+    isAlive = false
+  }
 }

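A standalone sketch of the liveness check in BatchProcess.state above (outside this diff): java.lang.Process.exitValue() throws IllegalThreadStateException while the process is still running, which turns a non-blocking poll into a Running()/Dead() answer.

    // Poll a process without blocking: exitValue() only succeeds after exit.
    def poll(process: Process): String =
      try {
        process.exitValue()
        "dead"
      } catch {
        case _: IllegalThreadStateException => "running"
      }

    val p = new ProcessBuilder("sleep", "1").start()
    println(poll(p))  // most likely "running"
    p.waitFor()
    println(poll(p))  // "dead"
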
+ 23 - 3
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/batch/BatchServlet.scala

@@ -47,9 +47,14 @@ class BatchServlet(batchManager: BatchManager)
 
   val getBatch = get("/:id") {
     val id = params("id").toInt
+
     batchManager.getBatch(id) match {
       case None => NotFound("batch not found")
-      case Some(batch) => batch
+      case Some(batch) =>
+        val from = params.get("from").map(_.toInt)
+        val size = params.get("size").map(_.toInt)
+
+        Serializers.serializeBatch(batch, from, size)
     }
   }
 
@@ -82,11 +87,26 @@ private object Serializers {
 
   def Formats: List[CustomSerializer[_]] = List(BatchSerializer)
 
-  case object BatchSerializer extends CustomSerializer[Batch](implicit formats => ( {
+  def serializeBatch(batch: Batch,
+                     fromOpt: Option[Int],
+                     sizeOpt: Option[Int]): JValue = {
+    val lines = batch.lines
+    val size = sizeOpt.getOrElse(10)
+    val from = fromOpt.getOrElse(math.max(0, lines.length - 10))
+    val until = from + size
+
+    ("id", batch.id) ~
+      ("state", batch.state.toString) ~
+      ("lines", lines.slice(from, until))
+  }
+
+  case object BatchSerializer extends CustomSerializer[Batch](
+    implicit formats => ( {
     // We don't support deserialization.
     PartialFunction.empty
   }, {
-    case batch: Batch => JObject(JField("id", batch.id))
+    case batch: Batch =>
+      serializeBatch(batch, None, None)
   }
     )
   )

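A sketch of the windowing semantics serializeBatch introduces (outside this diff; the sample data is made up): with no query parameters the last 10 lines are returned, while from/size select an explicit slice.

    val lines = (0 until 25).map(i => f"line $i")

    // Defaults mirror serializeBatch: size = 10, from = max(0, length - 10),
    // i.e. the tail of the captured output.
    val size = 10
    val from = math.max(0, lines.length - 10)
    println(lines.slice(from, from + size))  // Vector(line 15, ..., line 24)
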
+ 47 - 0
apps/spark/java/livy-server/src/test/scala/com/cloudera/hue/livy/server/batches/BatchProcessSpec.scala

@@ -0,0 +1,47 @@
+package com.cloudera.hue.livy.server.batches
+
+import java.io.FileWriter
+import java.nio.file.{Files, Path}
+import java.util.concurrent.TimeUnit
+
+import com.cloudera.hue.livy.Utils
+import com.cloudera.hue.livy.server.batch.{Dead, CreateBatchRequest, BatchProcess}
+import org.scalatest.{ShouldMatchers, BeforeAndAfterAll, FunSpec}
+
+import scala.concurrent.duration.Duration
+
+class BatchProcessSpec
+  extends FunSpec
+  with BeforeAndAfterAll
+  with ShouldMatchers {
+
+  val script: Path = {
+    val script = Files.createTempFile("livy-test", ".py")
+    script.toFile.deleteOnExit()
+    val writer = new FileWriter(script.toFile)
+    try {
+      writer.write(
+        """
+          |print "hello world"
+        """.stripMargin)
+    } finally {
+      writer.close()
+    }
+    script
+  }
+
+  describe("A Batch process") {
+    it("should create a process") {
+      val req = CreateBatchRequest(
+        file = script.toString
+      )
+      val batch = BatchProcess(0, "local[*]", req)
+
+      Utils.waitUntil({ () =>
+        batch.state == Dead()
+      }, Duration(10, TimeUnit.SECONDS))
+
+      batch.lines should contain("hello world")
+    }
+  }
+}

+ 39 - 16
apps/spark/java/livy-server/src/test/scala/com/cloudera/hue/livy/server/batches/BatchServletSpec.scala

@@ -2,7 +2,9 @@ package com.cloudera.hue.livy.server.batches
 
 import java.io.FileWriter
 import java.nio.file.{Files, Path}
+import java.util.concurrent.TimeUnit
 
+import com.cloudera.hue.livy.Utils
 import com.cloudera.hue.livy.server.batch._
 import org.json4s.JsonAST.{JArray, JInt, JObject, JString}
 import org.json4s.jackson.JsonMethods._
@@ -11,33 +13,32 @@ import org.json4s.{DefaultFormats, Formats}
 import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSpecLike}
 import org.scalatra.test.scalatest.ScalatraSuite
 
+import scala.concurrent.duration.Duration
+
 class BatchServletSpec extends ScalatraSuite with FunSpecLike with BeforeAndAfterAll with BeforeAndAfter {
 
   protected implicit def jsonFormats: Formats = DefaultFormats
 
-  var script: Path = _
-
-  val batchFactory = new BatchProcessFactory()
-  val batchManager = new BatchManager(batchFactory)
-  val servlet = new BatchServlet(batchManager)
-
-  addServlet(servlet, "/*")
-
-  override def beforeAll() = {
-    super.beforeAll()
-    script = Files.createTempFile("test", "livy-test")
+  val script: Path = {
+    val script = Files.createTempFile("livy-test", ".py")
+    script.toFile.deleteOnExit()
     val writer = new FileWriter(script.toFile)
     try {
-      writer.write("print 'hello world'")
+      writer.write(
+        """
+          |print "hello world"
+        """.stripMargin)
     } finally {
       writer.close()
     }
+    script
   }
 
-  override def afterAll() = {
-    script.toFile.delete()
-    super.afterAll()
-  }
+  val batchFactory = new BatchProcessFactory()
+  val batchManager = new BatchManager(batchFactory)
+  val servlet = new BatchServlet(batchManager)
+
+  addServlet(servlet, "/*")
 
   after {
     batchManager.shutdown()
@@ -67,6 +68,28 @@ class BatchServletSpec extends ScalatraSuite with FunSpecLike with BeforeAndAfte
         batch should be (defined)
       }
 
+      // Wait for the process to finish.
+      {
+        val batch: Batch = batchManager.getBatch(0).get
+        Utils.waitUntil({ () =>
+          batch.state == Dead()
+        }, Duration(10, TimeUnit.SECONDS))
+      }
+
+      get("/0") {
+        status should equal (200)
+        header("Content-Type") should include("application/json")
+        val parsedBody = parse(body)
+        parsedBody \ "id" should equal (JInt(0))
+        parsedBody \ "state" should equal (JString("dead"))
+        parsedBody \ "lines" should equal (JArray(List(
+          JString("hello world")
+        )))
+
+        val batch = batchManager.getBatch(0)
+        batch should be (defined)
+      }
+
       delete("/0") {
         status should equal (200)
         header("Content-Type") should include("application/json")