Browse Source

[Trino] Table/database names starting with digit should inclosed with double quotes (#3735)

Ayush Goyal 1 year ago
parent
commit
d7a54dd63b

+ 1 - 1
desktop/core/base_requirements.txt

@@ -67,7 +67,7 @@ SQLAlchemy==1.3.8
 sqlparse==0.5.0
 sqlparse==0.5.0
 tablib==0.13.0
 tablib==0.13.0
 tabulate==0.8.9
 tabulate==0.8.9
-trino==0.324.0  # Need to upgrade to the latest version but that requires requests>=2.31.0
+trino==0.329.0
 git+https://github.com/gethue/thrift.git
 git+https://github.com/gethue/thrift.git
 thrift-sasl==0.4.3
 thrift-sasl==0.4.3
 git+https://github.com/gethue/django-babel.git
 git+https://github.com/gethue/django-babel.git

+ 41 - 37
desktop/libs/notebook/src/notebook/connectors/trino.py

@@ -15,29 +15,25 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-import logging
 import json
 import json
-import requests
-import sys
-import textwrap
 import time
 import time
+import textwrap
+from urllib.parse import urlparse
 
 
+import requests
 from django.utils.translation import gettext as _
 from django.utils.translation import gettext as _
-from urllib.parse import urlparse
+from trino.auth import BasicAuthentication
+from trino.client import ClientSession, TrinoQuery, TrinoRequest
+from trino.exceptions import TrinoConnectionError
 
 
-from beeswax import conf
-from beeswax import data_export
-from desktop.conf import AUTH_USERNAME as DEFAULT_AUTH_USERNAME, AUTH_PASSWORD as DEFAULT_AUTH_PASSWORD
+from beeswax import conf, data_export
+from desktop.conf import AUTH_PASSWORD as DEFAULT_AUTH_PASSWORD, AUTH_USERNAME as DEFAULT_AUTH_USERNAME
 from desktop.lib import export_csvxls
 from desktop.lib import export_csvxls
 from desktop.lib.conf import coerce_password_from_script
 from desktop.lib.conf import coerce_password_from_script
 from desktop.lib.i18n import force_unicode
 from desktop.lib.i18n import force_unicode
 from desktop.lib.rest.http_client import HttpClient, RestException
 from desktop.lib.rest.http_client import HttpClient, RestException
 from desktop.lib.rest.resource import Resource
 from desktop.lib.rest.resource import Resource
-from notebook.connectors.base import Api, QueryError, ExecutionWrapper, ResultWrapper
-
-from trino import exceptions
-from trino.auth import BasicAuthentication
-from trino.client import ClientSession, TrinoRequest, TrinoQuery
+from notebook.connectors.base import Api, ExecutionWrapper, QueryError, ResultWrapper
 
 
 
 
 def query_error_handler(func):
 def query_error_handler(func):
@@ -47,8 +43,8 @@ def query_error_handler(func):
     except RestException as e:
     except RestException as e:
       try:
       try:
         message = force_unicode(json.loads(e.message)['errors'])
         message = force_unicode(json.loads(e.message)['errors'])
-      except:
-        message = e.message
+      except Exception as ex:
+        message = ex.message
       message = force_unicode(message)
       message = force_unicode(message)
       raise QueryError(message)
       raise QueryError(message)
     except Exception as e:
     except Exception as e:
@@ -81,7 +77,6 @@ class TrinoApi(Api):
       auth=self.auth
       auth=self.auth
     )
     )
 
 
-
   def get_auth_password(self):
   def get_auth_password(self):
     auth_password_script = self.options.get('auth_password_script')
     auth_password_script = self.options.get('auth_password_script')
     return (
     return (
@@ -90,21 +85,35 @@ class TrinoApi(Api):
         else DEFAULT_AUTH_PASSWORD.get()
         else DEFAULT_AUTH_PASSWORD.get()
     )
     )
 
 
