Browse Source

fixes for definition, repetition-levels and dicts

Joe Crobak 12 years ago
parent
commit
81f9c9df84
8 changed files with 425 additions and 20 deletions
  1. 150 8
      parquet/__init__.py
  2. 2 1
      parquet/__main__.py
  3. 22 0
      parquet/bitstring.py
  4. 151 0
      parquet/encoding.py
  5. 40 0
      parquet/schema.py
  6. 12 9
      setup.py
  7. 46 0
      test/test_encoding.py
  8. 2 2
      test/test_read_support.py

+ 150 - 8
parquet/__init__.py

@@ -1,12 +1,24 @@
+import gzip
+import json
+import logging
 import struct
 import struct
 import thrift
 import thrift
-import logging
-from ttypes import FileMetaData, CompressionCodec, Encoding, PageHeader, PageType, Type
+import StringIO
+from collections import defaultdict
+from ttypes import FileMetaData, CompressionCodec, Encoding, FieldRepetitionType, PageHeader, PageType, Type
 from thrift.protocol import TCompactProtocol
 from thrift.protocol import TCompactProtocol
 from thrift.transport import TTransport
 from thrift.transport import TTransport
+import encoding
+import schema
+
 
 
 logger = logging.getLogger("parquet")
 logger = logging.getLogger("parquet")
 
 
+try:
+    import snappy
+except ImportError:
+    logger.warn("Couldn't import snappy. Support for snappy compression disabled.")
+
 class ParquetFormatException(Exception):
 class ParquetFormatException(Exception):
     pass
     pass
 
 
@@ -42,9 +54,8 @@ def _read_footer(fo):
     return fmd
     return fmd
 
 
 
 
-def _read_page_header(fo, offset):
-    """Reads the page_header at the given offset"""
-    fo.seek(offset, 0)
+def _read_page_header(fo):
+    """Reads the page_header from the given fo"""
     tin = TTransport.TFileObjectTransport(fo)
     tin = TTransport.TFileObjectTransport(fo)
     pin = TCompactProtocol.TCompactProtocol(tin)
     pin = TCompactProtocol.TCompactProtocol(tin)
     ph = PageHeader()
     ph = PageHeader()
@@ -73,7 +84,8 @@ def dump_metadata(filename):
     print("  schema: ")
     print("  schema: ")
     for se in footer.schema:
     for se in footer.schema:
         print("    {name} ({type}): length={type_length}, repetition={repetition_type}, children={num_children}, converted_type={converted_type}".format(
         print("    {name} ({type}): length={type_length}, repetition={repetition_type}, children={num_children}, converted_type={converted_type}".format(
-            name=se.name, type=Type._VALUES_TO_NAMES[se.type] if se.type else None, type_length=se.type_length, repetition_type=se.repetition_type,
+            name=se.name, type=Type._VALUES_TO_NAMES[se.type] if se.type else None, type_length=se.type_length,
+            repetition_type=FieldRepetitionType._VALUES_TO_NAMES[se.repetition_type] if se.repetition_type else None,
             num_children=se.num_children, converted_type=se.converted_type))
             num_children=se.num_children, converted_type=se.converted_type))
     print("  row groups: ")
     print("  row groups: ")
     for rg in footer.row_groups:
     for rg in footer.row_groups:
@@ -96,10 +108,12 @@ def dump_metadata(filename):
                     ))
                     ))
             with open(filename, 'rb') as fo:
             with open(filename, 'rb') as fo:
                 offset = cmd.data_page_offset if (cmd.dictionary_page_offset is None or cmd.data_page_offset < cmd.dictionary_page_offset) else cmd.dictionary_page_offset
                 offset = cmd.data_page_offset if (cmd.dictionary_page_offset is None or cmd.data_page_offset < cmd.dictionary_page_offset) else cmd.dictionary_page_offset
+                fo.seek(offset, 0)
                 values_read = 0
                 values_read = 0
                 print("      pages: ")
                 print("      pages: ")
                 while values_read < num_rows:
                 while values_read < num_rows:
