Browse Source

[Trino] Refactor Trino-related code for readability and maintainability (#3637)

Ayush Goyal 1 year ago
parent
commit
a17d4995b5

+ 3 - 3
desktop/core/src/desktop/js/apps/notebook/snippet.js

@@ -1858,7 +1858,7 @@ class Snippet {
 
             if (self.type() === 'trino') {
               const existing_handle = self.result.handle();
-              existing_handle.row_n = data.handle.row_n;
+              existing_handle.row_count = data.handle.row_count;
               existing_handle.next_uri = data.handle.next_uri;
             }
             self.showLogs(true);
@@ -2189,7 +2189,7 @@ class Snippet {
 
                 if (self.type() === 'trino') {
                   const existing_handle = self.result.handle();
-                  existing_handle.row_n = data.result.row_n;
+                  existing_handle.row_count = data.result.row_count;
                   existing_handle.next_uri = data.result.next_uri;
                 }
               } else {
@@ -2369,7 +2369,7 @@ class Snippet {
                 ) {
                   if (self.type() === 'trino') {
                     const existing_handle = self.result.handle();
-                    existing_handle.row_n = 0;
+                    existing_handle.row_count = 0;
                     existing_handle.next_uri = data.query_status.next_uri;
                   }
                   const delay = self.result.executionTime() > 45000 ? 5000 : 1000; // 5s if more than 45s

+ 65 - 84
desktop/libs/notebook/src/notebook/connectors/trino.py

@@ -17,7 +17,6 @@
 
 import logging
 import json
-import posixpath
 import requests
 import sys
 import textwrap
@@ -26,6 +25,8 @@ import time
 from django.utils.translation import gettext as _
 from urllib.parse import urlparse
 
+from beeswax import conf
+from beeswax import data_export
 from desktop.lib import export_csvxls
 from desktop.lib.i18n import force_unicode
 from desktop.lib.rest.http_client import HttpClient, RestException
@@ -36,6 +37,7 @@ from trino import exceptions
 from trino.auth import BasicAuthentication
 from trino.client import ClientSession, TrinoRequest, TrinoQuery
 
+
 def query_error_handler(func):
   def decorator(*args, **kwargs):
     try:
@@ -53,14 +55,10 @@ def query_error_handler(func):
   return decorator
 
 
-
 class TrinoApi(Api):
-
   def __init__(self, user, interpreter=None):
     Api.__init__(self, user, interpreter=interpreter)
-
     self.options = interpreter['options']
-
     self.server_host, self.server_port, self.http_scheme = self.parse_api_url(self.options['url'])
     self.auth = None
 
@@ -70,8 +68,7 @@ class TrinoApi(Api):
       self.auth = BasicAuthentication(self.auth_username, self.auth_password)
 
     trino_session = ClientSession(user.username)
-    
-    self.db = TrinoRequest(
+    self.trino_request = TrinoRequest(
       host=self.server_host,
       port=self.server_port,
       client_session=trino_session,
@@ -93,16 +90,16 @@ class TrinoApi(Api):
   @query_error_handler
   def execute(self, notebook, snippet):
     database = snippet['database']
-    query_client = TrinoQuery(self.db, 'USE ' + database)
+    query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client.execute()
-    
+
     statement = snippet['statement'].rstrip(';')
-    query_client = TrinoQuery(self.db, statement)
-    response = self.db.post(query_client.query)
-    status = self.db.process(response)
+    query_client = TrinoQuery(self.trino_request, statement)
+    response = self.trino_request.post(query_client.query)
+    status = self.trino_request.process(response)
 
     return {
-      'row_n': 0,
+      'row_count': 0,
       'next_uri': status.next_uri,
       'sync': None,
       'has_result_set': status.next_uri is not None,
@@ -127,75 +124,68 @@ class TrinoApi(Api):
   def check_status(self, notebook, snippet):
     response = {}
     status = 'expired'
+    next_uri = snippet['result']['handle']['next_uri']
 
-    if snippet['result']['handle']['next_uri'] is None:
+    if next_uri is None:
       status = 'available'
     else:
-      _response = self.db.get(snippet['result']['handle']['next_uri'])
-      _status = self.db.process(_response)
+      _response = self.trino_request.get(next_uri)
+      _status = self.trino_request.process(_response)
       if _status.stats['state'] == 'QUEUED':
         status = 'waiting'
       elif _status.stats['state'] == 'RUNNING':
-        status = 'available' # need to varify
+        status = 'available' # need to verify
       else:
         status = 'available'
 
     response['status'] = status
-
-    if status != 'available':
-      response['next_uri'] = _status.next_uri
-    else:
-      response['next_uri'] = snippet['result']['handle']['next_uri']
-
+    response['next_uri'] = _status.next_uri if status != 'available' else next_uri
     return response
 
-
   @query_error_handler
   def fetch_result(self, notebook, snippet, rows, start_over):
     data = []
-    _columns = []
-    _next_uri = snippet['result']['handle']['next_uri']
-    processed_rows = snippet['result']['handle'].get('row_n', 0)
+    columns = []
+    next_uri = snippet['result']['handle']['next_uri']
+    processed_rows = snippet['result']['handle'].get('row_count', 0)
     status = False
 
     if processed_rows == 0:
       data = snippet['result']['handle']['result']['data']
 
-    while _next_uri:
+    while next_uri:
       try:
-        response = self.db.get(_next_uri)
+        response = self.trino_request.get(next_uri)
       except requests.exceptions.RequestException as e:
         raise trino.exceptions.TrinoConnectionError("failed to fetch: {}".format(e))
 
-      status = self.db.process(response)
+      status = self.trino_request.process(response)
       data += status.rows
-      _columns = status.columns
+      columns = status.columns
 
       if len(data) >= processed_rows + 100:
         if processed_rows < 0:
-          data = data[0:100]
+          data = data[:100]
         else:
           data = data[processed_rows:processed_rows + 100]
         break
 
-      _next_uri = status.next_uri
+      next_uri = status.next_uri
       current_length = len(data)
       data = data[processed_rows:processed_rows + 100]
-      processed_rows = processed_rows - current_length
+      processed_rows -= current_length
 
     return {
-        'row_n': 100 + processed_rows,
-        'next_uri': _next_uri,
-        'has_more': bool(status.next_uri) if status else False,
-        'data': data or [],
-        'meta': [{
-            'name': column['name'],
-            'type': column['type'],
-            'comment': ''
-          }
-          for column in _columns if status
-        ],
-        'type': 'table'
+      'row_count': 100 + processed_rows,
+      'next_uri': next_uri,
+      'has_more': bool(status.next_uri) if status else False,
+      'data': data or [],
+      'meta': [{
+        'name': column['name'],
+        'type': column['type'],
+        'comment': ''
+        } for column in columns] if status else [],
+      'type': 'table'
     }
 
 
@@ -203,8 +193,6 @@ class TrinoApi(Api):
   def autocomplete(self, snippet, database=None, table=None, column=None, nested=None, operation=None):
     response = {}
 
-    # if catalog is None:
-    #   response['catalogs'] = self._show_catalogs()
     if database is None:
       response['databases'] = self._show_databases()
     elif table is None:
@@ -213,10 +201,10 @@ class TrinoApi(Api):
       columns = self._get_columns(database, table)
       response['columns'] = [col['name'] for col in columns]
       response['extended_columns'] = [{
-          'comment': col.get('comment'),
-          'name': col.get('name'),
-          'type': col['type']
-        }
+        'comment': col.get('comment'),
+        'name': col.get('name'),
+        'type': col['type']
+      }
         for col in columns
       ]
 
@@ -225,9 +213,8 @@ class TrinoApi(Api):
 
   @query_error_handler
   def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
-    
     statement = self._get_select_query(database, table, column, operation)
-    query_client = TrinoQuery(self.db, statement)
+    query_client = TrinoQuery(self.trino_request, statement)
     query_client.execute()
 
     response = {
@@ -239,7 +226,7 @@ class TrinoApi(Api):
     response['full_headers'] = query_client.columns
 
     return response
-  
+
 
   def _get_select_query(self, database, table, column=None, operation=None, limit=100):
     if operation == 'hello':
@@ -251,11 +238,11 @@ class TrinoApi(Api):
           FROM %(database)s.%(table)s
           LIMIT %(limit)s
           ''' % {
-            'database': database,
-            'table': table,
-            'column': column,
-            'limit': limit,
-        })
+        'database': database,
+        'table': table,
+        'column': column,
+        'limit': limit,
+      })
 
     return statement
 
@@ -263,7 +250,7 @@ class TrinoApi(Api):
   def close_statement(self, notebook, snippet):
     try:
       if snippet['result']['handle']['next_uri']:
-        self.db.delete(snippet['result']['handle']['next_uri'])
+        self.trino_request.delete(snippet['result']['handle']['next_uri'])
       else:
         return {'status': -1} # missing operation ids
     except Exception as e:
@@ -285,8 +272,7 @@ class TrinoApi(Api):
     databases = []
 
     for catalog in catalogs:
-
-      query_client = TrinoQuery(self.db, 'SHOW SCHEMAS FROM ' + catalog)
+      query_client = TrinoQuery(self.trino_request, 'SHOW SCHEMAS FROM ' + catalog)
       response = query_client.execute()
       databases += [f'{catalog}.{item}' for sublist in response.rows for item in sublist]
 
@@ -294,8 +280,7 @@ class TrinoApi(Api):
 
 
   def _show_catalogs(self):
-
-    query_client = TrinoQuery(self.db, 'SHOW CATALOGS')
+    query_client = TrinoQuery(self.trino_request, 'SHOW CATALOGS')
     response = query_client.execute()
     res = response.rows
     catalogs = [item for sublist in res for item in sublist]
@@ -304,41 +289,37 @@ class TrinoApi(Api):
 
 
   def _show_tables(self, database):
-    
-    query_client = TrinoQuery(self.db, 'USE ' + database)
+    query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client.execute()
-    query_client = TrinoQuery(self.db, 'SHOW TABLES')
+    query_client = TrinoQuery(self.trino_request, 'SHOW TABLES')
     response = query_client.execute()
     tables = response.rows
     return [{
-        'name': table[0],
-        'type': 'table',
-        'comment': '',
-      }
+      'name': table[0],
+      'type': 'table',
+      'comment': '',
+    }
       for table in tables
     ]
 
 
   def _get_columns(self, database, table):
-
-    query_client = TrinoQuery(self.db, 'USE ' + database)
+    query_client = TrinoQuery(self.trino_request, 'USE ' + database)
     query_client.execute()
-    query_client = TrinoQuery(self.db, 'DESCRIBE ' + table)
+    query_client = TrinoQuery(self.trino_request, 'DESCRIBE ' + table)
     response = query_client.execute()
     columns = response.rows
 
     return [{
-        'name': col[0],
-        'type': col[1],
-        'comment': '',
-      }
+      'name': col[0],
+      'type': col[1],
+      'comment': '',
+    }
       for col in columns
     ]
-  
-  def download(self, notebook, snippet, file_format='csv'):
-    from beeswax import data_export #TODO: Move to notebook?
-    from beeswax import conf
 
+
+  def download(self, notebook, snippet, file_format='csv'):
     result_wrapper = TrinoExecutionWrapper(self, notebook, snippet)
 
     max_rows = conf.DOWNLOAD_ROW_LIMIT.get()
@@ -369,7 +350,7 @@ class TrinoExecutionWrapper(ExecutionWrapper):
       result = self.snippet['result']['handle']['result']
     else:
       result = self.api.fetch_result(self.notebook, self.snippet, rows, start_over)
-      self.snippet['result']['handle']['row_n'] = result['row_n']
+      self.snippet['result']['handle']['row_count'] = result['row_count']
       self.snippet['result']['handle']['next_uri'] = result['next_uri']
 
     return ResultWrapper(result.get('meta'), result.get('data'), result.get('has_more'))