+  def _format_identifier(self, identifier):
+    # Remove any backticks
+    identifier = identifier.replace('`', '')
+
+    # Check if already formatted
+    if not (identifier.startswith('"') and identifier.endswith('"')):
+      # Check if it's a multi-part identifier (e.g., catalog.schema)
+      if '.' in identifier:
+        # Split and format each part separately
+        identifier = '"{}"'.format('"."'.join(identifier.split('.')))
+      else:
+        # Format single-part identifier
+        identifier = f'"{identifier}"'
+
+    return identifier
 
 
   @query_error_handler
   @query_error_handler
   def parse_api_url(self, api_url):
   def parse_api_url(self, api_url):
     parsed_url = urlparse(api_url)
     parsed_url = urlparse(api_url)
     return parsed_url.hostname, parsed_url.port, parsed_url.scheme
     return parsed_url.hostname, parsed_url.port, parsed_url.scheme
 
 
-
   @query_error_handler
   @query_error_handler
   def create_session(self, lang=None, properties=None):
   def create_session(self, lang=None, properties=None):
     pass
     pass
 
 
-
   @query_error_handler
   @query_error_handler
   def execute(self, notebook, snippet):
   def execute(self, notebook, snippet):
     database = snippet['database']
     database = snippet['database']
+    database = self._format_identifier(database)
     query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client.execute()
     query_client.execute()
 
 
@@ -138,7 +147,6 @@ class TrinoApi(Api):
 
 
     return response
     return response
 
 
-
   @query_error_handler
   @query_error_handler
   def check_status(self, notebook, snippet):
   def check_status(self, notebook, snippet):
     response = {}
     response = {}
@@ -153,7 +161,7 @@ class TrinoApi(Api):
       if _status.stats['state'] == 'QUEUED':
       if _status.stats['state'] == 'QUEUED':
         status = 'waiting'
         status = 'waiting'
       elif _status.stats['state'] == 'RUNNING':
       elif _status.stats['state'] == 'RUNNING':
-        status = 'available' # need to verify
+        status = 'available'  # need to verify
       else:
       else:
         status = 'available'
         status = 'available'
 
 
@@ -176,7 +184,7 @@ class TrinoApi(Api):
       try:
       try:
         response = self.trino_request.get(next_uri)
         response = self.trino_request.get(next_uri)
       except requests.exceptions.RequestException as e:
       except requests.exceptions.RequestException as e:
-        raise trino.exceptions.TrinoConnectionError("failed to fetch: {}".format(e))
+        raise TrinoConnectionError("failed to fetch: {}".format(e))
 
 
       status = self.trino_request.process(response)
       status = self.trino_request.process(response)
       data += status.rows
       data += status.rows
@@ -209,7 +217,6 @@ class TrinoApi(Api):
       'type': 'table'
       'type': 'table'
     }
     }
 
 
-
   @query_error_handler
   @query_error_handler
   def autocomplete(self, snippet, database=None, table=None, column=None, nested=None, operation=None):
   def autocomplete(self, snippet, database=None, table=None, column=None, nested=None, operation=None):
     response = {}
     response = {}
@@ -231,9 +238,8 @@ class TrinoApi(Api):
 
 
     return response
     return response
 
 
-
   @query_error_handler
   @query_error_handler
-  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
+  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=False, is_async=False, operation=None):
     statement = self._get_select_query(database, table, column, operation)
     statement = self._get_select_query(database, table, column, operation)
     query_client = TrinoQuery(self.trino_request, statement)
     query_client = TrinoQuery(self.trino_request, statement)
     query_client.execute()
     query_client.execute()
@@ -248,12 +254,13 @@ class TrinoApi(Api):
 
 
     return response
     return response
 
 
-
   def _get_select_query(self, database, table, column=None, operation=None, limit=100):
   def _get_select_query(self, database, table, column=None, operation=None, limit=100):
     if operation == 'hello':
     if operation == 'hello':
       statement = "SELECT 'Hello World!'"
       statement = "SELECT 'Hello World!'"
     else:
     else:
