浏览代码

[sparksql] Improve session reuse and fix corner cases (#2851)

- Improve session handling
- Fix failing corner cases
- Add checks for different session states
- Cancel statement improvements
- Fix failing UTs
Harsh Gupta 3 年之前
父节点
当前提交
33d4f05497

+ 3 - 0
apps/spark/src/spark/livy_client.py

@@ -153,6 +153,9 @@ class LivyClient(object):
   def get_batches(self):
     return self._root.get('batches')
 
+  def cancel_statement(self, session, statement_id):
+    return self._root.post('sessions/%s/statements/%s/cancel' % (session, statement_id))
+
   def submit_batch(self, properties):
     properties['proxyUser'] = self.user
     return self._root.post('batches', data=json.dumps(properties), contenttype=_JSON_CONTENT_TYPE)

+ 92 - 29
desktop/libs/notebook/src/notebook/connectors/spark_shell.py

@@ -118,13 +118,26 @@ class SparkApi(Api):
     }
 
 
+  def _check_session(self, session):
+    '''
+    Check if the session is actually present and its state is healthy.
+    '''
+    api = self.get_api()
+    try:
+      session_present = api.get_session(session['id'])
+    except Exception as e:
+      session_present = None
+
+    if session_present and session_present['state'] not in ('dead', 'shutting_down', 'error', 'killed'):
+      return session_present
+
+
   def create_session(self, lang='scala', properties=None):
     api = self.get_api()
     session_key = self._get_session_key()
 
     if SESSIONS.get(session_key):
-      # Checking if the session is actually present to avoid stale value
-      session_present = api.get_session(SESSIONS[session_key]['id'])
+      session_present = self._check_session(SESSIONS[session_key])
       if session_present:
         return SESSIONS[session_key]
 
@@ -161,15 +174,18 @@ class SparkApi(Api):
     api = self.get_api()
     session = _get_snippet_session(notebook, snippet)
 
-    response = self._execute(api, session, snippet['statement'])
+    response = self._execute(api, session, snippet.get('type'), snippet['statement'])
     return response
 
 
-  def _execute(self, api, session, statement):
+  def _execute(self, api, session, snippet_type, statement):
     session_key = self._get_session_key()
 
-    if session['id'] is None and SESSIONS.get(session_key) is not None:
-      session = SESSIONS[session_key]
+    if not session or not self._check_session(session):
+      if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
+        session = SESSIONS[session_key]
+      else:
+        session = self.create_session(snippet_type)
 
     try:
       response = api.submit_statement(session['id'], statement)
@@ -191,6 +207,8 @@ class SparkApi(Api):
     session = _get_snippet_session(notebook, snippet)
     cell = snippet['result']['handle']['id']
 
