
[importer] Update SQL type mapping API to return unique SQL types as a sorted list

This commit changes the SQL type mapping API to return a sorted list of the unique SQL types supported by a given SQL dialect, instead of a dictionary mapping Polars data types to SQL types. The change makes the API clearer and easier to consume; a sketch of the old and new response shapes follows the change list below.

- Updates the `get_sql_type_mapping` function to return a sorted list of unique SQL types.
- Adjusts related tests to validate the new output format.
- Refactors internal mapping functions for consistency and clarity.
- Adds handling for dialect-specific overrides, ensuring accurate type mappings across different SQL dialects.
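
For illustration, a minimal sketch of the response payload before and after this change, using the Hive values from the updated tests (the exact serializer wrapping around the payload is not shown here):

# Before: dict mapping Polars dtypes to SQL types
{"Int32": "INT", "Utf8": "STRING", "Float64": "DOUBLE", "Boolean": "BOOLEAN"}

# After: sorted list of unique SQL types for the dialect
["ARRAY", "BIGINT", "BINARY", "BOOLEAN", "DATE", "DECIMAL", "DOUBLE", "FLOAT",
 "INT", "INTERVAL DAY TO SECOND", "SMALLINT", "STRING", "STRUCT", "TIMESTAMP", "TINYINT"]
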
Harsh Gupta 6 months ago
parent commit d3eab61c55

+ 6 - 6
desktop/core/src/desktop/lib/importer/api.py

@@ -203,18 +203,18 @@ def guess_file_header(request: Request) -> Response:
 @parser_classes([JSONParser])
 @api_error_handler
 def get_sql_type_mapping(request: Request) -> Response:
-  """Get mapping from Polars data types to SQL types for a specific dialect.
+  """Get SQL types supported by a specific dialect.
 
-  This API endpoint returns a dictionary mapping Polars data types to the corresponding
-  SQL types for a specific SQL dialect.
+  This API endpoint returns a sorted list of unique SQL types that are supported
+  by a specific SQL dialect.
 
   Args:
     request: Request object containing query parameters:
-      - sql_dialect: The SQL dialect to get mappings for (e.g., 'hive', 'impala', 'trino')
+      - sql_dialect: The SQL dialect to get types for (e.g., 'hive', 'impala', 'trino')
 
   Returns:
-    Response containing a mapping dictionary:
-      - A mapping from Polars data type names to SQL type names for the specified dialect
+    Response containing a list of SQL types:
+      - A sorted list of unique SQL type names supported by the specified dialect
   """
   serializer = SqlTypeMapperSerializer(data=request.query_params)
 

+ 34 - 2
desktop/core/src/desktop/lib/importer/api_tests.py

@@ -619,7 +619,23 @@ class TestSqlTypeMappingAPI:
     mock_serializer = MagicMock(is_valid=MagicMock(return_value=True), validated_data=mock_schema)
     mock_serializer_class.return_value = mock_serializer
 
-    mock_get_sql_type_mapping.return_value = {"Int32": "INT", "Utf8": "STRING", "Float64": "DOUBLE", "Boolean": "BOOLEAN"}
+    mock_get_sql_type_mapping.return_value = [
+      "ARRAY",
+      "BIGINT",
+      "BINARY",
+      "BOOLEAN",
+      "DATE",
+      "DECIMAL",
+      "DOUBLE",
+      "FLOAT",
+      "INT",
+      "INTERVAL DAY TO SECOND",
+      "SMALLINT",
+      "STRING",
+      "STRUCT",
+      "TIMESTAMP",
+      "TINYINT",
+    ]
 
     request = APIRequestFactory().get("importer/sql_type_mapping/")
     request.user = MagicMock(username="test_user")
@@ -628,7 +644,23 @@ class TestSqlTypeMappingAPI:
     response = api.get_sql_type_mapping(request)
 
     assert response.status_code == status.HTTP_200_OK
-    assert response.data == {"Int32": "INT", "Utf8": "STRING", "Float64": "DOUBLE", "Boolean": "BOOLEAN"}
+    assert response.data == [
+      "ARRAY",
+      "BIGINT",
+      "BINARY",
+      "BOOLEAN",
+      "DATE",
+      "DECIMAL",
+      "DOUBLE",
+      "FLOAT",
+      "INT",
+      "INTERVAL DAY TO SECOND",
+      "SMALLINT",
+      "STRING",
+      "STRUCT",
+      "TIMESTAMP",
+      "TINYINT",
+    ]
     mock_get_sql_type_mapping.assert_called_once_with(mock_schema)
 
   @patch("desktop.lib.importer.api.SqlTypeMapperSerializer")

+ 33 - 9
desktop/core/src/desktop/lib/importer/operations.py

@@ -87,7 +87,9 @@ SQL_TYPE_BASE_MAP = {
 # Per‑dialect overrides for the few differences
 SQL_TYPE_DIALECT_OVERRIDES = {
   "hive": {},
-  "impala": {},
+  "impala": {
+    "Duration": "STRING",  # Impala doesn't support INTERVAL types
+  },
   "sparksql": {},
   "trino": {
     "Int32": "INTEGER",
@@ -627,14 +629,13 @@ def guess_file_header(data: GuessFileHeaderSchema, username: str) -> bool:
     fh.close()
 
 
-def get_sql_type_mapping(data: SqlTypeMapperSchema) -> Dict[str, str]:
-  """Get all type mappings from Polars dtypes to SQL types for a given SQL dialect.
+def _get_polars_to_sql_mapping(dialect: str) -> Dict[str, str]:
+  """Get full mapping from Polars dtypes to SQL types for a given SQL dialect.
 
-  This function returns a dictionary mapping of all Polars data types to their
-  corresponding SQL types for a specific dialect.
+  Internal function that returns the complete mapping dictionary.
 
   Args:
-    data: A Pydantic schema with the SQL dialect.
+    dialect: One of "hive", "impala", "trino", "phoenix", "sparksql".
 
   Returns:
     A dict mapping Polars dtype names to SQL type names.
@@ -642,14 +643,37 @@ def get_sql_type_mapping(data: SqlTypeMapperSchema) -> Dict[str, str]:
   Raises:
     ValueError: If the dialect is not supported.
   """
-  dl = data.sql_dialect.lower()
+  dl = dialect.lower()
   if dl not in SQL_TYPE_DIALECT_OVERRIDES:
-    raise ValueError(f"Unsupported dialect: {data.sql_dialect}")
+    raise ValueError(f"Unsupported dialect: {dialect}")
 
   # Merge base_map and overrides[dl] into a new dict, giving precedence to any overlapping keys in overrides[dl]
   return {**SQL_TYPE_BASE_MAP, **SQL_TYPE_DIALECT_OVERRIDES[dl]}
 
 
+def get_sql_type_mapping(data: SqlTypeMapperSchema) -> List[str]:
+  """Get all unique SQL types supported by a given SQL dialect.
+
+  This function returns a sorted list of unique SQL types that are supported
+  by the specified SQL dialect based on the Polars to SQL type mappings.
+
+  Args:
+    data: A Pydantic schema with the SQL dialect.
+
+  Returns:
+    A sorted list of unique SQL type names for the dialect.
+
+  Raises:
+    ValueError: If the dialect is not supported.
+  """
+  # Get the full mapping
+  mapping = _get_polars_to_sql_mapping(data.sql_dialect)
+
+  # Extract unique SQL types and return as sorted list
+  unique_sql_types = sorted(set(mapping.values()))
+  return unique_sql_types
+
+
 def _map_polars_dtype_to_sql_type(dialect: str, polars_type: str) -> str:
   """Map a Polars dtype to the corresponding SQL type for a given dialect.
 
@@ -663,7 +687,7 @@ def _map_polars_dtype_to_sql_type(dialect: str, polars_type: str) -> str:
   Raises:
     ValueError: If the dialect or polars_type is not supported.
   """
-  mapping = get_sql_type_mapping(SqlTypeMapperSchema(sql_dialect=dialect))
+  mapping = _get_polars_to_sql_mapping(dialect)
 
   if polars_type not in mapping:
     raise ValueError(f"No mapping for Polars dtype {polars_type} in dialect {dialect}")

+ 139 - 167
desktop/core/src/desktop/lib/importer/operations_tests.py

@@ -585,42 +585,25 @@ class TestSqlTypeMapping:
 
     result = operations.get_sql_type_mapping(schema)
 
-    # Test all integer types (signed and unsigned)
-    assert result["Int8"] == "TINYINT"
-    assert result["Int16"] == "SMALLINT"
-    assert result["Int32"] == "INT"
-    assert result["Int64"] == "BIGINT"
-    assert result["UInt8"] == "TINYINT"  # Unsigned mapped to signed in Hive
-    assert result["UInt16"] == "SMALLINT"
-    assert result["UInt32"] == "INT"
-    assert result["UInt64"] == "BIGINT"
-
-    # Test floating point and decimal types
-    assert result["Float32"] == "FLOAT"
-    assert result["Float64"] == "DOUBLE"
-    assert result["Decimal"] == "DECIMAL"
-
-    # Test boolean, string, and binary types
-    assert result["Boolean"] == "BOOLEAN"
-    assert result["Utf8"] == "STRING"
-    assert result["String"] == "STRING"
-    assert result["Categorical"] == "STRING"
-    assert result["Enum"] == "STRING"
-    assert result["Binary"] == "BINARY"
-
-    # Test temporal types
-    assert result["Date"] == "DATE"
-    assert result["Time"] == "TIMESTAMP"  # No pure TIME type in Hive
-    assert result["Datetime"] == "TIMESTAMP"
-    assert result["Duration"] == "INTERVAL DAY TO SECOND"
-
-    # Test nested and other types
-    assert result["Array"] == "ARRAY"
-    assert result["List"] == "ARRAY"
-    assert result["Struct"] == "STRUCT"
-    assert result["Object"] == "STRING"
-    assert result["Null"] == "STRING"
-    assert result["Unknown"] == "STRING"
+    # Test that Hive returns the expected unique SQL types
+    expected_hive_types = [
+      "ARRAY",
+      "BIGINT",
+      "BINARY",
+      "BOOLEAN",
+      "DATE",
+      "DECIMAL",
+      "DOUBLE",
+      "FLOAT",
+      "INT",
+      "INTERVAL DAY TO SECOND",
+      "SMALLINT",
+      "STRING",
+      "STRUCT",
+      "TIMESTAMP",
+      "TINYINT",
+    ]
+    assert result == expected_hive_types
 
   def test_get_sql_type_mapping_trino(self):
     # Create schema object
@@ -628,26 +611,27 @@ class TestSqlTypeMapping:
 
     result = operations.get_sql_type_mapping(schema)
 
-    # Test Trino-specific overrides
-    assert result["Int32"] == "INTEGER"  # Not INT
-    assert result["UInt32"] == "INTEGER"  # Not INT
-    assert result["Float32"] == "REAL"  # Not FLOAT
-    assert result["Utf8"] == "VARCHAR"  # Not STRING
-    assert result["String"] == "VARCHAR"  # Not STRING
-    assert result["Binary"] == "VARBINARY"  # Not BINARY
-    assert result["Struct"] == "ROW"  # Not STRUCT
-    assert result["Object"] == "JSON"  # Not STRING
-    assert result["Duration"] == "INTERVAL DAY TO SECOND"
-
-    # Test types that remain the same as base mapping
-    assert result["Int8"] == "TINYINT"
-    assert result["Int16"] == "SMALLINT"
-    assert result["Int64"] == "BIGINT"
-    assert result["Float64"] == "DOUBLE"
-    assert result["Boolean"] == "BOOLEAN"
-    assert result["Date"] == "DATE"
-    assert result["Time"] == "TIMESTAMP"
-    assert result["Datetime"] == "TIMESTAMP"
+    # Test that Trino returns the expected unique SQL types
+    expected_trino_types = [
+      "ARRAY",
+      "BIGINT",
+      "BOOLEAN",
+      "DATE",
+      "DECIMAL",
+      "DOUBLE",
+      "INTEGER",
+      "INTERVAL DAY TO SECOND",
+      "JSON",
+      "REAL",
+      "ROW",
+      "SMALLINT",
+      "STRING",
+      "TIMESTAMP",
+      "TINYINT",
+      "VARBINARY",
+      "VARCHAR",
+    ]
+    assert result == expected_trino_types
 
   def test_get_sql_type_mapping_phoenix(self):
     # Create schema object
@@ -655,34 +639,29 @@ class TestSqlTypeMapping:
 
     result = operations.get_sql_type_mapping(schema)
 
-    # Test Phoenix-specific unsigned integer mappings
-    assert result["UInt8"] == "UNSIGNED_TINYINT"
-    assert result["UInt16"] == "UNSIGNED_SMALLINT"
-    assert result["UInt32"] == "UNSIGNED_INT"
-    assert result["UInt64"] == "UNSIGNED_LONG"
-
-    # Test other Phoenix-specific overrides
-    assert result["Utf8"] == "VARCHAR"  # Not STRING
-    assert result["String"] == "VARCHAR"  # Not STRING
-    assert result["Binary"] == "VARBINARY"  # Not BINARY
-    assert result["Duration"] == "STRING"  # Phoenix treats durations as strings
-    assert result["Struct"] == "STRING"  # No native STRUCT type
-    assert result["Object"] == "VARCHAR"  # Not STRING
-    assert result["Time"] == "TIME"  # Phoenix has its own TIME type
-    assert result["Decimal"] == "DECIMAL"
-
-    # Test signed integers (use base mapping)
-    assert result["Int8"] == "TINYINT"
-    assert result["Int16"] == "SMALLINT"
-    assert result["Int32"] == "INT"
-    assert result["Int64"] == "BIGINT"
-
-    # Test other types that remain the same
-    assert result["Float32"] == "FLOAT"
-    assert result["Float64"] == "DOUBLE"
-    assert result["Boolean"] == "BOOLEAN"
-    assert result["Date"] == "DATE"
-    assert result["Datetime"] == "TIMESTAMP"
+    # Test that Phoenix returns the expected unique SQL types
+    expected_phoenix_types = [
+      "ARRAY",
+      "BIGINT",
+      "BOOLEAN",
+      "DATE",
+      "DECIMAL",
+      "DOUBLE",
+      "FLOAT",
+      "INT",
+      "SMALLINT",
+      "STRING",
+      "TIME",
+      "TIMESTAMP",
+      "TINYINT",
+      "UNSIGNED_INT",
+      "UNSIGNED_LONG",
+      "UNSIGNED_SMALLINT",
+      "UNSIGNED_TINYINT",
+      "VARBINARY",
+      "VARCHAR",
+    ]
+    assert result == expected_phoenix_types
 
   def test_get_sql_type_mapping_impala(self):
     # Create schema object
@@ -690,30 +669,25 @@ class TestSqlTypeMapping:
 
     result = operations.get_sql_type_mapping(schema)
 
-    # Impala uses all base mappings (no overrides)
-    # Test a comprehensive set to ensure no overrides are applied
-    assert result["Int8"] == "TINYINT"
-    assert result["Int16"] == "SMALLINT"
-    assert result["Int32"] == "INT"
-    assert result["Int64"] == "BIGINT"
-    assert result["UInt8"] == "TINYINT"
-    assert result["UInt16"] == "SMALLINT"
-    assert result["UInt32"] == "INT"
-    assert result["UInt64"] == "BIGINT"
-    assert result["Float32"] == "FLOAT"
-    assert result["Float64"] == "DOUBLE"
-    assert result["Decimal"] == "DECIMAL"
-    assert result["Boolean"] == "BOOLEAN"
-    assert result["Utf8"] == "STRING"
-    assert result["String"] == "STRING"
-    assert result["Binary"] == "BINARY"
-    assert result["Date"] == "DATE"
-    assert result["Time"] == "TIMESTAMP"
-    assert result["Datetime"] == "TIMESTAMP"
-    assert result["Duration"] == "INTERVAL DAY TO SECOND"
-    assert result["Array"] == "ARRAY"
-    assert result["Struct"] == "STRUCT"
-    assert result["Object"] == "STRING"
+    # Test that Impala returns the expected unique SQL types
+    # Note: Impala doesn't support INTERVAL types, so Duration maps to STRING
+    expected_impala_types = [
+      "ARRAY",
+      "BIGINT",
+      "BINARY",
+      "BOOLEAN",
+      "DATE",
+      "DECIMAL",
+      "DOUBLE",
+      "FLOAT",
+      "INT",
+      "SMALLINT",
+      "STRING",
+      "STRUCT",
+      "TIMESTAMP",
+      "TINYINT",
+    ]
+    assert result == expected_impala_types
 
   def test_get_sql_type_mapping_sparksql(self):
     # Create schema object
@@ -721,73 +695,66 @@ class TestSqlTypeMapping:
 
     result = operations.get_sql_type_mapping(schema)
 
-    # SparkSQL uses all base mappings (no overrides)
-    # Test a comprehensive set to ensure no overrides are applied
-    assert result["Int8"] == "TINYINT"
-    assert result["Int16"] == "SMALLINT"
-    assert result["Int32"] == "INT"
-    assert result["Int64"] == "BIGINT"
-    assert result["UInt8"] == "TINYINT"
-    assert result["UInt16"] == "SMALLINT"
-    assert result["UInt32"] == "INT"
-    assert result["UInt64"] == "BIGINT"
-    assert result["Float32"] == "FLOAT"
-    assert result["Float64"] == "DOUBLE"
-    assert result["Decimal"] == "DECIMAL"
-    assert result["Boolean"] == "BOOLEAN"
-    assert result["Utf8"] == "STRING"
-    assert result["String"] == "STRING"
-    assert result["Binary"] == "BINARY"
-    assert result["Date"] == "DATE"
-    assert result["Time"] == "TIMESTAMP"
-    assert result["Datetime"] == "TIMESTAMP"
-    assert result["Duration"] == "INTERVAL DAY TO SECOND"
-    assert result["Array"] == "ARRAY"
-    assert result["Struct"] == "STRUCT"
-    assert result["Object"] == "STRING"
+    # Test that SparkSQL returns the expected unique SQL types (same as Hive)
+    expected_sparksql_types = [
+      "ARRAY",
+      "BIGINT",
+      "BINARY",
+      "BOOLEAN",
+      "DATE",
+      "DECIMAL",
+      "DOUBLE",
+      "FLOAT",
+      "INT",
+      "INTERVAL DAY TO SECOND",
+      "SMALLINT",
+      "STRING",
+      "STRUCT",
+      "TIMESTAMP",
+      "TINYINT",
+    ]
+    assert result == expected_sparksql_types
 
   def test_get_sql_type_mapping_all_dialects_consistency(self):
-    # Test that all dialects return mappings for all base types
+    # Test that all dialects return a non-empty list of SQL types
     dialects = ["hive", "impala", "sparksql", "trino", "phoenix"]
-    base_types = [
-      "Int8",
-      "Int16",
-      "Int32",
-      "Int64",
-      "UInt8",
-      "UInt16",
-      "UInt32",
-      "UInt64",
-      "Float32",
-      "Float64",
-      "Decimal",
-      "Boolean",
-      "Utf8",
-      "String",
-      "Categorical",
-      "Enum",
-      "Binary",
-      "Date",
-      "Time",
-      "Datetime",
-      "Duration",
-      "Array",
-      "List",
-      "Struct",
-      "Object",
-      "Null",
-      "Unknown",
-    ]
 
     for dialect in dialects:
       schema = SqlTypeMapperSchema(sql_dialect=dialect)
       result = operations.get_sql_type_mapping(schema)
 
-      # Ensure all base types have mappings
-      for base_type in base_types:
-        assert base_type in result, f"Missing mapping for {base_type} in {dialect} dialect"
-        assert isinstance(result[base_type], str), f"Invalid mapping type for {base_type} in {dialect} dialect"
-        assert len(result[base_type]) > 0, f"Empty mapping for {base_type} in {dialect} dialect"
+      # Ensure result is a list
+      assert isinstance(result, list), f"Result for {dialect} is not a list"
+
+      # Ensure the list is not empty
+      assert len(result) > 0, f"Empty result for {dialect} dialect"
+
+      # Ensure all items in the list are strings
+      for sql_type in result:
+        assert isinstance(sql_type, str), f"Invalid type in result for {dialect}: {sql_type}"
+        assert len(sql_type) > 0, f"Empty SQL type string in {dialect} dialect"
+
+  def test_get_polars_to_sql_mapping(self):
+    # Test the internal function that returns the full mapping
+
+    # Test Hive dialect
+    hive_mapping = operations._get_polars_to_sql_mapping("hive")
+    assert isinstance(hive_mapping, dict)
+    assert hive_mapping["Int32"] == "INT"
+    assert hive_mapping["Utf8"] == "STRING"
+    assert hive_mapping["Boolean"] == "BOOLEAN"
+
+    # Test Trino dialect with overrides
+    trino_mapping = operations._get_polars_to_sql_mapping("trino")
+    assert isinstance(trino_mapping, dict)
+    assert trino_mapping["Int32"] == "INTEGER"  # Override
+    assert trino_mapping["Float32"] == "REAL"  # Override
+    assert trino_mapping["Utf8"] == "VARCHAR"  # Override
+    assert trino_mapping["Boolean"] == "BOOLEAN"  # No override
+
+    # Test unsupported dialect
+    with pytest.raises(ValueError, match="Unsupported dialect"):
+      operations._get_polars_to_sql_mapping("unsupported_dialect")
 
   def test_map_polars_dtype_to_sql_type(self):
     # Test comprehensive type mapping for each dialect
@@ -817,6 +784,11 @@ class TestSqlTypeMapping:
     assert operations._map_polars_dtype_to_sql_type("phoenix", "Duration") == "STRING"
     assert operations._map_polars_dtype_to_sql_type("phoenix", "Struct") == "STRING"
 
+    # Impala dialect tests (with Duration override)
+    assert operations._map_polars_dtype_to_sql_type("impala", "Duration") == "STRING"
+    assert operations._map_polars_dtype_to_sql_type("impala", "Int32") == "INT"  # No override
+    assert operations._map_polars_dtype_to_sql_type("impala", "Utf8") == "STRING"  # No override
+
     # Test error for unknown type
     with pytest.raises(ValueError, match="No mapping for Polars dtype"):
       operations._map_polars_dtype_to_sql_type("hive", "UnknownType")