浏览代码

[notebook] Split APIs into respective packages

Romain Rigaux 10 年之前
父节点
当前提交
1435cd5

+ 1 - 1
desktop/libs/notebook/src/notebook/api.py

@@ -26,7 +26,7 @@ from desktop.lib.django_util import JsonResponse
 from desktop.models import Document2, Document
 from oozie.decorators import check_document_access_permission # Bad dependency
 
-from notebook.models import get_api, Notebook, QueryExpired
+from notebook.connectors.base import get_api, Notebook, QueryExpired
 from notebook.decorators import api_error_handler, check_document_modify_permission
 
 

+ 15 - 0
desktop/libs/notebook/src/notebook/connectors/__init__.py

@@ -0,0 +1,15 @@
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

+ 110 - 0
desktop/libs/notebook/src/notebook/connectors/base.py

@@ -0,0 +1,110 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+
+from desktop.lib.i18n import force_unicode
+
+
+LOG = logging.getLogger(__name__)
+
+
+class SessionExpired(Exception):
+  pass
+
+
+class QueryExpired(Exception):
+  pass
+
+
+class QueryError(Exception):
+  def __init__(self, message):
+    self.message = message
+
+  def __str__(self):
+    return force_unicode(str(self.message))
+
+
+class Notebook():
+
+  def __init__(self, document=None):
+    self.document = None
+
+    if document is not None:
+      self.data = document.data
+      self.document = document
+    else:
+      self.data = json.dumps({
+          'name': 'My Notebook',
+          'description': '',
+          'snippets': []
+      })
+
+  def get_json(self):
+    _data = self.get_data()
+
+    return json.dumps(_data)
+
+  def get_data(self):
+    _data = json.loads(self.data)
+
+    if self.document is not None:
+      _data['id'] = self.document.id
+
+    return _data
+
+  def get_str(self):
+    return '\n\n'.join([snippet['statement_raw'] for snippet in self.get_data()['snippets']])
+
+
+def get_api(user, snippet):
+  from notebook.connectors.hiveserver2 import HS2Api
+  from notebook.connectors.spark_batch import SparkBatchApi
+  from notebook.connectors.text import TextApi
+  from notebook.connectors.spark_shell import SparkApi
+
+  if snippet['type'] in ('hive', 'impala', 'spark-sql'):
+    return HS2Api(user)
+  elif snippet['type'] in ('jar', 'py'):
+    return SparkBatchApi(user)
+  elif snippet['type'] == 'text':
+    return TextApi(user)
+  else:
+    return SparkApi(user)
+
+
+def _get_snippet_session(notebook, snippet):
+  return [session for session in notebook['sessions'] if session['type'] == snippet['type']][0]
+
+
+# Base API
+
+class Api(object):
+
+  def __init__(self, user):
+    self.user = user
+
+  def create_session(self, lang, properties=None):
+    return {
+        'type': lang,
+        'id': None,
+        'properties': []
+    }
+
+  def close_session(self, session):
+    pass

+ 185 - 0
desktop/libs/notebook/src/notebook/connectors/hiveserver2.py