-                    ph = _read_page_header(fo, offset)
+                    ph = _read_page_header(fo)
+                    fo.seek(ph.compressed_page_size, 1) # seek past current page.
                     daph = ph.data_page_header
                     daph = ph.data_page_header
                     diph = ph.dictionary_page_header
                     diph = ph.dictionary_page_header
                     type_ = PageType._VALUES_TO_NAMES[ph.type] if ph.type else None
                     type_ = PageType._VALUES_TO_NAMES[ph.type] if ph.type else None
@@ -109,7 +123,8 @@ def dump_metadata(filename):
                         num_values = daph.num_values
                         num_values = daph.num_values
                         values_read += num_values
                         values_read += num_values
                     if ph.type == PageType.DICTIONARY_PAGE:
                     if ph.type == PageType.DICTIONARY_PAGE:
-                        num_values = diph.num_values
+                        pass
+                        #num_values = diph.num_values
 
 
                     encoding = None
                     encoding = None
                     def_level_encoding = None
                     def_level_encoding = None
@@ -127,3 +142,130 @@ def dump_metadata(filename):
                             encoding=encoding, def_level_encoding=def_level_encoding,
                             encoding=encoding, def_level_encoding=def_level_encoding,
                             rep_level_encoding=rep_level_encoding))
                             rep_level_encoding=rep_level_encoding))
 
 
