Răsfoiți Sursa

HUE-8740 [sql] Add create_session to sqlalchemy & cache engine.

jdesjean 6 ani în urmă
părinte
comite
d523248427

+ 0 - 3
desktop/libs/notebook/src/notebook/api.py

@@ -123,14 +123,11 @@ def _execute_notebook(request, notebook, snippet):
 
 
   try:
   try:
     try:
     try:
-      session = notebook.get('sessions') and notebook['sessions'][0] # Session reference for snippet execution without persisting it
       if historify:
       if historify:
         history = _historify(notebook, request.user)
         history = _historify(notebook, request.user)
         notebook = Notebook(document=history).get_data()
         notebook = Notebook(document=history).get_data()
 
 
       interpreter = get_api(request, snippet)
       interpreter = get_api(request, snippet)
-      if snippet.get('interface') == 'sqlalchemy':
-        interpreter.options['session'] = session
 
 
       response['handle'] = interpreter.execute(notebook, snippet)
       response['handle'] = interpreter.execute(notebook, snippet)
 
 

+ 37 - 6
desktop/libs/notebook/src/notebook/connectors/sqlalchemyapi.py

@@ -59,10 +59,10 @@ from desktop.lib.i18n import force_unicode
 from beeswax import data_export
 from beeswax import data_export
 from librdbms.server import dbms
 from librdbms.server import dbms
 
 
-from notebook.connectors.base import Api, QueryError, QueryExpired, _get_snippet_name, AuthenticationRequired
+from notebook.connectors.base import Api, QueryError, QueryExpired, _get_snippet_name, AuthenticationRequired, SessionExpired
 from notebook.models import escape_rows
 from notebook.models import escape_rows
 
 
-
+ENGINE_CACHE = None
 CONNECTION_CACHE = {}
 CONNECTION_CACHE = {}
 LOG = logging.getLogger(__name__)
 LOG = logging.getLogger(__name__)
 
 
@@ -77,6 +77,10 @@ def query_error_handler(func):
         raise AuthenticationRequired(message=message)
         raise AuthenticationRequired(message=message)
       else:
       else:
         raise e
         raise e
+    except SessionExpired, e:
+      raise e
+    except QueryExpired, e:
+      raise e
     except Exception, e:
     except Exception, e:
       message = force_unicode(e)
       message = force_unicode(e)
       if 'Invalid query handle' in message or 'Invalid OperationHandle' in message:
       if 'Invalid query handle' in message or 'Invalid OperationHandle' in message:
@@ -90,14 +94,15 @@ def query_error_handler(func):
 class SqlAlchemyApi(Api):
 class SqlAlchemyApi(Api):
 
 
   def __init__(self, user, interpreter=None):
   def __init__(self, user, interpreter=None):
+    global ENGINE_CACHE
     self.user = user
     self.user = user
     self.options = interpreter['options']
     self.options = interpreter['options']
-    self.engine = None # Currently instantiated by an execute()
+    self.engine = ENGINE_CACHE
 
 
-  def _create_engine(self):
+  def _create_engine(self, properties):
     if '${' in self.options['url']: # URL parameters substitution
     if '${' in self.options['url']: # URL parameters substitution
       vars = {'user': self.user.username}
       vars = {'user': self.user.username}
-      for _prop in self.options['session']['properties']:
+      for _prop in properties:
         if _prop['name'] == 'user':
         if _prop['name'] == 'user':
           vars['USER'] = _prop['value']
           vars['USER'] = _prop['value']
         if _prop['name'] == 'password':
         if _prop['name'] == 'password':
@@ -108,12 +113,26 @@ class SqlAlchemyApi(Api):
       url = self.options['url']
       url = self.options['url']
     return create_engine(url)
     return create_engine(url)
 
 
