@@ -18,65 +18,139 @@
import os
import tempfile
import zipfile
-from unittest.mock import MagicMock, mock_open, patch
+from unittest.mock import MagicMock, patch

import pytest
from django.core.files.uploadedfile import SimpleUploadedFile

+from desktop.conf import IMPORTER
from desktop.lib.importer import operations
+from desktop.lib.importer.schemas import (
+  GuessFileHeaderSchema,
+  GuessFileMetadataSchema,
+  LocalFileUploadSchema,
+  PreviewFileSchema,
+  SqlTypeMapperSchema,
+)


class TestLocalFileUpload:
-  @patch("uuid.uuid4")
-  def test_local_file_upload_success(self, mock_uuid):
-    # Mock uuid to get a predictable filename
-    mock_uuid.return_value.hex = "12345678"
+  def test_local_file_upload_success(self):
+    resets = [
+      IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.set_for_testing([".exe", ".bat"]),
+      IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.set_for_testing(10 * 1024 * 1024),  # 10 MiB limit
+    ]
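+    # Each set_for_testing() call returns a reset callable that restores the
+    # previous config value; the finally block below invokes them all.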

-    test_file = SimpleUploadedFile(name="test_file.csv", content=b"header1,header2\nvalue1,value2", content_type="text/csv")
+    try:
+      test_file = SimpleUploadedFile(name="test_file.csv", content=b"header1,header2\nvalue1,value2", content_type="text/csv")

-    result = operations.local_file_upload(test_file, "test_user")
+      # Create schema object
+      schema = LocalFileUploadSchema(file=test_file, filename="test_file.csv", filesize=test_file.size)

-    # Get the expected file path
-    temp_dir = tempfile.gettempdir()
-    expected_path = os.path.join(temp_dir, "test_user_12345678_test_file.csv")
+      result = operations.local_file_upload(schema, "test_user")

-    try:
      assert "file_path" in result
-      assert result["file_path"] == expected_path
+      file_path = result["file_path"]
+
+      # Verify the file path contains expected components
+      assert "test_user_" in file_path
+      assert "_test_file.csv" in file_path
+      assert file_path.startswith(tempfile.gettempdir())

      # Verify the file was created and has the right content
-      assert os.path.exists(expected_path)
-      with open(expected_path, "rb") as f:
+      assert os.path.exists(file_path)
+      with open(file_path, "rb") as f:
        assert f.read() == b"header1,header2\nvalue1,value2"

    finally:
-      # Clean up the file
-      if os.path.exists(expected_path):
-        os.remove(expected_path)
+      # Clean up, even if an assertion fails
+      if os.path.exists(result["file_path"]):
+        os.remove(result["file_path"])

-    assert not os.path.exists(expected_path), "Temporary file was not cleaned up properly"
+      for reset in resets:
+        reset()

-  def test_local_file_upload_none_file(self):
-    with pytest.raises(ValueError, match="Upload file cannot be None or empty."):
-      operations.local_file_upload(None, "test_user")
-
-  def test_local_file_upload_none_username(self):
+  def test_local_file_upload_empty_username(self):
    test_file = SimpleUploadedFile(name="test_file.csv", content=b"header1,header2\nvalue1,value2", content_type="text/csv")

+    # Create schema object
+    schema = LocalFileUploadSchema(file=test_file, filename="test_file.csv", filesize=test_file.size)
+
    with pytest.raises(ValueError, match="Username cannot be None or empty."):
-      operations.local_file_upload(test_file, None)
+      operations.local_file_upload(schema, "")
+
+  @patch("tempfile.NamedTemporaryFile")
+  @patch("shutil.copyfileobj")
+  def test_local_file_upload_exception_handling_with_cleanup(self, mock_copyfileobj, mock_tempfile):
+    resets = [
+      IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.set_for_testing([".exe", ".bat"]),
+      IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.set_for_testing(10 * 1024 * 1024),  # 10 MiB limit
+    ]

-  @patch("os.path.join")
-  @patch("builtins.open", new_callable=mock_open)
-  def test_local_file_upload_exception_handling(self, mock_file_open, mock_join):
-    # Setup mocks to raise an exception when opening the file
-    mock_file_open.side_effect = IOError("Test IO Error")
-    mock_join.return_value = "/tmp/test_user_12345678_test_file.csv"
+    # Mock the temporary file
+    mock_file = MagicMock()
+    mock_file.name = "/tmp/test_user_12345678_test_file.csv"
+    mock_tempfile.return_value = mock_file
+    mock_file.__enter__.return_value = mock_file
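+    # local_file_upload() presumably uses NamedTemporaryFile as a context manager,
+    # so __enter__ has to return the mock itself for the `as` binding to work.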
+
+    # Make copyfileobj raise an exception
+    mock_copyfileobj.side_effect = IOError("Test IO Error")
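+    # shutil.copyfileobj is presumably what streams the upload into the temp
+    # file, so this simulates a write failure partway through the copy.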

    test_file = SimpleUploadedFile(name="test_file.csv", content=b"header1,header2\nvalue1,value2", content_type="text/csv")

-    with pytest.raises(Exception, match="Test IO Error"):
-      operations.local_file_upload(test_file, "test_user")
+    # Create schema object
+    schema = LocalFileUploadSchema(file=test_file, filename="test_file.csv", filesize=test_file.size)
+
+    # Create a real file at the mocked path so we can verify it gets cleaned up
+    with open(mock_file.name, "w") as f:
+      f.write("temp")
+
+    try:
+      with pytest.raises(IOError, match="Test IO Error"):
+        operations.local_file_upload(schema, "test_user")
+
+      # Verify the file was cleaned up after the exception
+      assert not os.path.exists(mock_file.name), "Temporary file was not cleaned up after exception"
+
+    finally:
+      # Clean up, even if an assertion fails
+      if os.path.exists(mock_file.name):
+        os.remove(mock_file.name)
+
+      for reset in resets:
+        reset()
+
+  def test_local_file_upload_special_characters_in_filename(self):
+    resets = [
+      IMPORTER.RESTRICT_LOCAL_FILE_EXTENSIONS.set_for_testing([".exe", ".bat"]),
+      IMPORTER.MAX_LOCAL_FILE_SIZE_UPLOAD_LIMIT.set_for_testing(10 * 1024 * 1024),  # 10 MiB limit
+    ]
+
+    # Test with special characters in the filename
+    test_file = SimpleUploadedFile(name="test file (with) [special] {chars} & symbols!.csv", content=b"data", content_type="text/csv")
+
+    # Create schema object
+    schema = LocalFileUploadSchema(file=test_file, filename="test file (with) [special] {chars} & symbols!.csv", filesize=test_file.size)
+
+    result = operations.local_file_upload(schema, "test_user")
+
+    try:
+      assert "file_path" in result
+      file_path = result["file_path"]
+
+      # Verify the file was created
+      assert os.path.exists(file_path)
+
+      # Verify the original filename is preserved in the path
+      assert "_test file (with) [special] {chars} & symbols!.csv" in file_path
+
+    finally:
+      # Clean up the file
+      if os.path.exists(result["file_path"]):
+        os.remove(result["file_path"])
+
+      for reset in resets:
+        reset()


@pytest.mark.usefixtures("cleanup_temp_files")
@@ -107,7 +181,10 @@ class TestGuessFileMetadata:
    # Mock magic.from_buffer to return text/plain MIME type
    mock_magic.from_buffer.return_value = "text/plain"

-    result = operations.guess_file_metadata(temp_file.name, "local")
+    # Create schema object
+    schema = GuessFileMetadataSchema(file_path=temp_file.name, import_type="local")
+
+    result = operations.guess_file_metadata(data=schema)

    assert result == {
      "type": "csv",
@@ -130,7 +207,10 @@ class TestGuessFileMetadata:
    # Mock magic.from_buffer to return text/plain MIME type
    mock_magic.from_buffer.return_value = "text/plain"

-    result = operations.guess_file_metadata(temp_file.name, "local")
+    # Create schema object
+    schema = GuessFileMetadataSchema(file_path=temp_file.name, import_type="local")
+
+    result = operations.guess_file_metadata(data=schema)

    assert result == {
      "type": "tsv",
@@ -165,7 +245,10 @@ class TestGuessFileMetadata:
    # Mock _get_sheet_names_xlsx to return sheet names
    mock_get_sheet_names.return_value = ["Sheet1", "Sheet2", "Sheet3"]

-    result = operations.guess_file_metadata(temp_file.name, "local")
+    # Create schema object
+    schema = GuessFileMetadataSchema(file_path=temp_file.name, import_type="local")
+
+    result = operations.guess_file_metadata(data=schema)

    assert result == {
      "type": "excel",
@@ -184,21 +267,25 @@ class TestGuessFileMetadata:
    # Mock magic.from_buffer to return an unsupported MIME type
    mock_magic.from_buffer.return_value = "application/octet-stream"

+    # Create schema object
+    schema = GuessFileMetadataSchema(file_path=temp_file.name, import_type="local")
+
    with pytest.raises(ValueError, match="Unable to detect file format."):
-      operations.guess_file_metadata(temp_file.name, "local")
+      operations.guess_file_metadata(data=schema)

  def test_guess_file_metadata_nonexistent_file(self):
-    file_path = "/path/to/nonexistent/file.csv"
+    # Create schema object
+    schema = GuessFileMetadataSchema(file_path="/path/to/nonexistent/file.csv", import_type="local")

    with pytest.raises(ValueError, match="Local file does not exist."):
-      operations.guess_file_metadata(file_path, "local")
+      operations.guess_file_metadata(data=schema)

  def test_guess_remote_file_metadata_no_fs(self):
+    # Create schema object
+    schema = GuessFileMetadataSchema(file_path="s3a://bucket/user/test_user/test.csv", import_type="remote")
+
    with pytest.raises(ValueError, match="File system object is required for remote import type"):
-      operations.guess_file_metadata(
-        file_path="s3a://bucket/user/test_user/test.csv",  # Remote file path
-        import_type="remote",  # Remote file but no fs provided
-      )
+      operations.guess_file_metadata(data=schema, fs=None)

  def test_guess_file_metadata_empty_file(self, cleanup_temp_files):
    temp_file = tempfile.NamedTemporaryFile(delete=False)
@@ -206,8 +293,11 @@ class TestGuessFileMetadata:

    cleanup_temp_files.append(temp_file.name)

+    # Create schema object
+    schema = GuessFileMetadataSchema(file_path=temp_file.name, import_type="local")
+
    with pytest.raises(ValueError, match="File is empty, cannot detect file format."):
-      operations.guess_file_metadata(temp_file.name, "local")
+      operations.guess_file_metadata(data=schema)

  @patch("desktop.lib.importer.operations.is_magic_lib_available", False)
  def test_guess_file_metadata_no_magic_lib(self, cleanup_temp_files):
@@ -217,8 +307,11 @@ class TestGuessFileMetadata:

    cleanup_temp_files.append(temp_file.name)

+    # Create schema object
+    schema = GuessFileMetadataSchema(file_path=temp_file.name, import_type="local")
+
    with pytest.raises(RuntimeError, match="Unable to guess file type. python-magic or its dependency libmagic is not installed."):
-      operations.guess_file_metadata(temp_file.name, "local")
+      operations.guess_file_metadata(data=schema)


@pytest.mark.usefixtures("cleanup_temp_files")
@@ -260,10 +353,13 @@ class TestPreviewFile:

    mock_pl.read_excel.return_value = mock_df

-    result = operations.preview_file(
+    # Create schema object
+    schema = PreviewFileSchema(
      file_path=temp_file.name, file_type="excel", import_type="local", sql_dialect="hive", has_header=True, sheet_name="Sheet1"
    )

+    result = operations.preview_file(data=schema)
+
    assert result == {
      "type": "excel",
      "columns": [
@@ -282,7 +378,8 @@ class TestPreviewFile:

    cleanup_temp_files.append(temp_file.name)

-    result = operations.preview_file(
+    # Create schema object
+    schema = PreviewFileSchema(
      file_path=temp_file.name,
      file_type="csv",
      import_type="local",
@@ -293,6 +390,8 @@
      record_separator="\n",
    )

+    result = operations.preview_file(data=schema)
+
    assert result == {
      "type": "csv",
      "columns": [
@@ -311,7 +410,8 @@
    cleanup_temp_files.append(temp_file.name)

-    result = operations.preview_file(
+    # Create schema object
+    schema = PreviewFileSchema(
      file_path=temp_file.name,
      file_type="csv",
      import_type="local",
@@ -322,6 +422,8 @@
      record_separator="\n",
    )

+    result = operations.preview_file(data=schema)
+
    assert result == {
      "type": "csv",
      "columns": [{"name": "column_1", "type": "STRING"}, {"name": "column_2", "type": "STRING"}],
@@ -336,7 +438,8 @@
    cleanup_temp_files.append(temp_file.name)

-    result = operations.preview_file(
+    # Create schema object
+    schema = PreviewFileSchema(
      file_path=temp_file.name,
      file_type="csv",
      import_type="local",
@@ -347,54 +450,34 @@
      record_separator="\n",
    )

+    result = operations.preview_file(data=schema)
+
    assert result == {
      "type": "csv",
      "columns": [],
      "preview_data": {},
    }

-  def test_preview_invalid_file_path(self):
-    with pytest.raises(ValueError, match="File path cannot be empty"):
-      operations.preview_file(file_path="", file_type="csv", import_type="local", sql_dialect="hive", has_header=True)
-
-  def test_preview_unsupported_file_type(self):
-    with pytest.raises(ValueError, match="Unsupported file type: json"):
-      operations.preview_file(
-        file_path="/path/to/test.json",
-        file_type="json",  # Unsupported type
-        import_type="local",
-        sql_dialect="hive",
-        has_header=True,
-      )
-
-  def test_preview_unsupported_sql_dialect(self):
-    with pytest.raises(ValueError, match="Unsupported SQL dialect: mysql"):
-      operations.preview_file(
-        file_path="/path/to/test.csv",
-        file_type="csv",
-        import_type="local",
-        sql_dialect="mysql",  # Unsupported dialect
-        has_header=True,
-      )
-
  def test_preview_remote_file_no_fs(self):
+    # Create schema object
+    schema = PreviewFileSchema(
+      file_path="s3a://bucket/user/test_user/test.csv", file_type="csv", import_type="remote", sql_dialect="hive", has_header=True
+    )
+
    with pytest.raises(ValueError, match="File system object is required for remote import type"):
-      operations.preview_file(
-        file_path="s3a://bucket/user/test_user/test.csv",  # Remote file path
-        file_type="csv",
-        import_type="remote",  # Remote file but no fs provided
-        sql_dialect="hive",
-        has_header=True,
-      )
+      operations.preview_file(data=schema, fs=None)

  @patch("os.path.exists")
  def test_preview_nonexistent_local_file(self, mock_exists):
    mock_exists.return_value = False

+    # Create schema object
+    schema = PreviewFileSchema(
+      file_path="/path/to/nonexistent.csv", file_type="csv", import_type="local", sql_dialect="hive", has_header=True
+    )
+
    with pytest.raises(ValueError, match="Local file does not exist: /path/to/nonexistent.csv"):
-      operations.preview_file(
-        file_path="/path/to/nonexistent.csv", file_type="csv", import_type="local", sql_dialect="hive", has_header=True
-      )
+      operations.preview_file(data=schema)

  def test_preview_trino_dialect_type_mapping(self, cleanup_temp_files):
    test_content = "string_col\nfoo\nbar"
@@ -404,7 +487,8 @@
    cleanup_temp_files.append(temp_file.name)

-    result = operations.preview_file(
+    # Create schema object
+    schema = PreviewFileSchema(
      file_path=temp_file.name,
      file_type="csv",
      import_type="local",
@@ -413,6 +497,8 @@
      field_separator=",",
    )

+    result = operations.preview_file(data=schema)
+
    # Check the result for Trino-specific type mapping
    assert result["columns"][0]["type"] == "VARCHAR"  # Not STRING

@@ -439,7 +525,10 @@ class TestGuessFileHeader:

    cleanup_temp_files.append(temp_file.name)

-    result = operations.guess_file_header(file_path=temp_file.name, file_type="csv", import_type="local")
+    # Create schema object
+    schema = GuessFileHeaderSchema(file_path=temp_file.name, file_type="csv", import_type="local")
+
+    result = operations.guess_file_header(data=schema)

    assert result

@@ -467,120 +556,270 @@ class TestGuessFileHeader:
    mock_sniffer_instance.has_header.return_value = True
    mock_sniffer.return_value = mock_sniffer_instance

-    result = operations.guess_file_header(file_path=temp_file.name, file_type="excel", import_type="local", sheet_name="Sheet1")
-
-    assert result
-
-  def test_guess_header_excel_no_sheet_name(self, cleanup_temp_files):
-    test_content = """<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
-    <workbook xmlns="http://schemas.openxmlformats.org/spreadsheetml/2006/main">
-    </workbook>"""
-
-    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx")
-    temp_file.write(test_content.encode("utf-8"))
-    temp_file.close()
+    # Create schema object
+    schema = GuessFileHeaderSchema(file_path=temp_file.name, file_type="excel", import_type="local", sheet_name="Sheet1")

-    cleanup_temp_files.append(temp_file.name)
+    result = operations.guess_file_header(data=schema)

-    with pytest.raises(ValueError, match="Sheet name is required for Excel files"):
-      operations.guess_file_header(
-        file_path=temp_file.name,
-        file_type="excel",
-        import_type="local",
-        # Missing sheet_name
-      )
-
-  def test_guess_header_invalid_path(self):
-    with pytest.raises(ValueError, match="File path cannot be empty"):
-      operations.guess_file_header(file_path="", file_type="csv", import_type="local")
-
-  def test_guess_header_unsupported_file_type(self):
-    with pytest.raises(ValueError, match="Unsupported file type: json"):
-      operations.guess_file_header(
-        file_path="/path/to/test.json",
-        file_type="json",  # Unsupported type
-        import_type="local",
-      )
+    assert result

  def test_guess_header_nonexistent_local_file(self):
+    # Create schema object
+    schema = GuessFileHeaderSchema(file_path="/path/to/nonexistent/file.csv", file_type="csv", import_type="local")
+
    with pytest.raises(ValueError, match="Local file does not exist"):
-      operations.guess_file_header(file_path="/path/to/nonexistent.csv", file_type="csv", import_type="local")
+      operations.guess_file_header(data=schema)

  def test_guess_header_remote_file_no_fs(self):
+    # Create schema object
+    schema = GuessFileHeaderSchema(file_path="s3a://bucket/user/test_user/test.csv", file_type="csv", import_type="remote")
+
    with pytest.raises(ValueError, match="File system object is required for remote import type"):
-      operations.guess_file_header(
-        file_path="hdfs:///path/to/test.csv",
-        file_type="csv",
-        import_type="remote",  # Remote but no fs provided
-      )
+      operations.guess_file_header(data=schema, fs=None)


class TestSqlTypeMapping:
  def test_get_sql_type_mapping_hive(self):
-    mappings = operations.get_sql_type_mapping("hive")
+    # Create schema object
+    schema = SqlTypeMapperSchema(sql_dialect="hive")

-    # Check some key mappings for Hive
-    assert mappings["Int32"] == "INT"
-    assert mappings["Utf8"] == "STRING"
-    assert mappings["Float64"] == "DOUBLE"
-    assert mappings["Boolean"] == "BOOLEAN"
-    assert mappings["Decimal"] == "DECIMAL"
+    result = operations.get_sql_type_mapping(schema)

-  def test_get_sql_type_mapping_trino(self):
-    mappings = operations.get_sql_type_mapping("trino")
+    # Test all integer types (signed and unsigned)
+    assert result["Int8"] == "TINYINT"
+    assert result["Int16"] == "SMALLINT"
+    assert result["Int32"] == "INT"
+    assert result["Int64"] == "BIGINT"
+    assert result["UInt8"] == "TINYINT"  # Unsigned mapped to signed in Hive
+    assert result["UInt16"] == "SMALLINT"
+    assert result["UInt32"] == "INT"
+    assert result["UInt64"] == "BIGINT"
+
+    # Test floating point and decimal types
+    assert result["Float32"] == "FLOAT"
+    assert result["Float64"] == "DOUBLE"
+    assert result["Decimal"] == "DECIMAL"

-    # Check some key mappings for Trino that differ from Hive
-    assert mappings["Int32"] == "INTEGER"
-    assert mappings["Utf8"] == "VARCHAR"
-    assert mappings["Binary"] == "VARBINARY"
-    assert mappings["Float32"] == "REAL"
-    assert mappings["Struct"] == "ROW"
-    assert mappings["Object"] == "JSON"
+    # Test boolean, string, and binary types
+    assert result["Boolean"] == "BOOLEAN"
+    assert result["Utf8"] == "STRING"
+    assert result["String"] == "STRING"
+    assert result["Categorical"] == "STRING"
+    assert result["Enum"] == "STRING"
+    assert result["Binary"] == "BINARY"
+
+    # Test temporal types
+    assert result["Date"] == "DATE"
+    assert result["Time"] == "TIMESTAMP"  # No pure TIME type in Hive
+    assert result["Datetime"] == "TIMESTAMP"
+    assert result["Duration"] == "INTERVAL DAY TO SECOND"
+
+    # Test nested and other types
+    assert result["Array"] == "ARRAY"
+    assert result["List"] == "ARRAY"
+    assert result["Struct"] == "STRUCT"
+    assert result["Object"] == "STRING"
+    assert result["Null"] == "STRING"
+    assert result["Unknown"] == "STRING"
+
+  def test_get_sql_type_mapping_trino(self):
+    # Create schema object
+    schema = SqlTypeMapperSchema(sql_dialect="trino")
+
+    result = operations.get_sql_type_mapping(schema)
+
+    # Test Trino-specific overrides
+    assert result["Int32"] == "INTEGER"  # Not INT
+    assert result["UInt32"] == "INTEGER"  # Not INT
+    assert result["Float32"] == "REAL"  # Not FLOAT
+    assert result["Utf8"] == "VARCHAR"  # Not STRING
+    assert result["String"] == "VARCHAR"  # Not STRING
+    assert result["Binary"] == "VARBINARY"  # Not BINARY
+    assert result["Struct"] == "ROW"  # Not STRUCT
+    assert result["Object"] == "JSON"  # Not STRING
+    assert result["Duration"] == "INTERVAL DAY TO SECOND"
+
+    # Test types that remain the same as the base mapping
+    assert result["Int8"] == "TINYINT"
+    assert result["Int16"] == "SMALLINT"
+    assert result["Int64"] == "BIGINT"
+    assert result["Float64"] == "DOUBLE"
+    assert result["Boolean"] == "BOOLEAN"
+    assert result["Date"] == "DATE"
+    assert result["Time"] == "TIMESTAMP"
+    assert result["Datetime"] == "TIMESTAMP"

  def test_get_sql_type_mapping_phoenix(self):
-    mappings = operations.get_sql_type_mapping("phoenix")
+    # Create schema object
+    schema = SqlTypeMapperSchema(sql_dialect="phoenix")
+
+    result = operations.get_sql_type_mapping(schema)
+
+    # Test Phoenix-specific unsigned integer mappings
+    assert result["UInt8"] == "UNSIGNED_TINYINT"
+    assert result["UInt16"] == "UNSIGNED_SMALLINT"
+    assert result["UInt32"] == "UNSIGNED_INT"
+    assert result["UInt64"] == "UNSIGNED_LONG"
+
+    # Test other Phoenix-specific overrides
+    assert result["Utf8"] == "VARCHAR"  # Not STRING
+    assert result["String"] == "VARCHAR"  # Not STRING
+    assert result["Binary"] == "VARBINARY"  # Not BINARY
+    assert result["Duration"] == "STRING"  # Phoenix treats durations as strings
+    assert result["Struct"] == "STRING"  # No native STRUCT type
+    assert result["Object"] == "VARCHAR"  # Not STRING
+    assert result["Time"] == "TIME"  # Phoenix has its own TIME type
+    assert result["Decimal"] == "DECIMAL"
+
+    # Test signed integers (use base mapping)
+    assert result["Int8"] == "TINYINT"
+    assert result["Int16"] == "SMALLINT"
+    assert result["Int32"] == "INT"
+    assert result["Int64"] == "BIGINT"

-    # Check some key mappings for Phoenix
-    assert mappings["UInt32"] == "UNSIGNED_INT"
-    assert mappings["Utf8"] == "VARCHAR"
-    assert mappings["Time"] == "TIME"
-    assert mappings["Struct"] == "STRING"  # Phoenix treats structs as strings
-    assert mappings["Duration"] == "STRING"  # Phoenix treats durations as strings
+    # Test other types that remain the same
+    assert result["Float32"] == "FLOAT"
+    assert result["Float64"] == "DOUBLE"
+    assert result["Boolean"] == "BOOLEAN"
+    assert result["Date"] == "DATE"
+    assert result["Datetime"] == "TIMESTAMP"

  def test_get_sql_type_mapping_impala(self):
-    result = operations.get_sql_type_mapping("impala")
+    # Create schema object
+    schema = SqlTypeMapperSchema(sql_dialect="impala")
+
+    result = operations.get_sql_type_mapping(schema)

-    # Impala uses the base mappings, so check those
+    # Impala uses all base mappings (no overrides)
+    # Test a comprehensive set to ensure no overrides are applied
+    assert result["Int8"] == "TINYINT"
+    assert result["Int16"] == "SMALLINT"
    assert result["Int32"] == "INT"
    assert result["Int64"] == "BIGINT"
+    assert result["UInt8"] == "TINYINT"
+    assert result["UInt16"] == "SMALLINT"
+    assert result["UInt32"] == "INT"
+    assert result["UInt64"] == "BIGINT"
+    assert result["Float32"] == "FLOAT"
    assert result["Float64"] == "DOUBLE"
+    assert result["Decimal"] == "DECIMAL"
+    assert result["Boolean"] == "BOOLEAN"
    assert result["Utf8"] == "STRING"
+    assert result["String"] == "STRING"
+    assert result["Binary"] == "BINARY"
+    assert result["Date"] == "DATE"
+    assert result["Time"] == "TIMESTAMP"
+    assert result["Datetime"] == "TIMESTAMP"
+    assert result["Duration"] == "INTERVAL DAY TO SECOND"
+    assert result["Array"] == "ARRAY"
+    assert result["Struct"] == "STRUCT"
+    assert result["Object"] == "STRING"

  def test_get_sql_type_mapping_sparksql(self):
-    result = operations.get_sql_type_mapping("sparksql")
+    # Create schema object
+    schema = SqlTypeMapperSchema(sql_dialect="sparksql")

-    # SparkSQL uses the base mappings, so check those
+    result = operations.get_sql_type_mapping(schema)
+
+    # SparkSQL uses all base mappings (no overrides)
+    # Test a comprehensive set to ensure no overrides are applied
+    assert result["Int8"] == "TINYINT"
+    assert result["Int16"] == "SMALLINT"
    assert result["Int32"] == "INT"
    assert result["Int64"] == "BIGINT"
+    assert result["UInt8"] == "TINYINT"
+    assert result["UInt16"] == "SMALLINT"
+    assert result["UInt32"] == "INT"
+    assert result["UInt64"] == "BIGINT"
+    assert result["Float32"] == "FLOAT"
    assert result["Float64"] == "DOUBLE"
+    assert result["Decimal"] == "DECIMAL"
+    assert result["Boolean"] == "BOOLEAN"
    assert result["Utf8"] == "STRING"
-
-  def test_get_sql_type_mapping_unsupported_dialect(self):
-    with pytest.raises(ValueError, match="Unsupported dialect: mysql"):
-      operations.get_sql_type_mapping("mysql")
+    assert result["String"] == "STRING"
+    assert result["Binary"] == "BINARY"
+    assert result["Date"] == "DATE"
+    assert result["Time"] == "TIMESTAMP"
+    assert result["Datetime"] == "TIMESTAMP"
+    assert result["Duration"] == "INTERVAL DAY TO SECOND"
+    assert result["Array"] == "ARRAY"
+    assert result["Struct"] == "STRUCT"
+    assert result["Object"] == "STRING"
+
+  def test_get_sql_type_mapping_all_dialects_consistency(self):
+    # Test that all dialects return mappings for all base types
+    dialects = ["hive", "impala", "sparksql", "trino", "phoenix"]
+    base_types = [
+      "Int8",
+      "Int16",
+      "Int32",
+      "Int64",
+      "UInt8",
+      "UInt16",
+      "UInt32",
+      "UInt64",
+      "Float32",
+      "Float64",
+      "Decimal",
+      "Boolean",
+      "Utf8",
+      "String",
+      "Categorical",
+      "Enum",
+      "Binary",
+      "Date",
+      "Time",
+      "Datetime",
+      "Duration",
+      "Array",
+      "List",
+      "Struct",
+      "Object",
+      "Null",
+      "Unknown",
+    ]
+
+    for dialect in dialects:
+      schema = SqlTypeMapperSchema(sql_dialect=dialect)
+      result = operations.get_sql_type_mapping(schema)
+
+      # Ensure all base types have mappings
+      for base_type in base_types:
+        assert base_type in result, f"Missing mapping for {base_type} in {dialect} dialect"
+        assert isinstance(result[base_type], str), f"Invalid mapping type for {base_type} in {dialect} dialect"
+        assert len(result[base_type]) > 0, f"Empty mapping for {base_type} in {dialect} dialect"

  def test_map_polars_dtype_to_sql_type(self):
-    # Test with Hive dialect
-    assert operations._map_polars_dtype_to_sql_type("hive", "Int64") == "BIGINT"
-    assert operations._map_polars_dtype_to_sql_type("hive", "Float32") == "FLOAT"
-
-    # Test with Trino dialect
-    assert operations._map_polars_dtype_to_sql_type("trino", "Int64") == "BIGINT"
+    # Test comprehensive type mapping for each dialect
+
+    # Hive dialect tests
+    assert operations._map_polars_dtype_to_sql_type("hive", "Int8") == "TINYINT"
+    assert operations._map_polars_dtype_to_sql_type("hive", "Int32") == "INT"
+    assert operations._map_polars_dtype_to_sql_type("hive", "Float64") == "DOUBLE"
+    assert operations._map_polars_dtype_to_sql_type("hive", "Utf8") == "STRING"
+    assert operations._map_polars_dtype_to_sql_type("hive", "Boolean") == "BOOLEAN"
+    assert operations._map_polars_dtype_to_sql_type("hive", "Date") == "DATE"
+    assert operations._map_polars_dtype_to_sql_type("hive", "Array") == "ARRAY"
+
+    # Trino dialect tests (with overrides)
+    assert operations._map_polars_dtype_to_sql_type("trino", "Int32") == "INTEGER"
    assert operations._map_polars_dtype_to_sql_type("trino", "Float32") == "REAL"
-
-    # Test unsupported type
+    assert operations._map_polars_dtype_to_sql_type("trino", "Utf8") == "VARCHAR"
+    assert operations._map_polars_dtype_to_sql_type("trino", "Binary") == "VARBINARY"
+    assert operations._map_polars_dtype_to_sql_type("trino", "Struct") == "ROW"
+    assert operations._map_polars_dtype_to_sql_type("trino", "Object") == "JSON"
+
+    # Phoenix dialect tests (with unsigned types)
+    assert operations._map_polars_dtype_to_sql_type("phoenix", "UInt8") == "UNSIGNED_TINYINT"
+    assert operations._map_polars_dtype_to_sql_type("phoenix", "UInt32") == "UNSIGNED_INT"
+    assert operations._map_polars_dtype_to_sql_type("phoenix", "UInt64") == "UNSIGNED_LONG"
+    assert operations._map_polars_dtype_to_sql_type("phoenix", "Time") == "TIME"
+    assert operations._map_polars_dtype_to_sql_type("phoenix", "Duration") == "STRING"
+    assert operations._map_polars_dtype_to_sql_type("phoenix", "Struct") == "STRING"
+
+    # Test error for unknown type
    with pytest.raises(ValueError, match="No mapping for Polars dtype"):
-      operations._map_polars_dtype_to_sql_type("hive", "NonExistentType")
+      operations._map_polars_dtype_to_sql_type("hive", "UnknownType")


@pytest.mark.usefixtures("cleanup_temp_files")