
[sparksql] Store Livy session details in the UserProfile (#2860)

- Global variables are not a good option to act as a cache, especially with the Gunicorn server. We can either use a caching service or store the session info in the Hue DB to have atomicity in operations.
- Currently we are storing it in the user profile; in the future we can look at a cache service like Memcached or Redis.
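
A minimal sketch (hypothetical names, not part of the patch) of why the module-level dict fails under Gunicorn: each pre-forked worker process gets its own copy of the global, so a Livy session registered by one worker is invisible to requests that land on another, while reading through the shared Hue DB gives every worker the same view.

  # Sketch only: module-level state is per-process under a pre-forking server.
  SESSIONS = {}  # worker A populates its own copy ...

  def cached_session(session_key):
    # ... but worker B's copy is still empty, so this lookup misses.
    return SESSIONS.get(session_key)

  def db_backed_session(user, session_key):
    # The user profile row lives in the shared DB, so any worker sees it.
    return user.profile.data.get(session_key)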
Harsh Gupta 3 years ago
parent
commit
c3fe65ef4a

+ 59 - 28
desktop/libs/notebook/src/notebook/connectors/spark_shell.py

@@ -21,12 +21,14 @@ import re
 import sys
 import time
 import textwrap
+import json
 
 from desktop.conf import USE_DEFAULT_CONFIGURATION
 from desktop.lib.exceptions_renderable import PopupException
 from desktop.lib.i18n import force_unicode
 from desktop.lib.rest.http_client import RestException
 from desktop.models import DefaultConfiguration
+from desktop.auth.backend import rewrite_user
 
 from notebook.data_export import download as spark_download
 from notebook.connectors.base import Api, QueryError, SessionExpired, _get_snippet_session
@@ -46,8 +48,8 @@ try:
 except ImportError as e:
   LOG.exception('Spark is not enabled')
 
-SESSIONS = {}
 SESSION_KEY = '%(username)s-%(interpreter_name)s'
+
 class SparkApi(Api):
 
   SPARK_UI_RE = re.compile("Started SparkUI at (http[s]?://([0-9a-zA-Z-_\.]+):(\d+))")
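
_get_session_key itself is not shown in this diff; given the SESSION_KEY template above and the interpreter dict used in the tests below, a plausible implementation (an assumption, included only for context) looks like:

  def _get_session_key(self):
    # Yields e.g. 'hue_test-livy': one cached session per user per interpreter.
    return SESSION_KEY % {
      'username': self.user.username,
      'interpreter_name': self.interpreter['name'],
    }

The key shape is unchanged; only the storage behind it moves from process memory to the user profile.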
@@ -134,12 +136,12 @@ class SparkApi(Api):
 
   def create_session(self, lang='scala', properties=None):
     api = self.get_api()
-    session_key = self._get_session_key()
+    stored_session_info = self._get_session_info_from_user()
 
-    if SESSIONS.get(session_key):
-      session_present = self._check_session(SESSIONS[session_key])
+    if stored_session_info:
+      session_present = self._check_session(stored_session_info)
       if session_present:
-        return SESSIONS[session_key]
+        return stored_session_info
 
     if not properties and USE_DEFAULT_CONFIGURATION.get():
       user_config = DefaultConfiguration.objects.get_configuration_for_user(app='spark', user=self.user)
@@ -162,12 +164,14 @@ class SparkApi(Api):
       info = '\n'.join(status['log']) if status['log'] else 'timeout'
       raise QueryError(_('The Spark session is %s and could not be created in the cluster: %s') % (status['state'], info))
 
-    SESSIONS[session_key] = {
+    new_session_info = {
         'type': lang,
         'id': response['id'],
         'properties': self.to_properties(props)
     }
-    return SESSIONS[session_key]
+    self._set_session_info_to_user(new_session_info)
+
+    return new_session_info
     
 
   def execute(self, notebook, snippet):
@@ -179,11 +183,11 @@ class SparkApi(Api):
 
 
   def _execute(self, api, session, snippet_type, statement):
-    session_key = self._get_session_key()
 
     if not session or not self._check_session(session):
-      if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
-        session = SESSIONS[session_key]
+      stored_session_info = self._get_session_info_from_user()
+      if stored_session_info and self._check_session(stored_session_info):
+        session = stored_session_info
       else:
         session = self.create_session(snippet_type)
 
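The same fallback ordering recurs in _execute, _handle_session_health_check, autocomplete and get_sample_data below; distilled into one sketch (the create callback is a stand-in for create_session, or for raising PopupException in the health check):

  def resolve_session(api, session, create):
    # 1. Trust the caller-supplied session if Livy still reports it alive.
    if session and api._check_session(session):
      return session
    # 2. Fall back to the session cached in the user profile.
    stored = api._get_session_info_from_user()
    if stored and api._check_session(stored):
      return stored
    # 3. Last resort: create a new session (or fail, depending on call site).
    return create()
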
@@ -327,11 +331,11 @@ class SparkApi(Api):
   
 
   def _handle_session_health_check(self, session):
-    session_key = self._get_session_key()
 
     if not session or not self._check_session(session):
-      if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
-        session = SESSIONS[session_key]
+      stored_session_info = self._get_session_info_from_user()
+      if stored_session_info and self._check_session(stored_session_info):
+        session = stored_session_info
       else:
         raise PopupException(_("Session expired. Please create new session and try again."))
     
@@ -344,7 +348,6 @@ class SparkApi(Api):
 
   def close_session(self, session):
     api = self.get_api()
-    session_key = self._get_session_key()
 
     if session['id'] is not None:
       try:
@@ -357,8 +360,9 @@ class SparkApi(Api):
         if e.code == 404 or e.code == 500: # TODO remove the 500
           raise SessionExpired(e)
       finally:
-        if SESSIONS.get(session_key) and session['id'] == SESSIONS[session_key]['id']:
-          del SESSIONS[session_key]
+        stored_session_info = self._get_session_info_from_user()
+        if stored_session_info and session['id'] == stored_session_info['id']:
+          self._remove_session_info_from_user()
     else:
       return {'status': -1}
 
@@ -376,17 +380,16 @@ class SparkApi(Api):
     response = {}
     # As booting a new SQL session is slow and we don't send the id of the current one in /autocomplete
     # we could implement this by introducing an API cache per user similarly to SqlAlchemy.
-
     api = self.get_api()
-    session_key = self._get_session_key()
 
     # Trying to close unused sessions if there are any.
     # Calling the method here since this /autocomplete call can be frequent enough and we don't need a dedicated one.
-    if SESSIONS.get(session_key):
+    if self._get_session_info_from_user():
       self._close_unused_sessions()
     
-    if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
-      session = SESSIONS[session_key]
+    stored_session_info = self._get_session_info_from_user()
+    if stored_session_info and self._check_session(stored_session_info):
+      session = stored_session_info
     else:
       session = self.create_session(snippet.get('type'))
 
@@ -413,9 +416,8 @@ class SparkApi(Api):
     Closes all unused Livy sessions for a particular user to free up session resources.
     '''
     api = self.get_api()
-    session_key = self._get_session_key()
+    all_sessions = {}
 
-    all_session = {}
     try:
       all_sessions = api.get_sessions()
     except Exception as e:
@@ -423,8 +425,9 @@ class SparkApi(Api):
       LOG.debug(message)
 
     if all_sessions:
+      stored_session_info = self._get_session_info_from_user()
       for session in all_sessions['sessions']:
-        if session['owner'] == self.user.username and session['id'] != SESSIONS[session_key]['id'] and \
+        if session['owner'] == self.user.username and session['id'] != stored_session_info['id'] and \
           session['state'] in ('idle', 'shutting_down', 'error', 'dead', 'killed'):
           self.close_session(session)
 
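Restated as a standalone predicate (a sketch of the condition above, with the closable states pulled into a constant):

  CLOSABLE_STATES = ('idle', 'shutting_down', 'error', 'dead', 'killed')

  def is_unused(session, username, stored_session_id):
    # Only this user's sessions, never the one cached in the profile, and
    # only when Livy reports them idle or already terminated.
    return (session['owner'] == username
        and session['id'] != stored_session_id
        and session['state'] in CLOSABLE_STATES)
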
@@ -478,15 +481,15 @@ class SparkApi(Api):
 
   def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
     api = self.get_api()
-    session_key = self._get_session_key()
 
     # Trying to close unused sessions if there are any.
     # Calling the method here since this /sample_data call can be frequent enough and we don't need a dedicated one.
-    if SESSIONS.get(session_key):
+    if self._get_session_info_from_user():
       self._close_unused_sessions()
 
-    if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
-      session = SESSIONS[session_key]
+    stored_session_info = self._get_session_info_from_user()
+    if stored_session_info and self._check_session(stored_session_info):
+      session = stored_session_info
     else:
       session = self.create_session(snippet.get('type'))
 
@@ -573,6 +576,34 @@ class SparkApi(Api):
     return LIVY_SERVER_SESSION_KIND.get() == "yarn"
 
 
+  def _get_session_info_from_user(self):
+    self.user = rewrite_user(self.user)
+    session_key = self._get_session_key()
+
+    if self.user.profile.data.get(session_key):
+      return self.user.profile.data[session_key]
+
+
+  def _set_session_info_to_user(self, session_info):
+    self.user = rewrite_user(self.user)
+    session_key = self._get_session_key()
+
+    self.user.profile.update_data({session_key: session_info})
+    self.user.profile.save()
+
+
+  def _remove_session_info_from_user(self):
+    self.user = rewrite_user(self.user)
+    session_key = self._get_session_key()
+
+    if self.user.profile.data.get(session_key):
+      json_data = self.user.profile.data
+      json_data.pop(session_key)
+      self.user.profile.json_data = json.dumps(json_data)
+    
+    self.user.profile.save()
+
+
 class SparkConfiguration(object):
 
   APP_NAME = 'spark'

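These helpers lean on Hue's UserProfile, whose json_data text column is exposed through the data property and update_data(). A quick round-trip, assuming api is a SparkApi bound to a real DB-backed user:

  # Persist, read back, then clear a session entry via the new helpers.
  api._set_session_info_to_user({'type': 'sql', 'id': 7, 'properties': []})
  assert api._get_session_info_from_user()['id'] == 7
  api._remove_session_info_from_user()
  assert api._get_session_info_from_user() is None  # falls through to None

Note that _remove_session_info_from_user rewrites json_data directly rather than going through update_data, presumably because update_data only merges keys and cannot delete one.
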
+ 9 - 13
desktop/libs/notebook/src/notebook/connectors/spark_shell_tests.py

@@ -20,7 +20,10 @@ import sys
 from builtins import object
 from nose.tools import assert_equal, assert_true, assert_false, assert_raises
 
-from notebook.connectors.spark_shell import SparkApi, SESSIONS
+from desktop.lib.django_test_util import make_logged_in_client
+from useradmin.models import User
+
+from notebook.connectors.spark_shell import SparkApi
 
 if sys.version_info[0] > 2:
   from unittest.mock import patch, Mock
@@ -31,7 +34,9 @@ else:
 class TestSparkApi(object):
 
   def setUp(self):
-    self.user = 'hue_test'
+    self.client = make_logged_in_client(username="hue_test", groupname="default", recreate=True, is_superuser=False)
+    self.user = User.objects.get(username="hue_test")
+
     self.interpreter = {
         'name': 'livy',
         'options': {
@@ -94,8 +99,8 @@ class TestSparkApi(object):
               cores = p['value']
           assert_equal(cores, 2)
 
-          if SESSIONS.get(session_key):
-            del SESSIONS[session_key]
+          if self.api._get_session_info_from_user():
+            self.api._remove_session_info_from_user()
 
           # Case without user configuration. Expected 1 driverCores
           USE_DEFAULT_CONFIGURATION.get.return_value = True
@@ -110,9 +115,6 @@ class TestSparkApi(object):
               cores = p['value']
           assert_equal(cores, 1)
 
-          if SESSIONS.get(session_key):
-            del SESSIONS[session_key]
-
           # Case with no user configuration. Expected 1 driverCores
           USE_DEFAULT_CONFIGURATION.get.return_value = False
           session3 = self.api.create_session(lang=lang, properties=properties)
@@ -125,9 +127,6 @@ class TestSparkApi(object):
               cores = p['value']
           assert_equal(cores, 1)
 
-          if SESSIONS.get(session_key):
-            del SESSIONS[session_key]
-
 
   def test_create_session_plain(self):
     lang = 'pyspark'
@@ -153,9 +152,6 @@ class TestSparkApi(object):
       assert_true(files_properties, session['properties'])
       assert_equal(files_properties[0]['value'], [], session['properties'])
 
-      if SESSIONS.get(session_key):
-        del SESSIONS[session_key]
-
 
   def test_execute(self):
     with patch('notebook.connectors.spark_shell._get_snippet_session') as _get_snippet_session:
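
With the cache in the DB, per-test cleanup goes through the profile helpers instead of deleting a module global; a tearDown along these lines (a sketch mirroring the cleanup already done inside the tests above) keeps session info from leaking between tests:

  def tearDown(self):
    # Drop whatever session info the test stored in hue_test's profile.
    if self.api._get_session_info_from_user():
      self.api._remove_session_info_from_user()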