Browse Source

fix read_bitpacked byte count

Jaguar Xiong 11 years ago
parent
commit
03889c7539
2 changed files with 13 additions and 9 deletions
  1. 7 5
      parquet/__init__.py
  2. 6 4
      parquet/encoding.py

+ 7 - 5
parquet/__init__.py

@@ -2,7 +2,7 @@ import gzip
 import json
 import logging
 import struct
-import StringIO
+import cStringIO
 import sys
 from collections import defaultdict
 from ttypes import (FileMetaData, CompressionCodec, Encoding,
@@ -204,7 +204,7 @@ def _read_page(fo, page_header, column_metadata):
         if column_metadata.codec == CompressionCodec.SNAPPY:
             raw_bytes = snappy.decompress(bytes_from_file)
         elif column_metadata.codec == CompressionCodec.GZIP:
-            io_obj = StringIO.StringIO(bytes_from_file)
+            io_obj = cStringIO.StringIO(bytes_from_file)
             with gzip.GzipFile(fileobj=io_obj, mode='rb') as f:
                 raw_bytes = f.read()
         else:
@@ -252,7 +252,7 @@ def read_data_page(fo, schema_helper, page_header, column_metadata,
     """
     daph = page_header.data_page_header
     raw_bytes = _read_page(fo, page_header, column_metadata)
-    io_obj = StringIO.StringIO(raw_bytes)
+    io_obj = cStringIO.StringIO(raw_bytes)
     vals = []
 
     logger.debug("  definition_level_encoding: %s",
@@ -301,11 +301,13 @@ def read_data_page(fo, schema_helper, page_header, column_metadata,
         logger.debug("bit_width: %d", bit_width)
         total_seen = 0
         dict_values_bytes = io_obj.read()
-        dict_values_io_obj = StringIO.StringIO(dict_values_bytes)
+        dict_values_io_obj = cStringIO.StringIO(dict_values_bytes)
         # TODO jcrobak -- not sure that this loop is needed?
         while total_seen < daph.num_values:
             values = encoding.read_rle_bit_packed_hybrid(
                 dict_values_io_obj, bit_width, len(dict_values_bytes))
+            if len(values) + total_seen > daph.num_values:
+                values = values[0: daph.num_values - total_seen]
             vals += [dictionary[v] for v in values]
             total_seen += len(values)
     else:
@@ -316,7 +318,7 @@ def read_data_page(fo, schema_helper, page_header, column_metadata,
 
 def read_dictionary_page(fo, page_header, column_metadata):
     raw_bytes = _read_page(fo, page_header, column_metadata)
-    io_obj = StringIO.StringIO(raw_bytes)
+    io_obj = cStringIO.StringIO(raw_bytes)
     dict_items = []
     while io_obj.tell() < len(raw_bytes):
         # TODO - length for fixed byte array

+ 6 - 4
parquet/encoding.py

@@ -1,7 +1,7 @@
 import array
 import math
 import struct
-import StringIO
+import cStringIO
 import logging
 
 from ttypes import Type
@@ -133,9 +133,11 @@ def read_bitpacked(fo, header, width):
     Currently only supports width <=8 (doesn't support crossing bytes).
     """
     num_groups = header >> 1
-    logger.debug("Reading a bit-packed run with: %s groups", num_groups)
     count = num_groups * 8
-    raw_bytes = array.array('B', fo.read(count)).tolist()
+    byte_count = (width * count)/8
+    logger.debug("Reading a bit-packed run with: %s groups, count %s, bytes %s",
+        num_groups, count, byte_count)
+    raw_bytes = array.array('B', fo.read(byte_count)).tolist()
     current_byte = 0
     b = raw_bytes[current_byte]
     mask = _mask_for_bits(width)
@@ -206,7 +208,7 @@ def read_rle_bit_packed_hybrid(fo, width, length=None):
         raw_bytes = fo.read(length)
         if raw_bytes == '':
             return None
-        io_obj = StringIO.StringIO(raw_bytes)
+        io_obj = cStringIO.StringIO(raw_bytes)
     res = []
     while io_obj.tell() < length:
         header = read_unsigned_var_int(io_obj)