-      column = '%(column)s' % {'column': column} if column else '*'
+      database = self._format_identifier(database)
+      table = self._format_identifier(table)
+      column = '%(column)s' % {'column': self._format_identifier(column)} if column else '*'
       statement = textwrap.dedent('''\
       statement = textwrap.dedent('''\
           SELECT %(column)s
           SELECT %(column)s
           FROM %(database)s.%(table)s
           FROM %(database)s.%(table)s
@@ -267,13 +274,12 @@ class TrinoApi(Api):
 
 
     return statement
     return statement
 
 
-
   def close_statement(self, notebook, snippet):
   def close_statement(self, notebook, snippet):
     try:
     try:
       if snippet['result']['handle']['next_uri']:
       if snippet['result']['handle']['next_uri']:
         self.trino_request.delete(snippet['result']['handle']['next_uri'])
         self.trino_request.delete(snippet['result']['handle']['next_uri'])
       else:
       else:
-        return {'status': -1} # missing operation ids
+        return {'status': -1}  # missing operation ids
     except Exception as e:
     except Exception as e:
       if 'does not exist in current session:' in str(e):
       if 'does not exist in current session:' in str(e):
         return {'status': -1}  # skipped
         return {'status': -1}  # skipped
@@ -282,12 +288,10 @@ class TrinoApi(Api):
 
 
     return {'status': 0}
     return {'status': 0}
 
 
-
   def close_session(self, session):
   def close_session(self, session):
     # Avoid closing session on page refresh or editor close for now
     # Avoid closing session on page refresh or editor close for now
     pass
     pass
 
 
-
   def _show_databases(self):
   def _show_databases(self):
     catalogs = self._show_catalogs()
     catalogs = self._show_catalogs()
     databases = []
     databases = []
@@ -299,7 +303,6 @@ class TrinoApi(Api):
 
 
     return databases
     return databases
 
 
-
   def _show_catalogs(self):
   def _show_catalogs(self):
     query_client = TrinoQuery(self.trino_request, 'SHOW CATALOGS')
     query_client = TrinoQuery(self.trino_request, 'SHOW CATALOGS')
     response = query_client.execute()
     response = query_client.execute()
@@ -308,8 +311,8 @@ class TrinoApi(Api):
 
 
     return catalogs
     return catalogs
 
 
-
   def _show_tables(self, database):
   def _show_tables(self, database):
+    database = self._format_identifier(database)
     query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client.execute()
     query_client.execute()
     query_client = TrinoQuery(self.trino_request, 'SHOW TABLES')
     query_client = TrinoQuery(self.trino_request, 'SHOW TABLES')
@@ -323,10 +326,11 @@ class TrinoApi(Api):
       for table in tables
       for table in tables
     ]
     ]
 
 
-
   def _get_columns(self, database, table):
   def _get_columns(self, database, table):
+    database = self._format_identifier(database)
     query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client.execute()
     query_client.execute()
+    table = self._format_identifier(table)
     query_client = TrinoQuery(self.trino_request, 'DESCRIBE ' + table)
     query_client = TrinoQuery(self.trino_request, 'DESCRIBE ' + table)
     response = query_client.execute()
     response = query_client.execute()
     columns = response.rows
     columns = response.rows
@@ -339,7 +343,6 @@ class TrinoApi(Api):
       for col in columns
       for col in columns
     ]
     ]
 
 
-
   @query_error_handler
   @query_error_handler
   def explain(self, notebook, snippet):
   def explain(self, notebook, snippet):
     statement = snippet['statement'].rstrip(';')
     statement = snippet['statement'].rstrip(';')
@@ -347,7 +350,9 @@ class TrinoApi(Api):
 
 
     if statement:
     if statement:
       try:
       try:
-        TrinoQuery(self.trino_request, 'USE ' + snippet['database']).execute()
+        database = snippet['database']
+        database = self._format_identifier(database)
+        TrinoQuery(self.trino_request, 'USE ' + database).execute()
         result = TrinoQuery(self.trino_request, 'EXPLAIN ' + statement).execute()
         result = TrinoQuery(self.trino_request, 'EXPLAIN ' + statement).execute()
         explanation = result.rows
         explanation = result.rows
       except Exception as e:
       except Exception as e:
