浏览代码

HUE-9367 [phoenix] Adding PHOENIX-5938 Support impersonation in the python driver

Romain 5 年之前
父节点
当前提交
57a24ccc06

+ 57 - 20
desktop/core/ext-py/phoenixdb/phoenixdb/__init__.py

@@ -54,7 +54,7 @@ For example::
 
 
 
 
 def connect(url, max_retries=None, auth=None, authentication=None, avatica_user=None, avatica_password=None,
 def connect(url, max_retries=None, auth=None, authentication=None, avatica_user=None, avatica_password=None,
-            truststore=None, verify=None, **kwargs):
+            truststore=None, verify=None, do_as=None, user=None, password=None, **kwargs):
     """Connects to a Phoenix query server.
     """Connects to a Phoenix query server.
 
 
     :param url:
     :param url:
@@ -77,6 +77,14 @@ def connect(url, max_retries=None, auth=None, authentication=None, avatica_user=
         Authentication configuration object as expected by the underlying python_requests and
         Authentication configuration object as expected by the underlying python_requests and
         python_requests_gssapi library
         python_requests_gssapi library
 
 
+    :param verify:
+        The path to the PEM file for verifying the server's certificate. It is passed directly to
+        the `~verify` parameter of the underlying python_requests library.
+        Setting it to False disables the server certificate verification.
+
+    :param do_as:
+        Username to impersonate (sets the Hadoop doAs URL parameter)
+
     :param authentication:
     :param authentication:
         Alternative way to specify the authentication mechanism that mimics
         Alternative way to specify the authentication mechanism that mimics
         the semantics of the JDBC drirver
         the semantics of the JDBC drirver
@@ -89,10 +97,12 @@ def connect(url, max_retries=None, auth=None, authentication=None, avatica_user=
         Password for BASIC or DIGEST authentication. Use in conjunction with the
         Password for BASIC or DIGEST authentication. Use in conjunction with the
         `~authentication' option.
         `~authentication' option.
 
 
-    :param verify:
-        The path to the PEM file for verifying the server's certificate. It is passed directly to
-        the `~verify` parameter of the underlying python_requests library.
-        Setting it to false disables the server certificate verification.
+    :param user
+        If `~authentication' is BASIC or DIGEST then alias for `~avatica_user`
+        If `~authentication' is NONE or SPNEGO then alias for `~do_as`
+
+    :param password
+        If `~authentication' is BASIC or DIGEST then is alias for `~avatica_password`
 
 
     :param truststore:
     :param truststore:
         Alias for verify
         Alias for verify
@@ -101,33 +111,65 @@ def connect(url, max_retries=None, auth=None, authentication=None, avatica_user=
         :class:`~phoenixdb.connection.Connection` object.
         :class:`~phoenixdb.connection.Connection` object.
     """
     """
 
 
+    (url, auth, verify) = _process_args(
+        url, auth=auth, authentication=authentication,
+        avatica_user=avatica_user, avatica_password=avatica_password,
+        truststore=truststore, verify=verify, do_as=do_as, user=user, password=password)
+
+    client = AvaticaClient(url, max_retries=max_retries, auth=auth, verify=verify)
+    client.connect()
+    return Connection(client, **kwargs)
+
+
+def _process_args(
+        url, auth=None, authentication=None, avatica_user=None, avatica_password=None,
+        truststore=None, verify=None, do_as=None, user=None, password=None):
     url_parsed = urlparse(url)
     url_parsed = urlparse(url)
     url_params = parse_qs(url_parsed.query, keep_blank_values=True)
     url_params = parse_qs(url_parsed.query, keep_blank_values=True)
 
 
-    # Parse supported JDBC compatible options from URL. args have precendece
-    rebuild = False
+    # Parse supported JDBC compatible parameters from URL. args have precendece
+    # Unlike the JDBC driver, we are expecting these as query params, as the avatica java client
+    # has a different idea of what an URL param is than urlparse. (urlparse seems just broken
+    # in this regard)
+    params_changed = False
     if auth is None and authentication is None and 'authentication' in url_params:
     if auth is None and authentication is None and 'authentication' in url_params:
         authentication = url_params['authentication'][0]
         authentication = url_params['authentication'][0]
         del url_params['authentication']
         del url_params['authentication']
-        rebuild = True
+        params_changed = True
 
 
     if avatica_user is None and 'avatica_user' in url_params:
     if avatica_user is None and 'avatica_user' in url_params:
         avatica_user = url_params['avatica_user'][0]
         avatica_user = url_params['avatica_user'][0]
         del url_params['avatica_user']
         del url_params['avatica_user']
-        rebuild = True
+        params_changed = True
 
 
     if avatica_password is None and 'avatica_password' in url_params:
     if avatica_password is None and 'avatica_password' in url_params:
         avatica_password = url_params['avatica_password'][0]
         avatica_password = url_params['avatica_password'][0]
         del url_params['avatica_password']
         del url_params['avatica_password']
-        rebuild = True
+        params_changed = True
 
 
     if verify is None and truststore is None and 'truststore' in url_params:
     if verify is None and truststore is None and 'truststore' in url_params:
         truststore = url_params['truststore'][0]
         truststore = url_params['truststore'][0]
         del url_params['truststore']
         del url_params['truststore']
-        rebuild = True
-
-    if rebuild:
-        url_parsed._replace(query=urlencode(url_params, True))
+        params_changed = True
+
+    if authentication == 'BASIC' or authentication == 'DIGEST':
+        # Handle standard user and password parameters
+        if user is not None and avatica_user is None:
+            avatica_user = user
+        if password is not None and avatica_password is None:
+            avatica_password = password
+    else:
+        # interpret standard user parameter as do_as for SPNEGO and NONE
+        if user is not None and do_as is None:
+            do_as = user
+
+    # Add doAs
+    if do_as:
+        url_params['doAs'] = do_as
+        params_changed = True
+
+    if params_changed:
+        url_parsed = url_parsed._replace(query=urlencode(url_params))
         url = urlunparse(url_parsed)
         url = urlunparse(url_parsed)
 
 
     if auth == "SPNEGO":
     if auth == "SPNEGO":
@@ -144,9 +186,4 @@ def connect(url, max_retries=None, auth=None, authentication=None, avatica_user=
     if verify is None and truststore is not None:
     if verify is None and truststore is not None:
         verify = truststore
         verify = truststore
 
 
-    client = AvaticaClient(url, max_retries=max_retries,
-                           auth=auth,
-                           verify=verify
-                           )
-    client.connect()
-    return Connection(client, **kwargs)
+    return (url, auth, verify)

+ 4 - 1
desktop/core/ext-py/phoenixdb/phoenixdb/avatica/client.py

@@ -92,7 +92,10 @@ OPEN_CONNECTION_PROPERTIES = (
     'auth',
     'auth',
     'authentication',
     'authentication',
     'truststore',
     'truststore',
-    'verify'
+    'verify',
+    'do_as',
+    'user',
+    'password'
 )
 )
 
 
 
 

+ 18 - 5
desktop/core/ext-py/phoenixdb/phoenixdb/sqlalchemy_phoenix.py

@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
+import re
+import sys
+
+import phoenixdb
+
 from sqlalchemy import types
 from sqlalchemy import types
 from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
 from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
 from sqlalchemy.exc import CompileError
 from sqlalchemy.exc import CompileError
@@ -19,10 +24,6 @@ from sqlalchemy.sql.compiler import DDLCompiler
 from sqlalchemy.types import BIGINT, BOOLEAN, CHAR, DATE, DECIMAL, FLOAT, INTEGER, NUMERIC,\
 from sqlalchemy.types import BIGINT, BOOLEAN, CHAR, DATE, DECIMAL, FLOAT, INTEGER, NUMERIC,\
     SMALLINT, TIME, TIMESTAMP, VARBINARY, VARCHAR
     SMALLINT, TIME, TIMESTAMP, VARBINARY, VARCHAR
 
 
-import phoenixdb
-import re
-import sys
-
 if sys.version_info.major == 3:
 if sys.version_info.major == 3:
     from urllib.parse import urlunsplit, SplitResult, urlencode
     from urllib.parse import urlunsplit, SplitResult, urlencode
 else:
 else:
@@ -94,6 +95,13 @@ class PhoenixDialect(DefaultDialect):
     execution_ctx_cls = PhoenixExecutionContext
     execution_ctx_cls = PhoenixExecutionContext
 
 
     def __init__(self, tls=False, path='/', **opts):
     def __init__(self, tls=False, path='/', **opts):
+        '''
+        :param tls:
+            If True, then use https for connecting, otherwise use http
+
+        :param path:
+            The path component of the connection URL
+        '''
         # There is no way to pass these via the SqlAlchemy url object
         # There is no way to pass these via the SqlAlchemy url object
         self.tls = tls
         self.tls = tls
         self.path = path
         self.path = path
@@ -104,6 +112,11 @@ class PhoenixDialect(DefaultDialect):
         return phoenixdb
         return phoenixdb
 
 
     def create_connect_args(self, url):
     def create_connect_args(self, url):
+        connect_args = dict()
+        if url.username is not None:
+            connect_args['user'] = url.username
+            if url.password is not None:
+                connect_args['password'] = url.username
         phoenix_url = urlunsplit(SplitResult(
         phoenix_url = urlunsplit(SplitResult(
             scheme='https' if self.tls else 'http',
             scheme='https' if self.tls else 'http',
             netloc='{}:{}'.format(url.host, 8765 if url.port is None else url.port),
             netloc='{}:{}'.format(url.host, 8765 if url.port is None else url.port),
@@ -111,7 +124,7 @@ class PhoenixDialect(DefaultDialect):
             query=urlencode(url.query),
             query=urlencode(url.query),
             fragment='',
             fragment='',
         ))
         ))
-        return [phoenix_url], {}
+        return [phoenix_url], connect_args
 
 
     def has_table(self, connection, table_name, schema=None):
     def has_table(self, connection, table_name, schema=None):
         if schema is None:
         if schema is None:

+ 25 - 0
desktop/core/ext-py/phoenixdb/phoenixdb/tests/test_avatica.py

@@ -15,8 +15,11 @@
 
 
 import unittest
 import unittest
 
 
+import phoenixdb
 from phoenixdb.avatica.client import parse_url, urlparse
 from phoenixdb.avatica.client import parse_url, urlparse
 
 
+from requests.auth import HTTPBasicAuth
+
 
 
 class ParseUrlTest(unittest.TestCase):
 class ParseUrlTest(unittest.TestCase):
 
 
@@ -24,3 +27,25 @@ class ParseUrlTest(unittest.TestCase):
         self.assertEqual(urlparse.urlparse('http://localhost:8765/'), parse_url('localhost'))
         self.assertEqual(urlparse.urlparse('http://localhost:8765/'), parse_url('localhost'))
         self.assertEqual(urlparse.urlparse('http://localhost:2222/'), parse_url('localhost:2222'))
         self.assertEqual(urlparse.urlparse('http://localhost:2222/'), parse_url('localhost:2222'))
         self.assertEqual(urlparse.urlparse('http://localhost:2222/'), parse_url('http://localhost:2222/'))
         self.assertEqual(urlparse.urlparse('http://localhost:2222/'), parse_url('http://localhost:2222/'))
+
+    def test_url_params(self):
+        (url, auth, verify) = phoenixdb._process_args((
+            "https://localhost:8765?authentication=BASIC&"
+            "avatica_user=user&avatica_password=password&truststore=truststore"))
+        self.assertEqual("https://localhost:8765", url)
+        self.assertEqual("truststore", verify)
+        self.assertEqual(auth, HTTPBasicAuth("user", "password"))
+
+        (url, auth, verify) = phoenixdb._process_args(
+            "http://localhost:8765", authentication='BASIC', user='user', password='password',
+            truststore='truststore')
+        self.assertEqual("http://localhost:8765", url)
+        self.assertEqual("truststore", verify)
+        self.assertEqual(auth, HTTPBasicAuth("user", "password"))
+
+        (url, auth, verify) = phoenixdb._process_args(
+            "https://localhost:8765", authentication='SPNEGO', user='user', truststore='truststore')
+        self.assertEqual("https://localhost:8765?doAs=user", url)
+        self.assertEqual("truststore", verify)
+        # SPNEGO auth objects seem to have no working __eq__
+        # self.assertEqual(auth, HTTPSPNEGOAuth(opportunistic_auth=True))

+ 3 - 4
desktop/core/ext-py/phoenixdb/phoenixdb/tests/test_sqlalchemy.py

@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-import unittest
 import sys
 import sys
+import unittest
 
 
 import sqlalchemy as db
 import sqlalchemy as db
 from sqlalchemy import text
 from sqlalchemy import text
 
 
-from . import TEST_DB_URL, TEST_DB_AUTHENTICATION, TEST_DB_AVATICA_USER, TEST_DB_AVATICA_PASSWORD,\
-        TEST_DB_TRUSTSTORE
+from . import TEST_DB_AUTHENTICATION, TEST_DB_AVATICA_PASSWORD, TEST_DB_AVATICA_USER, \
+    TEST_DB_TRUSTSTORE, TEST_DB_URL
 
 
 if sys.version_info.major == 3:
 if sys.version_info.major == 3:
     from urllib.parse import urlparse, urlunparse
     from urllib.parse import urlparse, urlunparse
@@ -67,7 +67,6 @@ class SQLAlchemyTest(unittest.TestCase):
                 CONSTRAINT my_pk PRIMARY KEY (state, city))'''))
                 CONSTRAINT my_pk PRIMARY KEY (state, city))'''))
                 columns_result = inspector.get_columns('us_population')
                 columns_result = inspector.get_columns('us_population')
                 self.assertEqual(len(columns_result), 3)
                 self.assertEqual(len(columns_result), 3)
-                print(columns_result)
             finally:
             finally:
                 connection.execute('drop table if exists us_population')
                 connection.execute('drop table if exists us_population')