+  @query_error_handler
+  def create_session(self, lang=None, properties=None):
+    if self.engine:
+      return {}
+
+    global ENGINE_CACHE
+    engine = self._create_engine(properties)
+    connection = engine.connect() # Try to connect so we can check if we can authenticate
+    connection.close()
+    self.engine = self.engine
+    ENGINE_CACHE = engine
+
+    return {}
+
   @query_error_handler
   @query_error_handler
   def execute(self, notebook, snippet):
   def execute(self, notebook, snippet):
     guid = uuid.uuid4().hex
     guid = uuid.uuid4().hex
 
 
     if not self.engine:
     if not self.engine:
-      self.engine = self._create_engine()
+      raise SessionExpired()
     connection = self.engine.connect()
     connection = self.engine.connect()
     result = connection.execution_options(stream_results=True).execute(snippet['statement'])
     result = connection.execution_options(stream_results=True).execute(snippet['statement'])
     cache = {
     cache = {
@@ -154,6 +173,8 @@ class SqlAlchemyApi(Api):
   def fetch_result(self, notebook, snippet, rows, start_over):
   def fetch_result(self, notebook, snippet, rows, start_over):
     guid = snippet['result']['handle']['guid']
     guid = snippet['result']['handle']['guid']
     cache = CONNECTION_CACHE.get(guid)
     cache = CONNECTION_CACHE.get(guid)
+    if not cache:
+      raise QueryExpired()
 
 
     if cache:
     if cache:
       data = cache['result'].fetchmany(rows)
       data = cache['result'].fetchmany(rows)
@@ -198,6 +219,8 @@ class SqlAlchemyApi(Api):
     try:
     try:
       guid = snippet['result']['handle']['guid']
       guid = snippet['result']['handle']['guid']
       connection = CONNECTION_CACHE.get(guid)
       connection = CONNECTION_CACHE.get(guid)
+      if not connection:
+        raise QueryExpired()
       if connection:
       if connection:
         connection['connection'].close()
         connection['connection'].close()
         del CONNECTION_CACHE[guid]
         del CONNECTION_CACHE[guid]
@@ -216,6 +239,8 @@ class SqlAlchemyApi(Api):
     file_name = _get_snippet_name(notebook)
     file_name = _get_snippet_name(notebook)
     guid = uuid.uuid4().hex
     guid = uuid.uuid4().hex
 
 
+    if not self.engine:
+      raise SessionExpired()
     connection = self.engine.connect()
     connection = self.engine.connect()
     result = connection.execution_options(stream_results=True).execute(snippet['statement'])
     result = connection.execution_options(stream_results=True).execute(snippet['statement'])
 
 
@@ -241,6 +266,8 @@ class SqlAlchemyApi(Api):
     try:
     try:
       guid = snippet['result']['handle']['guid']
       guid = snippet['result']['handle']['guid']
       connection = CONNECTION_CACHE.get('guid')
       connection = CONNECTION_CACHE.get('guid')
+      if not connection:
+        raise QueryExpired()
       if connection:
       if connection:
         connection['connection'].close()
         connection['connection'].close()
         del CONNECTION_CACHE[guid]
         del CONNECTION_CACHE[guid]
@@ -251,6 +278,8 @@ class SqlAlchemyApi(Api):
 
 
   @query_error_handler
   @query_error_handler
   def autocomplete(self, snippet, database=None, table=None, column=None, nested=None):
   def autocomplete(self, snippet, database=None, table=None, column=None, nested=None):
+    if not self.engine:
+      raise SessionExpired()
     inspector = inspect(self.engine)
     inspector = inspect(self.engine)
 
 
     assist = Assist(inspector, self.engine)
     assist = Assist(inspector, self.engine)
@@ -287,6 +316,8 @@ class SqlAlchemyApi(Api):
 
 
   @query_error_handler
   @query_error_handler
   def get_sample_data(self, snippet, database=None, table=None, column=None, async=False, operation=None):
   def get_sample_data(self, snippet, database=None, table=None, column=None, async=False, operation=None):
+    if not self.engine:
+      raise SessionExpired()
     inspector = inspect(self.engine)
     inspector = inspect(self.engine)
 
 
     assist = Assist(inspector, self.engine)
     assist = Assist(inspector, self.engine)