Browse Source

[Trino] Ensures uninterrupted schema listing even if a catalog query fails (#3939)

Ayush Goyal 11 months ago
parent
commit
77faef4777

+ 14 - 3
desktop/libs/notebook/src/notebook/connectors/trino.py

@@ -17,6 +17,7 @@
 
 
 import json
 import json
 import time
 import time
+import logging
 import textwrap
 import textwrap
 from urllib.parse import urlparse
 from urllib.parse import urlparse
 
 
@@ -35,6 +36,8 @@ from desktop.lib.rest.http_client import HttpClient, RestException
 from desktop.lib.rest.resource import Resource
 from desktop.lib.rest.resource import Resource
 from notebook.connectors.base import Api, ExecutionWrapper, QueryError, ResultWrapper
 from notebook.connectors.base import Api, ExecutionWrapper, QueryError, ResultWrapper
 
 
+LOG = logging.getLogger()
+
 
 
 def query_error_handler(func):
 def query_error_handler(func):
   def decorator(*args, **kwargs):
   def decorator(*args, **kwargs):
@@ -297,12 +300,18 @@ class TrinoApi(Api):
     databases = []
     databases = []
 
 
     for catalog in catalogs:
     for catalog in catalogs:
-      query_client = TrinoQuery(self.trino_request, 'SHOW SCHEMAS FROM ' + catalog)
-      response = query_client.execute()
-      databases += [f'{catalog}.{item}' for sublist in response.rows for item in sublist]
+      try:
+        query_client = TrinoQuery(self.trino_request, 'SHOW SCHEMAS FROM ' + catalog)
+        response = query_client.execute()
+        databases += [f'{catalog}.{item}' for sublist in response.rows for item in sublist]
+      except Exception as e:
+        # Log the exception and continue with the next catalog
+        LOG.error(f"Failed to fetch schemas from catalog {catalog}: {str(e)}")
+        continue
 
 
     return databases
     return databases
 
 
+  @query_error_handler
   def _show_catalogs(self):
   def _show_catalogs(self):
     query_client = TrinoQuery(self.trino_request, 'SHOW CATALOGS')
     query_client = TrinoQuery(self.trino_request, 'SHOW CATALOGS')
     response = query_client.execute()
     response = query_client.execute()
@@ -311,6 +320,7 @@ class TrinoApi(Api):
 
 
     return catalogs
     return catalogs
 
 
+  @query_error_handler
   def _show_tables(self, database):
   def _show_tables(self, database):
     database = self._format_identifier(database, is_db=True)
     database = self._format_identifier(database, is_db=True)
     query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client = TrinoQuery(self.trino_request, 'USE ' + database)
@@ -326,6 +336,7 @@ class TrinoApi(Api):
       for table in tables
       for table in tables
     ]
     ]
 
 
+  @query_error_handler
   def _get_columns(self, database, table):
   def _get_columns(self, database, table):
     database = self._format_identifier(database, is_db=True)
     database = self._format_identifier(database, is_db=True)
     query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client = TrinoQuery(self.trino_request, 'USE ' + database)

+ 23 - 0
desktop/libs/notebook/src/notebook/connectors/trino_tests.py

@@ -380,3 +380,26 @@ class TestTrinoApi(TestCase):
     result = self.trino_api.get_log(notebook, snippet)
     result = self.trino_api.get_log(notebook, snippet)
 
 
     assert result == expected_log
     assert result == expected_log
+
+  def test_show_databases(self):
+    with patch('notebook.connectors.trino.LOG.error') as Log_error:
+      with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
+        with patch('notebook.connectors.trino.TrinoApi._show_catalogs') as _show_catalogs:
+          _show_catalogs.return_value = [
+            'catalog1', 'catalog2'
+          ]
+          query_instance = TrinoQuery.return_value
+          query_instance.execute.side_effect = [
+            MagicMock(rows=[["schema1"], ["schema2"]]),  # First catalog response
+            Exception("Some error")  # Second catalog raises an exception
+          ]
+          result = self.trino_api._show_databases()
+
+          # Assert the expected output
+          expected_result = ['catalog1.schema1', 'catalog1.schema2']
+          self.assertEqual(result, expected_result)
+
+          # Assert error logging was called for the exception
+          Log_error.assert_called_once_with(
+            "Failed to fetch schemas from catalog catalog2: Some error"
+          )