@@ -359,7 +364,6 @@ class TrinoApi(Api):
       'statement': statement
       'statement': statement
     }
     }
 
 
-
   def download(self, notebook, snippet, file_format='csv'):
   def download(self, notebook, snippet, file_format='csv'):
     result_wrapper = TrinoExecutionWrapper(self, notebook, snippet)
     result_wrapper = TrinoExecutionWrapper(self, notebook, snippet)
 
 
@@ -398,7 +402,7 @@ class TrinoExecutionWrapper(ExecutionWrapper):
 
 
   def _until_available(self):
   def _until_available(self):
     if self.snippet['result']['handle'].get('sync', False):
     if self.snippet['result']['handle'].get('sync', False):
-      return # Request is already completed
+      return  # Request is already completed
 
 
     count = 0
     count = 0
     sleep_seconds = 1
     sleep_seconds = 1

+ 36 - 39
desktop/libs/notebook/src/notebook/connectors/trino_tests.py

@@ -15,9 +15,10 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
+from unittest.mock import MagicMock, Mock, patch
+
 import pytest
 import pytest
 from django.test import TestCase
 from django.test import TestCase
-from unittest.mock import MagicMock, patch, Mock
 
 
 from desktop.auth.backend import rewrite_user
 from desktop.auth.backend import rewrite_user
 from desktop.lib.django_test_util import make_logged_in_client
 from desktop.lib.django_test_util import make_logged_in_client
@@ -40,6 +41,14 @@ class TestTrinoApi(TestCase):
     # Initialize TrinoApi with mock user and interpreter
     # Initialize TrinoApi with mock user and interpreter
     cls.trino_api = TrinoApi(cls.user, interpreter=cls.interpreter)
     cls.trino_api = TrinoApi(cls.user, interpreter=cls.interpreter)
 
 
+  def test_format_identifier(self):
+    test_cases = [
+      ("my_db", '"my_db"'),
+      ("my_db.table", '"my_db"."table"'),
+    ]
+
+    for database, expected_output in test_cases:
+      assert self.trino_api._format_identifier(database) == expected_output
 
 
   def test_parse_api_url(self):
   def test_parse_api_url(self):
     # Test parse_api_url method
     # Test parse_api_url method
@@ -49,7 +58,6 @@ class TestTrinoApi(TestCase):
 
 
     assert result == expected_result
     assert result == expected_result
 
 
