Bläddra i källkod

[sparksql] Enhance SparkSQL support via Apache Livy for Hue (#2794)

* [sparksql] Enhance SparkSQL support via Apache Livy for Hue

- Reuse Livy session per user similarly to SQLAlchemy
- Add support for /autocomplete, /get_sample_data and other APIs
- Add few UTs and update existing ones

NOTE: In the start, 2 sessions are being created, one for /autocomplete and other for /create_session. This is because /autocomplete call is always first in the call stack from the UI, plus this will help if user just want to access other tabs like Table Browser where we don't do /create_session call. But since we are reusing the session per user, eventually we only use one of the sessions afterwards for every user activity regarding SparkSQL.
Harsh Gupta 3 år sedan
förälder
incheckning
e751b863ca

+ 154 - 11
desktop/libs/notebook/src/notebook/connectors/spark_shell.py

@@ -20,6 +20,7 @@ import logging
 import re
 import sys
 import time
+import textwrap
 
 from desktop.conf import USE_DEFAULT_CONFIGURATION
 from desktop.lib.exceptions_renderable import PopupException
@@ -45,19 +46,23 @@ try:
 except ImportError as e:
   LOG.exception('Spark is not enabled')
 
-
+SESSIONS = {}
+SESSION_KEY = '%(username)s-%(interpreter_name)s'
 class SparkApi(Api):
 
   SPARK_UI_RE = re.compile("Started SparkUI at (http[s]?://([0-9a-zA-Z-_\.]+):(\d+))")
   YARN_JOB_RE = re.compile("tracking URL: (http[s]?://.+/)")
   STANDALONE_JOB_RE = re.compile("Got job (\d+)")
 
+
   def __init__(self, user, interpreter):
     super(SparkApi, self).__init__(user=user, interpreter=interpreter)
 
+
   def get_api(self):
     return get_spark_api(self.user, self.interpreter)
 
+
   @staticmethod
   def get_livy_props(lang, properties=None):
     props = dict([(p['name'], p['value']) for p in SparkConfiguration.PROPERTIES])
@@ -91,6 +96,7 @@ class SparkApi(Api):
 
     return props
 
+
   @staticmethod
   def to_properties(props=None):
     properties = list()
@@ -104,7 +110,21 @@ class SparkApi(Api):
 
     return properties
 
+
+  def _get_session_key(self):
+    return SESSION_KEY % {
+      'username': self.user.username if hasattr(self.user, 'username') else self.user,
+      'interpreter_name': self.interpreter['name']
+    }
+
+
   def create_session(self, lang='scala', properties=None):
+    api = self.get_api()
+    session_key = self._get_session_key()
+
+    if SESSIONS.get(session_key):
+      return SESSIONS[session_key]
+
     if not properties and USE_DEFAULT_CONFIGURATION.get():
       user_config = DefaultConfiguration.objects.get_configuration_for_user(app='spark', user=self.user)
       if user_config is not None:
@@ -112,7 +132,6 @@ class SparkApi(Api):
 
     props = self.get_livy_props(lang, properties)
 
-    api = get_spark_api(self.user)
     response = api.create_session(**props)
 
     status = api.get_session(response['id'])
@@ -127,18 +146,30 @@ class SparkApi(Api):
       info = '\n'.join(status['log']) if status['log'] else 'timeout'
       raise QueryError(_('The Spark session is %s and could not be created in the cluster: %s') % (status['state'], info))
 
-    return {
+    SESSIONS[session_key] = {
         'type': lang,
         'id': response['id'],
         'properties': self.to_properties(props)
     }
+    return SESSIONS[session_key]
+    
 
   def execute(self, notebook, snippet):
     api = self.get_api()
     session = _get_snippet_session(notebook, snippet)
 
+    response = self._execute(api, session, snippet['statement'])
+    return response
+
+
+  def _execute(self, api, session, statement):
+    session_key = self._get_session_key()
+
+    if session['id'] is None and SESSIONS.get(session_key) is not None:
+      session = SESSIONS[session_key]
+
     try:
-      response = api.submit_statement(session['id'], snippet['statement'])
+      response = api.submit_statement(session['id'], statement)
       return {
           'id': response['id'],
           'has_result_set': True,
@@ -151,6 +182,7 @@ class SparkApi(Api):
       else:
         raise e
 
+
   def check_status(self, notebook, snippet):
     api = self.get_api()
     session = _get_snippet_session(notebook, snippet)
@@ -168,11 +200,17 @@ class SparkApi(Api):
       else:
         raise e
 
+
   def fetch_result(self, notebook, snippet, rows, start_over):
     api = self.get_api()
     session = _get_snippet_session(notebook, snippet)
     cell = snippet['result']['handle']['id']
 
+    response = self._fetch_result(api, session, cell, start_over)
+    return response
+
+
+  def _fetch_result(self, api, session, cell, start_over):
     try:
       response = api.fetch_data(session['id'], cell)
     except Exception as e:
@@ -234,6 +272,7 @@ class SparkApi(Api):
 
       raise QueryError(msg)
 
+
   def cancel(self, notebook, snippet):
     api = self.get_api()
     session = _get_snippet_session(notebook, snippet)
@@ -241,17 +280,21 @@ class SparkApi(Api):
 
     return {'status': 0}
 
+
   def get_log(self, notebook, snippet, startFrom=0, size=None):
     api = self.get_api()
     session = _get_snippet_session(notebook, snippet)
 
     return api.get_log(session['id'], startFrom=startFrom, size=size)
 
+
   def close_statement(self, notebook, snippet): # Individual statements cannot be closed
     pass
 
+
   def close_session(self, session):
     api = self.get_api()
+    session_key = self._get_session_key()
 
     if session['id'] is not None:
       try:
@@ -263,9 +306,13 @@ class SparkApi(Api):
       except RestException as e:
         if e.code == 404 or e.code == 500: # TODO remove the 500
           raise SessionExpired(e)
+      finally:
+        if SESSIONS.get(session_key):
+          del SESSIONS[session_key]
     else:
       return {'status': -1}
 
+
   def get_jobs(self, notebook, snippet, logs):
     if self._is_yarn_mode():
       # Tracking URL is found at the start of the logs
@@ -274,29 +321,123 @@ class SparkApi(Api):
     else:
       return self._get_standalone_jobs(logs)
 
+
   def autocomplete(self, snippet, database=None, table=None, column=None, nested=None, operation=None):
     response = {}
 
     # As booting a new SQL session is slow and we don't send the id of the current one in /autocomplete
     # we could implement this by introducing an API cache per user similarly to SqlAlchemy.
+    api = self.get_api()
+    session_key = self._get_session_key()
+    
+    session = SESSIONS[session_key] if SESSIONS.get(session_key) else self.create_session(snippet.get('type'))
+
+    if database is None:
+      response['databases'] = self._show_databases(api, session)
+    elif table is None:
+      response['tables_meta'] = self._show_tables(api, session, database)
+    elif column is None:
+      columns = self._get_columns(api, session, database, table)
+      response['columns'] = [col['name'] for col in columns]
+      response['extended_columns'] = [{
+          'comment': col.get('comment'),
+          'name': col.get('name'),
+          'type': col['type']
+        }
+        for col in columns
+      ]
 
     return response
 
-  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
-    if operation != 'hello':
-      raise NotImplementedError()
 
-    response = {}
+  def _check_status_and_fetch_result(self, api, session, execute_resp):
+    check_status = api.fetch_data(session['id'], execute_resp['id'])
+
+    while check_status['state'] in ['running', 'waiting']:
+      check_status = api.fetch_data(session['id'], execute_resp['id'])
+      time.sleep(1)
+
+    if check_status['state'] == 'available':
+      return self._fetch_result(api, session, execute_resp['id'], start_over=True)
+
+
+  def _show_databases(self, api, session):
+    show_db_execute = self._execute(api, session, 'SHOW DATABASES')
+    db_list = self._check_status_and_fetch_result(api, session, show_db_execute)
+
+    if db_list:
+      return [db[0] for db in db_list['data']]
+
+
+  def _show_tables(self, api, session, database):
+    use_db_execute = self._execute(api, session, 'USE %(database)s' % {'database': database})
+    use_db_resp = self._check_status_and_fetch_result(api, session, use_db_execute)
+
+    show_tables_execute = self._execute(api, session, 'SHOW TABLES')
+    tables_list = self._check_status_and_fetch_result(api, session, show_tables_execute)
+
+    if tables_list:
+      return [table[1] for table in tables_list['data']]
+
+
+  def _get_columns(self, api, session, database, table):
+    use_db_execute = self._execute(api, session, 'USE %(database)s' % {'database': database})
+    use_db_resp = self._check_status_and_fetch_result(api, session, use_db_execute)
 
+    describe_tables_execute = self._execute(api, session, 'DESCRIBE %(table)s' % {'table': table})
+    columns_list = self._check_status_and_fetch_result(api, session, describe_tables_execute)
+
+    if columns_list:
+      return [{
+        'name': col[0],
+        'type': col[1],
+        'comment': '',
+      } for col in columns_list['data']]
+
+
+  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
     api = self.get_api()
+    session_key = self._get_session_key()
+    session = SESSIONS.get(session_key)
 
-    api.get_status()
+    statement = self._get_select_query(database, table, column, operation)
 
-    response['status'] = 0
-    response['rows'] = []
+    sample_execute = self._execute(api, session, statement)
+    sample_result = self._check_status_and_fetch_result(api, session, sample_execute)
+
+    response = {
+      'status': 0,
+      'result': {}
+    }
+    response['rows'] = sample_result['data']
+    response['full_headers'] = sample_result['meta']
 
     return response
 
+
+  def get_browse_query(self, snippet, database, table, partition_spec=None):
+    return self._get_select_query(database, table)
+
+  
+  def _get_select_query(self, database, table, column=None, operation=None, limit=100):
+    if operation == 'hello':
+      statement = "SELECT 'Hello World!'"
+    else:
+      column = '%(column)s' % {'column': column} if column else '*'
+      statement = textwrap.dedent('''\
+          SELECT %(column)s
+          FROM %(database)s.%(table)s
+          LIMIT %(limit)s
+          ''' % {
+            'database': database,
+            'table': table,
+            'column': column,
+            'limit': limit,
+        })
+
+    return statement
+
+
   def _get_standalone_jobs(self, logs):
     job_ids = set([])
 
@@ -321,6 +462,7 @@ class SparkApi(Api):
 
     return jobs
 
+
   def _get_yarn_jobs(self, logs):
     tracking_urls = set([])
 
@@ -336,6 +478,7 @@ class SparkApi(Api):
 
     return jobs
 
+
   def _is_yarn_mode(self):
     return LIVY_SERVER_SESSION_KIND.get() == "yarn"
 

+ 108 - 2
desktop/libs/notebook/src/notebook/connectors/spark_shell_tests.py

@@ -18,9 +18,9 @@
 import sys
 
 from builtins import object
-from nose.tools import assert_equal, assert_true, assert_false
+from nose.tools import assert_equal, assert_true, assert_false, assert_raises
 
-from notebook.connectors.spark_shell import SparkApi
+from notebook.connectors.spark_shell import SparkApi, SESSIONS
 
 if sys.version_info[0] > 2:
   from unittest.mock import patch, Mock
@@ -49,6 +49,7 @@ class TestSparkApi(object):
     spark_api = self.api.get_api()
     assert_equal(spark_api.__class__.__name__, 'LivyClient')
 
+
   def test_get_livy_props_method(self):
     test_properties = [{
         "name": "files",
@@ -57,9 +58,11 @@ class TestSparkApi(object):
     props = self.api.get_livy_props('scala', test_properties)
     assert_equal(props['files'], ['file_a', 'file_b', 'file_c'])
 
+
   def test_create_session_with_config(self):
     lang = 'pyspark'
     properties = None
+    session_key = self.api._get_session_key()
 
     with patch('notebook.connectors.spark_shell.get_spark_api') as get_spark_api:
       with patch('notebook.connectors.spark_shell.DefaultConfiguration') as DefaultConfiguration:
@@ -82,37 +85,54 @@ class TestSparkApi(object):
           # Case with user configuration. Expected 2 driverCores
           USE_DEFAULT_CONFIGURATION.get.return_value = True
           session = self.api.create_session(lang=lang, properties=properties)
+
           assert_equal(session['type'], 'pyspark')
           assert_equal(session['id'], '1')
+
           for p in session['properties']:
             if p['name'] == 'driverCores':
               cores = p['value']
           assert_equal(cores, 2)
 
+          if SESSIONS.get(session_key):
+            del SESSIONS[session_key]
+
           # Case without user configuration. Expected 1 driverCores
           USE_DEFAULT_CONFIGURATION.get.return_value = True
           DefaultConfiguration.objects.get_configuration_for_user.return_value = None
           session2 = self.api.create_session(lang=lang, properties=properties)
+
           assert_equal(session2['type'], 'pyspark')
           assert_equal(session2['id'], '1')
+
           for p in session2['properties']:
             if p['name'] == 'driverCores':
               cores = p['value']
           assert_equal(cores, 1)
 
+          if SESSIONS.get(session_key):
+            del SESSIONS[session_key]
+
           # Case with no user configuration. Expected 1 driverCores
           USE_DEFAULT_CONFIGURATION.get.return_value = False
           session3 = self.api.create_session(lang=lang, properties=properties)
+
           assert_equal(session3['type'], 'pyspark')
           assert_equal(session3['id'], '1')
+
           for p in session3['properties']:
             if p['name'] == 'driverCores':
               cores = p['value']
           assert_equal(cores, 1)
 
+          if SESSIONS.get(session_key):
+            del SESSIONS[session_key]
+
+
   def test_create_session_plain(self):
     lang = 'pyspark'
     properties = None
+    session_key = self.api._get_session_key()
 
     with patch('notebook.connectors.spark_shell.get_spark_api') as get_spark_api:
       get_spark_api.return_value = Mock(
@@ -133,6 +153,92 @@ class TestSparkApi(object):
       assert_true(files_properties, session['properties'])
       assert_equal(files_properties[0]['value'], [], session['properties'])
 
+      if SESSIONS.get(session_key):
+        del SESSIONS[session_key]
+
+
+  def test_execute(self):
+    with patch('notebook.connectors.spark_shell._get_snippet_session') as _get_snippet_session:
+      with patch('notebook.connectors.spark_shell.get_spark_api') as get_spark_api:
+        notebook = Mock()
+        snippet = {'statement': 'select * from test_table'}
+        _get_snippet_session.return_value = {'id': '1'}
+
+        get_spark_api.return_value = Mock(
+          submit_statement=Mock(
+            return_value={'id': 'test_id'}
+          )
+        )
+
+        response = self.api.execute(notebook, snippet)
+        assert_equal(response['id'], 'test_id')
+
+        get_spark_api.return_value = Mock(
+          submit_statement=Mock()
+        )
+        assert_raises(Exception, self.api.execute, notebook, snippet)
+
+
+  def test_check_status(self):
+    with patch('notebook.connectors.spark_shell._get_snippet_session') as _get_snippet_session:
+      with patch('notebook.connectors.spark_shell.get_spark_api') as get_spark_api:
+        notebook = Mock()
+        snippet = {
+          'result': {
+            'handle': {
+              'id': {'test_id'}
+            }
+          }
+        }
+        _get_snippet_session.return_value = {'id': '1'}
+
+        get_spark_api.return_value = Mock(
+          fetch_data=Mock(
+            return_value={'state': 'test_state'}
+          )
+        )
+
+        response = self.api.check_status(notebook, snippet)
+        assert_equal(response['status'], 'test_state')
+
+        get_spark_api.return_value = Mock(
+          submit_statement=Mock()
+        )
+        assert_raises(Exception, self.api.check_status, notebook, snippet)
+  
+
+  def test_get_sample_data(self):
+    snippet = Mock()
+    self.api._execute = Mock(
+      return_value='test_value'
+    )
+    self.api._check_status_and_fetch_result = Mock(
+      return_value={
+        'data': 'test_data',
+        'meta': 'test_meta'
+      }
+    )
+
+    response = self.api.get_sample_data(snippet, 'test_db', 'test_table', 'test_column')
+
+    assert_equal(response['rows'], 'test_data')
+    assert_equal(response['full_headers'], 'test_meta')
+  
+
+  def test_get_select_query(self):
+    # With operation as 'hello'
+    response = self.api._get_select_query('test_db', 'test_table', 'test_column', 'hello')
+    assert_equal(response, "SELECT 'Hello World!'")
+
+    # Without column name
+    response = self.api._get_select_query('test_db', 'test_table')
+    assert_equal(response, 'SELECT *\nFROM test_db.test_table\nLIMIT 100\n')
+
+    # With some column name
+    response = self.api._get_select_query('test_db', 'test_table', 'test_column')
+    assert_equal(response, 'SELECT test_column\nFROM test_db.test_table\nLIMIT 100\n')
+
+
   def test_get_jobs(self):
     local_jobs = [
       {'url': u'http://172.21.1.246:4040/jobs/job/?id=0', 'name': u'0'}