@@ -0,0 +1,185 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from desktop.lib.exceptions_renderable import PopupException
+from desktop.lib.i18n import force_unicode
+
+from beeswax import data_export
+from beeswax.design import hql_query
+from beeswax import conf as beeswax_conf
+from beeswax.models import QUERY_TYPES, HiveServerQueryHandle, QueryHistory, HiveServerQueryHistory
+from beeswax.server import dbms
+from beeswax.server.dbms import get_query_server_config, QueryServerException
+from beeswax.views import _parse_out_hadoop_jobs
+
+from notebook.connectors.base import Api, QueryError, QueryExpired
+
+
+LOG = logging.getLogger(__name__)
+
+
+def query_error_handler(func):
+  def decorator(*args, **kwargs):
+    try:
+      return func(*args, **kwargs)
+    except QueryServerException, e:
+      message = force_unicode(str(e))
+      if 'Invalid query handle' in message or 'Invalid OperationHandle' in message:
+        raise QueryExpired(e)
+      else:
+        raise QueryError(message)
+  return decorator
+
+
+class HS2Api(Api):
+
+  def _get_handle(self, snippet):
+    snippet['result']['handle']['secret'], snippet['result']['handle']['guid'] = HiveServerQueryHandle.get_decoded(snippet['result']['handle']['secret'], snippet['result']['handle']['guid'])
+    return HiveServerQueryHandle(**snippet['result']['handle'])
+
+  def _get_db(self, snippet):
+    if snippet['type'] == 'hive':
+      name = 'beeswax'
+    elif snippet['type'] == 'impala':
+      name = 'impala'
+    else:
+      name = 'spark-sql'
+
+    return dbms.get(self.user, query_server=get_query_server_config(name=name))
+
+  def execute(self, notebook, snippet):
+    db = self._get_db(snippet)
+    query = hql_query(snippet['statement'], QUERY_TYPES[0])
+
+    try:
+      handle = db.client.query(query)
+    except QueryServerException, ex:
+      raise QueryError(ex.message)
+
+    # All good
+    server_id, server_guid  = handle.get()
+    return {
+        'secret': server_id,
+        'guid': server_guid,
+        'operation_type': handle.operation_type,
+        'has_result_set': handle.has_result_set,
+        'modified_row_count': handle.modified_row_count,
+        'log_context': handle.log_context
+    }
+
+  @query_error_handler
+  def check_status(self, notebook, snippet):
+    response = {}
+    db = self._get_db(snippet)
+
+    handle = self._get_handle(snippet)
+    operation = db.get_operation_status(handle)
+    status = HiveServerQueryHistory.STATE_MAP[operation.operationState]
+
+    if status.index in (QueryHistory.STATE.failed.index, QueryHistory.STATE.expired.index):
+      raise QueryError(operation.errorMessage)
+
+    response['status'] = 'running' if status.index in (QueryHistory.STATE.running.index, QueryHistory.STATE.submitted.index) else 'available'
+
+    return response
+
+  @query_error_handler
+  def fetch_result(self, notebook, snippet, rows, start_over):
+    db = self._get_db(snippet)
+
+    handle = self._get_handle(snippet)
+    results = db.fetch(handle, start_over=start_over, rows=rows)
+
+    # No escaping...
+    return {
+        'has_more': results.has_more,
+        'data': list(results.rows()),
+        'meta': [{
+          'name': column.name,
+          'type': column.type,
+          'comment': column.comment
+        } for column in results.data_table.cols()],
+        'type': 'table'
+    }
+
+  @query_error_handler
+  def fetch_result_metadata(self):
+    pass
+
+  @query_error_handler
+  def cancel(self, notebook, snippet):
+    db = self._get_db(snippet)
+
+    handle = self._get_handle(snippet)
+    db.cancel_operation(handle)
+    return {'status': 0}
+
+  @query_error_handler
+  def get_log(self, notebook, snippet, startFrom=None, size=None):
+    db = self._get_db(snippet)
+
+    handle = self._get_handle(snippet)
+    return db.get_log(handle, start_over=startFrom == 0)
+
+  def download(self, notebook, snippet, format):
+    try:
+      db = self._get_db(snippet)
+      handle = self._get_handle(snippet)
+      return data_export.download(handle, format, db)
+    except Exception, e:
+      LOG.exception('error downloading notebook')
+
+      if not hasattr(e, 'message') or not e.message:
+        message = e
+      else:
+        message = e.message
+      raise PopupException(message, detail='')
+
+  def _progress(self, snippet, logs):
+    if snippet['type'] == 'hive':
+      match = re.search('Total jobs = (\d+)', logs, re.MULTILINE)
+      total = (int(match.group(1)) if match else 1) * 2
+
+      started = logs.count('Starting Job')
+      ended = logs.count('Ended Job')
+
+      return int((started + ended) * 100 / total)
+    elif snippet['type'] == 'impala':
+      match = re.search('(\d+)% Complete', logs, re.MULTILINE)
+      return int(match.group(1)) if match else 0
+    else:
+      return 50
+
+  @query_error_handler
+  def close_statement(self, snippet):
+    if snippet['type'] == 'impala':
+      from impala import conf as impala_conf
+
+    if (snippet['type'] == 'hive' and beeswax_conf.CLOSE_QUERIES.get()) or (snippet['type'] == 'impala' and impala_conf.CLOSE_QUERIES.get()):
+      db = self._get_db(snippet)
+
+      handle = self._get_handle(snippet)
+      db.close_operation(handle)
+      return {'status': 0}
+    else:
+      return {'status': -1}  # skipped
+
+  def _get_jobs(self, log):
+    return _parse_out_hadoop_jobs(log)