+def _read_page(fo, page_header, column_metadata):
+    """Reads the data page from the given file-object using the column metadata"""
+    bytes_from_file = fo.read(page_header.compressed_page_size)
+    if column_metadata.codec is not None and column_metadata.codec != CompressionCodec.UNCOMPRESSED:
+        if column_metadata.codec == CompressionCodec.SNAPPY:
+            raw_bytes = snappy.decompress(bytes_from_file)
+        elif column_metadata.codec == CompressionCodec.GZIP:
+            io_obj = StringIO.StringIO(bytes_from_file)
+            with gzip.GzipFile(fileobj=io_obj, mode='rb') as f:
+                raw_bytes = f.read()
+    else:
+        raw_bytes = bytes_from_file
+    return raw_bytes
+
+
+def _read_data(fo, fo_encoding, value_count, bit_width):
+    """Internal method to read data from the file-object using the given encoding. The data
+    could be definition levels, repetition levels, or actual values."""
+    vals = []
+    if fo_encoding == Encoding.RLE:
+        seen = 0
+        while seen < value_count:
+            values = encoding.read_rle_bit_packed_hybrid(fo, bit_width)
+            if values is None:
+                break  ## EOF was reached.
+            vals += values
+            seen += len(values)
+    elif fo_encoding == Encoding.BIT_PACKED:
+        raise NotImplementedError("Bit packing not yet supported")
+
+    return vals
+
+
+def read_data_page(fo, schema_helper, page_header, column_metadata, dictionary):
+    daph = page_header.data_page_header
+    raw_bytes = _read_page(fo, page_header, column_metadata)
+    io_obj = StringIO.StringIO(raw_bytes)
+    vals = []
+
+    print("  definition_level_encoding: {0}".format(Encoding._VALUES_TO_NAMES[daph.definition_level_encoding]))
+    print("  repetition_level_encoding: {0}".format(Encoding._VALUES_TO_NAMES[daph.repetition_level_encoding]))
+    print("  encoding: {0}".format(Encoding._VALUES_TO_NAMES[daph.encoding]) )
+
+    # definition levels are skipped if data is required.
+    if not schema_helper.is_required(column_metadata.path_in_schema[-1]):
+        max_definition_level = schema_helper.max_definition_level(column_metadata.path_in_schema)
+        bit_width = encoding.width_from_max_int(max_definition_level) # TODO Where does the -1 come from?
+        print "  max def level: {1}   bit_width: {0}".format(bit_width, max_definition_level)
+        if bit_width == 0:
+            definition_levels = [0] * daph.num_values
+        else:
+            definition_levels = _read_data(io_obj, daph.definition_level_encoding, daph.num_values, bit_width)
+
+        print ("  Definition levels: {0}".format(",".join([str(dl) for dl in definition_levels])))
+    
+    # repetition levels are skipped if data is at the first level.
+    if len(column_metadata.path_in_schema) > 1:
+        max_repetition_level = schema_helper.max_repetition_level(column_metadata.path_in_schema)
+        bit_width = encoding.width_from_max_int(max_repetition_level)
+        repetition_levels = _read_data(io_obj, daph.repetition_level_encoding, daph.num_values)
+
+    # TODO Actually use the definition and repetition levels.
+
+    if daph.encoding == Encoding.PLAIN:
+        for i in range(daph.num_values):
+            vals.append(encoding.read_plain(io_obj, column_metadata.type, None))
+        print "  Values: " + ",".join([str(x) for x in vals])
+    elif daph.encoding == Encoding.PLAIN_DICTIONARY:
+        bit_width = struct.unpack("<B", io_obj.read(1))[0]  # bitwidth is stored as single byte.
+        print "bit_width: {0}".format(bit_width)
+        total_seen = 0
+        dict_values_bytes = io_obj.read()
+        dict_values_io_obj = StringIO.StringIO(dict_values_bytes)
+        while total_seen < daph.num_values:  # TODO jcrobak -- not sure that this loop i sneeded?
+            values = encoding.read_rle_bit_packed_hybrid(dict_values_io_obj, bit_width, len(dict_values_bytes))
+            vals += [dictionary[v] for v in values]
+            total_seen += len(values)
+    else:
+        raise ParquetFormatException("Unsupported encoding: " + Encoding._VALUES_TO_NAMES[daph.encoding])
+    return vals
+
+
+def read_dictionary_page(fo, page_header, column_metadata):
+    raw_bytes = _read_page(fo, page_header, column_metadata)
+    io_obj = StringIO.StringIO(raw_bytes)
+    dict_items = []
+    while io_obj.tell() < len(raw_bytes):
+        dict_items.append(encoding.read_plain(io_obj, column_metadata.type, None))  # TODO - length for fixed byte array
+    return dict_items
+
+
+def dump(filename, max_records=10):
+    footer = read_footer(filename)
+    schema_helper = schema.SchemaHelper(footer.schema)
+    for rg in footer.row_groups:
+        res = defaultdict(list)
+        row_group_rows = rg.num_rows
+        dict_items = []
+        for idx, cg in enumerate(rg.columns):
+            cmd = cg.meta_data
+            with open(filename, 'rb') as fo:
+                offset = cmd.data_page_offset if (cmd.dictionary_page_offset is None or cmd.data_page_offset < cmd.dictionary_page_offset) else cmd.dictionary_page_offset
+                fo.seek(offset, 0)
+                values_seen = 0
+                print("reading column chunk of type: {0}".format(Type._VALUES_TO_NAMES[cmd.type]))
+                while values_seen < row_group_rows:
+                    ph = _read_page_header(fo)
+                    print("Reading page (type={2}, uncompressed={0} bytes, compressed={1} bytes)".format(
+                        ph.uncompressed_page_size, ph.compressed_page_size, PageType._VALUES_TO_NAMES[ph.type]))
+                    daph = ph.data_page_header
+                    diph = ph.dictionary_page_header
+                    if ph.type == PageType.DATA_PAGE:
+                        values = read_data_page(fo, schema_helper, ph, cmd, dict_items)
+                        res[".".join(cmd.path_in_schema)] += values
+                        values_seen += cmd.num_values
+                    elif ph.type == PageType.DICTIONARY_PAGE:
+                        print ph
+                        assert dict_items == []
+                        dict_items = read_dictionary_page(fo, ph, cmd)
+                        print("Dictionary: " + str(dict_items))
+                    else:
+                        logger.info("Skipping unknown page type={0}".format(ph.type))
+        print "Data for row group: "
+        keys = res.keys()
+        print "\t".join(keys)
+        for i in range(rg.num_rows):
+            print "\t".join(str(res[k][i]) for k in keys)

