소스 검색

HUE-8737 [indexer] Fix indexer unit tests in py3

Ying Chen 5 년 전
부모
커밋
3f5b11633e

+ 5 - 5
apps/beeswax/src/beeswax/templates/create_table_statement.mako

@@ -65,16 +65,16 @@ PARTITIONED BY ${column_list(partition_columns)|n}
 ## TODO: CLUSTERED BY here
 ## TODO: CLUSTERED BY here
 ## TODO: SORTED BY...INTO...BUCKETS here
 ## TODO: SORTED BY...INTO...BUCKETS here
 ROW FORMAT \
 ROW FORMAT \
-% if table.has_key('row_format'):
+% if 'row_format' in table:
 %   if table["row_format"] == "Delimited":
 %   if table["row_format"] == "Delimited":
   DELIMITED
   DELIMITED
-%     if table.has_key('field_terminator'):
+%     if 'field_terminator' in table:
     FIELDS TERMINATED BY '${table["field_terminator"] | n}'
     FIELDS TERMINATED BY '${table["field_terminator"] | n}'
 %     endif
 %     endif
-%     if table.has_key('collection_terminator'):
+%     if 'collection_terminator' in table:
     COLLECTION ITEMS TERMINATED BY '${table["collection_terminator"] | n}'
     COLLECTION ITEMS TERMINATED BY '${table["collection_terminator"] | n}'
 %     endif
 %     endif
-%     if table.has_key('map_key_terminator'):
+%     if 'map_key_terminator' in table:
     MAP KEYS TERMINATED BY '${table["map_key_terminator"] | n}'
     MAP KEYS TERMINATED BY '${table["map_key_terminator"] | n}'
 %     endif
 %     endif
 %   else:
 %   else:
@@ -84,7 +84,7 @@ ROW FORMAT \
 %     endif
 %     endif
 %   endif
 %   endif
 % endif
 % endif
-% if table.has_key('file_format'):
+% if 'file_format' in table:
   STORED AS ${table["file_format"] | n} \
   STORED AS ${table["file_format"] | n} \
 % endif
 % endif
 % if table.get("file_format") == "InputFormat":
 % if table.get("file_format") == "InputFormat":

+ 4 - 4
desktop/libs/indexer/src/indexer/file_format.py

@@ -312,9 +312,9 @@ class CSVFormat(FileFormat):
   _extensions = ["csv", "tsv"]
   _extensions = ["csv", "tsv"]
 
 
   def __init__(self, delimiter=',', line_terminator='\n', quote_char='"', has_header=False, sample="", fields=None):
   def __init__(self, delimiter=',', line_terminator='\n', quote_char='"', has_header=False, sample="", fields=None):
-    self._delimiter = delimiter
-    self._line_terminator = line_terminator
-    self._quote_char = quote_char
+    self._delimiter = delimiter if isinstance(delimiter, str) else delimiter.decode('utf-8')
+    self._line_terminator = line_terminator if isinstance(line_terminator, str) else line_terminator.decode('utf-8')
+    self._quote_char = quote_char if isinstance(quote_char, str) else quote_char.decode('utf-8')
     self._has_header = has_header
     self._has_header = has_header
 
 
     # sniffer insists on \r\n even when \n. This is safer and good enough for a preview
     # sniffer insists on \r\n even when \n. This is safer and good enough for a preview
@@ -346,7 +346,7 @@ class CSVFormat(FileFormat):
   @classmethod
   @classmethod
   def _guess_dialect(cls, sample):
   def _guess_dialect(cls, sample):
     sniffer = csv.Sniffer()
     sniffer = csv.Sniffer()
-    dialect = sniffer.sniff(sample)
+    dialect = sniffer.sniff(sample if isinstance(sample, str) else sample.decode('utf-8'))
     has_header = cls._hasHeader(sniffer, sample, dialect)
     has_header = cls._hasHeader(sniffer, sample, dialect)
     return dialect, has_header
     return dialect, has_header
 
 

+ 20 - 11
desktop/libs/indexer/src/indexer/indexers/sql_tests.py

@@ -18,6 +18,7 @@
 
 
 from builtins import object
 from builtins import object
 import json
 import json
+import sys
 
 
 from nose.tools import assert_equal, assert_true
 from nose.tools import assert_equal, assert_true
 
 
@@ -26,6 +27,10 @@ from useradmin.models import User
 
 
 from indexer.indexers.sql import SQLIndexer
 from indexer.indexers.sql import SQLIndexer
 
 
