Browse Source

[Trino] Support ldap password script in trino (#3689)

Ayush Goyal 1 year ago
parent
commit
c435498704

+ 18 - 3
desktop/libs/notebook/src/notebook/connectors/trino.py

@@ -27,7 +27,9 @@ from urllib.parse import urlparse
 
 
 from beeswax import conf
 from beeswax import conf
 from beeswax import data_export
 from beeswax import data_export
+from desktop.conf import AUTH_USERNAME as DEFAULT_AUTH_USERNAME, AUTH_PASSWORD as DEFAULT_AUTH_PASSWORD
 from desktop.lib import export_csvxls
 from desktop.lib import export_csvxls
+from desktop.lib.conf import coerce_password_from_script
 from desktop.lib.i18n import force_unicode
 from desktop.lib.i18n import force_unicode
 from desktop.lib.rest.http_client import HttpClient, RestException
 from desktop.lib.rest.http_client import HttpClient, RestException
 from desktop.lib.rest.resource import Resource
 from desktop.lib.rest.resource import Resource
@@ -62,9 +64,12 @@ class TrinoApi(Api):
     self.server_host, self.server_port, self.http_scheme = self.parse_api_url(self.options['url'])
     self.server_host, self.server_port, self.http_scheme = self.parse_api_url(self.options['url'])
     self.auth = None
     self.auth = None
 
 
-    if self.options.get('auth_username') and self.options.get('auth_password'):
-      self.auth_username = self.options['auth_username']
-      self.auth_password = self.options['auth_password']
+    auth_username = self.options.get('auth_username', DEFAULT_AUTH_USERNAME.get())
+    auth_password = self.options.get('auth_password', self.get_auth_password())
+
+    if auth_username and auth_password:
+      self.auth_username = auth_username
+      self.auth_password = auth_password
       self.auth = BasicAuthentication(self.auth_username, self.auth_password)
       self.auth = BasicAuthentication(self.auth_username, self.auth_password)
 
 
     trino_session = ClientSession(user.username)
     trino_session = ClientSession(user.username)
@@ -76,6 +81,16 @@ class TrinoApi(Api):
       auth=self.auth
       auth=self.auth
     )
     )
 
 
+
+  def get_auth_password(self):
+    auth_password_script = self.options.get('auth_password_script')
+    return (
+        coerce_password_from_script(auth_password_script)
+        if auth_password_script
+        else DEFAULT_AUTH_PASSWORD.get()
+    )
+
+
   @query_error_handler
   @query_error_handler
   def parse_api_url(self, api_url):
   def parse_api_url(self, api_url):
     parsed_url = urlparse(api_url)
     parsed_url = urlparse(api_url)

+ 33 - 1
desktop/libs/notebook/src/notebook/connectors/trino_tests.py

@@ -34,7 +34,7 @@ class TestTrinoApi(unittest.TestCase):
     cls.user = User.objects.get(username="hue_test")
     cls.user = User.objects.get(username="hue_test")
     cls.interpreter = {
     cls.interpreter = {
       'options': {
       'options': {
-        'url': 'http://example.com:8080'
+        'url': 'https://example.com:8080'
       }
       }
     }
     }
     # Initialize TrinoApi with mock user and interpreter
     # Initialize TrinoApi with mock user and interpreter
@@ -295,3 +295,35 @@ class TestTrinoApi(unittest.TestCase):
       # Assert the exception message
       # Assert the exception message
       assert_equal(result['explanation'], 'Mocked exception')
       assert_equal(result['explanation'], 'Mocked exception')
 
 
+
+  @patch('notebook.connectors.trino.DEFAULT_AUTH_USERNAME.get', return_value='mocked_username')
+  @patch('notebook.connectors.trino.DEFAULT_AUTH_PASSWORD.get', return_value='mocked_password')
+  def test_auth_username_and_auth_password_default(self, mock_default_username, mock_default_password):
+    trino_api = TrinoApi(self.user, interpreter=self.interpreter)
+
+    assert_equal(trino_api.auth_username, 'mocked_username')
+    assert_equal(trino_api.auth_password, 'mocked_password')
+
+
+  @patch('notebook.connectors.trino.DEFAULT_AUTH_USERNAME.get', return_value='mocked_username')
+  @patch('notebook.connectors.trino.DEFAULT_AUTH_PASSWORD.get', return_value='mocked_password')
+  def test_auth_username_custom(self, mock_default_username, mock_default_password):
+    self.interpreter['options']['auth_username'] = 'custom_username'
+    self.interpreter['options']['auth_password'] = 'custom_password'
+    trino_api = TrinoApi(self.user, interpreter=self.interpreter)
+
+    assert_equal(trino_api.auth_username, 'custom_username')
+    assert_equal(trino_api.auth_password, 'custom_password')  
+
+  @patch('notebook.connectors.trino.DEFAULT_AUTH_PASSWORD.get', return_value='mocked_password')
+  def test_auth_password_script(self, mock_default_password):
+    interpreter = {
+      'options': {
+        'url': 'https://example.com:8080',
+        'auth_password_script': 'custom_script'
+      }
+    }
+
+    with patch('notebook.connectors.trino.coerce_password_from_script', return_value='custom_password_script'):
+      trino_api = TrinoApi(self.user, interpreter=interpreter)
+      assert_equal(trino_api.auth_password, 'custom_password_script')