浏览代码

[Trino] Adding explain query method (#3645)

Ayush Goyal 1 年之前
父节点
当前提交
1f64deec8a

+ 20 - 0
desktop/libs/notebook/src/notebook/connectors/trino.py

@@ -321,6 +321,26 @@ class TrinoApi(Api):
     ]
     ]
 
 
 
 
+  @query_error_handler
+  def explain(self, notebook, snippet):
+    statement = snippet['statement'].rstrip(';')
+    explanation = ''
+
+    if statement:
+      try:
+        TrinoQuery(self.trino_request, 'USE ' + snippet['database']).execute()
+        result = TrinoQuery(self.trino_request, 'EXPLAIN ' + statement).execute()
+        explanation = result.rows
+      except Exception as e:
+        explanation = str(e)
+
+    return {
+      'status': 0,
+      'explanation': explanation,
+      'statement': statement
+    }
+
+
   def download(self, notebook, snippet, file_format='csv'):
   def download(self, notebook, snippet, file_format='csv'):
     result_wrapper = TrinoExecutionWrapper(self, notebook, snippet)
     result_wrapper = TrinoExecutionWrapper(self, notebook, snippet)
 
 

+ 42 - 0
desktop/libs/notebook/src/notebook/connectors/trino_tests.py

@@ -253,3 +253,45 @@ class TestTrinoApi(unittest.TestCase):
       self.trino_api._get_select_query(database, table),
       self.trino_api._get_select_query(database, table),
       expected_statement
       expected_statement
     )
     )
+
+
+  def test_explain(self):
+    with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
+      snippet = {'statement': 'SELECT * FROM tpch.sf1.partsupp LIMIT 100;', 'database': 'tpch.sf1'}
+      output = [['Trino version: 432\nFragment 0 [SINGLE]\n    Output layout: [partkey, suppkey, availqty, supplycost, comment]\n    '\
+      'Output partitioning: SINGLE []\n    Output[columnNames = [partkey, suppkey, availqty, supplycost, comment]]\n    │   '\
+      'Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n    │   '\
+      'Estimates: {rows: 100 (15.67kB), cpu: 0, memory: 0B, network: 0B}\n    └─ Limit[count = 100]\n       │   '\
+      'Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n       │   '\
+      'Estimates: {rows: 100 (15.67kB), cpu: 15.67k, memory: 0B, network: 0B}\n       └─ LocalExchange[partitioning = SINGLE]\n          '\
+      '│   Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n          │   '\
+      'Estimates: {rows: 100 (15.67kB), cpu: 0, memory: 0B, network: 0B}\n          └─ RemoteSource[sourceFragmentIds = [1]]\n'\
+      '                 Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n\n'\
+      'Fragment 1 [SOURCE]\n    Output layout: [partkey, suppkey, availqty, supplycost, comment]\n    Output partitioning: SINGLE []\n'\
+      '    LimitPartial[count = 100]\n    │   Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, '\
+      'comment:varchar(199)]\n    │   Estimates: {rows: 100 (15.67kB), cpu: 15.67k, memory: 0B, network: 0B}\n    └─ '\
+      'TableScan[table = tpch:sf1:partsupp]\n           Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, '\
+      'comment:varchar(199)]\n           Estimates: {rows: 800000 (122.44MB), cpu: 122.44M, memory: 0B, network: 0B}\n           '\
+      'partkey := tpch:partkey\n           availqty := tpch:availqty\n           supplycost := tpch:supplycost\n           '\
+      'comment := tpch:comment\n           suppkey := tpch:suppkey\n\n']]
+      # Mock TrinoQuery object and its execute method
+      query_instance = TrinoQuery.return_value
+      query_instance.execute.return_value = MagicMock(next_uri=None, id='123', rows=output, columns=[])
+      
+      # Call the explain method
+      result = self.trino_api.explain(notebook=None, snippet=snippet)
+
+      # Assert the result
+      assert_equal(result['status'], 0)
+      assert_equal(result['explanation'], output)
+      assert_equal(result['statement'], 'SELECT * FROM tpch.sf1.partsupp LIMIT 100')
+
+      query_instance = TrinoQuery.return_value
+      query_instance.execute.side_effect = Exception('Mocked exception')
+
+      # Call the explain method
+      result = self.trino_api.explain(notebook=None, snippet=snippet)
+
+      # Assert the exception message
+      assert_equal(result['explanation'], 'Mocked exception')
+