+ 2 - 1
parquet/__main__.py

@@ -1,3 +1,4 @@
 import parquet
 import parquet
 import sys
 import sys
-parquet.dump_metadata(sys.argv[1])
+parquet.dump_metadata(sys.argv[1])
+parquet.dump(sys.argv[1])

+ 22 - 0
parquet/bitstring.py

@@ -0,0 +1,22 @@
+
+SINGLE_BIT_MASK =  [1 << x for x in range(7, -1, -1)]
+
+]
+
+class BitString(object):
+
+	def __init__(self, bytes, length=None, offset=None):
+		self.bytes = bytes
+		self.offset = offset if offset is not None else 0
+		self.length = length if length is not None else 8 * len(data) - self.offset 
+
+
+	def __getitem__(self, key):
+		try:
+			start = key.start
+			stop = key.stop
+		except AttributeError:
+			if key < 0 or key >= length:
+				raise IndexError()
+			byte_index, bit_offset = divmod(self.offset + key), 8)
+			return self.bytes[byte_index] & SINGLE_BIT_MASK[bit_offset]

+ 151 - 0
parquet/encoding.py

@@ -0,0 +1,151 @@
+import array
+import math
+import struct
+import StringIO
+
+from ttypes import Type
+
+def read_plain_boolean(fo):
+    raise NotImplemented
+
+
+def read_plain_int32(fo):
+    tup = struct.unpack("<i", fo.read(4))
+    return tup[0]
+
+def read_plain_int64(fo):
+    tup = struct.unpack("<q", fo.read(8))
+    return tup[0]
+
+def read_plain_int96(fo):
+    tup = struct.unpack("<q<i", fo.read(12))
+    return tup[0] << 32 | tup[1]
+
+def read_plain_float(fo):
+    tup = struct.unpack("<f", fo.read(4))
+
+def read_plain_double(fo):
+    tup = struct.unpack("<d", fo.read(8))
+
+def read_plain_byte_array(fo):
+    length = read_plain_int32(fo)
+    return fo.read(length)
+
+def read_plain_byte_array_fixed(fo, fixed_length):
+    return fo.read(fixed_length)
+
+DECODE_PLAIN = {
+    Type.BOOLEAN: read_plain_boolean,
+    Type.INT32: read_plain_int32,
+    Type.INT64: read_plain_int64,
+    Type.INT96: read_plain_int96,
+    Type.FLOAT: read_plain_float,
+    Type.DOUBLE: read_plain_double,
+    Type.BYTE_ARRAY: read_plain_byte_array,
+    Type.FIXED_LEN_BYTE_ARRAY: read_plain_byte_array_fixed
+}
+
+def read_plain(fo, type_, type_length):
+    conv = DECODE_PLAIN[type_]
+    if type_ == Type.FIXED_LEN_BYTE_ARRAY:
+        return conv(fo, type_length)
+    return conv(fo)
+
+
+def read_unsigned_var_int(fo):
+    result = 0
+    shift = 0
+    while True:
+        byte = struct.unpack("<B", fo.read(1))[0]
+        result |= ((byte & 0x7F) << shift)
+        if (byte & 0x80) == 0:
+            break
+        shift += 7
+    return result
+
+def byte_width(bit_width):
+    "Returns the byte width for the given bit_width"
+    return (bit_width + 7) / 8;
+
+def read_rle(fo, header, bit_width):
+    """Grabs count from the header and uses width to grab the value that's
+    repeated. Returns an array with the value repeated count times."""
+    count = header >> 1
+    zero_data = "\x00\x00\x00\x00"
+    data = ""
+    width = byte_width(bit_width)
+    if width >= 1:
+        data += fo.read(1)
+    elif width >= 2:
+        data += fo.read(1)
+    elif width >= 3:
+        data +=  fo.read(1)
+    elif width == 4:
+        data = fo.read(1)
+    data = data + zero_data[len(data):]
+    value = struct.unpack("<i", data)[0]
+
+    return [value]*count
+
+
+def width_from_max_int(value):
+    return int(math.ceil(math.log(value + 1, 2)))
+
+def mask_for_bits(i):
+    return (1 << i) - 1
+
+def read_bitpacked(fo, header, width):
+    num_groups = header >> 1;
+    count = num_groups * 8
+    raw_bytes = array.array('B', fo.read(count)).tolist()
+    current_byte = 0
+    b = raw_bytes[current_byte]
+    mask = mask_for_bits(width)
+    bits_in_byte = 8
+    res = []
+    while current_byte < width and len(res) < (count / width):
+        print "width={0} bits_in_byte={1} b={2}".format(width, bits_in_byte, bin(b))
+        if bits_in_byte >= width:
+            res.append(b & mask)
+            b >>= width
+            bits_in_byte -= width
+        else:
+            next_b = raw_bytes[current_byte + 1]
+            borrowed_bits = next_b & mask_for_bits(width - bits_in_byte)
+            #print "  borrowing {0} bites".format(width - bits_in_byte)
+            #print "  next_b={0}, borrowed_bits={1}".format(bin(next_b), bin(borrowed_bits))
+            res.append((borrowed_bits << bits_in_byte) | b)
+            b = next_b >> (width - bits_in_byte)
+            #print "  shifting away: {0}".format(width - bits_in_byte)
+            bits_in_byte = 8 - (width - bits_in_byte)
+            current_byte += 1
+        print "  added: {0}".format(res[-1])
+    return res
+
+
+def read_bitpacked_deprecated(fo, count, width):
+    res = []
+    raw_bytes = array.array('B', fo.read(count)).tolist()
+    current_byte = 0
+    b = raw_bytes[current_byte]
+    mask = mask_for_bits(width)
+
+
+
+def read_rle_bit_packed_hybrid(fo, width, length=None):
+#    import pdb; pdb.set_trace()
+    io_obj = fo
+    if length is None:
+        length = read_plain_int32(fo)
+        raw_bytes = fo.read(length)
+        if raw_bytes == '':
+            return None
+        io_obj = StringIO.StringIO(raw_bytes)
+    res = []
+    while io_obj.tell() < length:
+        header = read_unsigned_var_int(io_obj)
+        if header & 1 == 0:
+            res += read_rle(io_obj, header, width)
+        else:
+            res += read_bitpacked(io_obj, header, width)
+    return res

