Pārlūkot izejas kodu

[Trino] Add ldap/basic username and password authentication support (#3624)

Ayush Goyal 1 gadu atpakaļ
vecāks
revīzija
8dd2540553

+ 7 - 0
desktop/conf.dist/hue.ini

@@ -1145,6 +1145,13 @@ tls=no
 #   name=Shell
 #   interface=oozie
 
+# [[[trino]]]
+#   name=Trino
+#   interface=trino
+#   ## username and password for LDAP enabled over HTTPS.
+#   options='{"url": "http://localhost:8080", "auth_username": "", "auth_password": ""}'
+
+
 # [[[presto]]]
 # name=Presto SQL
 # interface=presto

+ 6 - 0
desktop/conf/pseudo-distributed.ini.tmpl

@@ -1128,6 +1128,12 @@
     #   name=Shell
     #   interface=oozie
 
+    # [[[trino]]]
+    #   name=Trino
+    #   interface=trino
+    #   ## username and password for LDAP enabled over HTTPS.
+    #   options='{"url": "http://localhost:8080", "auth_username": "", "auth_password": ""}'
+
     # [[[presto]]]
     # name=Presto SQL
     # interface=presto

+ 21 - 13
desktop/libs/notebook/src/notebook/connectors/trino.py

@@ -33,6 +33,7 @@ from desktop.lib.rest.resource import Resource
 from notebook.connectors.base import Api, QueryError, ExecutionWrapper, ResultWrapper
 
 from trino import exceptions
+from trino.auth import BasicAuthentication
 from trino.client import ClientSession, TrinoRequest, TrinoQuery
 
 def query_error_handler(func):
@@ -59,19 +60,29 @@ class TrinoApi(Api):
     Api.__init__(self, user, interpreter=interpreter)
 
     self.options = interpreter['options']
-    
-    self.api_url = self.options['url']
-    hostname, port = self.get_hostname_and_port(self.api_url)
+
+    self.server_host, self.server_port, self.http_scheme = self.parse_api_url(self.options['url'])
+    self.auth = None
+
+    if self.options.get('auth_username') and self.options.get('auth_password'):
+      self.auth_username = self.options['auth_username']
+      self.auth_password = self.options['auth_password']
+      self.auth = BasicAuthentication(self.auth_username, self.auth_password)
+
     trino_session = ClientSession(user.username)
     
-    self.db = TrinoRequest(hostname, port, trino_session)
-
+    self.db = TrinoRequest(
+      host=self.server_host,
+      port=self.server_port,
+      client_session=trino_session,
+      http_scheme=self.http_scheme,
+      auth=self.auth
+    )
 
-  def get_hostname_and_port(self, api_url):
+  @query_error_handler
+  def parse_api_url(self, api_url):
     parsed_url = urlparse(api_url)
-    hostname = parsed_url.hostname
-    port = parsed_url.port
-    return hostname, port
+    return parsed_url.hostname, parsed_url.port, parsed_url.scheme
 
 
   @query_error_handler
@@ -271,14 +282,11 @@ class TrinoApi(Api):
 
   def _show_databases(self):
     catalogs = self._show_catalogs()
-    hostname, port = self.get_hostname_and_port(self.api_url)
     databases = []
 
     for catalog in catalogs:
-      trino_session = ClientSession(self.user.username, catalog)
-      trino_request = TrinoRequest(hostname, port, trino_session)
 
-      query_client = TrinoQuery(trino_request, 'SHOW SCHEMAS')
+      query_client = TrinoQuery(self.db, 'SHOW SCHEMAS FROM ' + catalog)
       response = query_client.execute()
       databases += [f'{catalog}.{item}' for sublist in response.rows for item in sublist]