
[livy] Add pyspark support

Erick Tryzelaar · 10 years ago · commit 42218ce

+ 22 - 0
apps/spark/java/livy-repl/src/main/resources/fake_pyspark.sh

@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+
+set -e
+
+if [ -z "$SPARK_HOME" ]; then
+	echo "\$SPARK_HOME is not set" 1>&2
+	exit 1
+fi
+
+source "$SPARK_HOME"/bin/utils.sh
+source "$SPARK_HOME"/bin/load-spark-env.sh
+
+export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH"
+
+for path in "$SPARK_HOME"/python/lib/*.zip; do
+	export PYTHONPATH="$path:$PYTHONPATH"
+done
+
+export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"
+export PYTHONSTARTUP="$SPARK_HOME/python/pyspark/shell.py"
+
+exec python "$@"
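
For context, a rough sketch of how this wrapper ends up being invoked: livy-repl extracts both scripts to temp files and runs the wrapper with the fake_shell path as its only argument, with SPARK_HOME inherited from the environment. The paths below are assumptions for illustration, not part of this commit.

# Illustrative sketch only; the SPARK_HOME and script paths are assumptions.
import os
import subprocess

env = dict(os.environ, SPARK_HOME="/opt/spark")  # assumed Spark install location
proc = subprocess.Popen(
    ["/tmp/fake_pyspark.sh", "/tmp/fake_shell.py"],  # wrapper + extracted fake shell
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    env=env,
)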

+ 26 - 14
apps/spark/java/livy-repl/src/main/resources/fake_shell.py

@@ -4,24 +4,13 @@ import datetime
 import decimal
 import json
 import logging
+import os
 import sys
 import traceback
 
 logging.basicConfig()
 logger = logging.getLogger('fake_shell')
 
-sys_stdin = sys.stdin
-sys_stdout = sys.stdout
-sys_stderr = sys.stderr
-
-fake_stdin = cStringIO.StringIO()
-fake_stdout = cStringIO.StringIO()
-fake_stderr = cStringIO.StringIO()
-
-sys.stdin = fake_stdin
-sys.stdout = fake_stdout
-sys.stderr = fake_stderr
-
 global_dict = {}
 
 execution_count = 0
@@ -72,7 +61,10 @@ def execute(code):
         return execute_reply_error(*sys.exc_info())
 
     stdout = fake_stdout.getvalue()
+    fake_stdout.truncate(0)
+
     stderr = fake_stderr.getvalue()
+    fake_stderr.truncate(0)
 
     output = ''
 
@@ -253,11 +245,31 @@ msg_type_router = {
     'execute_request': execute_request,
 }
 
+sys_stdin = sys.stdin
+sys_stdout = sys.stdout
+sys_stderr = sys.stderr
+
+fake_stdin = cStringIO.StringIO()
+fake_stdout = cStringIO.StringIO()
+fake_stderr = cStringIO.StringIO()
+
+sys.stdin = fake_stdin
+sys.stdout = fake_stdout
+sys.stderr = fake_stderr
 
 try:
-    while True:
-        fake_stdout.truncate(0)
+    # Load any startup files
+    try:
+        startup = os.environ['PYTHONSTARTUP']
+    except KeyError:
+        pass
+    else:
+        execfile(startup, global_dict)
+
+    fake_stdout.truncate(0)
+    fake_stderr.truncate(0)
 
+    while True:
         line = sys_stdin.readline()
 
         if line == '':
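
The fake shell reads one JSON message per line from the real stdin and dispatches it through msg_type_router, while user code only sees the redirected fake streams. A hedged sketch of driving it over a pipe follows; the message shape ({"msg_type": ..., "content": {"code": ...}}) is inferred from this diff and is an assumption, not a documented protocol.

# Sketch only: exchange one execute_request with fake_shell.py over stdio.
# The field names "msg_type", "content" and "code" are assumptions.
import json
import subprocess

shell = subprocess.Popen(
    ["python", "fake_shell.py"],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    universal_newlines=True,
)
request = {"msg_type": "execute_request", "content": {"code": "1 + 1"}}
shell.stdin.write(json.dumps(request) + "\n")
shell.stdin.flush()
print(json.loads(shell.stdout.readline()))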

+ 9 - 5
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/Main.scala

@@ -3,7 +3,7 @@ package com.cloudera.hue.livy.repl
 import javax.servlet.ServletContext
 
 import com.cloudera.hue.livy.repl.python.PythonSession
-import com.cloudera.hue.livy.repl.scala.ScalaSession
+import com.cloudera.hue.livy.repl.scala.SparkSession
 import com.cloudera.hue.livy.{Logging, WebServer}
 import org.scalatra.LifeCycle
 import org.scalatra.servlet.ScalatraListener
@@ -12,20 +12,22 @@ object Main extends Logging {
 
   val SESSION_KIND = "livy-repl.session.kind"
   val PYTHON_SESSION = "python"
+  val PYSPARK_SESSION = "pyspark"
   val SCALA_SESSION = "scala"
+  val SPARK_SESSION = "spark"
 
   def main(args: Array[String]): Unit = {
     val port = sys.env.getOrElse("PORT", "8999").toInt
 
     if (args.length != 1) {
-      println("Must specify either `python` or `scala` for the session kind")
+      println("Must specify one of `python`, `pyspark`, `scala`, or `spark` for the session kind")
       sys.exit(1)
     }
 
     val session_kind = args(0)
 
     session_kind match {
-      case PYTHON_SESSION | SCALA_SESSION =>
+      case PYTHON_SESSION | PYSPARK_SESSION | SCALA_SESSION | SPARK_SESSION =>
       case _ =>
         println("Unknown session kind: " + session_kind)
         sys.exit(1)
@@ -52,8 +54,10 @@ class ScalatraBootstrap extends LifeCycle {
 
   override def init(context: ServletContext): Unit = {
     val session = context.getInitParameter(Main.SESSION_KIND) match {
-      case Main.PYTHON_SESSION => PythonSession.create()
-      case Main.SCALA_SESSION => ScalaSession.create()
+      case Main.PYTHON_SESSION => PythonSession.createPython()
+      case Main.PYSPARK_SESSION => PythonSession.createPySpark()
+      case Main.SCALA_SESSION => SparkSession.create()
+      case Main.SPARK_SESSION => SparkSession.create()
     }
 
     context.mount(new WebApp(session), "/*")
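
As a usage sketch, the repl takes the session kind as its single command-line argument and reads the port from PORT (defaulting to 8999). The jar name below is an assumption for illustration only.

# Hypothetical launch of the repl with the new pyspark session kind;
# the jar path and SPARK_HOME are assumptions, not part of this commit.
import os
import subprocess

env = dict(os.environ, PORT="8999", SPARK_HOME="/opt/spark")
subprocess.check_call(["java", "-jar", "livy-repl.jar", "pyspark"], env=env)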

+ 27 - 24
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/python/PythonSession.scala

@@ -14,9 +14,17 @@ import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.{ExecutionContext, Future}
 
 object PythonSession {
-  def create(): Session = {
-    val file = createScript()
-    val pb = new ProcessBuilder("python", file.toString)
+  def createPython(): Session = {
+    create("python")
+  }
+
+  def createPySpark(): Session = {
+    create(createFakePySpark().toString)
+  }
+
+  private def create(driver: String) = {
+    val fakeShell = createFakeShell()
+    val pb = new ProcessBuilder(driver, fakeShell.toString)
     pb.redirectError(Redirect.INHERIT)
     val process = pb.start()
     val in = process.getInputStream
@@ -25,7 +33,7 @@ object PythonSession {
     new PythonSession(process, in, out)
   }
 
-  private def createScript(): File = {
+  private def createFakeShell(): File = {
     val source: InputStream = getClass.getClassLoader.getResourceAsStream("fake_shell.py")
 
     val file = Files.createTempFile("", "").toFile
@@ -46,32 +54,27 @@ object PythonSession {
     file
   }
 
-  // Java unfortunately wraps the input stream in a buffer, so we need to hack around it so we can read the output
-  // without blocking.
-  private def unwrapInputStream(inputStream: InputStream) = {
-    var filteredInputStream = inputStream
+  private def createFakePySpark(): File = {
+    val source: InputStream = getClass.getClassLoader.getResourceAsStream("fake_pyspark.sh")
 
-    while (filteredInputStream.isInstanceOf[FilterInputStream]) {
-      val field = classOf[FilterInputStream].getDeclaredField("in")
-      field.setAccessible(true)
-      filteredInputStream = field.get(filteredInputStream).asInstanceOf[InputStream]
-    }
+    val file = Files.createTempFile("", "").toFile
+    file.deleteOnExit()
 
-    filteredInputStream
-  }
+    file.setExecutable(true)
 
-  // Java unfortunately wraps the output stream in a buffer, so we need to hack around it so we can read the output
-  // without blocking.
-  private def unwrapOutputStream(outputStream: OutputStream) = {
-    var filteredOutputStream = outputStream
+    val sink = new FileOutputStream(file)
+    val buf = new Array[Byte](1024)
+    var n = source.read(buf)
 
-    while (filteredOutputStream.isInstanceOf[FilterOutputStream]) {
-      val field = classOf[FilterOutputStream].getDeclaredField("out")
-      field.setAccessible(true)
-      filteredOutputStream = field.get(filteredOutputStream).asInstanceOf[OutputStream]
+    while (n > 0) {
+      sink.write(buf, 0, n)
+      n = source.read(buf)
     }
 
-    filteredOutputStream
+    source.close()
+    sink.close()
+
+    file
   }
 }
 

+ 3 - 3
apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/scala/ScalaSession.scala → apps/spark/java/livy-repl/src/main/scala/com/cloudera/hue/livy/repl/scala/SparkSession.scala

@@ -12,11 +12,11 @@ import scala.collection.mutable
 import scala.concurrent.duration.Duration
 import scala.concurrent.{Await, ExecutionContext, Future, Promise}
 
-object ScalaSession {
-  def create(): Session = new ScalaSession()
+object SparkSession {
+  def create(): Session = new SparkSession()
 }
 
-private class ScalaSession extends Session {
+private class SparkSession extends Session {
   private implicit def executor: ExecutionContext = ExecutionContext.global
 
   implicit val formats = DefaultFormats

+ 2 - 0
apps/spark/java/livy-server/src/main/scala/com/cloudera/hue/livy/server/WebApp.scala

@@ -47,6 +47,8 @@ class WebApp(sessionManager: SessionManager)
 
     val sessionFuture = createSessionRequest.lang match {
       case "scala" => sessionManager.createSession(createSessionRequest.lang)
+      case "spark" => sessionManager.createSession(createSessionRequest.lang)
+      case "pyspark" => sessionManager.createSession(createSessionRequest.lang)
       case "python" => sessionManager.createSession(createSessionRequest.lang)
       case lang => halt(400, "unsupported language: " + lang)
     }
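
On the server side, the create-session request can now ask for the new kinds by language. A hedged example of requesting a pyspark session follows; the endpoint path and port are assumptions, since the route itself is not shown in this diff.

# Assumed endpoint and port, for illustration only.
import json
import urllib2

body = json.dumps({"lang": "pyspark"})
req = urllib2.Request("http://localhost:8998/sessions", body,
                      {"Content-Type": "application/json"})
print(urllib2.urlopen(req).read())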