+ 93 - 0
desktop/libs/notebook/src/notebook/connectors/spark_batch.py

@@ -0,0 +1,93 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from spark.job_server_api import get_api as get_spark_api
+
+from notebook.connectors.base import Api
+
+
+LOG = logging.getLogger(__name__)
+
+
+class SparkBatchApi(Api):
+
+  def create_session(self, lang, properties=None):
+    return {
+        'type': lang,
+        'id': None
+    }
+
+  def execute(self, notebook, snippet):
+    api = get_spark_api(self.user)
+
+    properties = {
+        'file': snippet['properties'].get('app_jar'),
+        'className': snippet['properties'].get('class'),
+        'args': snippet['properties'].get('arguments'),
+        'pyFiles': snippet['properties'].get('py_file'),
+        # files
+        # driverMemory
+        # driverCores
+        # executorMemory
+        # executorCores
+        # archives
+    }
+
+    response = api.submit_batch(properties)
+    return {
+        'id': response['id'],
+        'has_result_set': True,
+        'properties': []
+    }
+
+  def check_status(self, notebook, snippet):
+    api = get_spark_api(self.user)
+
+    state = api.get_batch_status(snippet['result']['handle']['id'])
+    return {
+        'status': state,
+    }
+
+  def get_log(self, notebook, snippet, startFrom=0, size=None):
+    api = get_spark_api(self.user)
+
+    return api.get_batch_log(snippet['result']['handle']['id'], startFrom=startFrom, size=size)
+
+  def close_statement(self, snippet):
+    api = get_spark_api(self.user)
+
+    session_id = snippet['result']['handle']['id']
+    if session_id is not None:
+      api.close_batch(session_id)
+      return {
+        'session': session_id,
+        'status': 0
+      }
+    else:
+      return {'status': -1}  # skipped
+
+  def cancel(self, notebook, snippet):
+    # Batch jobs do not support interruption, so close statement instead.
+    return self.close_statement(snippet)
+
+  def _progress(self, snippet, logs):
+    return 50
+
+  def _get_jobs(self, log):
+    return []

+ 219 - 0
desktop/libs/notebook/src/notebook/connectors/spark_shell.py

