Add nested param to get_sample_data and API compatibility tests to prevent connector signature mismatches (#4202)

- Add source code parsing test to validate connector method signatures
- Add base API contract validation to ensure consistent signatures
- Prevent runtime TypeError when adding parameters like 'nested' to get_sample_data()
Ayush Goyal · 4 months ago · parent commit 2445ddeced
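
For context, the failure this commit guards against is a plain Python TypeError: once a caller starts passing 'nested' as a keyword argument, any connector whose get_sample_data() override predates the parameter fails at call time. A minimal sketch with hypothetical classes (not actual Hue code):

  class BaseApi(object):
    # New base contract: 'nested' added with a safe default.
    def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False, operation=None):
      raise NotImplementedError()

  class StaleConnector(BaseApi):
    # Override written before 'nested' existed; its signature is now out of date.
    def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
      return []

  StaleConnector().get_sample_data({}, database='db', table='t', nested=True)
  # TypeError: get_sample_data() got an unexpected keyword argument 'nested'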

+ 8 - 6
desktop/libs/librdbms/src/librdbms/server/dbms.py

@@ -15,14 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from builtins import object
 import logging
+from builtins import object
 
 from desktop.lib.python_util import force_dict_to_strings
-
 from librdbms.conf import DATABASES, get_database_password
 
-
 LOG = logging.getLogger()
 
 MYSQL = 'mysql'
@@ -98,8 +96,8 @@ class Rdbms(object):
   def get_columns(self, database, table_name, names_only=True):
     return self.client.get_columns(database, table_name, names_only)
 
-  def get_sample_data(self, database, table_name, column=None, limit=100):
-    return self.client.get_sample_data(database, table_name, column, limit)
+  def get_sample_data(self, database, table_name, column=None, nested=None, limit=100):
+    return self.client.get_sample_data(database, table_name, column, nested, limit)
 
   def execute_statement(self, statement):
     return self.client.execute_statement(statement)
@@ -123,7 +121,11 @@ class Rdbms(object):
     )
     query_history.save()
 
-    LOG.debug("Updated QueryHistory id %s user %s statement_number: %s" % (query_history.id, self.client.user, query_history.statement_number))
+    LOG.debug(
+        "Updated QueryHistory id %s user %s statement_number: %s" % (
+            query_history.id, self.client.user, query_history.statement_number
+        )
+    )
 
     return query_history
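
A subtlety in the Rdbms change above: 'nested' is inserted before 'limit', so any pre-existing caller that passed 'limit' positionally would now silently bind that value to 'nested' instead. A quick illustration with standalone stand-in functions (hypothetical callers, not from this commit):

  def get_sample_data_old(database, table, column=None, limit=100):
    return limit

  def get_sample_data_new(database, table, column=None, nested=None, limit=100):
    return limit

  get_sample_data_old('db', 't', None, 10)               # limit == 10, as intended
  get_sample_data_new('db', 't', None, 10)               # limit == 100; the 10 bound to 'nested'
  get_sample_data_new('db', 't', column=None, limit=10)  # keyword arguments avoid the trap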
 

+ 1 - 2
desktop/libs/librdbms/src/librdbms/server/mysql_lib.py

@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
 import logging
 
 try:
@@ -177,7 +176,7 @@ class MySQLClient(BaseRDMSClient):
       columns = [dict(name=row[0], type=row[1], comment='') for row in cursor.fetchall()]
     return columns
 
-  def get_sample_data(self, database, table, column=None, limit=100):
+  def get_sample_data(self, database, table, column=None, nested=None, limit=100):
     column = '`%s`' % column if column else '*'
     statement = "SELECT %s FROM `%s`.`%s` LIMIT %d" % (column, database, table, limit)
     return self.execute_statement(statement)

+ 7 - 11
desktop/libs/librdbms/src/librdbms/server/oracle_lib.py

@@ -25,14 +25,15 @@ except ImportError as e:
 
 from librdbms.server.rdbms_base_lib import BaseRDBMSDataTable, BaseRDBMSResult, BaseRDMSClient
 
-
 LOG = logging.getLogger()
 
 
-class DataTable(BaseRDBMSDataTable): pass
+class DataTable(BaseRDBMSDataTable):
+  pass
 
 
-class Result(BaseRDBMSResult): pass
+class Result(BaseRDBMSResult):
+  pass
 
 
 class OracleClient(BaseRDMSClient):
@@ -66,12 +67,10 @@ class OracleClient(BaseRDMSClient):
     return "%s/%s@%s" % (self.query_server['username'],
                          self.query_server['password'], dsn)
 
-
   def use(self, database):
     # Oracle credentials are on a per database basis.
     pass
 
-
   def execute_statement(self, statement):
     cursor = self.connection.cursor()
     cursor.execute(statement)
@@ -82,11 +81,9 @@ class OracleClient(BaseRDMSClient):
       columns = []
     return self.data_table_cls(cursor, columns)
 
-
   def get_databases(self):
     return [self.query_server['name']]
 
-
   def get_tables(self, database, table_names=[]):
     cursor = self.connection.cursor()
     query = "SELECT table_name FROM user_tables"
@@ -97,7 +94,6 @@ class OracleClient(BaseRDMSClient):
     self.connection.commit()
     return [row[0] for row in cursor.fetchall()]
 
-
   def get_columns(self, database, table, names_only=True):
     cursor = self.connection.cursor()
     cursor.execute("SELECT column_name, data_type FROM user_tab_cols WHERE table_name = '%s'" % table)
@@ -108,7 +104,7 @@ class OracleClient(BaseRDMSClient):
       columns = [dict(name=row[0], type=row[1], comment='') for row in cursor.fetchall()]
     return columns
 
-  def get_sample_data(self, database, table, column=None, limit=100):
-    column = '"%s"' % column  if column else '*'
+  def get_sample_data(self, database, table, column=None, nested=None, limit=100):
+    column = '"%s"' % column if column else '*'
     statement = 'SELECT %s FROM "%s"."%s" LIMIT %d' % (column, database, table, limit)
-    return self.execute_statement(statement)
+    return self.execute_statement(statement)

+ 5 - 10
desktop/libs/librdbms/src/librdbms/server/postgresql_lib.py

@@ -25,14 +25,15 @@ except ImportError as e:
 
 from librdbms.server.rdbms_base_lib import BaseRDBMSDataTable, BaseRDBMSResult, BaseRDMSClient
 
-
 LOG = logging.getLogger()
 
 
-class DataTable(BaseRDBMSDataTable): pass
+class DataTable(BaseRDBMSDataTable):
+  pass
 
 
-class Result(BaseRDBMSResult): pass
+class Result(BaseRDBMSResult):
+  pass
 
 
 class PostgreSQLClient(BaseRDMSClient):
@@ -45,7 +46,6 @@ class PostgreSQLClient(BaseRDMSClient):
     super(PostgreSQLClient, self).__init__(*args, **kwargs)
     self.connection = Database.connect(**self._conn_params)
 
-
   @property
   def _conn_params(self):
     params = {
@@ -64,12 +64,10 @@ class PostgreSQLClient(BaseRDMSClient):
 
     return params
 
-
   def use(self, database):
     # No op since postgresql requires a new connection per database
     pass
 
-
   def execute_statement(self, statement):
     cursor = self.connection.cursor()
     cursor.execute(statement)
@@ -80,7 +78,6 @@ class PostgreSQLClient(BaseRDMSClient):
       columns = []
     return self.data_table_cls(cursor, columns)
 
-
   def get_databases(self):
     # List all the schemas in the database
     try:
@@ -92,7 +89,6 @@ class PostgreSQLClient(BaseRDMSClient):
       LOG.exception('Failed to select nspname from pg_catalog.pg_namespace')
       return [self._conn_params['database']]
 
-
   def get_tables(self, database, table_names=[]):
     cursor = self.connection.cursor()
     query = "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname = '%s'" % database
@@ -103,7 +99,6 @@ class PostgreSQLClient(BaseRDMSClient):
     self.connection.commit()
     return [row[0] for row in cursor.fetchall()]
 
-
   def get_columns(self, database, table, names_only=True):
     cursor = self.connection.cursor()
     query = """
@@ -133,7 +128,7 @@ class PostgreSQLClient(BaseRDMSClient):
       columns = [dict(name=row[0], type=row[1], comment='') for row in cursor.fetchall()]
     return columns
 
-  def get_sample_data(self, database, table, column=None, limit=100):
+  def get_sample_data(self, database, table, column=None, nested=None, limit=100):
     column = '"%s"' % column if column else '*'
     statement = 'SELECT %s FROM "%s"."%s" LIMIT %d' % (column, database, table, limit)
     return self.execute_statement(statement)

+ 6 - 11
desktop/libs/librdbms/src/librdbms/server/sqlite_lib.py

@@ -20,7 +20,7 @@ import logging
 try:
   try:
     from pysqlite2 import dbapi2 as Database
-  except ImportError as e1:
+  except ImportError:
     from sqlite3 import dbapi2 as Database
 except ImportError as exc:
   from django.core.exceptions import ImproperlyConfigured
@@ -28,14 +28,15 @@ except ImportError as exc:
 
 from librdbms.server.rdbms_base_lib import BaseRDBMSDataTable, BaseRDBMSResult, BaseRDMSClient
 
-
 LOG = logging.getLogger()
 
 
-class DataTable(BaseRDBMSDataTable): pass
+class DataTable(BaseRDBMSDataTable):
+  pass
 
 
-class Result(BaseRDBMSResult): pass
+class Result(BaseRDBMSResult):
+  pass
 
 
 class SQLiteClient(BaseRDMSClient):
@@ -48,7 +49,6 @@ class SQLiteClient(BaseRDMSClient):
     super(SQLiteClient, self).__init__(*args, **kwargs)
     self.connection = Database.connect(**self._conn_params)
 
-
   @property
   def _conn_params(self):
     params = {
@@ -64,12 +64,10 @@ class SQLiteClient(BaseRDMSClient):
 
     return params
 
-
   def use(self, database):
     # Do nothing because SQLite has one database per path.
     pass
 
-
   def execute_statement(self, statement):
     cursor = self.connection.cursor()
     cursor.execute(statement)
@@ -80,11 +78,9 @@ class SQLiteClient(BaseRDMSClient):
       columns = []
     return self.data_table_cls(cursor, columns)
 
-
   def get_databases(self):
     return [self._conn_params['database']]
 
-
   def get_tables(self, database, table_names=[]):
     # Doesn't use database and only retrieves tables for database currently in use.
     cursor = self.connection.cursor()
@@ -96,7 +92,6 @@ class SQLiteClient(BaseRDMSClient):
     self.connection.commit()
     return [row[0] for row in cursor.fetchall()]
 
-
   def get_columns(self, database, table, names_only=True):
     cursor = self.connection.cursor()
     cursor.execute("PRAGMA table_info(%s)" % table)
@@ -107,7 +102,7 @@ class SQLiteClient(BaseRDMSClient):
       columns = [dict(name=row[1], type=row[2], comment='') for row in cursor.fetchall()]
     return columns
 
-  def get_sample_data(self, database, table, column=None, limit=100):
+  def get_sample_data(self, database, table, column=None, nested=None, limit=100):
     column = '`%s`' % column if column else '*'
     statement = 'SELECT %s FROM `%s` LIMIT %d' % (column, table, limit)
     return self.execute_statement(statement)

+ 4 - 5
desktop/libs/notebook/src/notebook/connectors/base.py

@@ -15,11 +15,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
 import json
+import logging
+import re
 import time
 import uuid
-import logging
 from builtins import object
 
 from django.utils.encoding import smart_str
@@ -27,10 +27,9 @@ from django.utils.translation import gettext as _
 
 from beeswax.common import find_compute, is_compute
 from desktop.auth.backend import is_admin
-from desktop.conf import TASK_SERVER, has_connectors, is_cdw_compute_enabled
+from desktop.conf import has_connectors, is_cdw_compute_enabled, TASK_SERVER
 from desktop.lib import export_csvxls
 from desktop.lib.exceptions_renderable import PopupException
-from desktop.lib.i18n import smart_str
 from metadata.optimizer.base import get_api as get_optimizer_api
 from notebook.conf import get_ordered_interpreters
 from notebook.sql_utils import get_current_statement
@@ -594,7 +593,7 @@ class Api(object):
   def get_jobs(self, notebook, snippet, logs):
     return []
 
-  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
+  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False, operation=None):
     raise NotImplementedError()
 
   def explain(self, notebook, snippet):
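
The source-parsing test added below accepts either an explicit 'nested' parameter or a catch-all **kwargs, so a connector that wants to absorb future base-API additions could take the latter route. A hedged sketch (hypothetical connector, not part of this commit):

  from notebook.connectors.base import Api

  class FutureProofApi(Api):
    # Hypothetical: tolerates keywords added to the base contract later.
    def get_sample_data(self, snippet, database=None, table=None, column=None, **kwargs):
      nested = kwargs.get('nested')
      operation = kwargs.get('operation')
      return []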

+ 104 - 8
desktop/libs/notebook/src/notebook/connectors/base_tests.py

@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
 # Licensed to Cloudera, Inc. under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -16,16 +15,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
-import json
+import inspect
+import os
+import re
 from builtins import object
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import Mock, patch
 
 import pytest
-from django.urls import reverse
 
 from desktop.lib.django_test_util import make_logged_in_client
-from notebook.connectors.base import Notebook, get_api
+from notebook.connectors.base import get_api, Notebook
 from useradmin.models import User
 
 
@@ -71,9 +70,9 @@ class TestNotebook(object):
     request = Mock()
     operation_id = Mock()
 
-    with patch('notebook.api.Document2.objects.get_by_uuid') as get_by_uuid:
+    with patch('notebook.api.Document2.objects.get_by_uuid'):
       with patch('notebook.api.get_api') as get_api:
-        with patch('notebook.api.Notebook') as NotebookMock:
+        with patch('notebook.api.Notebook'):
           get_api.return_value = Mock(
             check_status=Mock(return_value={'status': 0})
           )
@@ -114,3 +113,100 @@ def check_status_side_effect(request, operation_id):
     return {'status': 0, 'query_status': {'status': 'running'}}
   else:
     return {'status': 0, 'query_status': {'status': 'available'}}
+
+
+@pytest.mark.django_db
+class TestConnectorApiCompatibility(object):
+  """
+  Test API compatibility across all connectors to prevent signature mismatches.
+  This ensures that when new parameters are added to the base API, all connectors
+  remain compatible and don't break due to signature differences.
+  """
+
+  def setup_method(self):
+    self.client = make_logged_in_client(username="test_connector_compatibility", groupname="default", recreate=True, is_superuser=False)
+    self.user = User.objects.get(username="test_connector_compatibility")
+
+  def teardown_method(self):
+    User.objects.filter(username="test_connector_compatibility").delete()
+
+  def test_base_api_method_signatures(self):
+    """
+    Test that the base Api class has the expected method signatures that all connectors should follow.
+    """
+    from notebook.connectors.base import Api
+
+    # Check base get_sample_data signature
+    base_method = getattr(Api, 'get_sample_data', None)
+    assert base_method is not None, "Base Api class missing get_sample_data method"
+
+    sig = inspect.signature(base_method)
+    expected_params = {'self', 'snippet', 'database', 'table', 'column', 'nested', 'is_async', 'operation'}
+    actual_params = set(sig.parameters.keys())
+
+    assert expected_params == actual_params, f"Base Api method signature changed. Expected: {expected_params}, Got: {actual_params}"
+
+    # Verify nested parameter has default None
+    nested_param = sig.parameters.get('nested')
+    assert nested_param is not None, "nested parameter missing from base Api"
+    assert nested_param.default is None, f"nested parameter should default to None, got: {nested_param.default}"
+
+  def test_source_code_signature_compatibility(self):
+    """
+    Test connector method signatures by parsing source code directly.
+    This is the most reliable way to check signatures, avoiding decorator interference.
+    """
+    # Define connectors and their file paths
+    connector_files = [
+      ('SqlAlchemy', 'desktop/libs/notebook/src/notebook/connectors/sql_alchemy.py'),
+      ('Spark', 'desktop/libs/notebook/src/notebook/connectors/spark_shell.py'),
+      ('HiveServer2', 'desktop/libs/notebook/src/notebook/connectors/hiveserver2.py'),
+      ('Flink', 'desktop/libs/notebook/src/notebook/connectors/flink_sql.py'),
+      ('JDBC', 'desktop/libs/notebook/src/notebook/connectors/jdbc.py'),
+      ('RDBMS', 'desktop/libs/notebook/src/notebook/connectors/rdbms.py'),
+      ('Solr', 'desktop/libs/notebook/src/notebook/connectors/solr.py'),
+      ('KSQL', 'desktop/libs/notebook/src/notebook/connectors/ksql.py'),
+      ('SQLFlow', 'desktop/libs/notebook/src/notebook/connectors/sqlflow.py'),
+      ('Trino', 'desktop/libs/notebook/src/notebook/connectors/trino.py'),
+      ('HiveMetastore', 'desktop/libs/notebook/src/notebook/connectors/hive_metastore.py'),
+    ]
+
+    failed_connectors = []
+    passed_connectors = []
+
+    # Pattern to match get_sample_data method definition
+    method_pattern = r'def get_sample_data\(([^)]+)\):'
+
+    for name, file_path in connector_files:
+      try:
+        if not os.path.exists(file_path):
+          continue
+
+        # Read the source file
+        with open(file_path, 'r') as f:
+          content = f.read()
+
+        # Find get_sample_data method signature
+        match = re.search(method_pattern, content)
+
+        if not match:
+          continue
+
+        signature_params = match.group(1)
+
+        # Check for nested parameter
+        has_nested = 'nested' in signature_params
+        has_kwargs = '**kwargs' in signature_params
+
+        if not has_nested and not has_kwargs:
+          failed_connectors.append(f"{name}: Missing 'nested' parameter in source: {signature_params}")
+        else:
+          passed_connectors.append(name)
+
+      except Exception as e:
+        failed_connectors.append(f"{name}: Error reading source file: {e}")
+
+    # Report results
+    if failed_connectors:
+      failure_details = '\n'.join([f"- {f}" for f in failed_connectors])
+      assert False, f"Source Code Compatibility Test Failed!\n\nConnectors missing 'nested' parameter:\n{failure_details}"
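
One detail worth noting about the parsing approach: connectors such as Flink and Trino wrap their get_sample_data signature across two lines, and the pattern still matches because [^)]+ is a character class, which consumes newlines without re.DOTALL. A standalone check (illustrative snippet, not part of the test file):

  import re

  method_pattern = r'def get_sample_data\(([^)]+)\):'
  wrapped_signature = (
    "  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False,\n"
    "                      operation=None):\n"
  )
  match = re.search(method_pattern, wrapped_signature)
  assert match is not None and 'nested' in match.group(1)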

+ 3 - 3
desktop/libs/notebook/src/notebook/connectors/flink_sql.py

@@ -15,11 +15,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
 import json
-import time
 import logging
 import posixpath
+import re
+import time
 
 from desktop.auth.backend import rewrite_user
 from desktop.lib.i18n import force_unicode
@@ -317,7 +317,7 @@ class FlinkSqlApi(Api):
     return response
 
   @query_error_handler
-  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=False, is_async=False,
+  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False,
                       operation=None):
     if operation == 'hello':
       snippet['statement'] = "SELECT 'Hello World!'"

+ 4 - 10
desktop/libs/notebook/src/notebook/connectors/hive_metastore.py

@@ -15,17 +15,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
 import logging
 
-from django.urls import reverse
-from django.utils.translation import gettext as _
-
 from desktop.lib.exceptions import StructuredException
-from desktop.lib.exceptions_renderable import PopupException
-from desktop.lib.i18n import force_unicode, smart_str
-from desktop.lib.rest.http_client import RestException
-from notebook.connectors.base import Api, OperationNotSupported, OperationTimeout, QueryError, QueryExpired
+from desktop.lib.i18n import force_unicode
+from notebook.connectors.base import Api, OperationTimeout, QueryError, QueryExpired
 
 LOG = logging.getLogger()
 
@@ -33,7 +27,7 @@ LOG = logging.getLogger()
 try:
   from beeswax.api import _autocomplete
   from beeswax.server import dbms
-  from beeswax.server.dbms import QueryServerException, get_query_server_config
+  from beeswax.server.dbms import get_query_server_config, QueryServerException
 except ImportError as e:
   LOG.warning('Hive and HiveMetastoreServer interfaces are not enabled: %s' % e)
   hive_settings = None
@@ -67,7 +61,7 @@ class HiveMetastoreApi(Api):
     return _autocomplete(db, database, table, column, nested, query=None, cluster=self.cluster)
 
   @query_error_handler
-  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
+  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False, operation=None):
     return []
 
   def _get_db(self, snippet, is_async=False, cluster=None):

+ 7 - 8
desktop/libs/notebook/src/notebook/connectors/jdbc.py

@@ -15,16 +15,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from builtins import object
 import logging
 import sys
+from builtins import object
+
+from django.utils.translation import gettext as _
 
-from beeswax import data_export
 from desktop.lib.i18n import force_unicode, smart_str
 from librdbms.jdbc import Jdbc, query_and_fetch
-
-from notebook.connectors.base import Api, QueryError, AuthenticationRequired, _get_snippet_name
-from django.utils.translation import gettext as _
+from notebook.connectors.base import Api, AuthenticationRequired, QueryError
 
 LOG = logging.getLogger()
 
@@ -151,14 +150,14 @@ class JdbcApi(Api):
     return response
 
   @query_error_handler
-  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
+  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False, operation=None):
     if self.db is None:
       raise AuthenticationRequired()
 
     assist = self._createAssist(self.db)
     response = {'status': -1, 'result': {}}
 
-    sample_data, description = assist.get_sample_data(database, table, column)
+    sample_data, description = assist.get_sample_data(database, table, column, nested)
 
     if sample_data or description:
       response['status'] = 0
@@ -226,7 +225,7 @@ class Assist(object):
             database, table))
         return [{"comment": col[2] and col[2].strip(), "type": col[1], "name": col[0] and col[0].strip()} for col in columns]
 
-  def get_sample_data(self, database, table, column=None):
+  def get_sample_data(self, database, table, column=None, nested=None):
     column = column or '*'
     # data, description =  query_and_fetch(self.db, 'SELECT %s FROM %s.%s limit 100' % (column, database, table))
     # response['rows'] = data

+ 6 - 4
desktop/libs/notebook/src/notebook/connectors/jdbc_clickhouse.py

@@ -16,14 +16,15 @@
 # limitations under the License.
 
 from librdbms.jdbc import query_and_fetch
+from notebook.connectors.jdbc import Assist, JdbcApi
+
 
-from notebook.connectors.jdbc import JdbcApi
-from notebook.connectors.jdbc import Assist
 class JdbcApiClickhouse(JdbcApi):
 
   def _createAssist(self, db):
     return ClickhouseAssist(db)
 
+
 class ClickhouseAssist(Assist):
 
   def get_databases(self):
@@ -35,9 +36,10 @@ class ClickhouseAssist(Assist):
     return [{"comment": table[1] and table[1].strip(), "type": "Table", "name": table[0] and table[0].strip()} for table in tables]
 
   def get_columns_full(self, database, table):
-    columns, description = query_and_fetch(self.db, "SELECT name, type, '' FROM system.columns WHERE database='%s' AND table = '%s'" % (database, table))
+    query = "SELECT name, type, '' FROM system.columns WHERE database='%s' AND table = '%s'" % (database, table)
+    columns, description = query_and_fetch(self.db, query)
     return [{"comment": col[2] and col[2].strip(), "type": col[1], "name": col[0] and col[0].strip()} for col in columns]
 
-  def get_sample_data(self, database, table, column=None):
+  def get_sample_data(self, database, table, column=None, nested=None):
     column = column or '*'
     return query_and_fetch(self.db, 'SELECT %s FROM %s.%s limit 100' % (column, database, table))

+ 3 - 4
desktop/libs/notebook/src/notebook/connectors/jdbc_kyuubi.py

@@ -16,15 +16,15 @@
 # limitations under the License.
 
 from librdbms.jdbc import query_and_fetch
+from notebook.connectors.jdbc import Assist, JdbcApi
 
-from notebook.connectors.jdbc import JdbcApi
-from notebook.connectors.jdbc import Assist
 
 class JdbcApiKyuubi(JdbcApi):
 
   def _createAssist(self, db):
     return KyuubiAssist(db)
 
+
 class KyuubiAssist(Assist):
 
   def get_databases(self):
@@ -39,7 +39,6 @@ class KyuubiAssist(Assist):
     columns, description = query_and_fetch(self.db, "DESCRIBE %s.%s" % (database, table))
     return [{"comment": col[2] and col[2].strip(), "type": col[1], "name": col[0] and col[0].strip()} for col in columns]
 
-  def get_sample_data(self, database, table, column=None):
+  def get_sample_data(self, database, table, column=None, nested=None):
     column = column or '*'
     return query_and_fetch(self.db, 'SELECT %s FROM %s.%s limit 100' % (column, database, table))
-

+ 10 - 7
desktop/libs/notebook/src/notebook/connectors/jdbc_teradata.py

@@ -16,9 +16,7 @@
 # limitations under the License.
 
 from librdbms.jdbc import query_and_fetch
-
-from notebook.connectors.jdbc import JdbcApi
-from notebook.connectors.jdbc import Assist
+from notebook.connectors.jdbc import Assist, JdbcApi
 
 
 class JdbcApiTeradata(JdbcApi):
@@ -34,14 +32,19 @@ class TeradataAssist(Assist):
     return [db[0] and db[0].strip() for db in dbs]
 
   def get_tables_full(self, database, table_names=[]):
-    tables, description = query_and_fetch(self.db, "SELECT TableName, CommentString FROM dbc.tables WHERE tablekind = 'T' and databasename='%s' ORDER BY TableName" % database)
+    query = ("SELECT TableName, CommentString FROM dbc.tables WHERE tablekind = 'T' and "
+             "databasename='%s' ORDER BY TableName" % database)
+    tables, description = query_and_fetch(self.db, query)
     return [{"comment": table[1] and table[1].strip(), "type": "Table", "name": table[0] and table[0].strip()} for table in tables]
 
   def get_columns_full(self, database, table):
-    columns, description = query_and_fetch(self.db, "SELECT ColumnName, ColumnType, CommentString FROM DBC.Columns WHERE DatabaseName='%s' AND TableName='%s' ORDER BY ColumnName" % (database, table))
-    return [{"comment": col[1] and col[1].strip(), "type": self._type_converter(col[1]), "name": col[0] and col[0].strip()} for col in columns]
+    query = ("SELECT ColumnName, ColumnType, CommentString FROM DBC.Columns WHERE "
+             "DatabaseName='%s' AND TableName='%s' ORDER BY ColumnName" % (database, table))
+    columns, description = query_and_fetch(self.db, query)
+    return [{"comment": col[1] and col[1].strip(), "type": self._type_converter(col[1]),
+             "name": col[0] and col[0].strip()} for col in columns]
 
-  def get_sample_data(self, database, table, column=None):
+  def get_sample_data(self, database, table, column=None, nested=None):
     column = column or '*'
     return query_and_fetch(self.db, 'SELECT %s FROM %s.%s sample 100' % (column, database, table))
 

+ 7 - 9
desktop/libs/notebook/src/notebook/connectors/jdbc_vertica.py

@@ -16,14 +16,13 @@
 # limitations under the License.
 
 from __future__ import division
-from librdbms.jdbc import query_and_fetch
 
-from notebook.connectors.jdbc import JdbcApi
-from notebook.connectors.jdbc import Assist
-import time
 import logging
 import math
+import time
 
+from librdbms.jdbc import query_and_fetch
+from notebook.connectors.jdbc import Assist, JdbcApi
 
 LOG = logging.getLogger()
 
@@ -44,10 +43,9 @@ class VerticaAssist(Assist):
             cache_key not in self.cached_data
             or time.time() - self.cached_data[cache_key]["time"] > self.freeze_time
         ):
-            dbs, description = query_and_fetch(
-                self.db,
-                "select schema_name FROM v_catalog.schemata where is_system_schema=0 and schema_name not in ('v_func', 'v_txtindex') order by 1",
-            )
+            query = ("select schema_name FROM v_catalog.schemata where is_system_schema=0 "
+                     "and schema_name not in ('v_func', 'v_txtindex') order by 1")
+            dbs, description = query_and_fetch(self.db, query)
             list_of_db = [db[0] and db[0].strip() for db in dbs]
             VerticaAssist.cached_data[cache_key] = {
                 "time": time.time(),
@@ -128,7 +126,7 @@ class VerticaAssist(Assist):
             VerticaAssist.cache_use_stat["cache"] += 1
         return VerticaAssist.cached_data[cache_key]["result"]
 
-    def get_sample_data(self, database, table, column=None):
+    def get_sample_data(self, database, table, column=None, nested=None):
         column = column or "*"
         return query_and_fetch(
             self.db, "SELECT %s FROM %s.%s limit 10" % (column, database, table)

+ 2 - 5
desktop/libs/notebook/src/notebook/connectors/ksql.py

@@ -18,12 +18,9 @@
 
 from __future__ import absolute_import
 
-import sys
 import json
 import logging
 
-from django.utils.translation import gettext as _
-
 from desktop.conf import has_channels
 from desktop.lib.i18n import force_unicode
 from kafka.ksql_client import KSqlApi as KSqlClientApi
@@ -33,7 +30,7 @@ LOG = logging.getLogger()
 
 
 if has_channels():
-  from notebook.consumer import _send_to_channel
+  pass
 
 
 def query_error_handler(func):
@@ -126,7 +123,7 @@ class KSqlApi(Api):
     return response
 
   @query_error_handler
-  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
+  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False, operation=None):
     notebook = {}
 
     snippet = {

+ 7 - 6
desktop/libs/notebook/src/notebook/connectors/rdbms.py

@@ -15,14 +15,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
 import logging
+import sys
 from builtins import next, object
 
-from beeswax import data_export
+from django.utils.translation import gettext as _
+
 from desktop.lib.i18n import force_unicode
 from librdbms.server import dbms
-from notebook.connectors.base import Api, QueryError, QueryExpired, _get_snippet_name
+from notebook.connectors.base import Api, QueryError, QueryExpired
 
 LOG = logging.getLogger()
 
@@ -130,14 +131,14 @@ class RdbmsApi(Api):
     return response
 
   @query_error_handler
-  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
+  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False, operation=None):
     query_server = self._get_query_server()
     db = dbms.get(self.user, query_server)
 
     assist = Assist(db)
     response = {'status': -1, 'result': {}}
 
-    sample_data = assist.get_sample_data(database, table, column)
+    sample_data = assist.get_sample_data(database, table, column, nested)
 
     if sample_data:
       response['status'] = 0
@@ -204,7 +205,7 @@ class Assist(object):
   def get_columns(self, database, table):
     return self.db.get_columns(database, table, names_only=False)
 
-  def get_sample_data(self, database, table, column=None):
+  def get_sample_data(self, database, table, column=None, nested=None):
     return self.db.get_sample_data(database, table, column)
 
 

+ 4 - 6
desktop/libs/notebook/src/notebook/connectors/solr.py

@@ -15,13 +15,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import sys
 import logging
 from builtins import object
 
 from django.utils.translation import gettext as _
 
-from desktop.lib.exceptions_renderable import PopupException
 from desktop.lib.i18n import force_unicode
 from indexer.solr_client import SolrClient
 from notebook.connectors.base import Api, QueryError
@@ -32,7 +30,7 @@ LOG = logging.getLogger()
 
 try:
   from libsolr.api import SolrApi as NativeSolrApi
-except (ImportError, AttributeError) as e:
+except (ImportError, AttributeError):
   LOG.exception('Search is not enabled')
 
 
@@ -147,7 +145,7 @@ class SolrApi(Api):
     return response
 
   @query_error_handler
-  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
+  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False, operation=None):
     from search.conf import SOLR_URL
     db = NativeSolrApi(SOLR_URL.get(), self.user)
 
@@ -157,7 +155,7 @@ class SolrApi(Api):
     if snippet.get('source') == 'sql':
       sample_data = assist.get_sample_data_sql(database, table, column)
     else:
-      sample_data = assist.get_sample_data(database, table, column)
+      sample_data = assist.get_sample_data(database, table, column, nested)
 
     if sample_data:
       response['status'] = 0
@@ -199,7 +197,7 @@ class Assist(object):
       } for field in self.db.schema_fields(table)['fields']
     ]
 
-  def get_sample_data(self, database, table, column=None):
+  def get_sample_data(self, database, table, column=None, nested=None):
     # Note: currently ignores dynamic fields
     full_headers = self.get_columns(database, table)
     headers = [col['name'] for col in full_headers]

+ 11 - 13
desktop/libs/notebook/src/notebook/connectors/spark_shell.py

@@ -15,13 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
-import sys
 import json
-import time
 import logging
+import re
 import textwrap
-from builtins import object, range
+import time
+from builtins import object
 
 from django.utils.translation import gettext as _
 
@@ -32,8 +31,7 @@ from desktop.lib.exceptions_renderable import PopupException
 from desktop.lib.i18n import force_unicode
 from desktop.lib.rest.http_client import RestException
 from desktop.models import DefaultConfiguration
-from notebook.connectors.base import Api, QueryError, SessionExpired, _get_snippet_session
-from notebook.data_export import download as spark_download
+from notebook.connectors.base import _get_snippet_session, Api, QueryError, SessionExpired
 
 LOG = logging.getLogger()
 
@@ -41,7 +39,7 @@ LOG = logging.getLogger()
 try:
   from spark.conf import LIVY_SERVER_SESSION_KIND
   from spark.livy_client import get_api as get_spark_api
-except ImportError as e:
+except ImportError:
   LOG.exception('Spark is not enabled')
 
 SESSION_KEY = '%(username)s-%(interpreter_name)s'
@@ -118,7 +116,7 @@ class SparkApi(Api):
     api = self.get_api()
     try:
       session_present = api.get_session(session['id'])
-    except Exception as e:
+    except Exception:
       session_present = None
 
     if session_present and session_present['state'] not in ('dead', 'shutting_down', 'error', 'killed'):
@@ -348,7 +346,7 @@ class SparkApi(Api):
     session = self._handle_session_health_check(session)
 
     try:
-      response = api.cancel(session['id'])
+      api.cancel(session['id'])
     except Exception as e:
       message = force_unicode(str(e)).lower()
       LOG.debug(message)
@@ -487,7 +485,7 @@ class SparkApi(Api):
 
   def _show_tables(self, api, session, snippet_type, database):
     use_db_execute = self._execute(api, session, snippet_type, 'USE %(database)s' % {'database': database})
-    use_db_resp = self._check_status_and_fetch_result(api, session, use_db_execute)
+    self._check_status_and_fetch_result(api, session, use_db_execute)
 
     show_tables_execute = self._execute(api, session, snippet_type, 'SHOW TABLES')
     tables_list = self._check_status_and_fetch_result(api, session, show_tables_execute)
@@ -497,7 +495,7 @@ class SparkApi(Api):
 
   def _get_columns(self, api, session, snippet_type, database, table):
     use_db_execute = self._execute(api, session, snippet_type, 'USE %(database)s' % {'database': database})
-    use_db_resp = self._check_status_and_fetch_result(api, session, use_db_execute)
+    self._check_status_and_fetch_result(api, session, use_db_execute)
 
     describe_tables_execute = self._execute(api, session, snippet_type, 'DESCRIBE %(table)s' % {'table': table})
     columns_list = self._check_status_and_fetch_result(api, session, describe_tables_execute)
@@ -519,7 +517,7 @@ class SparkApi(Api):
 
       return cols
 
-  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
+  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False, operation=None):
     api = self.get_api()
     response = {
       'status': 0,
@@ -797,7 +795,7 @@ class SparkDescribeTable(Table):
       elif 'LazySimpleSerDe' in self.serde:
         details_format = 'text'
       else:
-        details_format = serde.rsplit('.', 1)[-1]
+        details_format = self.serde.rsplit('.', 1)[-1]
 
       self._details = {
         'stats': self.stats,

+ 8 - 14
desktop/libs/notebook/src/notebook/connectors/sql_alchemy.py

@@ -50,27 +50,21 @@ Each query statement grabs a connection from the engine and will return it after
 Disposing the engine closes all its connections.
 '''
 
-import re
-import sys
+import datetime
 import json
-import uuid
 import logging
-import datetime
+import re
 import textwrap
+import uuid
 from string import Template
 from urllib.parse import parse_qs as urllib_parse_qs, quote_plus as urllib_quote_plus, urlparse as urllib_urlparse
 
-from django.core.cache import caches
-from django.utils.translation import gettext as _
 from past.builtins import long
-from sqlalchemy import MetaData, Table, create_engine, inspect
+from sqlalchemy import create_engine, inspect, MetaData, Table
 from sqlalchemy.exc import CompileError, NoSuchTableError, OperationalError, ProgrammingError, UnsupportedCompilationError
 
-from beeswax import data_export
-from desktop.lib import export_csvxls
 from desktop.lib.i18n import force_unicode
-from librdbms.server import dbms
-from notebook.connectors.base import Api, AuthenticationRequired, QueryError, QueryExpired, _get_snippet_name
+from notebook.connectors.base import Api, AuthenticationRequired, QueryError, QueryExpired
 from notebook.models import escape_rows
 
 ENGINES = {}
@@ -463,14 +457,14 @@ class SqlAlchemyApi(Api):
     return response
 
   @query_error_handler
-  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
+  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False, operation=None):
     engine = self._get_engine()
     inspector = inspect(engine)
 
     assist = Assist(inspector, engine, backticks=self.backticks, api=self)
     response = {'status': -1, 'result': {}}
 
-    metadata, sample_data = assist.get_sample_data(database, table, column=column, operation=operation)
+    metadata, sample_data = assist.get_sample_data(database, table, column=column, nested=nested, operation=operation)
 
     response['status'] = 0
     response['rows'] = escape_rows(sample_data)
@@ -548,7 +542,7 @@ class Assist(object):
     except NoSuchTableError:
       return []
 
-  def get_sample_data(self, database, table, column=None, operation=None):
+  def get_sample_data(self, database, table, column=None, nested=None, operation=None):
     if operation == 'hello':
       statement = "SELECT 'Hello World!'"
     else:

+ 2 - 7
desktop/libs/notebook/src/notebook/connectors/sqlflow.py

@@ -18,13 +18,10 @@
 
 from __future__ import absolute_import
 
-import os
-import sys
-import json
 import logging
+import os
 
 import sqlflow
-from django.utils.translation import gettext as _
 from sqlflow.rows import Rows
 
 from desktop.lib.i18n import force_unicode
@@ -63,8 +60,6 @@ class SqlFlowApi(Api):
   @query_error_handler
   @ssh_error_handler
   def execute(self, notebook, snippet):
-    db = self._get_db()
-
     statement = snippet['statement']
     statement = statement.replace('LIMIT 5000', '')
 
@@ -145,7 +140,7 @@ class SqlFlowApi(Api):
     return response
 
   @query_error_handler
-  def get_sample_data(self, snippet, database=None, table=None, column=None, is_async=False, operation=None):
+  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False, operation=None):
     result = self._execute('SELECT * FROM %s.%s LIMIT 10' % (database, table))
 
     response = {

+ 3 - 5
desktop/libs/notebook/src/notebook/connectors/trino.py

@@ -16,13 +16,12 @@
 # limitations under the License.
 
 import json
-import time
 import logging
 import textwrap
+import time
 from urllib.parse import urlparse
 
 import requests
-from django.utils.translation import gettext as _
 from trino.auth import BasicAuthentication
 from trino.client import ClientSession, TrinoQuery, TrinoRequest
 from trino.exceptions import TrinoConnectionError
@@ -33,8 +32,7 @@ from desktop.conf import AUTH_PASSWORD as DEFAULT_AUTH_PASSWORD, AUTH_USERNAME a
 from desktop.lib import export_csvxls
 from desktop.lib.conf import coerce_password_from_script
 from desktop.lib.i18n import force_unicode
-from desktop.lib.rest.http_client import HttpClient, RestException
-from desktop.lib.rest.resource import Resource
+from desktop.lib.rest.http_client import RestException
 from notebook.connectors.base import Api, ExecutionWrapper, QueryError, ResultWrapper
 
 LOG = logging.getLogger()
@@ -289,7 +287,7 @@ class TrinoApi(Api):
     return response
 
   @query_error_handler
-  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=False, is_async=False, operation=None):
+  def get_sample_data(self, snippet, database=None, table=None, column=None, nested=None, is_async=False, operation=None):
     statement = self._get_select_query(database, table, column, operation)
     query_client = TrinoQuery(self.trino_request, statement)
     query_client.execute()