Pārlūkot izejas kodu

HUE-8747 [editor] Fix download test.

jdesjean 6 gadi atpakaļ
vecāks
revīzija
d5265e4fc7

+ 18 - 12
apps/beeswax/src/beeswax/data_export.py

@@ -23,7 +23,7 @@ import types
 from django.utils.translation import ugettext as _
 
 from desktop.lib import export_csvxls
-from beeswax import common
+from beeswax import common, conf
 
 
 LOG = logging.getLogger(__name__)
@@ -32,7 +32,8 @@ LOG = logging.getLogger(__name__)
 FETCH_SIZE = 1000
 DOWNLOAD_COOKIE_AGE = 1800 # 30 minutes
 
-def download(format, db, id=None, file_name='query_result', user_agent=None, max_rows=-1, max_bytes=-1, store_data_type_in_header=False, start_over=True):
+
+def download(handle, format, db, id=None, file_name='query_result', user_agent=None):
   """
   download(query_model, format) -> HttpResponse
 
@@ -42,7 +43,10 @@ def download(format, db, id=None, file_name='query_result', user_agent=None, max
     LOG.error('Unknown download format "%s"' % (format,))
     return
 
-  content_generator = DataAdapter(db, max_rows=max_rows, start_over=start_over, max_bytes=max_bytes, store_data_type_in_header=store_data_type_in_header)
+  max_rows = conf.DOWNLOAD_ROW_LIMIT.get()
+  max_bytes = conf.DOWNLOAD_BYTES_LIMIT.get()
+
+  content_generator = DataAdapter(db, handle=handle, max_rows=max_rows, max_bytes=max_bytes)
   generator = export_csvxls.create_generator(content_generator, format)
 
   resp = export_csvxls.make_response(generator, format, file_name, user_agent=user_agent)
@@ -71,7 +75,7 @@ def upload(path, handle, user, db, fs, max_rows=-1, max_bytes=-1):
   else:
     fs.do_as_user(user.username, fs.create, path)
 
-  content_generator = DataAdapter(handle, db, max_rows=max_rows, start_over=True, max_bytes=max_bytes)
+  content_generator = DataAdapter(db, handle=handle, max_rows=max_rows, start_over=True, max_bytes=max_bytes)
   for header, data in content_generator:
     dataset = export_csvxls.dataset(None, data)
     fs.do_as_user(user.username, fs.append, path, dataset.csv)
@@ -79,7 +83,8 @@ def upload(path, handle, user, db, fs, max_rows=-1, max_bytes=-1):
 
 class DataAdapter:
 
-  def __init__(self, db, max_rows=-1, start_over=True, max_bytes=-1, store_data_type_in_header=False):
+  def __init__(self, db, handle=None, max_rows=-1, start_over=True, max_bytes=-1, store_data_type_in_header=False):
+    self.handle = handle
     self.db = db
     self.max_rows = max_rows
     self.max_bytes = max_bytes
@@ -129,15 +134,16 @@ class DataAdapter:
     return size
 
   def next(self):
-    results = self.db.fetch(start_over=self.start_over, rows=self.fetch_size)
+    results = self.db.fetch(self.handle, start_over=self.start_over, rows=self.fetch_size)
     if self.first_fetched:
       self.first_fetched = False
       self.start_over = False
-      self.num_cols = len(results['meta'])
+      results_headers = results.full_cols()
+      self.num_cols = len(results_headers)
       if self.store_data_type_in_header:
-        self.headers = [column['name'] + '|' + column['type'] for column in results['meta']]
+        self.headers = [column['name'] + '|' + column['type'] for column in results_headers]
       else:
-        self.headers = [column['name'] for column in results['meta']]
+        self.headers = [column['name'] for column in results_headers]
       if self.limit_bytes:
         self.bytes_counter += max(self.num_cols - 1, 0)
         for header in self.headers:
@@ -149,10 +155,10 @@ class DataAdapter:
         self.fetch_size = 100
 
     if self.has_more and not self.is_truncated:
-      self.has_more = results['has_more']
+      self.has_more = results.has_more
       data = []
 
-      for row in results['data']:
+      for row in results.rows():
         num_bytes = self._getsizeofascii(row)
         if self.limit_rows and self.row_counter + 1 > self.max_rows:
           LOG.warn('The query results exceeded the maximum row limit of %d and has been truncated to first %d rows.' % (self.max_rows, self.row_counter))
@@ -168,5 +174,5 @@ class DataAdapter:
 
       return self.headers, data
     else:
-      self.db.close()
+      self.db.close(self.handle)
       raise StopIteration

+ 18 - 6
desktop/libs/notebook/src/notebook/connectors/base.py

@@ -453,7 +453,7 @@ class Api(object):
     from beeswax import data_export #TODO: Move to notebook?
     from beeswax import conf
 
-    result_wrapper = ResultWrapper(self, notebook, snippet)
+    result_wrapper = ExecutionWrapper(self, notebook, snippet)
 
     max_rows = conf.DOWNLOAD_ROW_LIMIT.get()
     max_bytes = conf.DOWNLOAD_BYTES_LIMIT.get()
@@ -550,7 +550,7 @@ def _get_snippet_name(notebook, unique=False, table_format=False):
     name = re.sub('[-|\s:]', '_', name)
   return name
 
-class ResultWrapper():
+class ExecutionWrapper():
   def __init__(self, api, notebook, snippet, callback=None):
     self.api = api
     self.notebook = notebook
@@ -558,7 +558,7 @@ class ResultWrapper():
     self.callback = callback
     self.should_close = False
 
-  def fetch(self, start_over=None, rows=None):
+  def fetch(self, handle, start_over=None, rows=None):
     if start_over:
       if not self.snippet['result'].get('handle') or not self.snippet['result']['handle'].get('guid') or not self.api.can_start_over(self.notebook, self.snippet):
         start_over = False
@@ -569,10 +569,10 @@ class ResultWrapper():
         self.should_close = True
         self._until_available()
     if self.snippet['result']['handle'].get('sync', False):
-      return self.snippet['result']['handle']['result']
+      result = self.snippet['result']['handle']['result']
     else:
       result = self.api.fetch_result(self.notebook, self.snippet, rows, start_over)
-    return result
+    return ResultWrapper(result.get('meta'), result.get('data'), result.get('has_more'))
 
   def _until_available(self):
     if self.snippet['result']['handle'].get('sync', False):
@@ -602,7 +602,19 @@ class ResultWrapper():
         sleep_seconds = 10
       time.sleep(sleep_seconds)
 
-  def close(self):
+  def close(self, handle):
     if self.should_close:
       self.should_close = False
       self.api.close_statement(self.notebook, self.snippet)
+
+class ResultWrapper():
+  def __init__(self, cols, rows, has_more):
+    self._cols = cols
+    self._rows = rows
+    self.has_more = has_more
+
+  def full_cols(self):
+    return self._cols
+
+  def rows(self):
+    return self._rows

+ 3 - 3
desktop/libs/notebook/src/notebook/tasks.py

@@ -38,7 +38,7 @@ from desktop.lib import export_csvxls
 from desktop.lib import fsmanager
 from desktop.settings import CACHES_CELERY_KEY
 
-from notebook.connectors.base import get_api, QueryExpired, ResultWrapper
+from notebook.connectors.base import get_api, QueryExpired, ExecutionWrapper
 from notebook.sql_utils import get_current_statement
 
 LOG_TASK = get_task_logger(__name__)
@@ -60,7 +60,7 @@ STATE_MAP = {
 storage_info = json.loads(TASK_SERVER.RESULT_FILE_STORAGE.get())
 storage = get_storage_class(storage_info.get('backend'))(**storage_info.get('properties', {}))
 
-class ResultWrapperCallback(object):
+class ExecutionWrapperCallback(object):
   def __init__(self, uuid, meta, f_log):
     self.meta = meta
     self.uuid = uuid
@@ -96,7 +96,7 @@ def download_to_file(notebook, snippet, file_format='csv', max_rows=-1, **kwargs
   meta = {'row_counter': 0, 'handle': {}, 'status': '', 'truncated': False}
 
   with storage.open(_log_key(notebook), 'wb') as f_log:
-    result_wrapper = ResultWrapper(api, notebook, snippet, ResultWrapperCallback(notebook['uuid'], meta, f_log))
+    result_wrapper = ExecutionWrapper(api, notebook, snippet, ExecutionWrapperCallback(notebook['uuid'], meta, f_log))
     content_generator = data_export.DataAdapter(result_wrapper, max_rows=max_rows, store_data_type_in_header=True) #TODO: Move PREFETCH_RESULT_COUNT to front end
     response = export_csvxls.create_generator(content_generator, file_format)
 

+ 46 - 0
desktop/libs/notebook/src/notebook/tests.py

@@ -296,6 +296,25 @@ FROM déclenché c, c.addresses a"""
 
 
 class MockedApi(Api):
