
[sparksql] Enhance SparkSQL support via Apache Livy for Hue (#2794)

* [sparksql] Enhance SparkSQL support via Apache Livy for Hue

- Reuse the Livy session per user, similarly to SQLAlchemy
- Add support for /autocomplete, /get_sample_data and other APIs
- Add a few unit tests and update existing ones

NOTE: At startup, two sessions get created: one for /autocomplete and one for /create_session. The /autocomplete call always comes first in the call stack from the UI, and creating a session there also helps when a user only opens other tabs such as the Table Browser, where no /create_session call is made. Since the session is reused per user, all subsequent SparkSQL activity for that user ends up using just one of the two sessions.
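
To make the session-reuse behaviour described above concrete, here is a minimal, self-contained sketch of the caching pattern (not part of the patch): SESSIONS and SESSION_KEY mirror the change, while FakeLivyClient and this standalone create_session are hypothetical simplifications.

```python
# Sketch only: per-user Livy session cache, keyed by username and interpreter name.
SESSIONS = {}
SESSION_KEY = '%(username)s-%(interpreter_name)s'


class FakeLivyClient(object):
  """Hypothetical stand-in for the Livy REST client; only counts created sessions."""
  created = 0

  def create_session(self, **props):
    FakeLivyClient.created += 1
    return {'id': FakeLivyClient.created}


def create_session(user, interpreter_name, api, lang='sparksql'):
  # Reuse the cached session for this user/interpreter if one already exists.
  key = SESSION_KEY % {'username': user, 'interpreter_name': interpreter_name}
  if SESSIONS.get(key):
    return SESSIONS[key]

  response = api.create_session(kind=lang)
  SESSIONS[key] = {'type': lang, 'id': response['id']}
  return SESSIONS[key]


# Whether /autocomplete or /create_session runs first, both land on the same session:
api = FakeLivyClient()
first = create_session('hue_user', 'sparksql', api)
second = create_session('hue_user', 'sparksql', api)
assert first['id'] == second['id']
assert FakeLivyClient.created == 1
```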
Harsh Gupta 3 years ago
parent
commit
e751b863ca

+ 154 - 11
desktop/libs/notebook/src/notebook/connectors/spark_shell.py

@@ -20,6 +20,7 @@ import logging
 import re
 import sys
 import time
+import textwrap
 
 from desktop.conf import USE_DEFAULT_CONFIGURATION
 from desktop.lib.exceptions_renderable import PopupException
@@ -45,19 +46,23 @@ try:
 except ImportError as e:
   LOG.exception('Spark is not enabled')
 
-
+SESSIONS = {}
+SESSION_KEY = '%(username)s-%(interpreter_name)s'
 class SparkApi(Api):
 
   SPARK_UI_RE = re.compile("Started SparkUI at (http[s]?://([0-9a-zA-Z-_\.]+):(\d+))")
   YARN_JOB_RE = re.compile("tracking URL: (http[s]?://.+/)")
   STANDALONE_JOB_RE = re.compile("Got job (\d+)")
 
+
   def __init__(self, user, interpreter):
     super(SparkApi, self).__init__(user=user, interpreter=interpreter)
 
+
   def get_api(self):
     return get_spark_api(self.user, self.interpreter)
 
+
   @staticmethod
   def get_livy_props(lang, properties=None):
     props = dict([(p['name'], p['value']) for p in SparkConfiguration.PROPERTIES])
@@ -91,6 +96,7 @@ class SparkApi(Api):
 
     return props
 
+
   @staticmethod
   def to_properties(props=None):
     properties = list()
@@ -104,7 +110,21 @@ class SparkApi(Api):
 
     return properties
 
+
+  def _get_session_key(self):
+    return SESSION_KEY % {
+      'username': self.user.username if hasattr(self.user, 'username') else self.user,
+      'interpreter_name': self.interpreter['name']
+    }
+
+
   def create_session(self, lang='scala', properties=None):
+    api = self.get_api()
+    session_key = self._get_session_key()
+
+    if SESSIONS.get(session_key):
+      return SESSIONS[session_key]
+
     if not properties and USE_DEFAULT_CONFIGURATION.get():
       user_config = DefaultConfiguration.objects.get_configuration_for_user(app='spark', user=self.user)
       if user_config is not None:
@@ -112,7 +132,6 @@ class SparkApi(Api):
 
     props = self.get_livy_props(lang, properties)
 
-    api = get_spark_api(self.user)
     response = api.create_session(**props)
 
     status = api.get_session(response['id'])
@@ -127,18 +146,30 @@ class SparkApi(Api):
       info = '\n'.join(status['log']) if status['log'] else 'timeout'
       raise QueryError(_('The Spark session is %s and could not be created in the cluster: %s') % (status['state'], info))
 
-    return {
+    SESSIONS[session_key] = {
         'type': lang,
         'id': response['id'],
         'properties': self.to_properties(props)
     }
+    return SESSIONS[session_key]
+
 
   def execute(self, notebook, snippet):
     api = self.get_api()
     session = _get_snippet_session(notebook, snippet)
 
+    response = self._execute(api, session, snippet['statement'])
+    return response
+
+
+  def _execute(self, api, session, statement):
+    session_key = self._get_session_key()
+
+    if session['id'] is None and SESSIONS.get(session_key) is not None:
+      session = SESSIONS[session_key]
+
     try:
-      response = api.submit_statement(session['id'], snippet['statement'])
+      response = api.submit_statement(session['id'], statement)
       return {
           'id': response['id'],
           'has_result_set': True,
@@ -151,6 +182,7 @@ class SparkApi(Api):
       else:
         raise e
 
+
   def check_status(self, notebook, snippet):
     api = self.get_api()
     session = _get_snippet_session(notebook, snippet)
@@ -168,11 +200,17 @@ class SparkApi(Api):
       else:
         raise e
 
+
   def fetch_result(self, notebook, snippet, rows, start_over):
     api = self.get_api()
     session = _get_snippet_session(notebook, snippet)
     cell = snippet['result']['handle']['id']
 
+    response = self._fetch_result(api, session, cell, start_over)
+    return response
+
+
+  def _fetch_result(self, api, session, cell, start_over):
     try:
       response = api.fetch_data(session['id'], cell)
     except Exception as e:
@@ -234,6 +272,7 @@ class SparkApi(Api):
 
       raise QueryError(msg)
 
+
   def cancel(self, notebook, snippet):
     api = self.get_api()
     session = _get_snippet_session(notebook, snippet)
@@ -241,17 +280,21 @@ class SparkApi(Api):
 
     return {'status': 0}
 
+
   def get_log(self, notebook, snippet, startFrom=0, size=None):
     api = self.get_api()
     session = _get_snippet_session(notebook, snippet)
 
     return api.get_log(session['id'], startFrom=startFrom, size=size)
 
+
   def close_statement(self, notebook, snippet): # Individual statements cannot be closed
     pass
 
+
   def close_session(self, session):
     api = self.get_api()
+    session_key = self._get_session_key()
 
     if session['id'] is not None:
       try:
@@ -263,9 +306,13 @@ class SparkApi(Api):
       except RestException as e:
         if e.code == 404 or e.code == 500: # TODO remove the 500
           raise SessionExpired(e)
+      finally:
+        if SESSIONS.get(session_key):
+          del SESSIONS[session_key]
     else:
       return {'status': -1}
 
+
   def get_jobs(self, notebook, snippet, logs):
     if self._is_yarn_mode():
       # Tracking URL is found at the start of the logs
@@ -274,29 +321,123 @@ class SparkApi(Api):
     else:
       return self._get_standalone_jobs(logs)
 
+
   def autocomplete(self, snippet, database=None, table=None, column=None, nested=None, operation=None):
     response = {}
 
     # As booting a new SQL session is slow and we don't send the id of the current one in /autocomplete
     # we could implement this by introducing an API cache per user similarly to SqlAlchemy.
+    api = self.get_api()
+    session_key = self._get_session_key()
+    
+    session = SESSIONS[session_key] if SESSIONS.get(session_key) else self.create_session(snippet.get('type'))
+
+    if database is None:
+      response['databases'] = self._show_databases(api, session)
+    elif table is None:
+      response['tables_meta'] = self._show_tables(api, session, database)
+    elif column is None:
+      columns = self._get_columns(api, session, database, table)
+      response['columns'] = [col['name'] for col in columns]
+      response['extended_columns'] = [{
+          'comment': col.get('comment'),
+          'name': col.get('name'),
+          'type': col['type']
+        }
+        for col in columns
+      ]
 
     return response
 
-  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
-    if operation != 'hello':
-      raise NotImplementedError()
 
-    response = {}
+  def _check_status_and_fetch_result(self, api, session, execute_resp):
+    check_status = api.fetch_data(session['id'], execute_resp['id'])
+
+    while check_status['state'] in ['running', 'waiting']:
+      check_status = api.fetch_data(session['id'], execute_resp['id'])
+      time.sleep(1)
+
+    if check_status['state'] == 'available':
+      return self._fetch_result(api, session, execute_resp['id'], start_over=True)
+
+
+  def _show_databases(self, api, session):
+    show_db_execute = self._execute(api, session, 'SHOW DATABASES')
+    db_list = self._check_status_and_fetch_result(api, session, show_db_execute)
+
+    if db_list:
+      return [db[0] for db in db_list['data']]
+
+
+  def _show_tables(self, api, session, database):
+    use_db_execute = self._execute(api, session, 'USE %(database)s' % {'database': database})
+    use_db_resp = self._check_status_and_fetch_result(api, session, use_db_execute)
+
+    show_tables_execute = self._execute(api, session, 'SHOW TABLES')
+    tables_list = self._check_status_and_fetch_result(api, session, show_tables_execute)
+
+    if tables_list:
+      return [table[1] for table in tables_list['data']]
+
+
+  def _get_columns(self, api, session, database, table):
+    use_db_execute = self._execute(api, session, 'USE %(database)s' % {'database': database})
+    use_db_resp = self._check_status_and_fetch_result(api, session, use_db_execute)
 
+    describe_tables_execute = self._execute(api, session, 'DESCRIBE %(table)s' % {'table': table})
+    columns_list = self._check_status_and_fetch_result(api, session, describe_tables_execute)
+
+    if columns_list:
+      return [{
+        'name': col[0],
+        'type': col[1],
+        'comment': '',
+      } for col in columns_list['data']]
+
+
+  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
     api = self.get_api()
+    session_key = self._get_session_key()
+    session = SESSIONS.get(session_key)
 
-    api.get_status()
+    statement = self._get_select_query(database, table, column, operation)
 
-    response['status'] = 0
-    response['rows'] = []
+    sample_execute = self._execute(api, session, statement)
+    sample_result = self._check_status_and_fetch_result(api, session, sample_execute)
+
+    response = {
+      'status': 0,
+      'result': {}
+    }
+    response['rows'] = sample_result['data']
+    response['full_headers'] = sample_result['meta']
 
     return response
 
+
+  def get_browse_query(self, snippet, database, table, partition_spec=None):
+    return self._get_select_query(database, table)
+
+  
+  def _get_select_query(self, database, table, column=None, operation=None, limit=100):
+    if operation == 'hello':
+      statement = "SELECT 'Hello World!'"
+    else:
+      column = '%(column)s' % {'column': column} if column else '*'
+      statement = textwrap.dedent('''\
+          SELECT %(column)s
+          FROM %(database)s.%(table)s
+          LIMIT %(limit)s
+          ''' % {
+            'database': database,
+            'table': table,
+            'column': column,
+            'limit': limit,
+        })
+
+    return statement
+
+
   def _get_standalone_jobs(self, logs):
     job_ids = set([])
 
@@ -321,6 +462,7 @@ class SparkApi(Api):
 
     return jobs
 
+
   def _get_yarn_jobs(self, logs):
     tracking_urls = set([])
 
@@ -336,6 +478,7 @@ class SparkApi(Api):
 
     return jobs
 
+
   def _is_yarn_mode(self):
     return LIVY_SERVER_SESSION_KIND.get() == "yarn"
 

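For context before the test diff, here is the polling pattern that the new _check_status_and_fetch_result helper above implements, shown in isolation. This is a sketch only: poll_statement is a hypothetical name, and the api object is assumed to expose the same fetch_data(session_id, statement_id) call used in the patch.

```python
import time


def poll_statement(api, session_id, statement_id, interval=1):
  # Keep re-reading the statement from Livy until it leaves 'running'/'waiting'.
  status = api.fetch_data(session_id, statement_id)
  while status['state'] in ('running', 'waiting'):
    time.sleep(interval)
    status = api.fetch_data(session_id, statement_id)

  if status['state'] == 'available':
    # The patch then pulls the rows via _fetch_result(..., start_over=True).
    return status
```

For /autocomplete, for example, 'SHOW DATABASES' is submitted as a statement, polled this way, and the first column of the returned rows becomes the database list.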
+ 108 - 2
desktop/libs/notebook/src/notebook/connectors/spark_shell_tests.py

@@ -18,9 +18,9 @@
 import sys
 
 from builtins import object
-from nose.tools import assert_equal, assert_true, assert_false
+from nose.tools import assert_equal, assert_true, assert_false, assert_raises
 
-from notebook.connectors.spark_shell import SparkApi
+from notebook.connectors.spark_shell import SparkApi, SESSIONS
 
 if sys.version_info[0] > 2:
   from unittest.mock import patch, Mock
@@ -49,6 +49,7 @@ class TestSparkApi(object):
     spark_api = self.api.get_api()
     assert_equal(spark_api.__class__.__name__, 'LivyClient')
 
+
   def test_get_livy_props_method(self):
     test_properties = [{
         "name": "files",
@@ -57,9 +58,11 @@ class TestSparkApi(object):
     props = self.api.get_livy_props('scala', test_properties)
     assert_equal(props['files'], ['file_a', 'file_b', 'file_c'])
 
+
   def test_create_session_with_config(self):
     lang = 'pyspark'
     properties = None
+    session_key = self.api._get_session_key()
 
     with patch('notebook.connectors.spark_shell.get_spark_api') as get_spark_api:
       with patch('notebook.connectors.spark_shell.DefaultConfiguration') as DefaultConfiguration:
@@ -82,37 +85,54 @@ class TestSparkApi(object):
           # Case with user configuration. Expected 2 driverCores
           USE_DEFAULT_CONFIGURATION.get.return_value = True
           session = self.api.create_session(lang=lang, properties=properties)
+
           assert_equal(session['type'], 'pyspark')
           assert_equal(session['id'], '1')
+
           for p in session['properties']:
             if p['name'] == 'driverCores':
               cores = p['value']
           assert_equal(cores, 2)
 
+          if SESSIONS.get(session_key):
+            del SESSIONS[session_key]
+
           # Case without user configuration. Expected 1 driverCores
           USE_DEFAULT_CONFIGURATION.get.return_value = True
           DefaultConfiguration.objects.get_configuration_for_user.return_value = None
           session2 = self.api.create_session(lang=lang, properties=properties)
+
           assert_equal(session2['type'], 'pyspark')
           assert_equal(session2['id'], '1')
+
           for p in session2['properties']:
             if p['name'] == 'driverCores':
               cores = p['value']
           assert_equal(cores, 1)
 
+          if SESSIONS.get(session_key):
+            del SESSIONS[session_key]
+
           # Case with no user configuration. Expected 1 driverCores
           USE_DEFAULT_CONFIGURATION.get.return_value = False
           session3 = self.api.create_session(lang=lang, properties=properties)
+
           assert_equal(session3['type'], 'pyspark')
           assert_equal(session3['id'], '1')
+
           for p in session3['properties']:
             if p['name'] == 'driverCores':
               cores = p['value']
           assert_equal(cores, 1)
 
+          if SESSIONS.get(session_key):
+            del SESSIONS[session_key]
+
+
   def test_create_session_plain(self):
     lang = 'pyspark'
     properties = None
+    session_key = self.api._get_session_key()
 
     with patch('notebook.connectors.spark_shell.get_spark_api') as get_spark_api:
       get_spark_api.return_value = Mock(
@@ -133,6 +153,92 @@ class TestSparkApi(object):
       assert_true(files_properties, session['properties'])
       assert_equal(files_properties[0]['value'], [], session['properties'])
 
+      if SESSIONS.get(session_key):
+        del SESSIONS[session_key]
+
+
+  def test_execute(self):
+    with patch('notebook.connectors.spark_shell._get_snippet_session') as _get_snippet_session:
+      with patch('notebook.connectors.spark_shell.get_spark_api') as get_spark_api:
+        notebook = Mock()
+        snippet = {'statement': 'select * from test_table'}
+        _get_snippet_session.return_value = {'id': '1'}
+
+        get_spark_api.return_value = Mock(
+          submit_statement=Mock(
+            return_value={'id': 'test_id'}
+          )
+        )
+
+        response = self.api.execute(notebook, snippet)
+        assert_equal(response['id'], 'test_id')
+
+        get_spark_api.return_value = Mock(
+          submit_statement=Mock()
+        )
+        assert_raises(Exception, self.api.execute, notebook, snippet)
+
+
+  def test_check_status(self):
+    with patch('notebook.connectors.spark_shell._get_snippet_session') as _get_snippet_session:
+      with patch('notebook.connectors.spark_shell.get_spark_api') as get_spark_api:
+        notebook = Mock()
+        snippet = {
+          'result': {
+            'handle': {
+              'id': {'test_id'}
+            }
+          }
+        }
+        _get_snippet_session.return_value = {'id': '1'}
+
+        get_spark_api.return_value = Mock(
+          fetch_data=Mock(
+            return_value={'state': 'test_state'}
+          )
+        )
+
+        response = self.api.check_status(notebook, snippet)
+        assert_equal(response['status'], 'test_state')
+
+        get_spark_api.return_value = Mock(
+          submit_statement=Mock()
+        )
+        assert_raises(Exception, self.api.check_status, notebook, snippet)
+  
+
+  def test_get_sample_data(self):
+    snippet = Mock()
+    self.api._execute = Mock(
+      return_value='test_value'
+    )
+    self.api._check_status_and_fetch_result = Mock(
+      return_value={
+        'data': 'test_data',
+        'meta': 'test_meta'
+      }
+    )
+
+    response = self.api.get_sample_data(snippet, 'test_db', 'test_table', 'test_column')
+
+    assert_equal(response['rows'], 'test_data')
+    assert_equal(response['full_headers'], 'test_meta')
+  
+
+  def test_get_select_query(self):
+    # With operation as 'hello'
+    response = self.api._get_select_query('test_db', 'test_table', 'test_column', 'hello')
+    assert_equal(response, "SELECT 'Hello World!'")
+
+    # Without column name
+    response = self.api._get_select_query('test_db', 'test_table')
+    assert_equal(response, 'SELECT *\nFROM test_db.test_table\nLIMIT 100\n')
+
+    # With some column name
+    response = self.api._get_select_query('test_db', 'test_table', 'test_column')
+    assert_equal(response, 'SELECT test_column\nFROM test_db.test_table\nLIMIT 100\n')
+
+
   def test_get_jobs(self):
     local_jobs = [
       {'url': u'http://172.21.1.246:4040/jobs/job/?id=0', 'name': u'0'}