Browse Source

[Trino] Add a flag to distinguish between the database and table/column name (#3867)

Ayush Goyal 1 year ago
parent
commit
2e10b7129c

+ 7 - 7
desktop/libs/notebook/src/notebook/connectors/trino.py

@@ -85,14 +85,14 @@ class TrinoApi(Api):
         else DEFAULT_AUTH_PASSWORD.get()
     )
 
-  def _format_identifier(self, identifier):
+  def _format_identifier(self, identifier, is_db=False):
     # Remove any backticks
     identifier = identifier.replace('`', '')
 
     # Check if already formatted
     if not (identifier.startswith('"') and identifier.endswith('"')):
       # Check if it's a multi-part identifier (e.g., catalog.schema)
-      if '.' in identifier:
+      if '.' in identifier and is_db:
         # Split and format each part separately
         identifier = '"{}"'.format('"."'.join(identifier.split('.')))
       else:
@@ -113,7 +113,7 @@ class TrinoApi(Api):
   @query_error_handler
   def execute(self, notebook, snippet):
     database = snippet['database']
-    database = self._format_identifier(database)
+    database = self._format_identifier(database, is_db=True)
     query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client.execute()
 
@@ -258,7 +258,7 @@ class TrinoApi(Api):
     if operation == 'hello':
       statement = "SELECT 'Hello World!'"
     else:
-      database = self._format_identifier(database)
+      database = self._format_identifier(database, is_db=True)
       table = self._format_identifier(table)
       column = '%(column)s' % {'column': self._format_identifier(column)} if column else '*'
       statement = textwrap.dedent('''\
@@ -312,7 +312,7 @@ class TrinoApi(Api):
     return catalogs
 
   def _show_tables(self, database):
-    database = self._format_identifier(database)
+    database = self._format_identifier(database, is_db=True)
     query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client.execute()
     query_client = TrinoQuery(self.trino_request, 'SHOW TABLES')
@@ -327,7 +327,7 @@ class TrinoApi(Api):
     ]
 
   def _get_columns(self, database, table):
-    database = self._format_identifier(database)
+    database = self._format_identifier(database, is_db=True)
     query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client.execute()
     table = self._format_identifier(table)
@@ -356,7 +356,7 @@ class TrinoApi(Api):
     if statement:
       try:
         database = snippet['database']
-        database = self._format_identifier(database)
+        database = self._format_identifier(database, is_db=True)
         TrinoQuery(self.trino_request, 'USE ' + database).execute()
         result = TrinoQuery(self.trino_request, 'EXPLAIN ' + statement).execute()
         explanation = result.rows

+ 12 - 2
desktop/libs/notebook/src/notebook/connectors/trino_tests.py

@@ -42,13 +42,23 @@ class TestTrinoApi(TestCase):
     cls.trino_api = TrinoApi(cls.user, interpreter=cls.interpreter)
 
   def test_format_identifier(self):
+    # db name test
     test_cases = [
       ("my_db", '"my_db"'),
-      ("my_db.table", '"my_db"."table"'),
+      ("my_catalog.my_db", '"my_catalog"."my_db"'),
     ]
 
     for database, expected_output in test_cases:
-      assert self.trino_api._format_identifier(database) == expected_output
+      assert self.trino_api._format_identifier(database, is_db=True) == expected_output
+
+    # table name test
+    test_cases = [
+      ("io.airlift.discovery.store:name=dynamic,type=distributedstore", '"io.airlift.discovery.store:name=dynamic,type=distributedstore"'),
+      ("table", '"table"'),
+    ]
+
+    for table, expected_output in test_cases:
+      assert self.trino_api._format_identifier(table) == expected_output
 
   def test_parse_api_url(self):
     # Test parse_api_url method