Browse Source

[spark] Keep 5 spark sessions precached to speed up start time

Erick Tryzelaar 11 years ago
parent
commit
aeeaa4d

+ 3 - 3
apps/spark/java/sparker-server/src/main/java/com/cloudera/hue/sparker/server/resources/SessionResource.java

@@ -51,7 +51,7 @@ public class SessionResource {
     @Timed
     public Response createSession(@QueryParam("lang") String language,
                                   @Context HttpServletRequest request) throws IOException, InterruptedException, URISyntaxException {
-        int sessionType;
+        SessionManager.SessionType sessionType;
 
         if (language == null) {
             Response resp = new ResponseBuilderImpl().status(400).entity("missing language").build();
@@ -59,9 +59,9 @@ public class SessionResource {
         }
 
         if (language.equals(SCALA)) {
-            sessionType = SessionManager.SCALA;
+            sessionType = SessionManager.SessionType.SCALA;
         } else if (language.equals(PYTHON)) {
-            sessionType = SessionManager.PYTHON;
+            sessionType = SessionManager.SessionType.PYTHON;
         } else {
             Response resp = new ResponseBuilderImpl().status(400).entity("invalid language").build();
             throw new WebApplicationException(resp);

+ 66 - 19
apps/spark/java/sparker-server/src/main/java/com/cloudera/hue/sparker/server/sessions/SessionManager.java

@@ -18,24 +18,35 @@
 
 package com.cloudera.hue.sparker.server.sessions;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 import java.io.IOException;
 import java.util.Enumeration;
 import java.util.UUID;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeoutException;
+import java.util.concurrent.*;
 
 public class SessionManager {
 
-    public static final int UNKNOWN = 0;
-    public static final int SCALA = 1;
-    public static final int PYTHON = 2;
+    private static final Logger LOG = LoggerFactory.getLogger(SparkSession.class);
+
+    public enum SessionType {
+        SCALA,
+        PYTHON,
+    }
 
     private ConcurrentHashMap<String, Session> sessions = new ConcurrentHashMap<String, Session>();
+    private BlockingQueue<Session> freshScalaSessions = new LinkedBlockingQueue<Session>(5);
+
+    SessionManagerGarbageCollector gcThread = new SessionManagerGarbageCollector();
+    SessionCreator creatorThread = new SessionCreator(SessionType.SCALA);
 
     public SessionManager() {
-        SessionManagerGarbageCollector gc = new SessionManagerGarbageCollector(this);
-        gc.setDaemon(true);
-        gc.start();
+        gcThread.setDaemon(true);
+        gcThread.start();
+
+        creatorThread.setDaemon(true);
+        creatorThread.start();
     }
 
     public Session get(String id) throws SessionNotFound {
@@ -46,15 +57,14 @@ public class SessionManager {
         return session;
     }
 
-    public Session create(int language) throws IllegalArgumentException, IOException, InterruptedException {
-        String id = UUID.randomUUID().toString();
+    public Session create(SessionType type) throws IllegalArgumentException, IOException, InterruptedException {
         Session session;
-        switch (language) {
-            case SCALA:  session = new SparkSession(id); break;
+        switch (type) {
+            case SCALA: session = freshScalaSessions.take(); break;
             //case PYTHON: session = new PySparkSession(id); break;
             default: throw new IllegalArgumentException("Invalid language specified for shell session");
         }
-        sessions.put(id, session);
+        sessions.put(session.getId(), session);
         return session;
     }
 
@@ -63,12 +73,23 @@ public class SessionManager {
             sessions.remove(session.getId());
             session.close();
         }
+
+        gcThread.interrupt();
+        gcThread.join();
+        creatorThread.interrupt();
+        creatorThread.join();
+
+        Session session;
+        while ((session = freshScalaSessions.poll(500, TimeUnit.MILLISECONDS)) != null) {
+            session.close();
+        }
     }
 
     public void close(String id) throws InterruptedException, TimeoutException, IOException, SessionNotFound {
         Session session = this.get(id);
         sessions.remove(id);
         session.close();
+
     }
 
     public Enumeration<String> getSessionIds() {
@@ -89,21 +110,18 @@ public class SessionManager {
         }
     }
 
-    protected class SessionManagerGarbageCollector extends Thread {
-
-        protected SessionManager manager;
+    private class SessionManagerGarbageCollector extends Thread {
 
         protected long period = 600000; // Time in milliseconds; TODO: make configurable
 
-        public SessionManagerGarbageCollector(SessionManager manager) {
+        public SessionManagerGarbageCollector() {
             super();
-            this.manager = manager;
         }
 
         public void run() {
             try {
                 while(true) {
-                    manager.garbageCollect();
+                    garbageCollect();
                     sleep(period);
                 }
             } catch (InterruptedException e) {
@@ -121,4 +139,33 @@ public class SessionManager {
             super(id);
         }
     }
+
+    private class SessionCreator extends Thread {
+        SessionType type;
+
+        public SessionCreator(SessionType type) {
+            this.type = type;
+        }
+
+        public void run() {
+            try {
+                while(true) {
+                    String id = UUID.randomUUID().toString();
+
+                    Session session;
+                    switch (type) {
+                        case SCALA: session = new SparkSession(id); break;
+                        //case PYTHON: session = new PythonSession(id); break;
+                        default: throw new IllegalArgumentException("Invalid language specified for shell session");
+                    }
+
+                    freshScalaSessions.put(session);
+                }
+            } catch (InterruptedException e) {
+                e.printStackTrace();
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
+        }
+    }
 }