
HUE-8747 [editor] Download query result as task

jdesjean 6 years ago
parent commit a66d30e129

+ 9 - 6
apps/beeswax/src/beeswax/data_export.py

@@ -33,7 +33,7 @@ FETCH_SIZE = 1000
 DOWNLOAD_COOKIE_AGE = 1800 # 30 minutes
 
 
-def download(handle, format, db, id=None, file_name='query_result', user_agent=None, callback=None):
+def download(handle, format, db, id=None, file_name='query_result', user_agent=None, callback=None, max_rows=None, store_data_type_in_header=False):
   """
   download(query_model, format) -> HttpResponse
 
@@ -43,10 +43,10 @@ def download(handle, format, db, id=None, file_name='query_result', user_agent=N
     LOG.error('Unknown download format "%s"' % (format,))
     return
 
-  max_rows = conf.DOWNLOAD_ROW_LIMIT.get()
-  max_bytes = conf.DOWNLOAD_BYTES_LIMIT.get()
+  max_rows = max_rows if max_rows else conf.DOWNLOAD_ROW_LIMIT.get()
+  max_bytes = -1 if max_rows else conf.DOWNLOAD_BYTES_LIMIT.get()
 
-  content_generator = HS2DataAdapter(handle, db, max_rows=max_rows, start_over=True, max_bytes=max_bytes, callback=callback)
+  content_generator = HS2DataAdapter(handle, db, max_rows=max_rows, start_over=True, max_bytes=max_bytes, callback=callback, store_data_type_in_header=store_data_type_in_header)
   generator = export_csvxls.create_generator(content_generator, format)
 
   resp = export_csvxls.make_response(generator, format, file_name, user_agent=user_agent)
@@ -83,7 +83,7 @@ def upload(path, handle, user, db, fs, max_rows=-1, max_bytes=-1):
 
 class HS2DataAdapter:
 
-  def __init__(self, handle, db, max_rows=-1, start_over=True, max_bytes=-1, callback=None):
+  def __init__(self, handle, db, max_rows=-1, start_over=True, max_bytes=-1, callback=None, store_data_type_in_header=False):
     self.handle = handle
     self.db = db
     self.max_rows = max_rows
@@ -97,10 +97,11 @@ class HS2DataAdapter:
     self.first_fetched = True
     self.headers = None
     self.num_cols = None
-    self.row_counter = 1
+    self.row_counter = 0
     self.bytes_counter = 0
     self.is_truncated = False
     self.has_more = True
+    self.store_data_type_in_header = store_data_type_in_header
 
   def __iter__(self):
     return self
@@ -141,6 +142,8 @@ class HS2DataAdapter:
       self.start_over = False
       self.headers = results.cols()
       self.num_cols = len(self.headers)
+      if self.store_data_type_in_header:
+        self.headers = [column['name'] + '|' + column['type'] for column in results.full_cols()]
       if self.limit_bytes:
         self.bytes_counter += max(self.num_cols - 1, 0)
         for header in self.headers:

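For reference, a minimal sketch of the 'name|type' header encoding that the new store_data_type_in_header flag enables, together with the matching decode performed later by the task server when it re-reads the CSV (the sample columns are made up for illustration):

  # Hypothetical columns, standing in for results.full_cols()
  full_cols = [{'name': 'id', 'type': 'INT_TYPE'}, {'name': 'name', 'type': 'STRING_TYPE'}]

  # Encoding, as in HS2DataAdapter above when store_data_type_in_header is True:
  headers = [column['name'] + '|' + column['type'] for column in full_cols]
  # -> ['id|INT_TYPE', 'name|STRING_TYPE']

  # Decoding, as in notebook/tasks.py fetch_result() further down:
  cols = [{'name': h.split('|')[0], 'type': h.split('|')[1], 'comment': None} for h in headers]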
+ 7 - 1
desktop/core/src/desktop/conf.py

@@ -1640,12 +1640,18 @@ TASK_SERVER = ConfigSection(
       default='--time-limit=300',
       help=_('Default options provided to the task server at startup.')
     ),
-    BEAT_ENABLED= Config(
+    BEAT_ENABLED = Config(
       key='beat_enabled',
       default=False,
       type=coerce_bool,
       help=_('Switch on the integration with the Task Scheduler.')
     ),
+    PREFETCH_RESULT_COUNT = Config(
+      key='prefetch_result_count',
+      default=2000,
+      type=coerce_positive_integer,
+      help=_('Number of rows to prefetch to Hue storage.')
+    ),
 ))
 
 
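For context, the new setting lives under the task server section of hue.ini; a sketch of a configuration using it, where the surrounding keys are existing task server settings and 2000 is the shipped default:

  [desktop]
    [[task_server]]
      # Existing switches for the task server integration.
      enabled=true
      beat_enabled=true
      # New in this commit: number of rows to prefetch to Hue storage.
      prefetch_result_count=2000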

+ 2 - 6
desktop/core/src/desktop/lib/export_csvxls.py

@@ -36,7 +36,7 @@ from desktop.lib import i18n
 LOG = logging.getLogger(__name__)
 
 ILLEGAL_CHARS = r'[\000-\010]|[\013-\014]|[\016-\037]'
-
+FORMAT_TO_CONTENT_TYPE = {'csv': 'application/csv', 'xls': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', 'json': 'application/json'}
 
 def nullify(cell):
   return cell if cell is not None else "NULL"
@@ -120,20 +120,16 @@ def make_response(generator, format, name, encoding=None, user_agent=None):
   @param name Base name for output file
   @param encoding Unicode encoding for data
   """
+  content_type = FORMAT_TO_CONTENT_TYPE.get(format, 'application/octet-stream')
   if format == 'csv':
-    content_type = 'application/csv'
     resp = StreamingHttpResponse(generator, content_type=content_type)
     try:
       del resp['Content-Length']
     except KeyError:
       pass
   elif format == 'xls':
-    format = 'xlsx'
-    content_type = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
     resp = HttpResponse(next(generator), content_type=content_type)
-
   elif format == 'json':
-    content_type = 'application/json'
     resp = HttpResponse(generator, content_type=content_type)
   else:
     raise Exception("Unknown format: %s" % format)

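The refactor above replaces three per-branch content_type assignments with a single table lookup; a quick sketch of its behavior (the fallback value only matters transiently, since make_response still raises for formats without a branch):

  content_type = FORMAT_TO_CONTENT_TYPE.get('csv', 'application/octet-stream')
  # -> 'application/csv'
  content_type = FORMAT_TO_CONTENT_TYPE.get('parquet', 'application/octet-stream')
  # -> 'application/octet-stream'; make_response then raises "Unknown format: parquet"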
+ 5 - 6
desktop/libs/notebook/src/notebook/api.py

@@ -27,25 +27,24 @@ from django.utils.translation import ugettext as _
 from django.views.decorators.http import require_GET, require_POST
 
 from desktop.api2 import __paginate
-from desktop.conf import IS_K8S_ONLY
+from desktop.conf import IS_K8S_ONLY, TASK_SERVER
 from desktop.lib.i18n import smart_str
 from desktop.lib.django_util import JsonResponse
 from desktop.models import Document2, Document
 from indexer.file_format import HiveFormat
 from indexer.fields import Field
 
-from notebook.connectors.base import get_api, Notebook, QueryExpired, SessionExpired, QueryError, _get_snippet_name
+from notebook.connectors.base import Notebook, QueryExpired, SessionExpired, QueryError, _get_snippet_name
 from notebook.connectors.dataeng import DataEngApi
 from notebook.connectors.hiveserver2 import HS2Api
 from notebook.connectors.oozie_batch import OozieApi
 from notebook.decorators import api_error_handler, check_document_access_permission, check_document_modify_permission
 from notebook.models import escape_rows, make_notebook
-from notebook.views import upgrade_session_properties
+from notebook.views import upgrade_session_properties, get_api
 
 
 LOG = logging.getLogger(__name__)
 
-
 DEFAULT_HISTORY_NAME = ''
 
 
@@ -536,7 +535,7 @@ def close_notebook(request):
   for snippet in [_s for _s in notebook['snippets'] if _s['type'] in ('hive', 'impala')]:
     try:
       if snippet['status'] != 'running':
-        response['result'].append(get_api(request, snippet).close_statement(snippet))
+        response['result'].append(get_api(request, snippet).close_statement(notebook, snippet))
       else:
         LOG.info('Not closing SQL snippet as still running.')
     except QueryExpired:
@@ -560,7 +559,7 @@ def close_statement(request):
   snippet = json.loads(request.POST.get('snippet', '{}'))
 
   try:
-    response['result'] = get_api(request, snippet).close_statement(snippet)
+    response['result'] = get_api(request, snippet).close_statement(notebook, snippet)
   except QueryExpired:
     pass
 

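This and the remaining connector diffs below apply the same mechanical change: close_statement now receives the notebook alongside the snippet. A sketch of the new contract a connector implements (MyApi is hypothetical; Api is the base class from notebook.connectors.base):

  class MyApi(Api):  # illustrative only

    def close_statement(self, notebook, snippet):
      # The notebook is threaded through so that task-server implementations
      # (see notebook/tasks.py below) can look the query up by notebook['uuid'];
      # connectors that do not need it simply ignore the argument.
      return {'status': -1}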
+ 1 - 1
desktop/libs/notebook/src/notebook/connectors/altus_adb.py

@@ -57,7 +57,7 @@ class AltusAdbApi(Api):
     return HueQuery(self.user, cluster_crn=self.cluster_name).do_fetch_result(handle)
 
 
-  def close_statement(self, snippet):
+  def close_statement(self, notebook, snippet):
     return {'status': -1}
 
 

+ 1 - 0
desktop/libs/notebook/src/notebook/connectors/base.py

@@ -54,6 +54,7 @@ class OperationNotSupported(Exception):
 
 class QueryError(Exception):
   def __init__(self, message, handle=None):
+    super(QueryError, self).__init__(message)
     self.message = message or _('No error message, please check the logs.')
     self.handle = handle
     self.extra = {}

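The added super() call pins Exception.args to just the message, which matters once QueryError travels through the Celery result backend; a small sketch, assuming the backend round-trips exceptions via their args (the standard pickling convention for exceptions, not Hue-specific):

  import pickle

  class QueryError(Exception):  # as in the diff above; the fallback message and .extra are elided
    def __init__(self, message, handle=None):
      super(QueryError, self).__init__(message)  # pins args to (message,)
      self.message = message
      self.handle = handle

  e = pickle.loads(pickle.dumps(QueryError('Table not found', handle={'guid': 'abc'})))
  str(e)  # 'Table not found' -- the backend can rebuild the error via QueryError(*e.args)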
+ 1 - 1
desktop/libs/notebook/src/notebook/connectors/dataeng.py

@@ -133,7 +133,7 @@ class DataEngApi(Api):
     ]
 
 
-  def close_statement(self, snippet):
+  def close_statement(self, notebook, snippet):
     pass
 
 

+ 3 - 3
desktop/libs/notebook/src/notebook/connectors/hiveserver2.py

@@ -361,7 +361,7 @@ class HS2Api(Api):
 
 
   @query_error_handler
-  def close_statement(self, snippet):
+  def close_statement(self, notebook, snippet):
     if snippet['type'] == 'impala':
       from impala import conf as impala_conf
 
@@ -382,7 +382,7 @@ class HS2Api(Api):
 
 
   @query_error_handler
-  def download(self, notebook, snippet, format, user_agent=None):
+  def download(self, notebook, snippet, format, user_agent=None, max_rows=None, store_data_type_in_header=False):
     try:
       db = self._get_db(snippet, cluster=self.cluster)
       handle = self._get_handle(snippet)
@@ -391,7 +391,7 @@ class HS2Api(Api):
 
       file_name = _get_snippet_name(notebook)
 
-      return data_export.download(handle, format, db, id=snippet['id'], file_name=file_name, user_agent=user_agent)
+      return data_export.download(handle, format, db, id=snippet['id'], file_name=file_name, user_agent=user_agent, max_rows=max_rows, store_data_type_in_header=store_data_type_in_header)
     except Exception, e:
       title = 'The query result cannot be downloaded.'
       LOG.exception(title)

+ 1 - 1
desktop/libs/notebook/src/notebook/connectors/jdbc.py

@@ -131,7 +131,7 @@ class JdbcApi(Api):
     return 50
 
   @query_error_handler
-  def close_statement(self, snippet):
+  def close_statement(self, notebook, snippet):
     return {'status': -1}
 
   @query_error_handler

+ 1 - 1
desktop/libs/notebook/src/notebook/connectors/oozie_batch.py

@@ -164,7 +164,7 @@ class OozieApi(Api):
     return jobs
 
 
-  def close_statement(self, snippet):
+  def close_statement(self, notebook, snippet):
     pass
 
 

+ 1 - 1
desktop/libs/notebook/src/notebook/connectors/rdbms.py

@@ -117,7 +117,7 @@ class RdbmsApi(Api):
 
 
   @query_error_handler
-  def close_statement(self, snippet):
+  def close_statement(self, notebook, snippet):
     return {'status': -1}
 
 

+ 1 - 1
desktop/libs/notebook/src/notebook/connectors/solr.py

@@ -134,7 +134,7 @@ class SolrApi(Api):
 
 
   @query_error_handler
-  def close_statement(self, snippet):
+  def close_statement(self, notebook, snippet):
     return {'status': -1}
 
 

+ 1 - 1
desktop/libs/notebook/src/notebook/connectors/spark_batch.py

@@ -78,7 +78,7 @@ class SparkBatchApi(Api):
 
     return api.get_batch_log(snippet['result']['handle']['id'], startFrom=startFrom, size=size)
 
-  def close_statement(self, snippet):
+  def close_statement(self, notebook, snippet):
     api = get_spark_api(self.user)
 
     session_id = snippet['result']['handle']['id']

+ 1 - 1
desktop/libs/notebook/src/notebook/connectors/spark_shell.py

@@ -352,7 +352,7 @@ class SparkApi(Api):
   def progress(self, snippet, logs):
     return 50
 
-  def close_statement(self, snippet): # Individual statements cannot be closed
+  def close_statement(self, notebook, snippet): # Individual statements cannot be closed
     pass
 
   def close_session(self, session):

+ 1 - 1
desktop/libs/notebook/src/notebook/connectors/sqlalchemyapi.py

@@ -230,7 +230,7 @@ class SqlAlchemyApi(Api):
 
 
   @query_error_handler
-  def close_statement(self, snippet):
+  def close_statement(self, notebook, snippet):
     result = {'status': -1}
 
     try:

+ 2 - 2
desktop/libs/notebook/src/notebook/dashboard_api.py

@@ -351,7 +351,7 @@ class SQLDashboardApi(DashboardApi):
           status = api.check_status(mock_notebook, snippet)
           if status['status'] == 'available':
             result = api.fetch_result(mock_notebook, snippet, rows=10, start_over=True)
-            api.close_statement(snippet)
+            api.close_statement(mock_notebook, snippet)
             break
           time.sleep(sleep_interval)
           curr = time.time()
@@ -361,7 +361,7 @@ class SQLDashboardApi(DashboardApi):
             api.cancel_operation(snippet)
           except Exception, e:
             LOG.warning("Failed to cancel query: %s" % e)
-            api.close_statement(snippet)
+            api.close_statement(mock_notebook, snippet)
           raise OperationTimeout(e)
 
     return result

+ 363 - 0
desktop/libs/notebook/src/notebook/tasks.py

@@ -0,0 +1,363 @@
+#!/usr/bin/env python
+# Licensed to Cloudera, Inc. under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  Cloudera, Inc. licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import, unicode_literals
+
+import csv
+import os
+import django
+import json
+import logging
+import tempfile
+import time
+
+from celery.utils.log import get_task_logger
+from celery import states
+
+from django.contrib.auth.models import User
+from django.db import transaction
+from django.http import FileResponse, HttpRequest
+
+from desktop.auth.backend import rewrite_user
+from desktop.celery import app
+from desktop.conf import TASK_SERVER
+from desktop.lib.export_csvxls import FORMAT_TO_CONTENT_TYPE
+
+from notebook.connectors.base import get_api, QueryExpired
+
+LOG_TASK = get_task_logger(__name__)
+LOG = logging.getLogger(__name__)
+DOWNLOAD_COOKIE_AGE = 3600
+STATE_MAP = {
+  'SUBMITTED': 'waiting',
+  states.RECEIVED: 'waiting',
+  states.PENDING: 'waiting',
+  states.STARTED: 'running',
+  states.RETRY: 'running',
+  states.SUCCESS: 'available',
+  'PROGRESS': 'running',
+  states.FAILURE: 'failure',
+  states.REVOKED: 'canceled',
+  states.REJECTED: 'rejected',
+  states.IGNORED: 'ignored'
+}
+
+# TODO: Add a periodic cleanup task
+# TODO: Move file paths behind a file-like API so we can change the implementation
+@app.task()
+def download_to_file(notebook, snippet, file_format='csv', user_agent=None, postdict=None, user_id=None, create=False, store_data_type_in_header=False):
+  download_to_file.update_state(task_id=notebook['uuid'], state='STARTED', meta={})
+  request = _get_request(postdict, user_id)
+  api = get_api(request, snippet)
+  if create:
+    handle = api.execute(notebook, snippet)
+  else:
+    handle = snippet['result']['handle']
+
+  f, path = tempfile.mkstemp()
+  f_log, path_log = tempfile.mkstemp()
+  f_progress, path_progress = tempfile.mkstemp()
+  try:
+    os.write(f_progress, '0')
+    meta = {'row_counter': 0, 'file_path': path, 'handle': handle, 'log_path': path_log, 'progress_path': path_progress, 'status': 'running', 'truncated': False}
+    download_to_file.update_state(task_id=notebook['uuid'], state='PROGRESS', meta=meta)
+    _until_available(notebook, snippet, api, f_log, handle, meta)
+
+    snippet['result']['handle'] = handle.copy()
+    # TODO: Move PREFETCH_RESULT_COUNT to the front end
+    response = api.download(notebook, snippet, file_format, user_agent=user_agent, max_rows=TASK_SERVER.PREFETCH_RESULT_COUNT.get(), store_data_type_in_header=store_data_type_in_header)
+
+    row_count = 0
+    for chunk in response:
+      os.write(f, chunk)
+      row_count += chunk.count('\n')
+      meta['row_counter'] = row_count - 1 # Exclude the CSV header line from the count
+      download_to_file.update_state(task_id=notebook['uuid'], state='PROGRESS', meta=meta)
+
+    api.close_statement(notebook, snippet)
+  finally:
+    os.close(f)
+    os.close(f_log)
+    os.close(f_progress)
+  return meta
+
+@app.task(ignore_result=True)
+def cancel_async(notebook, snippet, postdict=None, user_id=None):
+  request = _get_request(postdict, user_id)
+  get_api(request, snippet).cancel(notebook, snippet)
+
+@app.task(ignore_result=True)
+def close_statement_async(notebook, snippet, postdict=None, user_id=None):
+  request = _get_request(postdict, user_id)
+  get_api(request, snippet).close_statement(notebook, snippet)
+
+def _until_available(notebook, snippet, api, f, handle, meta):
+  count = 0
+  sleep_seconds = 1
+  check_status_count = 0
+  while True:
+    snippet['result']['handle'] = handle.copy()
+    response = api.check_status(notebook, snippet)
+    meta['status'] = response['status']
+    download_to_file.update_state(task_id=notebook['uuid'], state='PROGRESS', meta=meta)
+    snippet['result']['handle'] = handle.copy()
+    log = api.get_log(notebook, snippet, startFrom=count)
+    os.write(f, log)
+    count += log.count('\n')
+    if response['status'] == 'available':
+      break
+    check_status_count += 1
+    if check_status_count > 10:
+      sleep_seconds = 10
+    elif check_status_count > 5:
+      sleep_seconds = 5
+    time.sleep(sleep_seconds)
+
+# TODO: Convert CSV to Excel if needed
+def download(*args, **kwargs):
+  result = download_to_file.AsyncResult(args[0]['uuid'])
+  state = result.state
+  if state == states.PENDING:
+    raise QueryExpired()
+  elif state in states.EXCEPTION_STATES:
+    result.maybe_reraise()
+
+  info = result.wait()
+  response = FileResponse(open(info['file_path'], 'rb'), content_type=FORMAT_TO_CONTENT_TYPE.get('csv', 'application/octet-stream'))
+  response['Content-Disposition'] = 'attachment; filename="%s.%s"' % (args[0]['uuid'], 'csv') #TODO: Add support for 3rd party (e.g. nginx file serving)
+  response.set_cookie(
+      'download-%s' % args[1]['id'],
+      json.dumps({
+        'truncated': info.get('truncated', False),
+        'row_counter': info.get('row_counter', 0)
+      }),
+      max_age=DOWNLOAD_COOKIE_AGE
+    )
+  return response
+
+# Why we need this:
+# 1) There is no way in Celery to differentiate between a task that was submitted but not yet started, and a task that has been GCed.
+# 2) The client will keep checking for data until the query is expired. The new definition of "expired" in this case is a task that has been GCed.
+def _patch_status(notebook):
+  result = download_to_file.AsyncResult(notebook['uuid'])
+  result.backend.store_result(notebook['uuid'], None, "SUBMITTED")
+
+def execute(*args, **kwargs):
+  notebook = args[0]
+  kwargs['create'] = True
+  kwargs['store_data_type_in_header'] = True
+  _patch_status(notebook)
+  download_to_file.apply_async(args=args, kwargs=kwargs, task_id=notebook['uuid'])
+  return {'sync': False,
+      'has_result_set': True,
+      'modified_row_count': 0,
+      'guid': '',
+      'result': {
+        'has_more': True,
+        'data': [],
+        'meta': [],
+        'type': 'table'
+      }}
+
+def check_status(*args, **kwargs):
+  notebook = args[0]
+  result = download_to_file.AsyncResult(notebook['uuid'])
+  state = result.state
+  if state == states.PENDING:
+    raise QueryExpired()
+  elif state in states.EXCEPTION_STATES:
+    result.maybe_reraise()
+
+  info = result.info
+  if not info or not info.get('status'):
+    status = STATE_MAP[state]
+  else:
+    status = info.get('status')
+  return {'status': status}
+
+def get_log(notebook, snippet, startFrom=None, size=None, postdict=None, user_id=None):
+  result = download_to_file.AsyncResult(notebook['uuid'])
+  state = result.state
+  if state == states.PENDING:
+    raise QueryExpired()
+  elif state == 'SUBMITTED' or states.state(result.state) < states.state('PROGRESS'):
+    return ''
+  elif state in states.EXCEPTION_STATES:
+    result.maybe_reraise()
+    return ''
+
+  info = result.info
+  if not startFrom:
+    with open(info.get('log_path'), 'r') as f:
+      return f.read()
+  else:
+    count = 0
+    data = ''
+    with open(info.get('log_path'), 'r') as f:
+      for line in f:
+        count += 1
+        if count <= startFrom:
+          continue
+        data += line
+    return data
+
+def get_jobs(notebook, snippet, logs, **kwargs): # Reimplemented to fetch the updated guid in download_to_file from the DB
+  result = download_to_file.AsyncResult(notebook['uuid'])
+  state = result.state
+  if state == states.PENDING:
+    raise QueryExpired()
+  elif state == 'SUBMITTED' or states.state(result.state) < states.state('PROGRESS'):
+    return []
+  elif state in states.EXCEPTION_STATES:
+    result.maybe_reraise()
+    return []
+
+  info = result.info
+  snippet['result']['handle'] = info.get('handle', {})
+
+  request = _get_request(**kwargs)
+  api = get_api(request, snippet)
+  # Insidious problem: each call in the Hive API transforms the guid/secret to binary form. get_log does the transform but get_jobs does not; get_jobs is normally called after get_log, so this is usually not an issue, but our get_log implementation doesn't do the transform.
+  if hasattr(api, '_get_handle'): # This is specific to Impala; should be handled in hiveserver2
+    api._get_handle(snippet)
+  return api.get_jobs(notebook, snippet, logs)
+
+def fetch_result(notebook, snippet, rows, start_over, **kwargs):
+  result = download_to_file.AsyncResult(notebook['uuid'])
+  state = result.state
+  if state == states.PENDING:
+    raise QueryExpired()
+  elif state == 'SUBMITTED' or states.state(result.state) < states.state('PROGRESS'):
+    return {
+      'has_more': False,
+      'data': [],
+      'meta': [],
+      'type': 'table'
+    }
+  elif state in states.EXCEPTION_STATES:
+    result.maybe_reraise()
+    return {
+      'has_more': False,
+      'data': [],
+      'meta': [],
+      'type': 'table'
+    }
+
+  info = result.info
+  data = []
+  skip = 0
+  if not start_over:
+    with open(info.get('progress_path'), 'r') as f:
+      skip = int(f.read())
+  target = skip + rows
+
+  with open(info.get('file_path'), 'r') as f:
+    csv_reader = csv.reader(f, delimiter=','.encode('utf-8'))
+    first = next(csv_reader)
+    cols = map(lambda x: {'name': x.split('|')[0], 'type': x.split('|')[1], 'comment': None}, first)
+    count = 0
+    for row in csv_reader:
+      count += 1
+      if count <= skip:
+        continue
+      data.append(row)
+      if count >= target:
+        break
+
+  with open(info.get('progress_path'), 'w') as f:
+    f.write(str(count))
+
+  has_more = count < info.get('row_counter') or state == states.state('PROGRESS')
+
+  return {
+      'has_more': has_more,
+      'data': data,
+      'meta': cols,
+      'type': 'table'
+  }
+
+def fetch_result_size(*args, **kwargs):
+  notebook = args[0]
+  result = download_to_file.AsyncResult(notebook['uuid'])
+  state = result.state
+  if state == states.PENDING:
+    raise QueryExpired()
+  elif state == 'SUBMITTED' or states.state(result.state) < states.state('PROGRESS'):
+    return {'rows': 0}
+  elif state in states.EXCEPTION_STATES:
+    result.maybe_reraise()
+    return {'rows': 0}
+
+  info = result.info
+  return {'rows': info.get('row_counter', 0)}
+
+def cancel(*args, **kwargs):
+  notebook = args[0]
+  snippet = args[1]
+  result = download_to_file.AsyncResult(notebook['uuid'])
+  state = result.state
+  if state == states.PENDING:
+    raise QueryExpired()
+  elif state == 'SUBMITTED' or states.state(result.state) < states.state('PROGRESS'):
+    return {'status': -1}
+  elif state in states.EXCEPTION_STATES:
+    result.maybe_reraise()
+    return {'status': -1}
+
+  info = result.info
+  snippet['result']['handle'] = info.get('handle', {})
+  cancel_async.apply_async(args=args, kwargs=kwargs, task_id=_cancel_statement_async_id(notebook))
+  result.forget()
+  os.remove(info.get('file_path'))
+  os.remove(info.get('log_path'))
+  os.remove(info.get('progress_path'))
+  return {'status': 0}
+
+def close_statement(*args, **kwargs):
+  notebook = args[0]
+  snippet = args[1]
+  result = download_to_file.AsyncResult(notebook['uuid'])
+  state = result.state
+  if state == states.PENDING:
+    raise QueryExpired()
+  elif state == 'SUBMITTED' or states.state(result.state) < states.state('PROGRESS'):
+    return {'status': -1}
+  elif state in states.EXCEPTION_STATES:
+    result.maybe_reraise()
+    return {'status': -1}
+
+  info = result.info
+  snippet['result']['handle'] = info.get('handle', {})
+  close_statement_async.apply_async(args=args, kwargs=kwargs, task_id=_close_statement_async_id(notebook))
+  result.forget()
+  os.remove(info.get('file_path'))
+  os.remove(info.get('log_path'))
+  os.remove(info.get('progress_path'))
+  return {'status': 0}
+
+def _cancel_statement_async_id(notebook):
+  return notebook['uuid'] + '_cancel'
+
+def _close_statement_async_id(notebook):
+  return notebook['uuid'] + '_close'
+
+def _get_request(postdict=None, user_id=None):
+  request = HttpRequest()
+  request.POST = postdict
+  user = User.objects.get(id=user_id)
+  user = rewrite_user(user)
+  request.user = user
+  return request

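Taken together, tasks.py mirrors the connector Api surface as module-level functions keyed by the notebook uuid; in practice the ApiWrapper added to views.py (next file) dispatches to them. A condensed, hypothetical sketch of the intended round trip (the notebook/snippet dicts are stubs):

  import time
  from notebook.tasks import execute, check_status, fetch_result, close_statement

  notebook = {'uuid': 'some-task-id'}  # stub; real notebooks carry much more
  snippet = {'type': 'hive', 'result': {'handle': {}, 'id': 'some-snippet-id'}}  # stub

  execute(notebook, snippet, postdict={}, user_id=1)     # enqueues download_to_file
  while check_status(notebook)['status'] in ('waiting', 'running'):
    time.sleep(1)                                        # poll the Celery state
  result = fetch_result(notebook, snippet, rows=100, start_over=True)  # read the CSV back
  close_statement(notebook, snippet, postdict={}, user_id=1)  # remove the temp files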
+ 22 - 5
desktop/libs/notebook/src/notebook/views.py

@@ -24,7 +24,7 @@ from django.shortcuts import redirect
 from django.utils.translation import ugettext as _
 from django.views.decorators.clickjacking import xframe_options_exempt
 
-from desktop.conf import ENABLE_DOWNLOAD, USE_NEW_EDITOR
+from desktop.conf import ENABLE_DOWNLOAD, USE_NEW_EDITOR, TASK_SERVER
 from desktop.lib.django_util import render, JsonResponse
 from desktop.lib.exceptions_renderable import PopupException
 from desktop.lib.json_utils import JSONEncoderForHTML
@@ -34,15 +34,33 @@ from desktop.views import serve_403_error
 from metadata.conf import has_optimizer, has_catalog, has_workload_analytics
 
 from notebook.conf import get_ordered_interpreters, SHOW_NOTEBOOKS
-from notebook.connectors.base import Notebook, get_api, _get_snippet_name
+from notebook.connectors.base import Notebook, get_api as _get_api, _get_snippet_name
 from notebook.connectors.spark_shell import SparkApi
 from notebook.decorators import check_editor_access_permission, check_document_access_permission, check_document_modify_permission
 from notebook.management.commands.notebook_setup import Command
 from notebook.models import make_notebook
 
-
 LOG = logging.getLogger(__name__)
 
+if TASK_SERVER.ENABLED.get():
+  import notebook.tasks as ntasks
+
+class ApiWrapper(object):
+  def __init__(self, request, snippet):
+    self.request = request
+    self.snippet = snippet
+  def __getattr__(self, name):
+    if TASK_SERVER.ENABLED.get() and hasattr(ntasks, name):
+      attr = object.__getattribute__(ntasks, name)
+      def _method(*args, **kwargs):
+        return attr(*args, **dict(kwargs, postdict=self.request.POST, user_id=self.request.user.id))
+      return _method
+    else:
+      api = _get_api(self.request, self.snippet)
+      return object.__getattribute__(api, name)
+
+def get_api(request, snippet):
+  return ApiWrapper(request, snippet)
 
 def notebooks(request):
   editor_type = request.GET.get('type', 'notebook')
@@ -311,7 +329,6 @@ def copy(request):
 
   return JsonResponse(response)
 
-
 @check_document_access_permission()
 def download(request):
   if not ENABLE_DOWNLOAD.get():
@@ -321,7 +338,7 @@ def download(request):
   snippet = json.loads(request.POST.get('snippet', '{}'))
   file_format = request.POST.get('format', 'csv')
 
-  response = get_api(request, snippet).download(notebook, snippet, file_format, user_agent=request.META.get('HTTP_USER_AGENT'))
+  response = get_api(request, snippet).download(notebook, snippet, file_format=file_format, user_agent=request.META.get('HTTP_USER_AGENT'))
 
   if response:
     request.audit = {