@@ -0,0 +1,219 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+
+from django.utils.translation import ugettext as _
+
+from desktop.lib.exceptions_renderable import PopupException
+from desktop.lib.i18n import force_unicode
+from desktop.lib.rest.http_client import RestException
+
+from spark.job_server_api import get_api as get_spark_api
+
+from notebook.data_export import download as spark_download
+from notebook.connectors.base import SessionExpired, _get_snippet_session, Api,\
+  QueryError
+
+
+LOG = logging.getLogger(__name__)
+
+
+class SparkApi(Api):
+  PROPERTIES = [
+    {'name': 'jars', 'nice_name': _('Jars'), 'default': '', 'type': 'csv-hdfs-files', 'is_yarn': False},
+    {'name': 'files', 'nice_name': _('Files'), 'default': '', 'type': 'csv-hdfs-files', 'is_yarn': False},
+    {'name': 'pyFiles', 'nice_name': _('pyFiles'), 'default': '', 'type': 'csv-hdfs-files', 'is_yarn': False},
+
+    {'name': 'driverMemory', 'nice_name': _('Driver Memory'), 'default': '1', 'type': 'jvm', 'is_yarn': False},
+
+    {'name': 'driverCores', 'nice_name': _('Driver Cores'), 'default': '1', 'type': 'number', 'is_yarn': True},
+    {'name': 'executorCores', 'nice_name': _('Executor Cores'), 'default': '1', 'type': 'number', 'is_yarn': True},
+    {'name': 'queue', 'nice_name': _('Queue'), 'default': '1', 'type': 'string', 'is_yarn': True},
+    {'name': 'archives', 'nice_name': _('Archives'), 'default': '', 'type': 'csv-hdfs-files', 'is_yarn': True},
+    {'name': 'numExecutors', 'nice_name': _('Executors Numbers'), 'default': '1', 'type': 'number', 'is_yarn': True},
+  ]
+
+  def create_session(self, lang='scala', properties=None):
+    properties = dict([(p['name'], p['value']) for p in properties]) if properties is not None else {}
+
+    properties['kind'] = lang
+
+    api = get_spark_api(self.user)
+
+    response = api.create_session(**properties)
+
+    status = api.get_session(response['id'])
+    count = 0
+
+    while status['state'] == 'starting' and count < 120:
+      status = api.get_session(response['id'])
+      count += 1
+      time.sleep(1)
+
+    if status['state'] != 'idle':
+      info = '\n'.join(status['log']) if status['log'] else 'timeout'
+      raise QueryError(_('The Spark session could not be created in the cluster: %s') % info)
+
+    return {
+        'type': lang,
+        'id': response['id'],
+        'properties': []
+    }
+
+  def execute(self, notebook, snippet):
+    api = get_spark_api(self.user)
+    session = _get_snippet_session(notebook, snippet)
+
+    try:
+      response = api.submit_statement(session['id'], snippet['statement'])
+      return {
+          'id': response['id'],
+          'has_result_set': True,
+      }
+    except Exception, e:
+      message = force_unicode(str(e)).lower()
+      if 'session not found' in message or 'connection refused' in message or 'session is in state busy' in message:
+        raise SessionExpired(e)
+      else:
+        raise e
+
+  def check_status(self, notebook, snippet):
+    api = get_spark_api(self.user)
+    session = _get_snippet_session(notebook, snippet)
+    cell = snippet['result']['handle']['id']
+
+    try:
+      response = api.fetch_data(session['id'], cell)
+      return {
+          'status': response['state'],
+      }
+    except Exception, e:
+      message = force_unicode(str(e)).lower()
+      if 'session not found' in message:
+        raise SessionExpired(e)
+      else:
+        raise e
+
+  def fetch_result(self, notebook, snippet, rows, start_over):
+    api = get_spark_api(self.user)
+    session = _get_snippet_session(notebook, snippet)
+    cell = snippet['result']['handle']['id']
+
+    try:
+      response = api.fetch_data(session['id'], cell)
+    except Exception, e:
+      message = force_unicode(str(e)).lower()
+      if 'session not found' in message:
+        raise SessionExpired(e)
+      else:
+        raise e
+
+    content = response['output']
+
+    if content['status'] == 'ok':
+      data = content['data']
+      images = []
+
+      try:
+        table = data['application/vnd.livy.table.v1+json']
+      except KeyError:
+        try:
+          images = [data['image/png']]
+        except KeyError:
+          images = []
+        data = [[data['text/plain']]]
+        meta = [{'name': 'Header', 'type': 'STRING_TYPE', 'comment': ''}]
+        type = 'text'
+      else:
+        data = table['data']
+        headers = table['headers']
+        meta = [{'name': h['name'], 'type': h['type'], 'comment': ''} for h in headers]
+        type = 'table'
+
+      # Non start_over not supported
+      if not start_over:
+        data = []
+
+      return {
+          'data': data,
+          'images': images,
+          'meta': meta,
+          'type': type
+      }
+    elif content['status'] == 'error':
+      tb = content.get('traceback', None)
+
+      if tb is None:
+        msg = content.get('ename', 'unknown error')
+
+        evalue = content.get('evalue')
+        if evalue is not None:
+          msg = '%s: %s' % (msg, evalue)
+      else:
+        msg = ''.join(tb)
+
+      raise QueryError(msg)
+
+  def download(self, notebook, snippet, format):
+    try:
+      api = get_spark_api(self.user)
+      session = _get_snippet_session(notebook, snippet)
+      cell = snippet['result']['handle']['id']
+
+      return spark_download(api, session['id'], cell, format)
+    except Exception, e:
+      raise PopupException(e)
+
+  def cancel(self, notebook, snippet):
+    api = get_spark_api(self.user)
+    session = _get_snippet_session(notebook, snippet)
+    response = api.cancel(session['id'])
+
+    return {'status': 0}
+
+  def get_log(self, notebook, snippet, startFrom=0, size=None):
+    api = get_spark_api(self.user)
+    session = _get_snippet_session(notebook, snippet)
+
+    return api.get_log(session['id'], startFrom=startFrom, size=size)
+
+  def _progress(self, snippet, logs):
+    return 50
+
+  def close_statement(self, snippet): # Individual statements cannot be closed
+    pass
+
+  def close_session(self, session):
+    api = get_spark_api(self.user)
+
+    if session['id'] is not None:
+      try:
+        api.close(session['id'])
+        return {
+          'session': session['id'],
+          'status': 0
+        }
+      except RestException, e:
+        if e.code == 404 or e.code == 500: # TODO remove the 500
+          raise SessionExpired(e)
+    else:
+      return {'status': -1}
+
+  def _get_jobs(self, log):
+    return []

