浏览代码

[Trino] Improve session management and property handling. (#3942)

Ayush Goyal 10 月之前
父节点
当前提交
2810ecccc6

+ 49 - 5
desktop/libs/notebook/src/notebook/connectors/trino.py

@@ -28,6 +28,7 @@ from trino.client import ClientSession, TrinoQuery, TrinoRequest
 from trino.exceptions import TrinoConnectionError
 
 from beeswax import conf, data_export
+from desktop.auth.backend import rewrite_user
 from desktop.conf import AUTH_PASSWORD as DEFAULT_AUTH_PASSWORD, AUTH_USERNAME as DEFAULT_AUTH_USERNAME
 from desktop.lib import export_csvxls
 from desktop.lib.conf import coerce_password_from_script
@@ -37,6 +38,7 @@ from desktop.lib.rest.resource import Resource
 from notebook.connectors.base import Api, ExecutionWrapper, QueryError, ResultWrapper
 
 LOG = logging.getLogger()
+SESSION_KEY = '%(username)s-%(interpreter_name)s'
 
 
 def query_error_handler(func):
@@ -71,11 +73,12 @@ class TrinoApi(Api):
       self.auth_password = auth_password
       self.auth = BasicAuthentication(self.auth_username, self.auth_password)
 
-    trino_session = ClientSession(user.username)
+    self.session_info = self.create_session()
+    self.trino_session = ClientSession(self.user.username, properties=self.session_info['properties'])
     self.trino_request = TrinoRequest(
       host=self.server_host,
       port=self.server_port,
-      client_session=trino_session,
+      client_session=self.trino_session,
       http_scheme=self.http_scheme,
       auth=self.auth
     )
@@ -109,9 +112,48 @@ class TrinoApi(Api):
     parsed_url = urlparse(api_url)
     return parsed_url.hostname, parsed_url.port, parsed_url.scheme
 
+  def _get_session_key(self):
+    return SESSION_KEY % {
+      'username': self.user.username if hasattr(self.user, 'username') else self.user,
+      'interpreter_name': self.interpreter['name']
+    }
+
+  def _get_session_info_from_user(self):
+    self.user = rewrite_user(self.user)
+    session_key = self._get_session_key()
+
+    if self.user.profile.data.get(session_key):
+      return self.user.profile.data[session_key]
+
+  def _set_session_info_to_user(self, session_info):
+    self.user = rewrite_user(self.user)
+    session_key = self._get_session_key()
+
+    self.user.profile.update_data({session_key: session_info})
+    self.user.profile.save()
+
+  def _remove_session_info_from_user(self):
+    self.user = rewrite_user(self.user)
+    session_key = self._get_session_key()
+
+    if self.user.profile.data.get(session_key):
+      json_data = self.user.profile.data
+      json_data.pop(session_key)
+      self.user.profile.json_data = json.dumps(json_data)
+
+    self.user.profile.save()
+
   @query_error_handler
   def create_session(self, lang=None, properties=None):
-    pass
+    properties = properties or self._get_session_info_from_user()
+
+    new_session_info = {
+        'type': lang,
+        'id': None,
+        'properties': properties if not None else []
+    }
+
+    return new_session_info
 
   @query_error_handler
   def execute(self, notebook, snippet):
@@ -207,6 +249,9 @@ class TrinoApi(Api):
       data = data[processed_rows:processed_rows + 100]
       processed_rows -= current_length
 
+    properties = self.trino_session.properties
+    self._set_session_info_to_user(properties)
+
     return {
       'row_count': 100 + processed_rows,
       'next_uri': next_uri,
@@ -292,8 +337,7 @@ class TrinoApi(Api):
     return {'status': 0}
 
   def close_session(self, session):
-    # Avoid closing session on page refresh or editor close for now
-    pass
+    self._remove_session_info_from_user()
 
   def _show_databases(self):
     catalogs = self._show_catalogs()

+ 4 - 2
desktop/libs/notebook/src/notebook/connectors/trino_tests.py

@@ -36,7 +36,8 @@ class TestTrinoApi(TestCase):
     cls.interpreter = {
       'options': {
         'url': 'https://example.com:8080'
-      }
+      },
+      'name': 'trino'
     }
     # Initialize TrinoApi with mock user and interpreter
     cls.trino_api = TrinoApi(cls.user, interpreter=cls.interpreter)
@@ -358,7 +359,8 @@ class TestTrinoApi(TestCase):
       'options': {
         'url': 'https://example.com:8080',
         'auth_password_script': 'custom_script'
-      }
+      },
+      'name': 'trino'
     }
 
     with patch('notebook.connectors.trino.coerce_password_from_script', return_value='custom_password_script'):