+ 40 - 0
parquet/schema.py

@@ -0,0 +1,40 @@
+"""Utils for working with the parquet thrift models"""
+
+from ttypes import FieldRepetitionType
+
+
+class SchemaHelper(object):
+
+	def __init__(self, schema_elements):
+		self.schema_elements = schema_elements
+		self.schema_elements_by_name = dict([(se.name, se) for se in schema_elements])
+		assert len(self.schema_elements) == len(self.schema_elements_by_name)
+
+
+	def schema_element(self, name):
+		"""Get the schema element with the given name."""
+		return self.schema_elements_by_name[name]
+
+	def is_required(self, name):
+		"""Returns true iff the schema element with the given name is required"""
+		return self.schema_element(name).repetition_type == FieldRepetitionType.REQUIRED
+
+
+	def max_repetition_level(self, path):
+		"""get the max repetition level for the given schema path."""
+		max_level = 0
+		for part in path:
+			se = self.schema_element(part)
+			if se.repetition_type == FieldRepetitionType.REQUIRED:
+				max_level += 1
+		return max_level
+
+
+	def max_definition_level(self, path):
+		"""get the max definition level for the given schema path."""
+		max_level = 0
+		for part in path:
+			se = self.schema_element(part)
+			if se.repetition_type != FieldRepetitionType.REQUIRED:
+				max_level += 1
+		return max_level

+ 12 - 9
setup.py

