
Merge pull request #3 from jaguarx/master

Support bit widths >= 8, and bit-packed runs with fewer than 8 values
Joe Crobak 11 years ago
commit 65f7fc9f5c
2 changed files with 35 additions and 36 deletions
  1. parquet/__init__.py  +10 -9
  2. parquet/encoding.py  +25 -27

parquet/__init__.py  +10 -9

@@ -2,7 +2,7 @@ import gzip
 import json
 import logging
 import struct
-import StringIO
+import cStringIO
 import sys
 from collections import defaultdict
 from ttypes import (FileMetaData, CompressionCodec, Encoding,
@@ -204,7 +204,7 @@ def _read_page(fo, page_header, column_metadata):
         if column_metadata.codec == CompressionCodec.SNAPPY:
             raw_bytes = snappy.decompress(bytes_from_file)
         elif column_metadata.codec == CompressionCodec.GZIP:
-            io_obj = StringIO.StringIO(bytes_from_file)
+            io_obj = cStringIO.StringIO(bytes_from_file)
             with gzip.GzipFile(fileobj=io_obj, mode='rb') as f:
                 raw_bytes = f.read()
         else:
@@ -252,7 +252,7 @@ def read_data_page(fo, schema_helper, page_header, column_metadata,
     """
     daph = page_header.data_page_header
     raw_bytes = _read_page(fo, page_header, column_metadata)
-    io_obj = StringIO.StringIO(raw_bytes)
+    io_obj = cStringIO.StringIO(raw_bytes)
     vals = []
 
     logger.debug("  definition_level_encoding: %s",
@@ -276,8 +276,7 @@ def read_data_page(fo, schema_helper, page_header, column_metadata,
                                            daph.num_values,
                                            bit_width)
 
-        logger.debug("  Definition levels: %s",
-                     ",".join([str(dl) for dl in definition_levels]))
+        logger.debug("  Definition levels: %s", len(definition_levels))
 
     # repetition levels are skipped if data is at the first level.
     if len(column_metadata.path_in_schema) > 1:
@@ -294,18 +293,20 @@ def read_data_page(fo, schema_helper, page_header, column_metadata,
         for i in range(daph.num_values):
             vals.append(
                 encoding.read_plain(io_obj, column_metadata.type, None))
-        logger.debug("  Values: %s", ",".join([str(x) for x in vals]))
+        logger.debug("  Values: %s", len(vals))
     elif daph.encoding == Encoding.PLAIN_DICTIONARY:
         # bit_width is stored as single byte.
         bit_width = struct.unpack("<B", io_obj.read(1))[0]
         logger.debug("bit_width: %d", bit_width)
         total_seen = 0
         dict_values_bytes = io_obj.read()
-        dict_values_io_obj = StringIO.StringIO(dict_values_bytes)
+        dict_values_io_obj = cStringIO.StringIO(dict_values_bytes)
         # TODO jcrobak -- not sure that this loop is needed?
         while total_seen < daph.num_values:
             values = encoding.read_rle_bit_packed_hybrid(
                 dict_values_io_obj, bit_width, len(dict_values_bytes))
+            if len(values) + total_seen > daph.num_values:
+                values = values[0: daph.num_values - total_seen]
             vals += [dictionary[v] for v in values]
             total_seen += len(values)
     else:
@@ -316,7 +317,7 @@ def read_data_page(fo, schema_helper, page_header, column_metadata,
 
 def read_dictionary_page(fo, page_header, column_metadata):
     raw_bytes = _read_page(fo, page_header, column_metadata)
-    io_obj = StringIO.StringIO(raw_bytes)
+    io_obj = cStringIO.StringIO(raw_bytes)
     dict_items = []
     while io_obj.tell() < len(raw_bytes):
         # TODO - length for fixed byte array
@@ -360,7 +361,7 @@ def _dump(fo, options, out=sys.stdout):
                     values = read_data_page(fo, schema_helper, ph, cmd,
                                             dict_items)
                     res[".".join(cmd.path_in_schema)] += values
-                    values_seen += cmd.num_values
+                    values_seen += ph.data_page_header.num_values
                 elif ph.type == PageType.DICTIONARY_PAGE:
                     logger.debug(ph)
                     assert dict_items == []
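
The slice added to the PLAIN_DICTIONARY branch above guards against a bit-packed run decoding into more indices than the page declares, since runs come in whole groups of 8 and may be zero-padded at the end. A minimal sketch of that guard in isolation, assuming a hypothetical decode_run callable that stands in for encoding.read_rle_bit_packed_hybrid:

    # Illustration only: decode_run is a hypothetical stand-in for
    # encoding.read_rle_bit_packed_hybrid and is assumed to return whole
    # 8-value groups, possibly zero-padded past the real data.
    def collect_dictionary_indices(decode_run, num_values):
        """Accumulate dictionary indices, discarding trailing padding."""
        seen = []
        while len(seen) < num_values:
            run = decode_run()
            remaining = num_values - len(seen)
            if len(run) > remaining:
                run = run[:remaining]  # same truncation as in the diff above
            seen.extend(run)
        return seen

Without the truncation, padding values could be appended as real dictionary indices and total_seen could overshoot daph.num_values.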

parquet/encoding.py  +25 -27

@@ -1,7 +1,7 @@
 import array
 import math
 import struct
-import StringIO
+import cStringIO
 import logging
 
 from ttypes import Type
@@ -130,41 +130,39 @@ def _mask_for_bits(i):
 def read_bitpacked(fo, header, width):
     """Reads a bitpacked run of the rle/bitpack hybrid.
 
-    Currently only supports width <=8 (doesn't support crossing bytes).
+    Supports width >8 (crossing bytes).
     """
-    assert width <= 8
     num_groups = header >> 1
-    logger.debug("Reading a bit-packed run with: %s groups", num_groups)
     count = num_groups * 8
-    raw_bytes = array.array('B', fo.read(count)).tolist()
+    byte_count = (width * count)/8
+    logger.debug("Reading a bit-packed run with: %s groups, count %s, bytes %s",
+        num_groups, count, byte_count)
+    raw_bytes = array.array('B', fo.read(byte_count)).tolist()
     current_byte = 0
     b = raw_bytes[current_byte]
     mask = _mask_for_bits(width)
-    bits_in_byte = 8
+    bits_wnd_l = 8
+    bits_wnd_r = 0
     res = []
-    while current_byte < len(raw_bytes):
+    total = len(raw_bytes)*8;
+    while (total >= width):
         # TODO zero-padding could produce extra zero-values
-        logger.debug("  read bitpacked: width=%s bits_in_byte=%s b=%s,"
+        logger.debug("  read bitpacked: width=%s window=(%s %s) b=%s,"
                      " current_byte=%s",
-                     width, bits_in_byte, bin(b), current_byte)
-        if bits_in_byte >= width:
-            res.append(b & mask)
-            b >>= width
-            bits_in_byte -= width
-        else:
-            if current_byte + 1 == len(raw_bytes):
-                break  # partial results / padding at the end.
-            next_b = raw_bytes[current_byte + 1]
-            borrowed_bits = next_b & _mask_for_bits(width - bits_in_byte)
-            logger.debug("    borrowing %d bits", width - bits_in_byte)
-            logger.debug("    next_b=%s, borrowed_bits=%s",
-                         bin(next_b), bin(borrowed_bits))
-            res.append((borrowed_bits << bits_in_byte) | b)
-            b = next_b >> (width - bits_in_byte)
-            logger.debug("  shifting away: %d", width - bits_in_byte)
-            bits_in_byte = 8 - (width - bits_in_byte)
+                     width, bits_wnd_l, bits_wnd_r, bin(b), current_byte)
+        if bits_wnd_r >= 8:
+            bits_wnd_r -= 8
+            bits_wnd_l -= 8
+            b >>= 8
+        elif bits_wnd_l - bits_wnd_r >= width:
+            res.append((b >> bits_wnd_r) & mask)
+            total -= width
+            bits_wnd_r += width
+            logger.debug("  read bitpackage: added: %s", res[-1])
+        elif current_byte + 1 < len(raw_bytes):
             current_byte += 1
-        logger.debug("  read bitpackage: added: %s", res[-1])
+            b |= (raw_bytes[current_byte] << bits_wnd_l)
+            bits_wnd_l += 8
     return res
 
 
@@ -210,7 +208,7 @@ def read_rle_bit_packed_hybrid(fo, width, length=None):
         raw_bytes = fo.read(length)
         if raw_bytes == '':
             return None
-        io_obj = StringIO.StringIO(raw_bytes)
+        io_obj = cStringIO.StringIO(raw_bytes)
     res = []
     while io_obj.tell() < length:
         header = read_unsigned_var_int(io_obj)
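
For readers following the read_bitpacked rewrite above: the old implementation borrowed bits from at most one following byte, which is why it asserted width <= 8. The new version keeps a widening integer window with two bit cursors, so a value may straddle several bytes. Below is a simplified, self-contained sketch of that sliding-window idea, not the function from the diff; here raw_bytes is assumed to be a plain list of byte values and count is passed in directly rather than derived from the run header:

    def unpack_bitpacked(raw_bytes, width, count):
        """Decode `count` little-endian bit-packed values of `width` bits
        each from a list of byte values (illustrative sketch only)."""
        mask = (1 << width) - 1
        window = 0             # bits accumulated so far, lowest bits first
        bits_in_window = 0
        out = []
        byte_iter = iter(raw_bytes)
        while len(out) < count:
            if bits_in_window < width:
                # Pull in another byte; values wider than 8 bits simply
                # keep accumulating until enough bits are available.
                window |= next(byte_iter) << bits_in_window
                bits_in_window += 8
            else:
                out.append(window & mask)
                window >>= width
                bits_in_window -= width
        return out

    # Three 9-bit values (1, 2, 3) occupy 27 bits and pack into 4 bytes.
    packed = [0b00000001, 0b00000100, 0b00001100, 0b00000000]
    assert unpack_bitpacked(packed, 9, 3) == [1, 2, 3]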