|
@@ -15,9 +15,10 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
+from unittest.mock import MagicMock, Mock, patch
|
|
|
|
|
+
|
|
|
import pytest
|
|
import pytest
|
|
|
from django.test import TestCase
|
|
from django.test import TestCase
|
|
|
-from unittest.mock import MagicMock, patch, Mock
|
|
|
|
|
|
|
|
|
|
from desktop.auth.backend import rewrite_user
|
|
from desktop.auth.backend import rewrite_user
|
|
|
from desktop.lib.django_test_util import make_logged_in_client
|
|
from desktop.lib.django_test_util import make_logged_in_client
|
|
@@ -40,6 +41,14 @@ class TestTrinoApi(TestCase):
|
|
|
# Initialize TrinoApi with mock user and interpreter
|
|
# Initialize TrinoApi with mock user and interpreter
|
|
|
cls.trino_api = TrinoApi(cls.user, interpreter=cls.interpreter)
|
|
cls.trino_api = TrinoApi(cls.user, interpreter=cls.interpreter)
|
|
|
|
|
|
|
|
|
|
+ def test_format_identifier(self):
|
|
|
|
|
+ test_cases = [
|
|
|
|
|
+ ("my_db", '"my_db"'),
|
|
|
|
|
+ ("my_db.table", '"my_db"."table"'),
|
|
|
|
|
+ ]
|
|
|
|
|
+
|
|
|
|
|
+ for database, expected_output in test_cases:
|
|
|
|
|
+ assert self.trino_api._format_identifier(database) == expected_output
|
|
|
|
|
|
|
|
def test_parse_api_url(self):
|
|
def test_parse_api_url(self):
|
|
|
# Test parse_api_url method
|
|
# Test parse_api_url method
|
|
@@ -49,7 +58,6 @@ class TestTrinoApi(TestCase):
|
|
|
|
|
|
|
|
assert result == expected_result
|
|
assert result == expected_result
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def test_autocomplete_with_database(self):
|
|
def test_autocomplete_with_database(self):
|
|
|
with patch('notebook.connectors.trino.TrinoApi._show_databases') as _show_databases:
|
|
with patch('notebook.connectors.trino.TrinoApi._show_databases') as _show_databases:
|
|
|
_show_databases.return_value = [
|
|
_show_databases.return_value = [
|
|
@@ -62,7 +70,6 @@ class TestTrinoApi(TestCase):
|
|
|
assert (response['databases'] ==
|
|
assert (response['databases'] ==
|
|
|
[{'name': 'test_catalog1.test_db1'}, {'name': 'test_catalog2.test_db1'}, {'name': 'test_catalog2.test_db2'}])
|
|
[{'name': 'test_catalog1.test_db1'}, {'name': 'test_catalog2.test_db1'}, {'name': 'test_catalog2.test_db2'}])
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def test_autocomplete_with_database_and_table(self):
|
|
def test_autocomplete_with_database_and_table(self):
|
|
|
with patch('notebook.connectors.trino.TrinoApi._show_tables') as _show_tables:
|
|
with patch('notebook.connectors.trino.TrinoApi._show_tables') as _show_tables:
|
|
|
_show_tables.return_value = [
|
|
_show_tables.return_value = [
|
|
@@ -73,7 +80,7 @@ class TestTrinoApi(TestCase):
|
|
|
snippet = {}
|
|
snippet = {}
|
|
|
database = 'test_db1'
|
|
database = 'test_db1'
|
|
|
response = self.trino_api.autocomplete(snippet, database)
|
|
response = self.trino_api.autocomplete(snippet, database)
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
assert 'tables_meta' in response # Check if 'table_meta' key exists in the response
|
|
assert 'tables_meta' in response # Check if 'table_meta' key exists in the response
|
|
|
assert (response['tables_meta'] ==
|
|
assert (response['tables_meta'] ==
|
|
|
[
|
|
[
|
|
@@ -82,7 +89,6 @@ class TestTrinoApi(TestCase):
|
|
|
{'name': 'test_table3', 'type': 'table', 'comment': ''}
|
|
{'name': 'test_table3', 'type': 'table', 'comment': ''}
|
|
|
])
|
|
])
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def test_autocomplete_with_database_table_and_column(self):
|
|
def test_autocomplete_with_database_table_and_column(self):
|
|
|
with patch('notebook.connectors.trino.TrinoApi._get_columns') as _get_columns:
|
|
with patch('notebook.connectors.trino.TrinoApi._get_columns') as _get_columns:
|
|
|
_get_columns.return_value = [
|
|
_get_columns.return_value = [
|
|
@@ -106,7 +112,6 @@ class TestTrinoApi(TestCase):
|
|
|
assert 'columns' in response # Check if 'columns' key exists in the response
|
|
assert 'columns' in response # Check if 'columns' key exists in the response
|
|
|
assert response['columns'] == ['test_column1', 'test_column2', 'test_column3']
|
|
assert response['columns'] == ['test_column1', 'test_column2', 'test_column3']
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def test_get_sample_data_success(self):
|
|
def test_get_sample_data_success(self):
|
|
|
with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
|
|
with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
|
|
|
# Mock TrinoQuery object and its execute method
|
|
# Mock TrinoQuery object and its execute method
|
|
@@ -124,7 +129,6 @@ class TestTrinoApi(TestCase):
|
|
|
assert (result['full_headers'] ==
|
|
assert (result['full_headers'] ==
|
|
|
[{'name': 'test_column1', 'type': 'string', 'comment': ''}, {'name': 'test_column2', 'type': 'string', 'comment': ''}])
|
|
[{'name': 'test_column1', 'type': 'string', 'comment': ''}, {'name': 'test_column2', 'type': 'string', 'comment': ''}])
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def test_check_status_available(self):
|
|
def test_check_status_available(self):
|
|
|
mock_trino_request = MagicMock()
|
|
mock_trino_request = MagicMock()
|
|
|
self.trino_api.trino_request = mock_trino_request
|
|
self.trino_api.trino_request = mock_trino_request
|
|
@@ -139,7 +143,6 @@ class TestTrinoApi(TestCase):
|
|
|
assert result['status'] == 'available'
|
|
assert result['status'] == 'available'
|
|
|
assert result['next_uri'] == 'http://url'
|
|
assert result['next_uri'] == 'http://url'
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def test_execute(self):
|
|
def test_execute(self):
|
|
|
with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
|
|
with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
|
|
|
# Mock TrinoQuery object and its methods
|
|
# Mock TrinoQuery object and its methods
|
|
@@ -147,7 +150,6 @@ class TestTrinoApi(TestCase):
|
|
|
mock_query_instance.query = "SELECT * FROM test_table"
|
|
mock_query_instance.query = "SELECT * FROM test_table"
|
|
|
mock_query_instance.execute.return_value = MagicMock(next_uri=None, id='123', rows=[], columns=[])
|
|
mock_query_instance.execute.return_value = MagicMock(next_uri=None, id='123', rows=[], columns=[])
|
|
|
|
|
|
|
|
-
|
|
|
|
|
mock_trino_request = MagicMock()
|
|
mock_trino_request = MagicMock()
|
|
|
self.trino_api.trino_request = mock_trino_request
|
|
self.trino_api.trino_request = mock_trino_request
|
|
|
|
|
|
|
@@ -207,7 +209,6 @@ class TestTrinoApi(TestCase):
|
|
|
}
|
|
}
|
|
|
assert result == expected_result
|
|
assert result == expected_result
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def test_fetch_result(self):
|
|
def test_fetch_result(self):
|
|
|
# Mock TrinoRequest object and its methods
|
|
# Mock TrinoRequest object and its methods
|
|
|
mock_trino_request = MagicMock()
|
|
mock_trino_request = MagicMock()
|
|
@@ -252,21 +253,20 @@ class TestTrinoApi(TestCase):
|
|
|
} for column in _columns],
|
|
} for column in _columns],
|
|
|
'type': 'table'
|
|
'type': 'table'
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
assert result == expected_result
|
|
assert result == expected_result
|
|
|
assert len(result['data']) == 6
|
|
assert len(result['data']) == 6
|
|
|
assert len(result['meta']) == 2
|
|
assert len(result['meta']) == 2
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def test_get_select_query(self):
|
|
def test_get_select_query(self):
|
|
|
# Test with specified database, table, and column
|
|
# Test with specified database, table, and column
|
|
|
- database = 'test_db'
|
|
|
|
|
- table = 'test_table'
|
|
|
|
|
|
|
+ database = '`test_schema.test_db`'
|
|
|
|
|
+ table = '`test_table`'
|
|
|
column = 'test_column'
|
|
column = 'test_column'
|
|
|
expected_statement = (
|
|
expected_statement = (
|
|
|
- "SELECT test_column\n"
|
|
|
|
|
- "FROM test_db.test_table\n"
|
|
|
|
|
- "LIMIT 100\n"
|
|
|
|
|
|
|
+ 'SELECT "test_column"\n'
|
|
|
|
|
+ 'FROM "test_schema"."test_db"."test_table"\n'
|
|
|
|
|
+ 'LIMIT 100\n'
|
|
|
)
|
|
)
|
|
|
assert (
|
|
assert (
|
|
|
self.trino_api._get_select_query(database, table, column) ==
|
|
self.trino_api._get_select_query(database, table, column) ==
|
|
@@ -276,38 +276,37 @@ class TestTrinoApi(TestCase):
|
|
|
database = 'test_db'
|
|
database = 'test_db'
|
|
|
table = 'test_table'
|
|
table = 'test_table'
|
|
|
expected_statement = (
|
|
expected_statement = (
|
|
|
- "SELECT *\n"
|
|
|
|
|
- "FROM test_db.test_table\n"
|
|
|
|
|
- "LIMIT 100\n"
|
|
|
|
|
|
|
+ 'SELECT *\n'
|
|
|
|
|
+ 'FROM "test_db"."test_table"\n'
|
|
|
|
|
+ 'LIMIT 100\n'
|
|
|
)
|
|
)
|
|
|
assert (
|
|
assert (
|
|
|
self.trino_api._get_select_query(database, table) ==
|
|
self.trino_api._get_select_query(database, table) ==
|
|
|
expected_statement)
|
|
expected_statement)
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def test_explain(self):
|
|
def test_explain(self):
|
|
|
with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
|
|
with patch('notebook.connectors.trino.TrinoQuery') as TrinoQuery:
|
|
|
snippet = {'statement': 'SELECT * FROM tpch.sf1.partsupp LIMIT 100;', 'database': 'tpch.sf1'}
|
|
snippet = {'statement': 'SELECT * FROM tpch.sf1.partsupp LIMIT 100;', 'database': 'tpch.sf1'}
|
|
|
- output = [['Trino version: 432\nFragment 0 [SINGLE]\n Output layout: [partkey, suppkey, availqty, supplycost, comment]\n '\
|
|
|
|
|
- 'Output partitioning: SINGLE []\n Output[columnNames = [partkey, suppkey, availqty, supplycost, comment]]\n │ '\
|
|
|
|
|
- 'Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n │ '\
|
|
|
|
|
- 'Estimates: {rows: 100 (15.67kB), cpu: 0, memory: 0B, network: 0B}\n └─ Limit[count = 100]\n │ '\
|
|
|
|
|
- 'Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n │ '\
|
|
|
|
|
- 'Estimates: {rows: 100 (15.67kB), cpu: 15.67k, memory: 0B, network: 0B}\n └─ LocalExchange[partitioning = SINGLE]\n '\
|
|
|
|
|
- '│ Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n │ '\
|
|
|
|
|
- 'Estimates: {rows: 100 (15.67kB), cpu: 0, memory: 0B, network: 0B}\n └─ RemoteSource[sourceFragmentIds = [1]]\n'\
|
|
|
|
|
- ' Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n\n'\
|
|
|
|
|
- 'Fragment 1 [SOURCE]\n Output layout: [partkey, suppkey, availqty, supplycost, comment]\n Output partitioning: SINGLE []\n'\
|
|
|
|
|
- ' LimitPartial[count = 100]\n │ Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, '\
|
|
|
|
|
- 'comment:varchar(199)]\n │ Estimates: {rows: 100 (15.67kB), cpu: 15.67k, memory: 0B, network: 0B}\n └─ '\
|
|
|
|
|
- 'TableScan[table = tpch:sf1:partsupp]\n Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, '\
|
|
|
|
|
- 'comment:varchar(199)]\n Estimates: {rows: 800000 (122.44MB), cpu: 122.44M, memory: 0B, network: 0B}\n '\
|
|
|
|
|
- 'partkey := tpch:partkey\n availqty := tpch:availqty\n supplycost := tpch:supplycost\n '\
|
|
|
|
|
|
|
+ output = [['Trino version: 432\nFragment 0 [SINGLE]\n Output layout: [partkey, suppkey, availqty, supplycost, comment]\n '
|
|
|
|
|
+ 'Output partitioning: SINGLE []\n Output[columnNames = [partkey, suppkey, availqty, supplycost, comment]]\n │ '
|
|
|
|
|
+ 'Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n │ '
|
|
|
|
|
+ 'Estimates: {rows: 100 (15.67kB), cpu: 0, memory: 0B, network: 0B}\n └─ Limit[count = 100]\n │ '
|
|
|
|
|
+ 'Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n │ '
|
|
|
|
|
+ 'Estimates: {rows: 100 (15.67kB), cpu: 15.67k, memory: 0B, network: 0B}\n └─ LocalExchange[partitioning = SINGLE]\n '
|
|
|
|
|
+ '│ Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n │ '
|
|
|
|
|
+ 'Estimates: {rows: 100 (15.67kB), cpu: 0, memory: 0B, network: 0B}\n └─ RemoteSource[sourceFragmentIds = [1]]\n'
|
|
|
|
|
+ ' Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, comment:varchar(199)]\n\n'
|
|
|
|
|
+ 'Fragment 1 [SOURCE]\n Output layout: [partkey, suppkey, availqty, supplycost, comment]\n Output partitioning: SINGLE []\n'
|
|
|
|
|
+ ' LimitPartial[count = 100]\n │ Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, '
|
|
|
|
|
+ 'comment:varchar(199)]\n │ Estimates: {rows: 100 (15.67kB), cpu: 15.67k, memory: 0B, network: 0B}\n └─ '
|
|
|
|
|
+ 'TableScan[table = tpch:sf1:partsupp]\n Layout: [partkey:bigint, suppkey:bigint, availqty:integer, supplycost:double, '
|
|
|
|
|
+ 'comment:varchar(199)]\n Estimates: {rows: 800000 (122.44MB), cpu: 122.44M, memory: 0B, network: 0B}\n '
|
|
|
|
|
+ 'partkey := tpch:partkey\n availqty := tpch:availqty\n supplycost := tpch:supplycost\n '
|
|
|
'comment := tpch:comment\n suppkey := tpch:suppkey\n\n']]
|
|
'comment := tpch:comment\n suppkey := tpch:suppkey\n\n']]
|
|
|
# Mock TrinoQuery object and its execute method
|
|
# Mock TrinoQuery object and its execute method
|
|
|
query_instance = TrinoQuery.return_value
|
|
query_instance = TrinoQuery.return_value
|
|
|
query_instance.execute.return_value = MagicMock(next_uri=None, id='123', rows=output, columns=[])
|
|
query_instance.execute.return_value = MagicMock(next_uri=None, id='123', rows=output, columns=[])
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
# Call the explain method
|
|
# Call the explain method
|
|
|
result = self.trino_api.explain(notebook=None, snippet=snippet)
|
|
result = self.trino_api.explain(notebook=None, snippet=snippet)
|
|
|
|
|
|
|
@@ -325,7 +324,6 @@ class TestTrinoApi(TestCase):
|
|
|
# Assert the exception message
|
|
# Assert the exception message
|
|
|
assert result['explanation'] == 'Mocked exception'
|
|
assert result['explanation'] == 'Mocked exception'
|
|
|
|
|
|
|
|
-
|
|
|
|
|
@patch('notebook.connectors.trino.DEFAULT_AUTH_USERNAME.get', return_value='mocked_username')
|
|
@patch('notebook.connectors.trino.DEFAULT_AUTH_USERNAME.get', return_value='mocked_username')
|
|
|
@patch('notebook.connectors.trino.DEFAULT_AUTH_PASSWORD.get', return_value='mocked_password')
|
|
@patch('notebook.connectors.trino.DEFAULT_AUTH_PASSWORD.get', return_value='mocked_password')
|
|
|
def test_auth_username_and_auth_password_default(self, mock_default_username, mock_default_password):
|
|
def test_auth_username_and_auth_password_default(self, mock_default_username, mock_default_password):
|
|
@@ -334,7 +332,6 @@ class TestTrinoApi(TestCase):
|
|
|
assert trino_api.auth_username == 'mocked_username'
|
|
assert trino_api.auth_username == 'mocked_username'
|
|
|
assert trino_api.auth_password == 'mocked_password'
|
|
assert trino_api.auth_password == 'mocked_password'
|
|
|
|
|
|
|
|
-
|
|
|
|
|
@patch('notebook.connectors.trino.DEFAULT_AUTH_USERNAME.get', return_value='mocked_username')
|
|
@patch('notebook.connectors.trino.DEFAULT_AUTH_USERNAME.get', return_value='mocked_username')
|
|
|
@patch('notebook.connectors.trino.DEFAULT_AUTH_PASSWORD.get', return_value='mocked_password')
|
|
@patch('notebook.connectors.trino.DEFAULT_AUTH_PASSWORD.get', return_value='mocked_password')
|
|
|
def test_auth_username_custom(self, mock_default_username, mock_default_password):
|
|
def test_auth_username_custom(self, mock_default_username, mock_default_password):
|