+ 28 - 0
desktop/libs/notebook/src/notebook/connectors/text.py

@@ -0,0 +1,28 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from notebook.connectors.base import Api
+
+
+LOG = logging.getLogger(__name__)
+
+
+class TextApi(Api):
+
+  pass

+ 1 - 2
desktop/libs/notebook/src/notebook/decorators.py

@@ -27,8 +27,7 @@ from desktop.lib.exceptions_renderable import PopupException
 from desktop.lib.i18n import force_unicode
 from desktop.models import Document2, Document
 
-
-from notebook.models import QueryExpired, QueryError, SessionExpired
+from notebook.connectors.base import QueryExpired, QueryError, SessionExpired
 
 
 LOG = logging.getLogger(__name__)

+ 0 - 523
desktop/libs/notebook/src/notebook/models.py

@@ -14,526 +14,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
-import json
-import logging
-import re
-import time
-
-from django.utils.translation import ugettext as _
-
-from desktop.lib.exceptions_renderable import PopupException
-from desktop.lib.i18n import force_unicode
-from desktop.lib.rest.http_client import RestException
-
-from beeswax import data_export
-from beeswax.design import hql_query
-from beeswax import conf as beeswax_conf
-from beeswax.models import QUERY_TYPES, HiveServerQueryHandle, QueryHistory, HiveServerQueryHistory
-from beeswax.server import dbms
-from beeswax.server.dbms import get_query_server_config, QueryServerException
-from beeswax.views import _parse_out_hadoop_jobs
-
-from spark.job_server_api import get_api as get_spark_api
-
-from notebook.data_export import download as spark_download
-
-
-LOG = logging.getLogger(__name__)
-
-
-# To move to Editor API
-class SessionExpired(Exception):
-  pass
-
-
-class QueryExpired(Exception):
-  pass
-
-
-class QueryError(Exception):
-  def __init__(self, message):
-    self.message = message
-
-  def __str__(self):
-    return force_unicode(str(self.message))
-
-
-class Notebook():
-
-  def __init__(self, document=None):
-    self.document = None
-
-    if document is not None:
-      self.data = document.data
-      self.document = document
-    else:
-      self.data = json.dumps({
-          'name': 'My Notebook',
-          'description': '',
-          'snippets': []
-      })
-
-  def get_json(self):
-    _data = self.get_data()
-
-    return json.dumps(_data)
-
-  def get_data(self):
-    _data = json.loads(self.data)
-
-    if self.document is not None:
-      _data['id'] = self.document.id
-
-    return _data
-
-  def get_str(self):
-    return '\n\n'.join([snippet['statement_raw'] for snippet in self.get_data()['snippets']])
-
-
-def get_api(user, snippet):
-  if snippet['type'] in ('hive', 'impala', 'spark-sql'):
-    return HS2Api(user)
-  elif snippet['type'] in ('jar', 'py'):
-    return SparkBatchApi(user)
-  elif snippet['type'] == 'text':
-    return TextApi(user)
-  else:
-    return SparkApi(user)
-
-
-def _get_snippet_session(notebook, snippet):
-  return [session for session in notebook['sessions'] if session['type'] == snippet['type']][0]
-
-
-# Base API
-
-class Api(object):
-
-  def __init__(self, user):
-    self.user = user
-
-  def create_session(self, lang, properties=None):
-    return {
-        'type': lang,
-        'id': None,
-        'properties': []
-    }
-
-  def close_session(self, session):
-    pass
-
-
-# Text
-
-class TextApi(Api):
-
-  pass
-
-
-# HS2
-
-def query_error_handler(func):
-  def decorator(*args, **kwargs):
-    try:
-      return func(*args, **kwargs)
-    except QueryServerException, e:
-      message = force_unicode(str(e))
-      if 'Invalid query handle' in message or 'Invalid OperationHandle' in message:
-        raise QueryExpired(e)
-      else:
-        raise QueryError(message)
-  return decorator
-
-
-class HS2Api(Api):
-
-  def _get_handle(self, snippet):
-    snippet['result']['handle']['secret'], snippet['result']['handle']['guid'] = HiveServerQueryHandle.get_decoded(snippet['result']['handle']['secret'], snippet['result']['handle']['guid'])
-    return HiveServerQueryHandle(**snippet['result']['handle'])
-
-  def _get_db(self, snippet):
-    if snippet['type'] == 'hive':
-      name = 'beeswax'
-    elif snippet['type'] == 'impala':
-      name = 'impala'
-    else:
-      name = 'spark-sql'
-
-    return dbms.get(self.user, query_server=get_query_server_config(name=name))
-
-  def execute(self, notebook, snippet):
-    db = self._get_db(snippet)
-    query = hql_query(snippet['statement'], QUERY_TYPES[0])
-
-    try:
-      handle = db.client.query(query)
-    except QueryServerException, ex:
-      raise QueryError(ex.message)
-
-    # All good
-    server_id, server_guid  = handle.get()
-    return {
-        'secret': server_id,
-        'guid': server_guid,
-        'operation_type': handle.operation_type,
-        'has_result_set': handle.has_result_set,
-        'modified_row_count': handle.modified_row_count,
-        'log_context': handle.log_context
-    }
-
-  @query_error_handler
-  def check_status(self, notebook, snippet):
-    response = {}
-    db = self._get_db(snippet)
-
-    handle = self._get_handle(snippet)
-    operation = db.get_operation_status(handle)
-    status = HiveServerQueryHistory.STATE_MAP[operation.operationState]
-
-    if status.index in (QueryHistory.STATE.failed.index, QueryHistory.STATE.expired.index):
-      raise QueryError(operation.errorMessage)
-
-    response['status'] = 'running' if status.index in (QueryHistory.STATE.running.index, QueryHistory.STATE.submitted.index) else 'available'
-
-    return response
-
-  @query_error_handler
-  def fetch_result(self, notebook, snippet, rows, start_over):
-    db = self._get_db(snippet)
-
-    handle = self._get_handle(snippet)
-    results = db.fetch(handle, start_over=start_over, rows=rows)
-
-    # No escaping...
-    return {
-        'has_more': results.has_more,
-        'data': list(results.rows()),
-        'meta': [{
-          'name': column.name,
-          'type': column.type,
-          'comment': column.comment
-        } for column in results.data_table.cols()],
-        'type': 'table'
-    }
-
-  @query_error_handler
-  def fetch_result_metadata(self):
-    pass
-
-  @query_error_handler
-  def cancel(self, notebook, snippet):
-    db = self._get_db(snippet)
-
-    handle = self._get_handle(snippet)
-    db.cancel_operation(handle)
-    return {'status': 0}
-
-  @query_error_handler
-  def get_log(self, notebook, snippet, startFrom=None, size=None):
-    db = self._get_db(snippet)
-
-    handle = self._get_handle(snippet)
-    return db.get_log(handle, start_over=startFrom == 0)
-
-  def download(self, notebook, snippet, format):
-    try:
-      db = self._get_db(snippet)
-      handle = self._get_handle(snippet)
-      return data_export.download(handle, format, db)
-    except Exception, e:
-      LOG.exception('error downloading notebook')
-
-      if not hasattr(e, 'message') or not e.message:
-        message = e
-      else:
-        message = e.message
-      raise PopupException(message, detail='')
-
-  def _progress(self, snippet, logs):
-    if snippet['type'] == 'hive':
-      match = re.search('Total jobs = (\d+)', logs, re.MULTILINE)
-      total = (int(match.group(1)) if match else 1) * 2
-
-      started = logs.count('Starting Job')
-      ended = logs.count('Ended Job')
-
-      return int((started + ended) * 100 / total)
-    elif snippet['type'] == 'impala':
-      match = re.search('(\d+)% Complete', logs, re.MULTILINE)
-      return int(match.group(1)) if match else 0
-    else:
-      return 50
-
-  @query_error_handler
-  def close_statement(self, snippet):
-    if snippet['type'] == 'impala':
-      from impala import conf as impala_conf
-
-    if (snippet['type'] == 'hive' and beeswax_conf.CLOSE_QUERIES.get()) or (snippet['type'] == 'impala' and impala_conf.CLOSE_QUERIES.get()):
-      db = self._get_db(snippet)
-
-      handle = self._get_handle(snippet)
-      db.close_operation(handle)
-      return {'status': 0}
-    else:
-      return {'status': -1}  # skipped
-
-  def _get_jobs(self, log):
-    return _parse_out_hadoop_jobs(log)
-
-
-# Spark
-
-class SparkApi(Api):
-  PROPERTIES = [
-    {'name': 'jars', 'nice_name': _('Jars'), 'default': '', 'type': 'csv-hdfs-files', 'is_yarn': False},
-    {'name': 'files', 'nice_name': _('Files'), 'default': '', 'type': 'csv-hdfs-files', 'is_yarn': False},
-    {'name': 'pyFiles', 'nice_name': _('pyFiles'), 'default': '', 'type': 'csv-hdfs-files', 'is_yarn': False},
-
-    {'name': 'driverMemory', 'nice_name': _('Driver Memory'), 'default': '1', 'type': 'jvm', 'is_yarn': False},
-
-    {'name': 'driverCores', 'nice_name': _('Driver Cores'), 'default': '1', 'type': 'number', 'is_yarn': True},
-    {'name': 'executorCores', 'nice_name': _('Executor Cores'), 'default': '1', 'type': 'number', 'is_yarn': True},
-    {'name': 'queue', 'nice_name': _('Queue'), 'default': '1', 'type': 'string', 'is_yarn': True},
-    {'name': 'archives', 'nice_name': _('Archives'), 'default': '', 'type': 'csv-hdfs-files', 'is_yarn': True},
-    {'name': 'numExecutors', 'nice_name': _('Executors Numbers'), 'default': '1', 'type': 'number', 'is_yarn': True},
-  ]
-
-  def create_session(self, lang='scala', properties=None):
-    properties = dict([(p['name'], p['value']) for p in properties]) if properties is not None else {}
-
-    properties['kind'] = lang
-
-    api = get_spark_api(self.user)
-
-    response = api.create_session(**properties)
-
-    status = api.get_session(response['id'])
-    count = 0
-
-    while status['state'] == 'starting' and count < 120:
-      status = api.get_session(response['id'])
-      count += 1
-      time.sleep(1)
-
-    if status['state'] != 'idle':
-      info = '\n'.join(status['log']) if status['log'] else 'timeout'
-      raise QueryError(_('The Spark session could not be created in the cluster: %s') % info)
-
-    return {
-        'type': lang,
-        'id': response['id'],
-        'properties': []
-    }
-
-  def execute(self, notebook, snippet):
-    api = get_spark_api(self.user)
-    session = _get_snippet_session(notebook, snippet)
-
-    try:
-      response = api.submit_statement(session['id'], snippet['statement'])
-      return {
-          'id': response['id'],
-          'has_result_set': True,
-      }
-    except Exception, e:
-      message = force_unicode(str(e)).lower()
-      if 'session not found' in message or 'connection refused' in message or 'session is in state busy' in message:
-        raise SessionExpired(e)
-      else:
-        raise e
-
-  def check_status(self, notebook, snippet):
-    api = get_spark_api(self.user)
-    session = _get_snippet_session(notebook, snippet)
-    cell = snippet['result']['handle']['id']
-
-    try:
-      response = api.fetch_data(session['id'], cell)
-      return {
-          'status': response['state'],
-      }
-    except Exception, e:
-      message = force_unicode(str(e)).lower()
-      if 'session not found' in message:
-        raise SessionExpired(e)
-      else:
-        raise e
-
-  def fetch_result(self, notebook, snippet, rows, start_over):
-    api = get_spark_api(self.user)
-    session = _get_snippet_session(notebook, snippet)
-    cell = snippet['result']['handle']['id']
-
-    try:
-      response = api.fetch_data(session['id'], cell)
-    except Exception, e:
-      message = force_unicode(str(e)).lower()
-      if 'session not found' in message:
-        raise SessionExpired(e)
-      else:
-        raise e
-
-    content = response['output']
-
-    if content['status'] == 'ok':
-      data = content['data']
-      images = []
-
-      try:
-        table = data['application/vnd.livy.table.v1+json']
-      except KeyError:
-        try:
-          images = [data['image/png']]
-        except KeyError:
-          images = []
-        data = [[data['text/plain']]]
-        meta = [{'name': 'Header', 'type': 'STRING_TYPE', 'comment': ''}]
-        type = 'text'
-      else:
-        data = table['data']
-        headers = table['headers']
-        meta = [{'name': h['name'], 'type': h['type'], 'comment': ''} for h in headers]
-        type = 'table'
-
-      # Non start_over not supported
-      if not start_over:
-        data = []
-
-      return {
-          'data': data,
-          'images': images,
-          'meta': meta,
-          'type': type
-      }
-    elif content['status'] == 'error':
-      tb = content.get('traceback', None)
-
-      if tb is None:
-        msg = content.get('ename', 'unknown error')
-
-        evalue = content.get('evalue')
-        if evalue is not None:
-          msg = '%s: %s' % (msg, evalue)
-      else:
-        msg = ''.join(tb)
-
-      raise QueryError(msg)
-
-  def download(self, notebook, snippet, format):
-    try:
-      api = get_spark_api(self.user)
-      session = _get_snippet_session(notebook, snippet)
-      cell = snippet['result']['handle']['id']
-
-      return spark_download(api, session['id'], cell, format)
-    except Exception, e:
-      raise PopupException(e)
-
-  def cancel(self, notebook, snippet):
-    api = get_spark_api(self.user)
-    session = _get_snippet_session(notebook, snippet)
-    response = api.cancel(session['id'])
-
-    return {'status': 0}
-
-  def get_log(self, notebook, snippet, startFrom=0, size=None):
-    api = get_spark_api(self.user)
-    session = _get_snippet_session(notebook, snippet)
-
-    return api.get_log(session['id'], startFrom=startFrom, size=size)
-
-  def _progress(self, snippet, logs):
-    return 50
-
-  def close_statement(self, snippet): # Individual statements cannot be closed
-    pass
-
-  def close_session(self, session):
-    api = get_spark_api(self.user)
-
-    if session['id'] is not None:
-      try:
-        api.close(session['id'])
-        return {
-          'session': session['id'],
-          'status': 0
-        }
-      except RestException, e:
-        if e.code == 404 or e.code == 500: # TODO remove the 500
-          raise SessionExpired(e)
-    else:
-      return {'status': -1}
-
-  def _get_jobs(self, log):
-    return []
-
-
-class SparkBatchApi(Api):
-
-  def create_session(self, lang, properties=None):
-    return {
-        'type': lang,
-        'id': None
-    }
-
-  def execute(self, notebook, snippet):
-    api = get_spark_api(self.user)
-
-    properties = {
-        'file': snippet['properties'].get('app_jar'),
-        'className': snippet['properties'].get('class'),
-        'args': snippet['properties'].get('arguments'),
-        'pyFiles': snippet['properties'].get('py_file'),
-        # files
-        # driverMemory
-        # driverCores
-        # executorMemory
-        # executorCores
-        # archives
-    }
-
-    response = api.submit_batch(properties)
-    return {
-        'id': response['id'],
-        'has_result_set': True,
-        'properties': []
-    }
-
-  def check_status(self, notebook, snippet):
-    api = get_spark_api(self.user)
-
-    state = api.get_batch_status(snippet['result']['handle']['id'])
-    return {
-        'status': state,
-    }
-
-  def get_log(self, notebook, snippet, startFrom=0, size=None):
-    api = get_spark_api(self.user)
-
-    return api.get_batch_log(snippet['result']['handle']['id'], startFrom=startFrom, size=size)
-
-  def close_statement(self, snippet):
-    api = get_spark_api(self.user)
-
-    session_id = snippet['result']['handle']['id']
-    if session_id is not None:
-      api.close_batch(session_id)
-      return {
-        'session': session_id,
-        'status': 0
-      }
-    else:
-      return {'status': -1}  # skipped
-
-  def cancel(self, notebook, snippet):
-    # Batch jobs do not support interruption, so close statement instead.
-    return self.close_statement(snippet)
-
-  def _progress(self, snippet, logs):
-    return 50
-
-  def _get_jobs(self, log):
-    return []

+ 2 - 1
desktop/libs/notebook/src/notebook/views.py

@@ -28,8 +28,9 @@ from spark.conf import LIVY_SERVER_SESSION_KIND
 
 from notebook.conf import LANGUAGES
 from notebook.decorators import check_document_access_permission, check_document_modify_permission
-from notebook.models import Notebook, get_api, SparkApi
+from notebook.connectors.base import Notebook, get_api
 from notebook.management.commands.notebook_setup import Command
+from notebook.connectors.spark_shell import SparkApi
 
 
 LOG = logging.getLogger(__name__)