+    session = self._handle_session_health_check(session)
+
     try:
       response = api.fetch_data(session['id'], cell)
       return {
@@ -209,6 +227,8 @@ class SparkApi(Api):
     session = _get_snippet_session(notebook, snippet)
     cell = snippet['result']['handle']['id']
 
+    session = self._handle_session_health_check(session)
+
     response = self._fetch_result(api, session, cell, start_over)
     return response
 
@@ -279,16 +299,43 @@ class SparkApi(Api):
   def cancel(self, notebook, snippet):
     api = self.get_api()
     session = _get_snippet_session(notebook, snippet)
-    response = api.cancel(session['id'])
+
+    session = self._handle_session_health_check(session)
+
+    try:
+      response = api.cancel(session['id'])
+    except Exception as e:
+      message = force_unicode(str(e)).lower()
+      LOG.debug(message)
 
     return {'status': 0}
 
 
   def get_log(self, notebook, snippet, startFrom=0, size=None):
+    response = {'status': 0}
     api = self.get_api()
     session = _get_snippet_session(notebook, snippet)
 
-    return api.get_log(session['id'], startFrom=startFrom, size=size)
+    session = self._handle_session_health_check(session)
+    try:
+      response = api.get_log(session['id'], startFrom=startFrom, size=size)
+    except RestException as e:
+      message = force_unicode(str(e)).lower()
+      LOG.debug(message)
+
+    return response
+  
+
+  def _handle_session_health_check(self, session):
+    session_key = self._get_session_key()
+
+    if not session or not self._check_session(session):
+      if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
+        session = SESSIONS[session_key]
+      else:
+        raise PopupException(_("Session expired. Please create new session and try again."))
+    
+    return session
 
 
   def close_statement(self, notebook, snippet): # Individual statements cannot be closed
@@ -327,9 +374,9 @@ class SparkApi(Api):
 
   def autocomplete(self, snippet, database=None, table=None, column=None, nested=None, operation=None):
     response = {}
-
     # As booting a new SQL session is slow and we don't send the id of the current one in /autocomplete
     # we could implement this by introducing an API cache per user similarly to SqlAlchemy.
+
     api = self.get_api()
     session_key = self._get_session_key()
 
@@ -338,14 +385,17 @@ class SparkApi(Api):
     if SESSIONS.get(session_key):
       self._close_unused_sessions()
     
-    session = SESSIONS[session_key] if SESSIONS.get(session_key) else self.create_session(snippet.get('type'))
+    if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
+      session = SESSIONS[session_key]
+    else:
+      session = self.create_session(snippet.get('type'))
 
     if database is None:
-      response['databases'] = self._show_databases(api, session)
+      response['databases'] = self._show_databases(api, session, snippet.get('type'))
     elif table is None:
-      response['tables_meta'] = self._show_tables(api, session, database)
+      response['tables_meta'] = self._show_tables(api, session, snippet.get('type'), database)
     elif column is None:
-      columns = self._get_columns(api, session, database, table)
+      columns = self._get_columns(api, session, snippet.get('type'), database, table)
       response['columns'] = [col['name'] for col in columns]
       response['extended_columns'] = [{
           'comment': col.get('comment'),
@@ -360,52 +410,62 @@ class SparkApi(Api):
 
   def _close_unused_sessions(self):
     '''
-    Closes all unsused Livy sessions for a particular user to free up session resources.
+    Closes all unused Livy sessions for a particular user to free up session resources.
     '''
     api = self.get_api()
     session_key = self._get_session_key()
 
-    all_sessions = api.get_sessions()
-    for session in all_sessions['sessions']:
-      if session['owner'] == self.user.username and session['id'] != SESSIONS[session_key]['id']:
-        self.close_session(session)
+    all_session = {}
+    try:
+      all_sessions = api.get_sessions()
+    except Exception as e:
+      message = force_unicode(str(e)).lower()
+      LOG.debug(message)
+
+    if all_sessions:
+      for session in all_sessions['sessions']:
+        if session['owner'] == self.user.username and session['id'] != SESSIONS[session_key]['id'] and \
+          session['state'] in ('idle', 'shutting_down', 'error', 'dead', 'killed'):
+          self.close_session(session)
 
 
   def _check_status_and_fetch_result(self, api, session, execute_resp):
     check_status = api.fetch_data(session['id'], execute_resp['id'])
 
-    while check_status['state'] in ['running', 'waiting']:
+    count = 0
+    while check_status['state'] in ['running', 'waiting'] and count < 120:
       check_status = api.fetch_data(session['id'], execute_resp['id'])
+      count += 1
       time.sleep(1)
 
     if check_status['state'] == 'available':
       return self._fetch_result(api, session, execute_resp['id'], start_over=True)
 
 
-  def _show_databases(self, api, session):
-    show_db_execute = self._execute(api, session, 'SHOW DATABASES')
+  def _show_databases(self, api, session, snippet_type):
+    show_db_execute = self._execute(api, session, snippet_type, 'SHOW DATABASES')
     db_list = self._check_status_and_fetch_result(api, session, show_db_execute)
 
     if db_list:
       return [db[0] for db in db_list['data']]
 
 
-  def _show_tables(self, api, session, database):
-    use_db_execute = self._execute(api, session, 'USE %(database)s' % {'database': database})
+  def _show_tables(self, api, session, snippet_type, database):
+    use_db_execute = self._execute(api, session, snippet_type, 'USE %(database)s' % {'database': database})
     use_db_resp = self._check_status_and_fetch_result(api, session, use_db_execute)
 
-    show_tables_execute = self._execute(api, session, 'SHOW TABLES')
+    show_tables_execute = self._execute(api, session, snippet_type, 'SHOW TABLES')
     tables_list = self._check_status_and_fetch_result(api, session, show_tables_execute)
 
     if tables_list:
       return [table[1] for table in tables_list['data']]
 
 
-  def _get_columns(self, api, session, database, table):
-    use_db_execute = self._execute(api, session, 'USE %(database)s' % {'database': database})
+  def _get_columns(self, api, session, snippet_type, database, table):
+    use_db_execute = self._execute(api, session, snippet_type, 'USE %(database)s' % {'database': database})
     use_db_resp = self._check_status_and_fetch_result(api, session, use_db_execute)
 
-    describe_tables_execute = self._execute(api, session, 'DESCRIBE %(table)s' % {'table': table})
+    describe_tables_execute = self._execute(api, session, snippet_type, 'DESCRIBE %(table)s' % {'table': table})
     columns_list = self._check_status_and_fetch_result(api, session, describe_tables_execute)
 
     if columns_list:
@@ -425,11 +485,14 @@ class SparkApi(Api):
     if SESSIONS.get(session_key):
       self._close_unused_sessions()
 
-    session = SESSIONS[session_key] if SESSIONS.get(session_key) else self.create_session(snippet.get('type'))
+    if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
+      session = SESSIONS[session_key]
+    else:
+      session = self.create_session(snippet.get('type'))
 
     statement = self._get_select_query(database, table, column, operation)
 
-    sample_execute = self._execute(api, session, statement)
+    sample_execute = self._execute(api, session, snippet.get('type'), statement)
     sample_result = self._check_status_and_fetch_result(api, session, sample_execute)
 
     response = {

+ 2 - 0
desktop/libs/notebook/src/notebook/connectors/spark_shell_tests.py

@@ -169,6 +169,7 @@ class TestSparkApi(object):
             return_value={'id': 'test_id'}
           )
         )
+        self.api._check_session = Mock(return_value={'id': '1'})
 
         response = self.api.execute(notebook, snippet)
         assert_equal(response['id'], 'test_id')
@@ -197,6 +198,7 @@ class TestSparkApi(object):
             return_value={'state': 'test_state'}
           )
         )
+        self.api._handle_session_health_check = Mock(return_value={'id': '1'})
 
         response = self.api.check_status(notebook, snippet)
         assert_equal(response['status'], 'test_state')