+table_properties_py2 = '"transactional" = "false", "skip.header.line.count" = "1"'
+table_properties_py3 = '"skip.header.line.count" = "1", "transactional" = "false"'
+is_py3 = sys.version_info[0] > 2
+
 
 
 class MockRequest(object):
 class MockRequest(object):
   def __init__(self, fs=None, user=None):
   def __init__(self, fs=None, user=None):
@@ -66,7 +71,7 @@ def test_generate_create_text_table_with_data_partition():
 
 
   assert_true('''USE default;''' in  sql, sql)
   assert_true('''USE default;''' in  sql, sql)
 
 
-  assert_true('''CREATE TABLE `default`.`customer_stats`
+  statement = '''CREATE TABLE `default`.`customer_stats`
 (
 (
   `customers.id` bigint ,
   `customers.id` bigint ,
   `customers.name` string ,
   `customers.name` string ,
@@ -78,8 +83,9 @@ ROW FORMAT   DELIMITED
     FIELDS TERMINATED BY ','
     FIELDS TERMINATED BY ','
     COLLECTION ITEMS TERMINATED BY '\\002'
     COLLECTION ITEMS TERMINATED BY '\\002'
     MAP KEYS TERMINATED BY '\\003'
     MAP KEYS TERMINATED BY '\\003'
-  STORED AS TextFile TBLPROPERTIES("transactional" = "false", "skip.header.line.count" = "1")
-;''' in  sql, sql)
+  STORED AS TextFile TBLPROPERTIES(%s)
+;''' % table_properties_py3 if is_py3 else  table_properties_py2
+  assert_true(statement in sql, sql)
 
 
   assert_true('''LOAD DATA INPATH '/user/romain/customer_stats.csv' INTO TABLE `default`.`customer_stats` PARTITION (new_field_1='AAA');''' in  sql, sql)
   assert_true('''LOAD DATA INPATH '/user/romain/customer_stats.csv' INTO TABLE `default`.`customer_stats` PARTITION (new_field_1='AAA');''' in  sql, sql)
 
 
@@ -93,7 +99,7 @@ def test_generate_create_kudu_table_with_data():
 
 
   assert_true('''DROP TABLE IF EXISTS `default`.`hue__tmp_index_data`;''' in  sql, sql)
   assert_true('''DROP TABLE IF EXISTS `default`.`hue__tmp_index_data`;''' in  sql, sql)
 
 
-  assert_true('''CREATE EXTERNAL TABLE `default`.`hue__tmp_index_data`
+  statement = '''CREATE EXTERNAL TABLE `default`.`hue__tmp_index_data`
 (
 (
   `business_id` string ,
   `business_id` string ,
   `cool` bigint ,
   `cool` bigint ,
@@ -116,7 +122,8 @@ def test_generate_create_kudu_table_with_data():
 ROW FORMAT   DELIMITED
 ROW FORMAT   DELIMITED
     FIELDS TERMINATED BY ','
     FIELDS TERMINATED BY ','
   STORED AS TextFile LOCATION '/A'
   STORED AS TextFile LOCATION '/A'
-TBLPROPERTIES("transactional" = "false", "skip.header.line.count" = "1")''' in  sql, sql)
+TBLPROPERTIES(%s)''' % table_properties_py3 if is_py3 else  table_properties_py2
+  assert_true(statement in sql in sql, sql)
 
 
   assert_true('''CREATE TABLE `default`.`index_data` COMMENT "Big Data"
   assert_true('''CREATE TABLE `default`.`index_data` COMMENT "Big Data"
         PRIMARY KEY (id)
         PRIMARY KEY (id)
@@ -140,7 +147,7 @@ def test_generate_create_parquet_table():
 
 
   assert_true('''USE default;''' in  sql, sql)
   assert_true('''USE default;''' in  sql, sql)
 
 
-  assert_true('''CREATE EXTERNAL TABLE `default`.`hue__tmp_parquet_table`
+  statement = '''CREATE EXTERNAL TABLE `default`.`hue__tmp_parquet_table`
 (
 (
   `acct_client` string ,
   `acct_client` string ,
   `tran_amount` double ,
   `tran_amount` double ,
@@ -152,8 +159,9 @@ def test_generate_create_parquet_table():
     COLLECTION ITEMS TERMINATED BY '\\002'
     COLLECTION ITEMS TERMINATED BY '\\002'
     MAP KEYS TERMINATED BY '\\003'
     MAP KEYS TERMINATED BY '\\003'
   STORED AS TextFile LOCATION '/user/hue/data'
   STORED AS TextFile LOCATION '/user/hue/data'
-TBLPROPERTIES("transactional" = "false", "skip.header.line.count" = "1")
-;''' in  sql, sql)
+TBLPROPERTIES(%s)
+;''' % table_properties_py3 if is_py3 else  table_properties_py2
+  assert_true(statement in  sql, sql)
 
 
   assert_true('''CREATE TABLE `default`.`parquet_table`
   assert_true('''CREATE TABLE `default`.`parquet_table`
         STORED AS parquet
         STORED AS parquet
@@ -176,7 +184,7 @@ def test_generate_create_orc_table_transactional():
 
 
   assert_true('''USE default;''' in  sql, sql)
   assert_true('''USE default;''' in  sql, sql)
 
 
-  assert_true('''CREATE EXTERNAL TABLE `default`.`hue__tmp_parquet_table`
+  statement = '''CREATE EXTERNAL TABLE `default`.`hue__tmp_parquet_table`
 (
 (
   `acct_client` string ,
   `acct_client` string ,
   `tran_amount` double ,
   `tran_amount` double ,
@@ -188,8 +196,9 @@ def test_generate_create_orc_table_transactional():
     COLLECTION ITEMS TERMINATED BY '\\002'
     COLLECTION ITEMS TERMINATED BY '\\002'
     MAP KEYS TERMINATED BY '\\003'
     MAP KEYS TERMINATED BY '\\003'
   STORED AS TextFile LOCATION '/user/hue/data'
   STORED AS TextFile LOCATION '/user/hue/data'
-TBLPROPERTIES("transactional" = "false", "skip.header.line.count" = "1")
-;''' in  sql, sql)
+TBLPROPERTIES(%s)
+;''' % table_properties_py3 if is_py3 else  table_properties_py2
+  assert_true(statement in sql in  sql, sql)
 
 
   assert_true('''CREATE TABLE `default`.`parquet_table`
   assert_true('''CREATE TABLE `default`.`parquet_table`
         STORED AS orc
         STORED AS orc

+ 2 - 2
desktop/libs/indexer/src/indexer/templates/gen/create_table_statement.mako

@@ -101,7 +101,7 @@ PARTITIONED BY ${ column_list(table, partition_columns) | n }
 ROW FORMAT \
 ROW FORMAT \
 %   if table["row_format"] == "Delimited":
 %   if table["row_format"] == "Delimited":
   DELIMITED
   DELIMITED
-%     if table.has_key('field_terminator'):
+%     if 'field_terminator' in table:
     FIELDS TERMINATED BY '${table["field_terminator"] | n}'
     FIELDS TERMINATED BY '${table["field_terminator"] | n}'
 %     endif
 %     endif
 ## [LINES TERMINATED BY char]
 ## [LINES TERMINATED BY char]
@@ -118,7 +118,7 @@ ROW FORMAT \
 %     endif
 %     endif
 %   endif
 %   endif
 % endif
 % endif
-% if table.has_key('file_format'):
+% if 'file_format' in table:
   STORED AS ${ table["file_format"] | n } \
   STORED AS ${ table["file_format"] | n } \
 % endif
 % endif
 % if table.get("file_format") == "InputFormat":
 % if table.get("file_format") == "InputFormat":

+ 6 - 1
desktop/libs/indexer/src/indexer/test_utils.py

@@ -20,6 +20,7 @@ from future import standard_library
 standard_library.install_aliases()
 standard_library.install_aliases()
 import sys
 import sys
 
 
+from desktop.lib.i18n import force_unicode
 from nose.tools import assert_equal
 from nose.tools import assert_equal
 
 
 from indexer.utils import field_values_from_separated_file
 from indexer.utils import field_values_from_separated_file
@@ -41,6 +42,10 @@ def test_get_ensemble():
   assert_equal(u'rel=""nofollow"">Twitter for BlackBerry®', result[0]['fieldA'])
   assert_equal(u'rel=""nofollow"">Twitter for BlackBerry®', result[0]['fieldA'])
 
 
   # Bad binary
   # Bad binary
-  data = string_io('fieldA\naaa\x80\x02\x03')
+  test_str = b'fieldA\naaa\x80\x02\x03'
+  if sys.version_info[0] > 2:
+    data = string_io(force_unicode(test_str, errors='ignore'))
+  else:
+    data = string_io(test_str)
   result = list(field_values_from_separated_file(data, delimiter='\t', quote_character='"'))
   result = list(field_values_from_separated_file(data, delimiter='\t', quote_character='"'))
   assert_equal(u'aaa\x02\x03', result[0]['fieldA'])
   assert_equal(u'aaa\x02\x03', result[0]['fieldA'])