+  def execute(self, notebook, snippet):
+    return {
+      'sync': True,
+      'has_result_set': True,
+      'result': {
+        'has_more': False,
+        'data': [['test']],
+        'meta': [{
+          'name': 'test',
+          'type': '',
+          'comment': ''
+        }],
+        'type': 'table'
+      }
+    }
+
+  def close_statement(self, notebook, snippet):
+    pass
+
   def export_data_as_hdfs_file(self, snippet, target_file, overwrite):
     return {'destination': target_file}
 
@@ -420,6 +439,33 @@ class TestNotebookApiMocked(object):
         assert_equal(0, data['status'], data)
         assert_equal('adl:/user/hue/path.csv', data['watch_url']['destination'], data)
 
+  def test_download_result(self):
+    notebook_json = """
+      {
+        "selectedSnippet": "hive",
+        "showHistory": false,
+        "description": "Test Hive Query",
+        "name": "Test Hive Query",
+        "sessions": [
+            {
+                "type": "hive",
+                "properties": [],
+                "id": null
+            }
+        ],
+        "type": "query-hive",
+        "id": null,
+        "snippets": [{"id":"2b7d1f46-17a0-30af-efeb-33d4c29b1055","type":"hive","status":"running","statement":"select * from web_logs","properties":{"settings":[],"variables":[],"files":[],"functions":[]},"result":{"id":"b424befa-f4f5-8799-a0b4-79753f2552b1","type":"table","handle":{"log_context":null,"statements_count":1,"end":{"column":21,"row":0},"statement_id":0,"has_more_statements":false,"start":{"column":0,"row":0},"secret":"rVRWw7YPRGqPT7LZ/TeFaA==an","has_result_set":true,"statement":"select * from web_logs","operation_type":0,"modified_row_count":null,"guid":"7xm6+epkRx6dyvYvGNYePA==an"}},"lastExecuted": 1462554843817,"database":"default"}],
+        "uuid": "d9efdee1-ef25-4d43-b8f9-1a170f69a05a"
+    }
+    """
+    response = self.client.post(reverse('notebook:download'), {
+        'notebook': notebook_json,
+        'snippet': json.dumps(json.loads(notebook_json)['snippets'][0]),
+        'format': 'csv'
+    })
+    content = "".join(response)
+    assert_true(len(content) > 0)
 
 def test_get_interpreters_to_show():
   default_interpreters = OrderedDict((