|
@@ -28,6 +28,7 @@ from trino.client import ClientSession, TrinoQuery, TrinoRequest
|
|
|
from trino.exceptions import TrinoConnectionError
|
|
from trino.exceptions import TrinoConnectionError
|
|
|
|
|
|
|
|
from beeswax import conf, data_export
|
|
from beeswax import conf, data_export
|
|
|
|
|
+from desktop.auth.backend import rewrite_user
|
|
|
from desktop.conf import AUTH_PASSWORD as DEFAULT_AUTH_PASSWORD, AUTH_USERNAME as DEFAULT_AUTH_USERNAME
|
|
from desktop.conf import AUTH_PASSWORD as DEFAULT_AUTH_PASSWORD, AUTH_USERNAME as DEFAULT_AUTH_USERNAME
|
|
|
from desktop.lib import export_csvxls
|
|
from desktop.lib import export_csvxls
|
|
|
from desktop.lib.conf import coerce_password_from_script
|
|
from desktop.lib.conf import coerce_password_from_script
|
|
@@ -37,6 +38,7 @@ from desktop.lib.rest.resource import Resource
|
|
|
from notebook.connectors.base import Api, ExecutionWrapper, QueryError, ResultWrapper
|
|
from notebook.connectors.base import Api, ExecutionWrapper, QueryError, ResultWrapper
|
|
|
|
|
|
|
|
LOG = logging.getLogger()
|
|
LOG = logging.getLogger()
|
|
|
|
|
+SESSION_KEY = '%(username)s-%(interpreter_name)s'
|
|
|
|
|
|
|
|
|
|
|
|
|
def query_error_handler(func):
|
|
def query_error_handler(func):
|
|
@@ -71,11 +73,12 @@ class TrinoApi(Api):
|
|
|
self.auth_password = auth_password
|
|
self.auth_password = auth_password
|
|
|
self.auth = BasicAuthentication(self.auth_username, self.auth_password)
|
|
self.auth = BasicAuthentication(self.auth_username, self.auth_password)
|
|
|
|
|
|
|
|
- trino_session = ClientSession(user.username)
|
|
|
|
|
|
|
+ self.session_info = self.create_session()
|
|
|
|
|
+ self.trino_session = ClientSession(self.user.username, properties=self.session_info['properties'])
|
|
|
self.trino_request = TrinoRequest(
|
|
self.trino_request = TrinoRequest(
|
|
|
host=self.server_host,
|
|
host=self.server_host,
|
|
|
port=self.server_port,
|
|
port=self.server_port,
|
|
|
- client_session=trino_session,
|
|
|
|
|
|
|
+ client_session=self.trino_session,
|
|
|
http_scheme=self.http_scheme,
|
|
http_scheme=self.http_scheme,
|
|
|
auth=self.auth
|
|
auth=self.auth
|
|
|
)
|
|
)
|
|
@@ -109,9 +112,48 @@ class TrinoApi(Api):
|
|
|
parsed_url = urlparse(api_url)
|
|
parsed_url = urlparse(api_url)
|
|
|
return parsed_url.hostname, parsed_url.port, parsed_url.scheme
|
|
return parsed_url.hostname, parsed_url.port, parsed_url.scheme
|
|
|
|
|
|
|
|
|
|
+ def _get_session_key(self):
|
|
|
|
|
+ return SESSION_KEY % {
|
|
|
|
|
+ 'username': self.user.username if hasattr(self.user, 'username') else self.user,
|
|
|
|
|
+ 'interpreter_name': self.interpreter['name']
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ def _get_session_info_from_user(self):
|
|
|
|
|
+ self.user = rewrite_user(self.user)
|
|
|
|
|
+ session_key = self._get_session_key()
|
|
|
|
|
+
|
|
|
|
|
+ if self.user.profile.data.get(session_key):
|
|
|
|
|
+ return self.user.profile.data[session_key]
|
|
|
|
|
+
|
|
|
|
|
+ def _set_session_info_to_user(self, session_info):
|
|
|
|
|
+ self.user = rewrite_user(self.user)
|
|
|
|
|
+ session_key = self._get_session_key()
|
|
|
|
|
+
|
|
|
|
|
+ self.user.profile.update_data({session_key: session_info})
|
|
|
|
|
+ self.user.profile.save()
|
|
|
|
|
+
|
|
|
|
|
+ def _remove_session_info_from_user(self):
|
|
|
|
|
+ self.user = rewrite_user(self.user)
|
|
|
|
|
+ session_key = self._get_session_key()
|
|
|
|
|
+
|
|
|
|
|
+ if self.user.profile.data.get(session_key):
|
|
|
|
|
+ json_data = self.user.profile.data
|
|
|
|
|
+ json_data.pop(session_key)
|
|
|
|
|
+ self.user.profile.json_data = json.dumps(json_data)
|
|
|
|
|
+
|
|
|
|
|
+ self.user.profile.save()
|
|
|
|
|
+
|
|
|
@query_error_handler
|
|
@query_error_handler
|
|
|
def create_session(self, lang=None, properties=None):
|
|
def create_session(self, lang=None, properties=None):
|
|
|
- pass
|
|
|
|
|
|
|
+ properties = properties or self._get_session_info_from_user()
|
|
|
|
|
+
|
|
|
|
|
+ new_session_info = {
|
|
|
|
|
+ 'type': lang,
|
|
|
|
|
+ 'id': None,
|
|
|
|
|
+ 'properties': properties if not None else []
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return new_session_info
|
|
|
|
|
|
|
|
@query_error_handler
|
|
@query_error_handler
|
|
|
def execute(self, notebook, snippet):
|
|
def execute(self, notebook, snippet):
|
|
@@ -207,6 +249,9 @@ class TrinoApi(Api):
|
|
|
data = data[processed_rows:processed_rows + 100]
|
|
data = data[processed_rows:processed_rows + 100]
|
|
|
processed_rows -= current_length
|
|
processed_rows -= current_length
|
|
|
|
|
|
|
|
|
|
+ properties = self.trino_session.properties
|
|
|
|
|
+ self._set_session_info_to_user(properties)
|
|
|
|
|
+
|
|
|
return {
|
|
return {
|
|
|
'row_count': 100 + processed_rows,
|
|
'row_count': 100 + processed_rows,
|
|
|
'next_uri': next_uri,
|
|
'next_uri': next_uri,
|
|
@@ -292,8 +337,7 @@ class TrinoApi(Api):
|
|
|
return {'status': 0}
|
|
return {'status': 0}
|
|
|
|
|
|
|
|
def close_session(self, session):
|
|
def close_session(self, session):
|
|
|
- # Avoid closing session on page refresh or editor close for now
|
|
|
|
|
- pass
|
|
|
|
|
|
|
+ self._remove_session_info_from_user()
|
|
|
|
|
|
|
|
def _show_databases(self):
|
|
def _show_databases(self):
|
|
|
catalogs = self._show_catalogs()
|
|
catalogs = self._show_catalogs()
|