
[spark] Initial progress on a SparkYarnSession

Erick Tryzelaar 11 years ago
commit 4df0913

+ 6 - 0
apps/spark/java/livy-server/pom.xml

@@ -64,6 +64,12 @@
             <version>${project.version}</version>
         </dependency>
 
+        <dependency>
+            <groupId>com.cloudera.hue.livy</groupId>
+            <artifactId>livy-yarn</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
         <dependency>
             <groupId>com.fasterxml.jackson.core</groupId>
             <artifactId>jackson-databind</artifactId>

+ 1 - 1
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/Main.scala

@@ -23,7 +23,7 @@ object Main {
 
 class ScalatraBootstrap extends LifeCycle {
 
-  val sessionFactory = new ProcessSessionFactory
+  val sessionFactory = new YarnSessionFactory
   val sessionManager = new SessionManager(sessionFactory)
 
   override def init(context: ServletContext): Unit = {
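
Swapping ProcessSessionFactory for YarnSessionFactory means every session created through this bootstrap is now launched as a YARN application. Since the factory holds a livy-yarn Client (added below), the natural counterpart is releasing it when the servlet context shuts down; a minimal sketch, assuming Scalatra's LifeCycle.destroy hook (the class name and the destroy override are illustrative, not part of this commit):

// Sketch only (not part of this commit): release the factory's YARN client
// when the servlet container shuts down, mirroring the wiring shown above.
package com.cloudera.hue.livy.server

import javax.servlet.ServletContext
import org.scalatra.LifeCycle

class ScalatraBootstrapWithShutdown extends LifeCycle {
  val sessionFactory = new YarnSessionFactory
  val sessionManager = new SessionManager(sessionFactory)

  override def destroy(context: ServletContext): Unit = {
    sessionFactory.close()  // stops the Client held by YarnSessionFactory
  }
}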

+ 24 - 2
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/SessionFactory.scala

@@ -2,20 +2,42 @@ package com.cloudera.hue.livy.server
 
 import java.util.UUID
 
+import com.cloudera.hue.livy.yarn.Client
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+
 import scala.concurrent.{ExecutionContext, Future}
 
 trait SessionFactory {
-  def createSparkSession: Future[Session]
+  def createSparkSession(): Future[Session]
+
+  def close(): Unit = {}
 }
 
 class ProcessSessionFactory extends SessionFactory {
 
   implicit def executor: ExecutionContext = ExecutionContext.global
 
-  override def createSparkSession: Future[Session] = {
+  override def createSparkSession(): Future[Session] = {
     Future {
       val id = UUID.randomUUID().toString
       SparkProcessSession.create(id)
     }
   }
 }
+
+class YarnSessionFactory extends SessionFactory {
+
+  val yarnConf = new YarnConfiguration()
+  yarnConf.set("yarn.resourcemanager.am.max-attempts", "1")
+
+  val client = new Client(yarnConf)
+
+  override def createSparkSession(): Future[Session] = {
+    val id = UUID.randomUUID().toString
+    SparkYarnSession.create(client, id)
+  }
+
+  override def close(): Unit = {
+    client.close()
+  }
+}
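
For context on how the new close() hook and the zero-arg createSparkSession() fit together, here is a hedged end-to-end sketch; the object name, the blocking Await, and the two-minute bound are assumptions for illustration, not code from this commit:

// Illustrative only: drive YarnSessionFactory end to end, then shut it down.
package com.cloudera.hue.livy.server

import scala.concurrent.Await
import scala.concurrent.duration._

object YarnFactorySketch {
  def main(args: Array[String]): Unit = {
    val factory = new YarnSessionFactory
    try {
      // Submits the AM package to YARN and completes once the application
      // reports an RPC host/port (see SparkYarnSession below).
      val session = Await.result(factory.createSparkSession(), 2.minutes)
      println("session ready: " + session)
    } finally {
      factory.close()  // stops the underlying YARN client
    }
  }
}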

+ 52 - 0
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/SparkYarnSession.scala

@@ -0,0 +1,52 @@
+package com.cloudera.hue.livy.server
+
+import com.cloudera.hue.livy.yarn.{Client, Job}
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.yarn.api.ApplicationConstants
+
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future, TimeoutException}
+
+object SparkYarnSession {
+  private val LIVY_YARN_PACKAGE = System.getenv("LIVY_YARN_PACKAGE")
+
+  protected implicit def executor: ExecutionContextExecutor = ExecutionContext.global
+
+  def create(client: Client, id: String): Future[Session] = {
+    val packagePath = new Path(LIVY_YARN_PACKAGE)
+
+    val job = client.submitApplication(
+      packagePath,
+      List(
+        "__package/bin/run-am.sh 1>%s/stdout 2>%s/stderr" format (
+          ApplicationConstants.LOG_DIR_EXPANSION_VAR,
+          ApplicationConstants.LOG_DIR_EXPANSION_VAR
+          )
+      )
+    )
+
+    Future {
+      var x = job.waitForRPC(10000)
+
+      println("x: %s" format x)
+
+      x match {
+        case Some((hostname, port)) =>
+          new SparkYarnSession(id, job, hostname, port)
+        case None =>
+          throw new TimeoutException()
+      }
+    }
+  }
+}
+
+private class SparkYarnSession(id: String, job: Job, hostname: String, port: Int)
+  extends SparkWebSession(id, hostname, port) {
+
+  override def close(): Future[Unit] = {
+    super.close() andThen { case r =>
+      job.waitForFinish(10000)
+      r
+    }
+  }
+
+}
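
Note the contract of create(): if the application master reaches Running but never reports an RPC host/port within 10 seconds, the returned Future fails with a TimeoutException. A sketch of how a caller might observe both outcomes (the helper name and the logging are illustrative assumptions):

// Illustrative only: observe both outcomes of SparkYarnSession.create.
package com.cloudera.hue.livy.server

import com.cloudera.hue.livy.yarn.Client

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.{Future, TimeoutException}
import scala.util.{Failure, Success}

object YarnSessionCreateSketch {
  def createWithLogging(client: Client, id: String): Future[Session] = {
    val session = SparkYarnSession.create(client, id)
    session.onComplete {
      case Success(_) =>
        println("session %s is up" format id)
      case Failure(_: TimeoutException) =>
        // waitForRPC returned None: the AM never exposed a host/port in time.
        println("session %s timed out waiting for the AM's RPC address" format id)
      case Failure(e) =>
        println("session %s failed: %s" format (id, e))
    }
    session
  }
}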

+ 17 - 0
apps/spark/java/livy-yarn/src/main/scala/com/cloudera/hue/livy/yarn/Client.scala

@@ -163,6 +163,23 @@ class Job(client: YarnClient, appId: ApplicationId) {
     None
   }
 
+  def waitForRPC(timeoutMs: Long): Option[(String, Int)] = {
+    waitForStatus(Running(), timeoutMs)
+
+    val startTimeMs = System.currentTimeMillis()
+
+    while (System.currentTimeMillis() - startTimeMs < timeoutMs) {
+      val statusResponse = client.getApplicationReport(appId)
+
+      (statusResponse.getHost, statusResponse.getRpcPort) match {
+        case ("N/A", _) | (_, -1) =>
+        case (hostname, port) => return Some((hostname, port))
+      }
+    }
+
+    None
+  }
+
   def getHost: String = {
     val statusResponse = client.getApplicationReport(appId)
     statusResponse.getHost
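
One thing worth flagging in waitForRPC: the while loop re-queries the ResourceManager back to back with no pause, so it can call getApplicationReport in a tight loop for the full timeout window. A gentler variant is sketched below, under the assumption that it sits inside Job next to waitForRPC (it uses the same client and appId fields); the poll interval is an arbitrary choice, not something this commit defines:

// Sketch: same contract as waitForRPC above, but sleeping between RM polls.
def waitForRPCPolling(timeoutMs: Long, pollIntervalMs: Long = 500): Option[(String, Int)] = {
  waitForStatus(Running(), timeoutMs)

  val deadline = System.currentTimeMillis() + timeoutMs
  while (System.currentTimeMillis() < deadline) {
    val report = client.getApplicationReport(appId)
    (report.getHost, report.getRpcPort) match {
      case ("N/A", _) | (_, -1) => Thread.sleep(pollIntervalMs)  // AM not registered yet
      case (hostname, port)     => return Some((hostname, port))
    }
  }
  None
}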