瀏覽代碼

HUE-8758 [sqlalchemy] Pick proper backtick depending on dialect

Romain 5 年之前
父節點
當前提交
8edec160bc

+ 8 - 4
desktop/libs/notebook/src/notebook/connectors/sql_alchemy.py

@@ -107,7 +107,11 @@ class SqlAlchemyApi(Api):
   def __init__(self, user, interpreter):
     self.user = user
     self.options = interpreter['options']
-    self.backticks = '"' if re.match('^(postgresql://|awsathena|elasticsearch)', self.options.get('url', '')) else '`'
+
+    if interpreter.get('dialect_properties'):
+      self.backticks = interpreter['dialect_properties']['sql_identifier_quote']
+    else:
+      self.backticks = '"' if re.match('^(postgresql://|awsathena|elasticsearch)', self.options.get('url', '')) else '`'
 
   def _create_engine(self):
     if '${' in self.options['url']: # URL parameters substitution
@@ -304,8 +308,8 @@ class SqlAlchemyApi(Api):
           'default': col.get('default'),
           'name': col.get('name'),
           'nullable': col.get('nullable'),
-          'type': str(col.get('type'))
-        } for col in columns if not isinstance(col.get('type'), NullType)
+          'type': str(col.get('type')) if not isinstance(col.get('type'), NullType) else 'Null'
+        } for col in columns
       ]
     else:
       columns = assist.get_columns(database, table)
@@ -333,7 +337,7 @@ class SqlAlchemyApi(Api):
       columns = assist.get_columns(database, table)
       response['full_headers'] = [{
           'name': col.get('name'),
-          'type': str(col.get('type')),
+          'type': str(col.get('type')) if not isinstance(col.get('type'), NullType) else 'Null',
           'comment': ''
         } for col in columns
       ]

+ 35 - 2
desktop/libs/notebook/src/notebook/connectors/sql_alchemy_tests.py

@@ -201,11 +201,43 @@ class TestApi(object):
         )
 
 
-class TestAutocomplete(object):
+class TestDialects(object):
 
   def setUp(self):
     self.client = make_logged_in_client(username="test", groupname="default", recreate=True, is_superuser=False)
+    self.user = rewrite_user(User.objects.get(username="test"))
+
+
+  def test_backticks_with_connectors(self):
+    interpreter = {'options': {'url': 'dialect://'}, 'dialect_properties': {'sql_identifier_quote': '`'}}
+    data = SqlAlchemyApi(self.user, interpreter).get_browse_query(snippet=Mock(), database='db1', table='table1')
+
+    assert_equal(data, 'SELECT *\nFROM `db1`.`table1`\nLIMIT 1000\n')
+
+
+    interpreter = {'options': {'url': 'dialect://'}, 'dialect_properties': {'sql_identifier_quote': '"'}}
+    data = SqlAlchemyApi(self.user, interpreter).get_browse_query(snippet=Mock(), database='db1', table='table1')
+
+    assert_equal(data, 'SELECT *\nFROM "db1"."table1"\nLIMIT 1000\n')
+
+
+  def test_backticks_without_connectors(self):
+    interpreter = {'options': {'url': 'phoenix://'}}
+    data = SqlAlchemyApi(self.user, interpreter).get_browse_query(snippet=Mock(), database='db1', table='table1')
 
+    assert_equal(data, 'SELECT *\nFROM `db1`.`table1`\nLIMIT 1000\n')
+
+
+    interpreter = {'options': {'url': 'postgresql://'}}
+    data = SqlAlchemyApi(self.user, interpreter).get_browse_query(snippet=Mock(), database='db1', table='table1')
+
+    assert_equal(data, 'SELECT *\nFROM "db1"."table1"\nLIMIT 1000\n')
+
+
+class TestAutocomplete(object):
+
+  def setUp(self):
+    self.client = make_logged_in_client(username="test", groupname="default", recreate=True, is_superuser=False)
     self.user = rewrite_user(User.objects.get(username="test"))
 
 
@@ -237,6 +269,7 @@ class TestAutocomplete(object):
           def col1_dict(key):
             return {
               'name': 'col1',
+              'type': 'string'
             }.get(key, Mock())
           col1 = MagicMock()
           col1.__getitem__.side_effect = col1_dict
@@ -255,4 +288,4 @@ class TestAutocomplete(object):
           data = SqlAlchemyApi(self.user, interpreter).autocomplete(snippet, database='database', table='table')
 
           assert_equal(data['columns'], ['col1', 'col2'])
-          assert_equal([col['name'] for col in data['extended_columns']], ['col1'])  # Skip col2
+          assert_equal([col['type'] for col in data['extended_columns']], ['string', 'Null'])