瀏覽代碼

HUE-8737 [beeswax] Fix unit test test_column_format_values_nulls

Ying Chen 5 年之前
父節點
當前提交
b91c599871

+ 3 - 2
apps/beeswax/src/beeswax/server/hive_server2_lib.py

@@ -30,7 +30,7 @@ from operator import itemgetter
 
 from django.utils.translation import ugettext as _
 
-from desktop.lib import thrift_util
+from desktop.lib import python_util, thrift_util
 from desktop.conf import DEFAULT_USER
 from beeswax import conf
 
@@ -310,7 +310,8 @@ class HiveServerTColumnValue2(object):
     if sys.version_info[0] < 3 or isinstance(bytestring, bytes):
       mask = bytearray(bytestring)
     else:
-      mask = bytearray(bytestring, 'utf-8')
+      bitstring = python_util.from_string_to_bits(bytestring)
+      mask = python_util.get_bytes_from_bits(bitstring)
 
     for n in mask:
       yield n & 0x01

+ 36 - 0
apps/beeswax/src/beeswax/tests.py

@@ -60,6 +60,7 @@ from desktop.redaction import logfilter
 from desktop.redaction.engine import RedactionPolicy, RedactionRule
 from desktop.lib.django_test_util import make_logged_in_client, assert_equal_mod_whitespace
 from desktop.lib.parameterization import substitute_variables
+from desktop.lib.python_util import from_string_to_bits, get_bytes_from_bits
 from desktop.lib.test_utils import grant_access, add_to_group
 from desktop.lib.security_util import get_localhost_name
 from desktop.lib.test_export_csvxls import _read_xls_sheet_data
@@ -2704,6 +2705,41 @@ class TestHiveServer2API(object):
     assert_false(data is HiveServerTColumnValue2.set_nulls(data, nulls))
 
 
+  def test_bits_to_bytes_conversion(self):
+    if sys.version_info[0] < 3:
+      raise SkipTest
+
+    nulls = '\x00'
+    bitstring = from_string_to_bits(nulls)
+    assert_equal('00000000', bitstring)
+    assert_equal([0, 0], get_bytes_from_bits(bitstring))
+
+    nulls = '\x03'
+    bitstring = from_string_to_bits(nulls)
+    assert_equal('00000011', bitstring)
+    assert_equal([3, 0], get_bytes_from_bits(bitstring))
+
+    nulls = 't'
+    bitstring = from_string_to_bits(nulls)
+    assert_equal('01110100', bitstring)
+    assert_equal([116, 0], get_bytes_from_bits(bitstring))
+
+    nulls = '\xff\xee\x03'
+    bitstring = from_string_to_bits(nulls)
+    assert_equal('111111111110111000000011', bitstring)
+    assert_equal([255, 238, 3, 0], get_bytes_from_bits(bitstring))
+
+    nulls = '\x41'
+    bitstring = from_string_to_bits(nulls)
+    assert_equal('01000001', bitstring)
+    assert_equal([65, 0], get_bytes_from_bits(bitstring))
+
+    nulls = '\x01\x23\x45\x67\x89\xab\xcd\xef'
+    bitstring = from_string_to_bits(nulls)
+    assert_equal('0000000100100011010001010110011110001001101010111100110111101111', bitstring)
+    assert_equal([1, 35, 69, 103, 137, 171, 205, 239, 0], get_bytes_from_bits(bitstring))
+
+
 class MockDbms(object):
 
   def __init__(self, client, server_type):

+ 13 - 0
desktop/core/src/desktop/lib/python_util.py

@@ -156,6 +156,19 @@ def force_dict_to_strings(dictionary):
 
   return new_dict
 
+
+def from_string_to_bits(str_value):
+  return ''.join(format(ord(byte), '08b') for byte in str_value)
+
+
+def get_bytes_from_bits(bit_string):
+  """
+  This should be used in py3 or above
+  """
+  padded_bits = bit_string + '0' * (8 - len(bit_string) % 8)
+  return list(int(padded_bits, 2).to_bytes(len(padded_bits) // 8, 'big'))
+
+
 def isASCII(data):
   try:
     data.decode('ASCII')