-
   def test_autocomplete_with_database(self):
   def test_autocomplete_with_database(self):
     with patch('notebook.connectors.trino.TrinoApi._show_databases') as _show_databases:
     with patch('notebook.connectors.trino.TrinoApi._show_databases') as _show_databases:
       _show_databases.return_value = [
       _show_databases.return_value = [
@@ -62,7 +70,6 @@ class TestTrinoApi(TestCase):
       assert (response['databases'] ==
       assert (response['databases'] ==
       [{'name': 'test_catalog1.test_db1'}, {'name': 'test_catalog2.test_db1'}, {'name': 'test_catalog2.test_db2'}])
       [{'name': 'test_catalog1.test_db1'}, {'name': 'test_catalog2.test_db1'}, {'name': 'test_catalog2.test_db2'}])
 
 
-
   def test_autocomplete_with_database_and_table(self):
   def test_autocomplete_with_database_and_table(self):
     with patch('notebook.connectors.trino.TrinoApi._show_tables') as _show_tables:
     with patch('notebook.connectors.trino.TrinoApi._show_tables') as _show_tables:
       _show_tables.return_value = [
       _show_tables.return_value = [
@@ -73,7 +80,7 @@ class TestTrinoApi(TestCase):
       snippet = {}
       snippet = {}
       database = 'test_db1'
       database = 'test_db1'
       response = self.trino_api.autocomplete(snippet, database)
       response = self.trino_api.autocomplete(snippet, database)
-      
+
       assert 'tables_meta' in response   # Check if 'table_meta' key exists in the response
       assert 'tables_meta' in response   # Check if 'table_meta' key exists in the response
       assert (response['tables_meta'] ==
       assert (response['tables_meta'] ==
       [
       [
@@ -82,7 +89,6 @@ class TestTrinoApi(TestCase):
       {'name': 'test_table3', 'type': 'table', 'comment': ''}
       {'name': 'test_table3', 'type': 'table', 'comment': ''}
       ])
       ])
 
 
-
   def test_autocomplete_with_database_table_and_column(self):
   def test_autocomplete_with_database_table_and_column(self):
     with patch('notebook.connectors.trino.TrinoApi._get_columns') as _get_columns:
     with patch('notebook.connectors.trino.TrinoApi._get_columns') as _get_columns:
       _get_columns.return_value = [
       _get_columns.return_value = [
@@ -106,7 +112,6 @@ class TestTrinoApi(TestCase):
       assert 'columns' in response   # Check if 'columns' key exists in the response
       assert 'columns' in response   # Check if 'columns' key exists in the response
       assert response['columns'] == ['test_column1', 'test_column2', 'test_column3']
       assert response['columns'] == ['test_column1', 'test_column2', 'test_column3']
 
 
-
   def test_get_sample_data_success(self):
   def test_get_sample_data_success(self):
     with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
     with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
       # Mock TrinoQuery object and its execute method
       # Mock TrinoQuery object and its execute method
@@ -124,7 +129,6 @@ class TestTrinoApi(TestCase):
       assert (result['full_headers'] ==
       assert (result['full_headers'] ==
       [{'name': 'test_column1', 'type': 'string', 'comment': ''}, {'name': 'test_column2', 'type': 'string', 'comment': ''}])
       [{'name': 'test_column1', 'type': 'string', 'comment': ''}, {'name': 'test_column2', 'type': 'string', 'comment': ''}])
 
 
-
   def test_check_status_available(self):
   def test_check_status_available(self):
     mock_trino_request = MagicMock()
     mock_trino_request = MagicMock()
     self.trino_api.trino_request = mock_trino_request
     self.trino_api.trino_request = mock_trino_request
@@ -139,7 +143,6 @@ class TestTrinoApi(TestCase):
     assert result['status'] == 'available'
     assert result['status'] == 'available'
     assert result['next_uri'] == 'http://url'
     assert result['next_uri'] == 'http://url'
 
 
-
   def test_execute(self):
   def test_execute(self):
     with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
     with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
       # Mock TrinoQuery object and its methods
       # Mock TrinoQuery object and its methods
@@ -147,7 +150,6 @@ class TestTrinoApi(TestCase):
       mock_query_instance.query = "SELECT * FROM test_table"
       mock_query_instance.query = "SELECT * FROM test_table"
       mock_query_instance.execute.return_value = MagicMock(next_uri=None, id='123', rows=[], columns=[])
       mock_query_instance.execute.return_value = MagicMock(next_uri=None, id='123', rows=[], columns=[])
 
 
-
       mock_trino_request = MagicMock()
       mock_trino_request = MagicMock()
       self.trino_api.trino_request = mock_trino_request
       self.trino_api.trino_request = mock_trino_request
 
 
@@ -207,7 +209,6 @@ class TestTrinoApi(TestCase):
       }
       }
       assert result == expected_result
       assert result == expected_result
 
 
-
   def test_fetch_result(self):
   def test_fetch_result(self):
     # Mock TrinoRequest object and its methods
     # Mock TrinoRequest object and its methods
     mock_trino_request = MagicMock()
     mock_trino_request = MagicMock()
@@ -252,21 +253,20 @@ class TestTrinoApi(TestCase):
         } for column in _columns],
         } for column in _columns],
       'type': 'table'
       'type': 'table'
     }
     }
