Browse Source

HUE-8507 [editor] Fix download sqlalchemy

jdesjean 7 years ago
parent
commit
86aec9b

+ 6 - 3
apps/beeswax/src/beeswax/data_export.py

@@ -33,7 +33,7 @@ FETCH_SIZE = 1000
 DOWNLOAD_COOKIE_AGE = 1800 # 30 minutes
 
 
-def download(handle, format, db, id=None, file_name='query_result', user_agent=None):
+def download(handle, format, db, id=None, file_name='query_result', user_agent=None, callback=None):
   """
   download(query_model, format) -> HttpResponse
 
@@ -46,7 +46,7 @@ def download(handle, format, db, id=None, file_name='query_result', user_agent=N
   max_rows = conf.DOWNLOAD_ROW_LIMIT.get()
   max_bytes = conf.DOWNLOAD_BYTES_LIMIT.get()
 
-  content_generator = HS2DataAdapter(handle, db, max_rows=max_rows, start_over=True, max_bytes=max_bytes)
+  content_generator = HS2DataAdapter(handle, db, max_rows=max_rows, start_over=True, max_bytes=max_bytes, callback=callback)
   generator = export_csvxls.create_generator(content_generator, format)
 
   resp = export_csvxls.make_response(generator, format, file_name, user_agent=user_agent)
@@ -83,7 +83,7 @@ def upload(path, handle, user, db, fs, max_rows=-1, max_bytes=-1):
 
 class HS2DataAdapter:
 
-  def __init__(self, handle, db, max_rows=-1, start_over=True, max_bytes=-1):
+  def __init__(self, handle, db, max_rows=-1, start_over=True, max_bytes=-1, callback=None):
     self.handle = handle
     self.db = db
     self.max_rows = max_rows
@@ -92,6 +92,7 @@ class HS2DataAdapter:
     self.fetch_size = FETCH_SIZE
     self.limit_rows = max_rows > -1
     self.limit_bytes = max_bytes > -1
+    self.callback = callback
 
     self.first_fetched = True
     self.headers = None
@@ -171,4 +172,6 @@ class HS2DataAdapter:
 
       return self.headers, data
     else:
+      if self.callback:
+        self.callback()
       raise StopIteration

+ 85 - 34
desktop/libs/notebook/src/notebook/connectors/sqlalchemyapi.py

@@ -17,6 +17,7 @@
 
 import json
 import logging
+import uuid
 
 from desktop.lib import export_csvxls
 from desktop.lib.i18n import force_unicode
@@ -32,7 +33,7 @@ from sqlalchemy import create_engine, inspect
 
 
 LOG = logging.getLogger(__name__)
-
+CONNECTION_CACHE = {}
 
 def query_error_handler(func):
   def decorator(*args, **kwargs):
@@ -52,46 +53,54 @@ class SqlAlchemyApi(Api):
     self.options = interpreter['options']
     self.engine = create_engine(self.options['url'])
 
-  def _execute(self, notebook, snippet):
-    connection = self.engine.connect()
-    try:
-      result = connection.execute(snippet['statement'])
-      return result.cursor.description, result.fetchmany(100) # TODO: execute statement stub in Rdbms
-    finally:
-      connection.close()
-
   @query_error_handler
   def execute(self, notebook, snippet):
-    metadata, data = self._execute(notebook, snippet)
-    has_result_set = data is not None
+    guid = uuid.uuid4().hex
+    connection = self.engine.connect()
+    result = connection.execute(snippet['statement'])
+    CONNECTION_CACHE[guid] = {
+      'connection': connection,
+      'result': result
+    }
 
     return {
-      'sync': True,
-      'has_result_set': has_result_set,
+      'sync': False,
+      'has_result_set': True,
       'modified_row_count': 0,
+      'guid': guid,
       'result': {
         'has_more': True,
-        'data': data if has_result_set else [],
+        'data': [],
         'meta': [{
           'name': col[0] if type(col) is dict or type(col) is tuple else col,
           'type': 'String', #TODO: resolve
           'comment': ''
-        } for col in metadata] if has_result_set else [],
+        } for col in result.cursor.description],
         'type': 'table'
       }
     }
 
-
   @query_error_handler
   def check_status(self, notebook, snippet):
-    return {'status': 'expired'}
-
+    guid = snippet['result']['handle']['guid']
+    connection = CONNECTION_CACHE.get(guid)
+    if connection:
+      return {'status': 'available'}
+    else:
+      return {'status': 'canceled'}
 
   @query_error_handler
   def fetch_result(self, notebook, snippet, rows, start_over):
+    guid = snippet['result']['handle']['guid']
+    connection = CONNECTION_CACHE.get(guid)
+    if connection:
+      data = connection['result'].fetchmany(rows)
+    else:
+      data = []
+    has_result_set = data is not None
     return {
-      'has_more': False,
-      'data': [],
+      'has_more': has_result_set and len(data) >= rows,
+      'data': data if has_result_set else [],
       'meta': [],
       'type': 'table'
     }
@@ -104,27 +113,54 @@ class SqlAlchemyApi(Api):
 
   @query_error_handler
   def cancel(self, notebook, snippet):
-    return {'status': 0}
+    result = {'status': -1}
+    try:
+      guid = snippet['result']['handle']['guid']
+      connection = CONNECTION_CACHE.get(guid)
+      if connection:
+        connection['connection'].close()
+        del CONNECTION_CACHE[guid]
+      result['status'] = 0
+    finally:
+      return result
 
 
   @query_error_handler
   def get_log(self, notebook, snippet, startFrom=None, size=None):
-    return 'No logs'
+    return ''
 
 
   @query_error_handler
-  def download(self, notebook, snippet, format):
-
+  def download(self, notebook, snippet, format, user_agent=None):
     file_name = _get_snippet_name(notebook)
-    results = self._execute(notebook, snippet)
-    db = FixedResult(results)
-
-    return data_export.download(None, format, db, id=snippet['id'], file_name=file_name)
+    guid = uuid.uuid4().hex
+    connection = self.engine.connect()
+    result = connection.execute(snippet['statement'])
+    CONNECTION_CACHE[guid] = {
+      'connection': connection,
+      'result': result
+    }
+    db = FixedResult([col[0] if type(col) is dict or type(col) is tuple else col for col in result.cursor.description])
+    def callback():
+      connection = CONNECTION_CACHE.get(guid)
+      if connection:
+        connection['connection'].close()
+        del CONNECTION_CACHE[guid]
+    return data_export.download({'guid': guid}, format, db, id=snippet['id'], file_name=file_name, callback=callback)
 
 
   @query_error_handler
   def close_statement(self, snippet):
-    return {'status': -1}
+    result = {'status': -1}
+    try:
+      guid = snippet['result']['handle']['guid']
+      connection = CONNECTION_CACHE.get('guid')
+      if connection:
+        connection['connection'].close()
+        del CONNECTION_CACHE[guid]
+      result['status'] = 0
+    finally:
+      return result
 
 
   @query_error_handler
@@ -214,11 +250,26 @@ class Assist():
     finally:
       connection.close()
 
-class FixedResult():
+class FixedResultSet():
+  def __init__(self, metadata, data, has_more):
+    self.metadata = metadata
+    self.data = data
+    self.has_more = has_more
+
+  def cols(self):
+    return self.metadata
 
-  def __init__(self, result):
-    self.result = result
-    self.has_more = False
+  def rows(self):
+    return self.data if self.data is not None else []
+
+class FixedResult():
+  def __init__(self, metadata):
+    self.metadata = metadata
 
   def fetch(self, handle=None, start_over=None, rows=None):
-    return self.result
+    connection = CONNECTION_CACHE.get(handle['guid'])
+    if connection:
+      data = connection['result'].fetchmany(rows)
+      return FixedResultSet(self.metadata, data, data is not None and len(data) >= rows)
+    else:
+      return FixedResultSet([], [])