Эх сурвалжийг харах

Raise AuthenticationRequired error whenever invalid credentials to datasource throws exception

emmanuel 4 жил өмнө
parent
commit
0a85782542

+ 22 - 5
desktop/libs/notebook/src/notebook/connectors/sql_alchemy.py

@@ -128,11 +128,14 @@ class SqlAlchemyApi(Api):
     else:
       self.backticks = '"' if re.match('^(postgresql://|awsathena|elasticsearch|phoenix)', self.options.get('url', '')) else '`'
 
-  def _get_engine(self):
-    engine_key = ENGINE_KEY % {
+  def _get_engine_key(self):
+    return ENGINE_KEY % {
       'username': self.user.username,
       'connector_name': self.interpreter['name']
     }
+  
+  def _get_engine(self):
+    engine_key = self._get_engine_key()
 
     if engine_key not in ENGINES:
       ENGINES[engine_key] = self._create_engine()
@@ -211,6 +214,20 @@ class SqlAlchemyApi(Api):
 
     return None
 
+
+  def _create_connection(self, engine):
+    connection = None
+    try:
+      connection = engine.connect()
+    except Exception as e:
+      engine_key = self._get_engine_key()
+      del ENGINES[engine_key]
+      
+      raise AuthenticationRequired(message='Could not establish connection to datasource')
+
+    return connection
+
+
   @query_error_handler
   def execute(self, notebook, snippet):
     guid = uuid.uuid4().hex
@@ -220,7 +237,7 @@ class SqlAlchemyApi(Api):
       self.options['session'] = session
 
     engine = self._get_engine()
-    connection = engine.connect()
+    connection = self._create_connection(engine)
     statement = snippet['statement']
 
     if self.interpreter['dialect_properties'].get('trim_statement_semicolon', True):
@@ -271,7 +288,7 @@ class SqlAlchemyApi(Api):
       self.options['session'] = session
 
     engine = self._get_engine()
-    connection = engine.connect()
+    connection = self._create_connection(engine)
     statement = snippet['statement']
 
     explanation = ''
@@ -536,7 +553,7 @@ class Assist(object):
           'backticks': self.backticks
       })
 
-    connection = self.engine.connect()
+    connection = self._create_connection(self.engine)
     try:
       result = connection.execute(statement)
       return result.cursor.description, result.fetchall()

+ 55 - 20
desktop/libs/notebook/src/notebook/connectors/sql_alchemy_tests.py

@@ -181,6 +181,44 @@ class TestApi(object):
       SqlAlchemyApi(self.user, interpreter)._create_engine()
 
 
+  @raises(AuthenticationRequired)
+  def test_create_connection_error(self):
+    interpreter = {
+      'name': 'hive',
+      'options': {
+        'url': 'mysql://${USER}:${PASSWORD}@hue:3306/hue'
+      }
+    }
+
+    with patch('notebook.connectors.sql_alchemy.create_engine') as create_engine:
+      engine = SqlAlchemyApi(self.user, interpreter)._create_engine()
+      SqlAlchemyApi(self.user, interpreter)._create_connection(engine)
+
+  def test_create_connection(self):
+    interpreter = {
+      'name': 'hive',
+      'options': {
+        'url': 'mysql://${USER}:${PASSWORD}@hue:3306/hue',
+        'session': {
+          'properties': [
+            {
+              'name': 'user',
+              'value': 'test_user'
+            },
+            {
+              'name': 'password',
+              'value': 'test_pass'
+            }
+          ]
+        }
+      }
+    }
+
+    with patch('notebook.connectors.sql_alchemy.create_engine') as create_engine:
+      engine = SqlAlchemyApi(self.user, interpreter)._create_engine()
+      SqlAlchemyApi(self.user, interpreter)._create_connection(engine)
+
+
   def test_create_engine_with_impersonation(self):
     interpreter = {
       'name': 'hive',
@@ -258,32 +296,29 @@ class TestApi(object):
       'dialect_properties': {},
     }
 
-    with patch('notebook.connectors.sql_alchemy.SqlAlchemyApi._create_engine') as _create_engine:
-      with patch('notebook.connectors.sql_alchemy.SqlAlchemyApi._get_session') as _get_session:
-        execute = Mock(return_value=Mock(cursor=None))
-        _create_engine.return_value = Mock(
-          connect=Mock(
-            return_value=Mock(
-              execute=execute
-            )
+    with patch('notebook.connectors.sql_alchemy.SqlAlchemyApi._create_connection') as _create_connection:
+      with patch('notebook.connectors.sql_alchemy.SqlAlchemyApi._create_engine') as _create_engine:
+        with patch('notebook.connectors.sql_alchemy.SqlAlchemyApi._get_session') as _get_session:
+          execute = Mock(return_value=Mock(cursor=None))
+          _create_connection.return_value = Mock(
+            execute=execute
           )
-        )
-        notebook = {}
-        snippet = {'statement': 'SELECT 1;'}
+          notebook = {}
+          snippet = {'statement': 'SELECT 1;'}
 
-        # Trim
-        engine = SqlAlchemyApi(self.user, interpreter).execute(notebook, snippet)
+          # Trim
+          engine = SqlAlchemyApi(self.user, interpreter).execute(notebook, snippet)
 
-        execute.assert_called_with('SELECT 1')
+          execute.assert_called_with('SELECT 1')
 
-        # No Trim
-        interpreter['options']['url'] = 'mysql://hue:3306/hue'
-        interpreter['dialect_properties']['trim_statement_semicolon'] = False
-        interpreter['dialect_properties']['sql_identifier_quote'] = '`'
+          # No Trim
+          interpreter['options']['url'] = 'mysql://hue:3306/hue'
+          interpreter['dialect_properties']['trim_statement_semicolon'] = False
+          interpreter['dialect_properties']['sql_identifier_quote'] = '`'
 
-        engine = SqlAlchemyApi(self.user, interpreter).execute(notebook, snippet)
+          engine = SqlAlchemyApi(self.user, interpreter).execute(notebook, snippet)
 
-        execute.assert_called_with('SELECT 1;')
+          execute.assert_called_with('SELECT 1;')
 
 
 class TestDialects(object):