Browse Source

[Trino] Enhance the logic to efficiently retrieve the rows (#3972)

Ayush Goyal 10 months ago
parent
commit
887bba5d4e

+ 3 - 0
desktop/core/src/desktop/js/apps/notebook/snippet.js

@@ -1864,6 +1864,7 @@ class Snippet {
             if (self.type() === 'trino') {
             if (self.type() === 'trino') {
               const existing_handle = self.result.handle();
               const existing_handle = self.result.handle();
               existing_handle.row_count = data.handle.row_count;
               existing_handle.row_count = data.handle.row_count;
+              existing_handle.rows_remaining = data.handle.rows_remaining;
               existing_handle.next_uri = data.handle.next_uri;
               existing_handle.next_uri = data.handle.next_uri;
             }
             }
             self.showLogs(true);
             self.showLogs(true);
@@ -2195,6 +2196,7 @@ class Snippet {
                 if (self.type() === 'trino') {
                 if (self.type() === 'trino') {
                   const existing_handle = self.result.handle();
                   const existing_handle = self.result.handle();
                   existing_handle.row_count = data.result.row_count;
                   existing_handle.row_count = data.result.row_count;
+                  existing_handle.rows_remaining = data.result.rows_remaining;
                   existing_handle.next_uri = data.result.next_uri;
                   existing_handle.next_uri = data.result.next_uri;
                 }
                 }
               } else {
               } else {
@@ -2369,6 +2371,7 @@ class Snippet {
                   if (self.type() === 'trino') {
                   if (self.type() === 'trino') {
                     const existing_handle = self.result.handle();
                     const existing_handle = self.result.handle();
                     existing_handle.row_count = 0;
                     existing_handle.row_count = 0;
+                    existing_handle.rows_remaining = 0;
                     existing_handle.next_uri = data.query_status.next_uri;
                     existing_handle.next_uri = data.query_status.next_uri;
                   }
                   }
                   const delay = self.result.executionTime() > 45000 ? 5000 : 1000; // 5s if more than 45s
                   const delay = self.result.executionTime() > 45000 ? 5000 : 1000; // 5s if more than 45s

+ 16 - 13
desktop/libs/notebook/src/notebook/connectors/trino.py

@@ -170,6 +170,7 @@ class TrinoApi(Api):
 
 
     response = {
     response = {
       'row_count': 0,
       'row_count': 0,
+      'rows_remaining': 0,
       'next_uri': status.next_uri,
       'next_uri': status.next_uri,
       'sync': None,
       'sync': None,
       'has_result_set': status.next_uri is not None,
       'has_result_set': status.next_uri is not None,
@@ -219,10 +220,11 @@ class TrinoApi(Api):
     data = []
     data = []
     columns = []
     columns = []
     next_uri = snippet['result']['handle']['next_uri']
     next_uri = snippet['result']['handle']['next_uri']
-    processed_rows = snippet['result']['handle'].get('row_count', 0)
+    row_count = snippet['result']['handle'].get('row_count', 0)
+    rows_remaining = snippet['result']['handle'].get('rows_remaining', 0)
     status = False
     status = False
 
 
-    if processed_rows == 0:
+    if row_count == 0:
       data = snippet['result']['handle']['result']['data']
       data = snippet['result']['handle']['result']['data']
 
 
     while next_uri:
     while next_uri:
@@ -235,25 +237,25 @@ class TrinoApi(Api):
       data += status.rows
       data += status.rows
       columns = status.columns
       columns = status.columns
 
 
-      if len(data) >= processed_rows + 100:
-        if processed_rows < 0:
-          data = data[:100]
-        else:
-          data = data[processed_rows:processed_rows + 100]
+      if rows_remaining:
+        data = data[-rows_remaining:]  # Trim the data to only include the remaining rows
+        rows_remaining = 0  # Reset rows_remaining since we've handled the trimming
+
+      if len(data) > 100:
+        rows_remaining = len(data) - 100  # no of rows remaining to fetch in the present uri
         break
         break
+      rows_remaining = 0
 
 
       next_uri = status.next_uri
       next_uri = status.next_uri
-      current_length = len(data)
-      if processed_rows < 0:
-        processed_rows = 0
-      data = data[processed_rows:processed_rows + 100]
-      processed_rows -= current_length
+
+    data = data[:100]
 
 
     properties = self.trino_session.properties
     properties = self.trino_session.properties
     self._set_session_info_to_user(properties)
     self._set_session_info_to_user(properties)
 
 
     return {
     return {
-      'row_count': 100 + processed_rows,
+      'row_count': len(data) + row_count,
+      'rows_remaining': rows_remaining,
       'next_uri': next_uri,
       'next_uri': next_uri,
       'has_more': bool(status.next_uri) if status else False,
       'has_more': bool(status.next_uri) if status else False,
       'data': data or [],
       'data': data or [],
@@ -456,6 +458,7 @@ class TrinoExecutionWrapper(ExecutionWrapper):
     else:
     else:
       result = self.api.fetch_result(self.notebook, self.snippet, rows, start_over)
       result = self.api.fetch_result(self.notebook, self.snippet, rows, start_over)
       self.snippet['result']['handle']['row_count'] = result['row_count']
       self.snippet['result']['handle']['row_count'] = result['row_count']
+      self.snippet['result']['handle']['rows_remaining'] = result['rows_remaining']
       self.snippet['result']['handle']['next_uri'] = result['next_uri']
       self.snippet['result']['handle']['next_uri'] = result['next_uri']
 
 
     return ResultWrapper(result.get('meta'), result.get('data'), result.get('has_more'))
     return ResultWrapper(result.get('meta'), result.get('data'), result.get('has_more'))

+ 61 - 12
desktop/libs/notebook/src/notebook/connectors/trino_tests.py

@@ -178,6 +178,7 @@ class TestTrinoApi(TestCase):
 
 
       expected_result = {
       expected_result = {
         'row_count': 0,
         'row_count': 0,
+        'rows_remaining': 0,
         'next_uri': 'http://url',
         'next_uri': 'http://url',
         'sync': None,
         'sync': None,
         'has_result_set': True,
         'has_result_set': True,
@@ -204,6 +205,7 @@ class TestTrinoApi(TestCase):
 
 
       expected_result = {
       expected_result = {
         'row_count': 0,
         'row_count': 0,
+        'rows_remaining': 0,
         'next_uri': 'http://url',
         'next_uri': 'http://url',
         'sync': None,
         'sync': None,
         'has_result_set': True,
         'has_result_set': True,
@@ -220,7 +222,7 @@ class TestTrinoApi(TestCase):
       }
       }
       assert result == expected_result
       assert result == expected_result
 
 
-  def test_fetch_result(self):
+  def test_fetch_result_more_than_100(self):
     # Mock TrinoRequest object and its methods
     # Mock TrinoRequest object and its methods
     mock_trino_request = MagicMock()
     mock_trino_request = MagicMock()
     self.trino_api.trino_request = mock_trino_request
     self.trino_api.trino_request = mock_trino_request
@@ -229,18 +231,67 @@ class TestTrinoApi(TestCase):
     mock_trino_request.get.return_value = MagicMock()
     mock_trino_request.get.return_value = MagicMock()
     _columns = [{'comment': '', 'name': 'test_column1', 'type': 'str'}, {'comment': '', 'name': 'test_column2', 'type': 'str'}]
     _columns = [{'comment': '', 'name': 'test_column1', 'type': 'str'}, {'comment': '', 'name': 'test_column2', 'type': 'str'}]
 
 
+    # Generate more than 100 rows of mock data
+    mock_data = [[f'value{i}', f'value{i + 1}'] for i in range(1, 201, 1)]
+
     mock_trino_request.process.side_effect = [
     mock_trino_request.process.side_effect = [
       MagicMock(
       MagicMock(
-        stats={'state': 'FINISHED'}, next_uri='http://url', id=123,
-        rows=[['value1', 'value2'], ['value3', 'value4']], columns=_columns
+        stats={'state': 'FINISHED'}, next_uri='http://url1', id=123,
+        rows=mock_data[:57], columns=_columns
       ),
       ),
       MagicMock(
       MagicMock(
-        stats={'state': 'FINISHED'}, next_uri='http://url1', id=124,
-        rows=[['value5', 'value6'], ['value7', 'value8']], columns=_columns
+        stats={'state': 'FINISHED'}, next_uri='http://url2', id=124,
+        rows=mock_data[57:105], columns=_columns
       ),
       ),
       MagicMock(
       MagicMock(
         stats={'state': 'FINISHED'}, next_uri=None, id=125,
         stats={'state': 'FINISHED'}, next_uri=None, id=125,
-        rows=[['value9', 'value10'], ['value11', 'value12']], columns=_columns
+        rows=mock_data[105:], columns=_columns
+      )
+    ]
+
+    # Call the fetch_result method
+    result = self.trino_api.fetch_result(
+      notebook={}, snippet={'result': {'handle': {'next_uri': 'http://url', 'result': {'data': []}}}}, rows=0, start_over=False
+    )
+
+    expected_result = {
+      'row_count': 100,
+      'rows_remaining': 5,
+      'next_uri': 'http://url1',
+      'has_more': True,
+      'data': mock_data[:100],
+      'meta': [{
+        'name': column['name'],
+        'type': column['type'],
+        'comment': ''
+        } for column in _columns],
+      'type': 'table'
+    }
+
+    assert result == expected_result
+    assert len(result['data']) == 100
+    assert len(result['meta']) == 2
+
+  def test_fetch_result_less_than_100(self):
+    # Mock TrinoRequest object and its methods
+    mock_trino_request = MagicMock()
+    self.trino_api.trino_request = mock_trino_request
+
+    # Configure the MagicMock object to return expected responses
+    mock_trino_request.get.return_value = MagicMock()
+    _columns = [{'comment': '', 'name': 'test_column1', 'type': 'str'}, {'comment': '', 'name': 'test_column2', 'type': 'str'}]
+
+    # Generate 100 rows of mock data
+    mock_data = [[f'value{i}', f'value{i + 1}'] for i in range(1, 90, 1)]
+
+    mock_trino_request.process.side_effect = [
+      MagicMock(
+        stats={'state': 'FINISHED'}, next_uri='http://url1', id=123,
+        rows=mock_data[:57], columns=_columns
+      ),
+      MagicMock(
+        stats={'state': 'FINISHED'}, next_uri=None, id=124,
+        rows=mock_data[57:], columns=_columns
       )
       )
     ]
     ]
 
 
@@ -250,13 +301,11 @@ class TestTrinoApi(TestCase):
     )
     )
 
 
     expected_result = {
     expected_result = {
-      'row_count': 94,
+      'row_count': 89,
+      'rows_remaining': 0,
       'next_uri': None,
       'next_uri': None,
       'has_more': False,
       'has_more': False,
-      'data': [
-        ['value1', 'value2'], ['value3', 'value4'], ['value5', 'value6'],
-        ['value7', 'value8'], ['value9', 'value10'], ['value11', 'value12']
-      ],
+      'data': mock_data[:90],
       'meta': [{
       'meta': [{
         'name': column['name'],
         'name': column['name'],
         'type': column['type'],
         'type': column['type'],
@@ -266,7 +315,7 @@ class TestTrinoApi(TestCase):
     }
     }
 
 
     assert result == expected_result
     assert result == expected_result
-    assert len(result['data']) == 6
+    assert len(result['data']) == 89
     assert len(result['meta']) == 2
     assert len(result['meta']) == 2
 
 
   def test_get_select_query(self):
   def test_get_select_query(self):