浏览代码

HUE-9270 [sqlalchemy] Proper engine cache

Romain 5 年之前
父节点
当前提交
a0610f34dc
共有 1 个文件被更改,包括 22 次插入6 次删除
  1. 22 6
      desktop/libs/notebook/src/notebook/connectors/sql_alchemy.py

+ 22 - 6
desktop/libs/notebook/src/notebook/connectors/sql_alchemy.py

@@ -55,6 +55,7 @@ import textwrap
 
 from string import Template
 
+from django.core.cache import caches
 from django.utils.translation import ugettext as _
 from sqlalchemy import create_engine, inspect, Table, MetaData
 from sqlalchemy.exc import OperationalError
@@ -75,7 +76,9 @@ else:
   from urllib import quote_plus as urllib_quote_plus
 
 
+ENGINES = {}
 CONNECTION_CACHE = {}
+
 LOG = logging.getLogger(__name__)
 
 
@@ -104,7 +107,7 @@ def query_error_handler(func):
 class SqlAlchemyApi(Api):
 
   def __init__(self, user, interpreter):
-    self.user = user
+    super().__init__(user=user, interpreter=interpreter)
     self.options = interpreter['options']
 
     if interpreter.get('dialect_properties'):
@@ -112,9 +115,20 @@ class SqlAlchemyApi(Api):
     else:
       self.backticks = '"' if re.match('^(postgresql://|awsathena|elasticsearch)', self.options.get('url', '')) else '`'
 
+  def _get_engine(self):
+    engine_key = '%(username)s-%(connector_name)s' % {
+      'username': self.user.username,
+      'connector_name': self.interpreter['name']
+    }
+
+    if engine_key not in ENGINES:
+      ENGINES[engine_key] = self._create_engine()
+
+    return ENGINES[engine_key]
+
   def _create_engine(self):
     if '${' in self.options['url']: # URL parameters substitution
-      auth_provided=False
+      auth_provided = False
       vars = {'USER': self.user.username}
       if 'session' in self.options:
         for _prop in self.options['session']['properties']:
@@ -147,6 +161,7 @@ class SqlAlchemyApi(Api):
 
     return create_engine(url, **options)
 
+
   def _get_session(self, notebook, snippet):
     for session in notebook['sessions']:
       if session['type'] == snippet['type']:
@@ -159,9 +174,10 @@ class SqlAlchemyApi(Api):
     guid = uuid.uuid4().hex
 
     session = self._get_session(notebook, snippet)
-    if not session is None:
+    if session is not None:
       self.options['session'] = session
-    engine = self._create_engine()
+
+    engine = self._get_engine()
     connection = engine.connect()
 
     result = connection.execute(snippet['statement'])
@@ -287,7 +303,7 @@ class SqlAlchemyApi(Api):
 
   @query_error_handler
   def autocomplete(self, snippet, database=None, table=None, column=None, nested=None):
-    engine = self._create_engine()
+    engine = self._get_engine()
     inspector = inspect(engine)
 
     assist = Assist(inspector, engine, backticks=self.backticks)
@@ -329,7 +345,7 @@ class SqlAlchemyApi(Api):
 
   @query_error_handler
   def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
-    engine = self._create_engine()
+    engine = self._get_engine()
     inspector = inspect(engine)
 
     assist = Assist(inspector, engine, backticks=self.backticks)