@@ -1,12 +1,15 @@
 from distutils.core import setup
 from distutils.core import setup
 
 
 setup(name='parquet',
 setup(name='parquet',
-      version='1.0',
-      description='Python support for Parquet file format',
-      author='Joe Crobak',
-      author_email='joecrow@gmail.com',
-      packages=['parquet'],
-      requires=[
-      	'thrift',
-      ]
-     )
+    version='1.0',
+    description='Python support for Parquet file format',
+    author='Joe Crobak',
+    author_email='joecrow@gmail.com',
+    packages=[ 'parquet' ],
+    requires=[
+        'thrift',
+    ],
+    extras_require = {
+        'snappy support': ['python-snappy']
+    },
+)

+ 46 - 0
test/test_encoding.py

@@ -0,0 +1,46 @@
+import array
+import StringIO
+import unittest
+
+import parquet.encoding
+
+class TestBitPacked(unittest.TestCase):
+
+    def testFromExample(self):
+        encoded_bitstring = array.array('B', [0b10001000, 0b11000110, 0b11111010]).tostring()
+        fo = StringIO.StringIO(encoded_bitstring)
+        count = 3 << 1
+        res = parquet.encoding.read_bitpacked(fo, count, 3)
+        self.assertEquals(range(8), res)
+
+
+class TestBitPackedDeprecated(unittest.TestCase):
+
+    def testFromExample(self):
+        encoded_bitstring = array.array('B', [0b00000101, 0b00111001, 0b01110111])
+        fo = StringIO.StringIO(encoded_bitstring)
+        res = parquet.encoding.read_bitpacked_deprecated(fo, 3, 3)
+        self.assertEquals(range(8), res)
+
+
+class TestWidthFromMaxInt(unittest.TestCase):
+
+    def testWidths(self):
+        self.assertEquals(0, parquet.encoding.width_from_max_int(0));
+        self.assertEquals(1, parquet.encoding.width_from_max_int(1));
+        self.assertEquals(2, parquet.encoding.width_from_max_int(2));
+        self.assertEquals(2, parquet.encoding.width_from_max_int(3));
+        self.assertEquals(3, parquet.encoding.width_from_max_int(4));
+        self.assertEquals(3, parquet.encoding.width_from_max_int(5));
+        self.assertEquals(3, parquet.encoding.width_from_max_int(6));
+        self.assertEquals(3, parquet.encoding.width_from_max_int(7));
+        self.assertEquals(4, parquet.encoding.width_from_max_int(8));
+        self.assertEquals(4, parquet.encoding.width_from_max_int(15));
+        self.assertEquals(5, parquet.encoding.width_from_max_int(16));
+        self.assertEquals(5, parquet.encoding.width_from_max_int(31));
+        self.assertEquals(6, parquet.encoding.width_from_max_int(32));
+        self.assertEquals(6, parquet.encoding.width_from_max_int(63));
+        self.assertEquals(7, parquet.encoding.width_from_max_int(64));
+        self.assertEquals(7, parquet.encoding.width_from_max_int(127));
+        self.assertEquals(8, parquet.encoding.width_from_max_int(128));
+        self.assertEquals(8, parquet.encoding.width_from_max_int(255));

+ 2 - 2
test/test_read_support.py

@@ -25,11 +25,11 @@ class TestFileFormat(unittest.TestCase):
 
 
 class TestMetadata(unittest.TestCase):
 class TestMetadata(unittest.TestCase):
 
 
-	f = "/Users/joecrow/Code/parquet-compatibility/parquet-testdata/impala/1.0.4-SNAPPY/nation.impala.parquet"
+	f = "/Users/joecrow/Code/parquet-compatibility/parquet-testdata/impala/1.1.1-SNAPPY/nation.impala.parquet"
 	
 	
 	def testFooterBytes(self):
 	def testFooterBytes(self):
 		with open(self.f) as fo:
 		with open(self.f) as fo:
-			self.assertEquals(229, parquet._get_footer_size(fo))
+			self.assertEquals(327, parquet._get_footer_size(fo))
 
 
 	def testReadFOoter(self):
 	def testReadFOoter(self):
 		parquet.read_footer(self.f)
 		parquet.read_footer(self.f)