-    
+
     assert result == expected_result
     assert result == expected_result
     assert len(result['data']) == 6
     assert len(result['data']) == 6
     assert len(result['meta']) == 2
     assert len(result['meta']) == 2
 
 
-
   def test_get_select_query(self):
   def test_get_select_query(self):
     # Test with specified database, table, and column
     # Test with specified database, table, and column
-    database = 'test_db'
-    table = 'test_table'
+    database = '`test_schema.test_db`'
+    table = '`test_table`'
     column = 'test_column'
     column = 'test_column'
     expected_statement = (
     expected_statement = (
-        "SELECT test_column\n"
-        "FROM test_db.test_table\n"
-        "LIMIT 100\n"
+        'SELECT "test_column"\n'
+        'FROM "test_schema"."test_db"."test_table"\n'
+        'LIMIT 100\n'
     )
     )
     assert (
     assert (
       self.trino_api._get_select_query(database, table, column) ==
       self.trino_api._get_select_query(database, table, column) ==
@@ -276,38 +276,37 @@ class TestTrinoApi(TestCase):
     database = 'test_db'
     database = 'test_db'
     table = 'test_table'
     table = 'test_table'
     expected_statement = (
     expected_statement = (
-        "SELECT *\n"
-        "FROM test_db.test_table\n"
-        "LIMIT 100\n"
+        'SELECT *\n'
+        'FROM "test_db"."test_table"\n'
+        'LIMIT 100\n'
     )
     )
     assert (
     assert (
       self.trino_api._get_select_query(database, table) ==
       self.trino_api._get_select_query(database, table) ==
       expected_statement)
       expected_statement)
 
 
-
   def test_explain(self):
   def test_explain(self):
     with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
     with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
       snippet = {'statement': 'SELECT * FROM tpch.sf1.partsupp LIMIT 100;', 'database': 'tpch.sf1'}
       snippet = {'statement': 'SELECT * FROM tpch.sf1.partsupp LIMIT 100;', 'database': 'tpch.sf1'}
-      output = [['Trino version: 432\nFragment 0 [SINGLE]\n    Output layout: [partkey, suppkey, availqty, supplycost, comment]\n    '\
-      'Output partitioning: SINGLE []\n    Output[columnNames = [partkey, suppkey, availqty, supplycost, comment]]\n    │   '\
-      'Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n    │   '\
-      'Estimates: {rows: 100 (15.67kB), cpu: 0, memory: 0B, network: 0B}\n    └─ Limit[count = 100]\n       │   '\
-      'Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n       │   '\
-      'Estimates: {rows: 100 (15.67kB), cpu: 15.67k, memory: 0B, network: 0B}\n       └─ LocalExchange[partitioning = SINGLE]\n          '\
-      '│   Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n          │   '\
-      'Estimates: {rows: 100 (15.67kB), cpu: 0, memory: 0B, network: 0B}\n          └─ RemoteSource[sourceFragmentIds = [1]]\n'\
-      '                 Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n\n'\
-      'Fragment 1 [SOURCE]\n    Output layout: [partkey, suppkey, availqty, supplycost, comment]\n    Output partitioning: SINGLE []\n'\
-      '    LimitPartial[count = 100]\n    │   Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, '\
-      'comment:varchar(199)]\n    │   Estimates: {rows: 100 (15.67kB), cpu: 15.67k, memory: 0B, network: 0B}\n    └─ '\
-      'TableScan[table = tpch:sf1:partsupp]\n           Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, '\
-      'comment:varchar(199)]\n           Estimates: {rows: 800000 (122.44MB), cpu: 122.44M, memory: 0B, network: 0B}\n           '\
-      'partkey := tpch:partkey\n           availqty := tpch:availqty\n           supplycost := tpch:supplycost\n           '\
+      output = [['Trino version: 432\nFragment 0 [SINGLE]\n    Output layout: [partkey, suppkey, availqty, supplycost, comment]\n    '
+      'Output partitioning: SINGLE []\n    Output[columnNames = [partkey, suppkey, availqty, supplycost, comment]]\n    │   '
+      'Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n    │   '
+      'Estimates: {rows: 100 (15.67kB), cpu: 0, memory: 0B, network: 0B}\n    └─ Limit[count = 100]\n       │   '
+      'Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n       │   '
+      'Estimates: {rows: 100 (15.67kB), cpu: 15.67k, memory: 0B, network: 0B}\n       └─ LocalExchange[partitioning = SINGLE]\n          '
+      '│   Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n          │   '
+      'Estimates: {rows: 100 (15.67kB), cpu: 0, memory: 0B, network: 0B}\n          └─ RemoteSource[sourceFragmentIds = [1]]\n'
+      '                 Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n\n'
+      'Fragment 1 [SOURCE]\n    Output layout: [partkey, suppkey, availqty, supplycost, comment]\n    Output partitioning: SINGLE []\n'
+      '    LimitPartial[count = 100]\n    │   Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, '
+      'comment:varchar(199)]\n    │   Estimates: {rows: 100 (15.67kB), cpu: 15.67k, memory: 0B, network: 0B}\n    └─ '
+      'TableScan[table = tpch:sf1:partsupp]\n           Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, '
+      'comment:varchar(199)]\n           Estimates: {rows: 800000 (122.44MB), cpu: 122.44M, memory: 0B, network: 0B}\n           '
+      'partkey := tpch:partkey\n           availqty := tpch:availqty\n           supplycost := tpch:supplycost\n           '
       'comment := tpch:comment\n           suppkey := tpch:suppkey\n\n']]
       'comment := tpch:comment\n           suppkey := tpch:suppkey\n\n']]
       # Mock TrinoQuery object and its execute method
       # Mock TrinoQuery object and its execute method
       query_instance = TrinoQuery.return_value
       query_instance = TrinoQuery.return_value
       query_instance.execute.return_value = MagicMock(next_uri=None, id='123', rows=output, columns=[])
       query_instance.execute.return_value = MagicMock(next_uri=None, id='123', rows=output, columns=[])
-      
+
       # Call the explain method
       # Call the explain method
       result = self.trino_api.explain(notebook=None, snippet=snippet)
       result = self.trino_api.explain(notebook=None, snippet=snippet)
 
 
@@ -325,7 +324,6 @@ class TestTrinoApi(TestCase):
       # Assert the exception message
       # Assert the exception message
       assert result['explanation'] == 'Mocked exception'
       assert result['explanation'] == 'Mocked exception'
 
 
-
   @patch('notebook.connectors.trino.DEFAULT_AUTH_USERNAME.get', return_value='mocked_username')
   @patch('notebook.connectors.trino.DEFAULT_AUTH_USERNAME.get', return_value='mocked_username')
   @patch('notebook.connectors.trino.DEFAULT_AUTH_PASSWORD.get', return_value='mocked_password')
   @patch('notebook.connectors.trino.DEFAULT_AUTH_PASSWORD.get', return_value='mocked_password')
   def test_auth_username_and_auth_password_default(self, mock_default_username, mock_default_password):
   def test_auth_username_and_auth_password_default(self, mock_default_username, mock_default_password):
@@ -334,7 +332,6 @@ class TestTrinoApi(TestCase):
     assert trino_api.auth_username == 'mocked_username'
     assert trino_api.auth_username == 'mocked_username'
     assert trino_api.auth_password == 'mocked_password'
     assert trino_api.auth_password == 'mocked_password'
 
 
-
   @patch('notebook.connectors.trino.DEFAULT_AUTH_USERNAME.get', return_value='mocked_username')
   @patch('notebook.connectors.trino.DEFAULT_AUTH_USERNAME.get', return_value='mocked_username')
   @patch('notebook.connectors.trino.DEFAULT_AUTH_PASSWORD.get', return_value='mocked_password')
   @patch('notebook.connectors.trino.DEFAULT_AUTH_PASSWORD.get', return_value='mocked_password')
   def test_auth_username_custom(self, mock_default_username, mock_default_password):
   def test_auth_username_custom(self, mock_default_username, mock_default_password):