Browse Source

HUE-8768 [task] First basic unit test for async SQL query task

Romain 5 years ago
parent
commit
105af052f2

+ 37 - 5
desktop/libs/notebook/src/notebook/task_tests.py

@@ -26,18 +26,50 @@ from desktop.lib.django_test_util import make_logged_in_client
 from useradmin.models import User
 
 from notebook.connectors.sql_alchemy import SqlAlchemyApi
-from notebook.tasks import run_sync_query
-
+from notebook.tasks import run_sync_query, download_to_file
 
 if sys.version_info[0] > 2:
-  from unittest.mock import patch, Mock
+  from unittest.mock import patch, Mock, MagicMock
 else:
-  from mock import patch, Mock
+  from mock import patch, Mock, MagicMock
 
 
 LOG = logging.getLogger(__name__)
 
 
+
+class TestRunAsyncQueryTask():
+
+  def setUp(self):
+    self.client = make_logged_in_client(username="test", groupname="default", recreate=True, is_superuser=False)
+    self.user = User.objects.get(username="test")
+
+
+  def test_run_query_only(self):
+    with patch('notebook.tasks._get_request') as _get_request:
+      with patch('notebook.tasks.get_api') as get_api:
+        with patch('notebook.tasks.DataAdapter') as DataAdapter:
+          with patch('notebook.tasks.export_csvxls.create_generator') as create_generator:
+
+            DataAdapter.return_value = MagicMock(row_counter=2)
+
+            get_api.return_value = Mock(
+              check_status=Mock(return_value={'status': 0})
+            )
+
+            def notebook_dict(key):
+              return {
+                'uuid': '1ca47e0d-4708-4709-82c1-a9280e15452b',
+              }.get(key, Mock())
+            notebook = MagicMock()
+            notebook.__getitem__.side_effect = notebook_dict
+
+            snippet = MagicMock()
+            meta = download_to_file(notebook, snippet)
+
+            assert_equal(meta['row_counter'], 2, meta)
+
+
 class TestRunSyncQueryTask():
 
   def setUp(self):
@@ -45,7 +77,7 @@ class TestRunSyncQueryTask():
     self.user = User.objects.get(username="test")
 
 
-  def test_run_sync_query(self):
+  def test_run_query(self):
     snippet = {'type': 'mysql', 'statement_raw': 'SHOW TABLES', 'variables': []}
 
     with patch('notebook.tasks.Document2.objects.get_by_uuid') as get_by_uuid:

+ 2 - 2
desktop/libs/notebook/src/notebook/tasks.py

@@ -33,7 +33,7 @@ from django.core.files.storage import get_storage_class
 from django.db import transaction
 from django.http import FileResponse, HttpRequest
 
-from beeswax import data_export
+from beeswax.data_export import DataAdapter
 from desktop.auth.backend import rewrite_user
 from desktop.celery import app
 from desktop.conf import TASK_SERVER
@@ -121,7 +121,7 @@ def download_to_file(notebook, snippet, file_format='csv', max_rows=-1, **kwargs
         snippet,
         ExecutionWrapperCallback(notebook['uuid'], meta, f_log)
     )
-    content_generator = data_export.DataAdapter(
+    content_generator = DataAdapter(
         result_wrapper,
         max_rows=max_rows,
         store_data_type_in_header=True