Explorar o código

PR1056 [editor] Support session properties for the SqlAlchemy connector (#1056)

* Raise AuthenticationRequired error when no credentials are present

* Handle notebook session data

* Find the session from the snippet type

* Check if the session exists

* Update vars key to USER

Co-authored-by: Romain Rigaux <romain.rigaux@gmail.com>
Jamie Davenport %!s(int64=5) %!d(string=hai) anos
pai
achega
03c7e4c476

+ 26 - 6
desktop/libs/notebook/src/notebook/connectors/sql_alchemy.py

@@ -89,6 +89,8 @@ def query_error_handler(func):
         raise AuthenticationRequired(message=message)
       else:
         raise e
+    except AuthenticationRequired:
+      raise
     except Exception as e:
       message = force_unicode(e)
       if 'Invalid query handle' in message or 'Invalid OperationHandle' in message:
@@ -108,12 +110,20 @@ class SqlAlchemyApi(Api):
 
   def _create_engine(self):
     if '${' in self.options['url']: # URL parameters substitution
-      vars = {'user': self.user.username}
-      for _prop in self.options['session']['properties']:
-        if _prop['name'] == 'user':
-          vars['USER'] = _prop['value']
-        if _prop['name'] == 'password':
-          vars['PASSWORD'] = _prop['value']
+      auth_provided=False
+      vars = {'USER': self.user.username}
+      if 'session' in self.options:
+        for _prop in self.options['session']['properties']:
+          if _prop['name'] == 'user':
+            vars['USER'] = _prop['value']
+            auth_provided = True
+          if _prop['name'] == 'password':
+            vars['PASSWORD'] = _prop['value']
+            auth_provided = True
+
+      if not auth_provided:
+        raise AuthenticationRequired(message='Missing username and/or password')
+
       raw_url = Template(self.options['url'])
       url = raw_url.safe_substitute(**vars)
     else:
@@ -131,10 +141,20 @@ class SqlAlchemyApi(Api):
 
     return create_engine(url, **options)
 
+  def _get_session(self, notebook, snippet):
+    for session in notebook['sessions']:
+      if session['type'] == snippet['type']:
+        return session
+
+    return None
+
   @query_error_handler
   def execute(self, notebook, snippet):
     guid = uuid.uuid4().hex
 
+    session = self._get_session(notebook, snippet)
+    if not session is None:
+      self.options['session'] = session
     engine = self._create_engine()
     connection = engine.connect()
 

+ 36 - 2
desktop/libs/notebook/src/notebook/connectors/sql_alchemy_tests.py

@@ -20,11 +20,12 @@ from builtins import object
 import logging
 import sys
 
-from nose.tools import assert_equal, assert_not_equal, assert_true, assert_false
+from nose.tools import assert_equal, assert_not_equal, assert_true, assert_false, raises
 
 from desktop.auth.backend import rewrite_user
 from desktop.lib.django_test_util import make_logged_in_client
 from useradmin.models import User
+from notebook.connectors.base import AuthenticationRequired
 
 from notebook.connectors.sql_alchemy import SqlAlchemyApi
 
@@ -66,7 +67,6 @@ class TestApi(object):
     }
     assert_equal(SqlAlchemyApi(self.user, interpreter).backticks, '"')
 
-
   def test_create_athena_engine(self):
     interpreter = {
       'options': {
@@ -135,6 +135,40 @@ class TestApi(object):
       assert_equal(data['data'], [['row1'], ['row2']])
       assert_equal(data['meta'](), [{'type': 'BIGINT_TYPE'}])
 
+  @raises(AuthenticationRequired)
+  def test_create_engine_auth_error(self):
+    interpreter = {
+        'options': {
+            'url': 'mysql://${USER}:${PASSWORD}@hue:3306/hue'
+        }
+    }
+
+    with patch('notebook.connectors.sql_alchemy.create_engine') as create_engine:
+      SqlAlchemyApi(self.user, interpreter)._create_engine()
+
+
+  def test_create_engine_auth(self):
+    interpreter = {
+      'options': {
+        'url': 'mysql://${USER}:${PASSWORD}@hue:3306/hue',
+        'session': {
+          'properties': [
+            {
+              'name': 'user',
+              'value': 'test_user'
+            },
+            {
+              'name': 'password',
+              'value': 'test_pass'
+            }
+          ]
+        }
+      }
+    }
+
+    with patch('notebook.connectors.sql_alchemy.create_engine') as create_engine:
+      SqlAlchemyApi(self.user, interpreter)._create_engine()
+
 
   def test_check_status(self):
     notebook = Mock()