encoding.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. from __future__ import unicode_literals
  5. import array
  6. import io
  7. import math
  8. import os
  9. import struct
  10. import logging
  11. import thriftpy
  12. THRIFT_FILE = os.path.join(os.path.dirname(__file__), "parquet.thrift")
  13. parquet_thrift = thriftpy.load(THRIFT_FILE, module_name=str("parquet_thrift"))
  14. logger = logging.getLogger("parquet")
  15. def read_plain_boolean(fo, count):
  16. """Reads `count` booleans using the plain encoding"""
  17. # for bit packed, the count is stored shifted up. But we want to pass in a count,
  18. # so we shift up.
  19. # bit width is 1 for a single-bit boolean.
  20. return read_bitpacked(fo, count << 1, 1, logger.isEnabledFor(logging.DEBUG))
  21. def read_plain_int32(fo, count):
  22. """Reads `count` 32-bit ints using the plain encoding"""
  23. length = 4 * count
  24. data = fo.read(length)
  25. if len(data) != length:
  26. raise EOFError("Expected {} bytes but got {0} bytes".format(length, len(data)))
  27. res = struct.unpack(b"<{0}i".format(count), data)
  28. return res
  29. def read_plain_int64(fo, count):
  30. """Reads `count` 64-bit ints using the plain encoding"""
  31. return struct.unpack(b"<{0}q".format(count), fo.read(8 * count))
  32. def read_plain_int96(fo, count):
  33. """Reads `count` 96-bit ints using the plain encoding"""
  34. items = struct.unpack(b"<qi" * count, fo.read(12) * count)
  35. args = [iter(items)] * 2
  36. return [q << 32 | i for (q, i) in zip(*args)]
  37. def read_plain_float(fo, count):
  38. """Reads `count` 32-bit floats using the plain encoding"""
  39. return struct.unpack(b"<{0}f".format(count), fo.read(4 * count))
  40. def read_plain_double(fo, count):
  41. """Reads `count` 64-bit float (double) using the plain encoding"""
  42. return struct.unpack(b"<{0}d".format(count), fo.read(8 * count))
  43. def read_plain_byte_array(fo, count):
  44. """Read `count` byte arrays using the plain encoding"""
  45. return [fo.read(struct.unpack(b"<i", fo.read(4))[0]) for i in range(count)]
  46. def read_plain_byte_array_fixed(fo, fixed_length):
  47. """Reads a byte array of the given fixed_length"""
  48. return fo.read(fixed_length)
  49. DECODE_PLAIN = {
  50. parquet_thrift.Type.BOOLEAN: read_plain_boolean,
  51. parquet_thrift.Type.INT32: read_plain_int32,
  52. parquet_thrift.Type.INT64: read_plain_int64,
  53. parquet_thrift.Type.INT96: read_plain_int96,
  54. parquet_thrift.Type.FLOAT: read_plain_float,
  55. parquet_thrift.Type.DOUBLE: read_plain_double,
  56. parquet_thrift.Type.BYTE_ARRAY: read_plain_byte_array,
  57. parquet_thrift.Type.FIXED_LEN_BYTE_ARRAY: read_plain_byte_array_fixed
  58. }
  59. def read_plain(fo, type_, count):
  60. """Reads `count` items `type` from the fo using the plain encoding."""
  61. if count == 0:
  62. return []
  63. conv = DECODE_PLAIN[type_]
  64. return conv(fo, count)
  65. def read_unsigned_var_int(fo):
  66. result = 0
  67. shift = 0
  68. while True:
  69. byte = struct.unpack(b"<B", fo.read(1))[0]
  70. result |= ((byte & 0x7F) << shift)
  71. if (byte & 0x80) == 0:
  72. break
  73. shift += 7
  74. return result
  75. def read_rle(fo, header, bit_width, debug_logging):
  76. """Read a run-length encoded run from the given fo with the given header
  77. and bit_width.
  78. The count is determined from the header and the width is used to grab the
  79. value that's repeated. Yields the value repeated count times.
  80. """
  81. count = header >> 1
  82. zero_data = b"\x00\x00\x00\x00"
  83. width = (bit_width + 7) // 8
  84. data = fo.read(width)
  85. data = data + zero_data[len(data):]
  86. value = struct.unpack(b"<i", data)[0]
  87. if debug_logging:
  88. logger.debug("Read RLE group with value %s of byte-width %s and count %s",
  89. value, width, count)
  90. for i in range(count):
  91. yield value
  92. def width_from_max_int(value):
  93. """Converts the value specified to a bit_width."""
  94. return int(math.ceil(math.log(value + 1, 2)))
  95. def _mask_for_bits(i):
  96. """Helper function for read_bitpacked to generage a mask to grab i bits."""
  97. return (1 << i) - 1
  98. def read_bitpacked(fo, header, width, debug_logging):
  99. """Reads a bitpacked run of the rle/bitpack hybrid.
  100. Supports width >8 (crossing bytes).
  101. """
  102. num_groups = header >> 1
  103. count = num_groups * 8
  104. byte_count = (width * count) // 8
  105. if debug_logging:
  106. logger.debug("Reading a bit-packed run with: %s groups, count %s, bytes %s",
  107. num_groups, count, byte_count)
  108. raw_bytes = array.array(str('B'), fo.read(byte_count)).tolist()
  109. current_byte = 0
  110. b = raw_bytes[current_byte]
  111. mask = _mask_for_bits(width)
  112. bits_wnd_l = 8
  113. bits_wnd_r = 0
  114. res = []
  115. total = len(raw_bytes)*8;
  116. while (total >= width):
  117. # TODO zero-padding could produce extra zero-values
  118. if debug_logging:
  119. logger.debug(" read bitpacked: width=%s window=(%s %s) b=%s,"
  120. " current_byte=%s",
  121. width, bits_wnd_l, bits_wnd_r, bin(b), current_byte)
  122. if bits_wnd_r >= 8:
  123. bits_wnd_r -= 8
  124. bits_wnd_l -= 8
  125. b >>= 8
  126. elif bits_wnd_l - bits_wnd_r >= width:
  127. res.append((b >> bits_wnd_r) & mask)
  128. total -= width
  129. bits_wnd_r += width
  130. if debug_logging:
  131. logger.debug(" read bitpackage: added: %s", res[-1])
  132. elif current_byte + 1 < len(raw_bytes):
  133. current_byte += 1
  134. b |= (raw_bytes[current_byte] << bits_wnd_l)
  135. bits_wnd_l += 8
  136. return res
  137. def read_bitpacked_deprecated(fo, byte_count, count, width, debug_logging):
  138. raw_bytes = array.array(str('B'), fo.read(byte_count)).tolist()
  139. mask = _mask_for_bits(width)
  140. index = 0
  141. res = []
  142. word = 0
  143. bits_in_word = 0
  144. while len(res) < count and index <= len(raw_bytes):
  145. if debug_logging:
  146. logger.debug("index = %d", index)
  147. logger.debug("bits in word = %d", bits_in_word)
  148. logger.debug("word = %s", bin(word))
  149. if bits_in_word >= width:
  150. # how many bits over the value is stored
  151. offset = (bits_in_word - width)
  152. # figure out the value
  153. value = (word & (mask << offset)) >> offset
  154. if debug_logging:
  155. logger.debug("offset = %d", offset)
  156. logger.debug("value = %d (%s)", value, bin(value))
  157. res.append(value)
  158. bits_in_word -= width
  159. else:
  160. word = (word << 8) | raw_bytes[index]
  161. index += 1
  162. bits_in_word += 8
  163. return res
  164. def read_rle_bit_packed_hybrid(fo, width, length=None):
  165. """Implemenation of a decoder for the rel/bit-packed hybrid encoding.
  166. If length is not specified, then a 32-bit int is read first to grab the
  167. length of the encoded data.
  168. """
  169. debug_logging = logger.isEnabledFor(logging.DEBUG)
  170. io_obj = fo
  171. if length is None:
  172. length = read_plain_int32(fo, 1)[0]
  173. raw_bytes = fo.read(length)
  174. if raw_bytes == b'':
  175. return None
  176. io_obj = io.BytesIO(raw_bytes)
  177. res = []
  178. while io_obj.tell() < length:
  179. header = read_unsigned_var_int(io_obj)
  180. if header & 1 == 0:
  181. res += read_rle(io_obj, header, width, debug_logging)
  182. else:
  183. res += read_bitpacked(io_obj, header, width, debug_logging)
  184. return res