Browse Source

test updates.

* new tests
* fix some bugs discovered in tests
* update setup.py and main based upon basic install test
Joe Crobak 12 years ago
parent
commit
ae430fa9ea
7 changed files with 192 additions and 52 deletions
  1. 3 0
      .gitignore
  2. 11 6
      parquet/__main__.py
  3. 5 5
      parquet/encoding.py
  4. 32 33
      parquet/schema.py
  5. 10 2
      setup.py
  6. 77 0
      test/test_encoding.py
  7. 54 6
      test/test_read_support.py

+ 3 - 0
.gitignore

@@ -1,3 +1,6 @@
 *.pyc
 .coverage
 cover
+build
+dist
+parquet.egg-info

+ 11 - 6
parquet/__main__.py

@@ -2,18 +2,21 @@ import argparse
 import logging
 import sys
 
-import parquet
 
-
-def setup_logging(options):
-    level = logging.DEBUG if options.debug else logging.WARNING
-    logging.basicConfig(level=level)
+def setup_logging(options=None):
+    level = logging.DEBUG if options is not None and options.debug \
+        else logging.WARNING
+    console = logging.StreamHandler()
+    console.setLevel(level)
+    formatter = logging.Formatter('%(name)s: %(levelname)-8s %(message)s')
+    console.setFormatter(formatter)
+    logging.getLogger('parquet').addHandler(console)
 
 
 def main(argv=None):
     argv = argv or sys.argv[1:]
 
-    parser = argparse.ArgumentParser(description='Process some integers.')
+    parser = argparse.ArgumentParser(description='Read parquet files')
     parser.add_argument('--metadata', action='store_true',
                         help='show metadata on file')
     parser.add_argument('--row-group-metadata', action='store_true',
@@ -39,6 +42,8 @@ def main(argv=None):
 
     setup_logging(args)
 
+    import parquet
+
     if args.metadata:
         parquet.dump_metadata(args.file, args.row_group_metadata)
     if not args.no_data:

+ 5 - 5
parquet/encoding.py

@@ -28,7 +28,7 @@ def read_plain_int64(fo):
 
 def read_plain_int96(fo):
     """Reads a 96-bit int using the plain encoding"""
-    tup = struct.unpack("<q<i", fo.read(12))
+    tup = struct.unpack("<qi", fo.read(12))
     return tup[0] << 32 | tup[1]
 
 
@@ -103,12 +103,12 @@ def read_rle(fo, header, bit_width):
     width = byte_width(bit_width)
     if width >= 1:
         data += fo.read(1)
-    elif width >= 2:
+    if width >= 2:
         data += fo.read(1)
-    elif width >= 3:
+    if width >= 3:
+        data += fo.read(1)
+    if width == 4:
         data += fo.read(1)
-    elif width == 4:
-        data = fo.read(1)
     data = data + zero_data[len(data):]
     value = struct.unpack("<i", data)[0]
     logger.debug("Read RLE group with value %s of byte-width %s and count %s",

+ 32 - 33
parquet/schema.py

@@ -5,36 +5,35 @@ from ttypes import FieldRepetitionType
 
 class SchemaHelper(object):
 
-	def __init__(self, schema_elements):
-		self.schema_elements = schema_elements
-		self.schema_elements_by_name = dict([(se.name, se) for se in schema_elements])
-		assert len(self.schema_elements) == len(self.schema_elements_by_name)
-
-
-	def schema_element(self, name):
-		"""Get the schema element with the given name."""
-		return self.schema_elements_by_name[name]
-
-	def is_required(self, name):
-		"""Returns true iff the schema element with the given name is required"""
-		return self.schema_element(name).repetition_type == FieldRepetitionType.REQUIRED
-
-
-	def max_repetition_level(self, path):
-		"""get the max repetition level for the given schema path."""
-		max_level = 0
-		for part in path:
-			se = self.schema_element(part)
-			if se.repetition_type == FieldRepetitionType.REQUIRED:
-				max_level += 1
-		return max_level
-
-
-	def max_definition_level(self, path):
-		"""get the max definition level for the given schema path."""
-		max_level = 0
-		for part in path:
-			se = self.schema_element(part)
-			if se.repetition_type != FieldRepetitionType.REQUIRED:
-				max_level += 1
-		return max_level
+    def __init__(self, schema_elements):
+        self.schema_elements = schema_elements
+        self.schema_elements_by_name = dict(
+            [(se.name, se) for se in schema_elements])
+        assert len(self.schema_elements) == len(self.schema_elements_by_name)
+
+    def schema_element(self, name):
+        """Get the schema element with the given name."""
+        return self.schema_elements_by_name[name]
+
+    def is_required(self, name):
+        """Returns true iff the schema element with the given name is
+        required"""
+        return self.schema_element(name).repetition_type == FieldRepetitionType.REQUIRED
+
+    def max_repetition_level(self, path):
+        """get the max repetition level for the given schema path."""
+        max_level = 0
+        for part in path:
+            se = self.schema_element(part)
+            if se.repetition_type == FieldRepetitionType.REQUIRED:
+                max_level += 1
+        return max_level
+
+    def max_definition_level(self, path):
+        """get the max definition level for the given schema path."""
+        max_level = 0
+        for part in path:
+            se = self.schema_element(part)
+            if se.repetition_type != FieldRepetitionType.REQUIRED:
+                max_level += 1
+        return max_level

+ 10 - 2
setup.py

@@ -1,4 +1,7 @@
-from distutils.core import setup
+try:
+    from setuptools import setup
+except ImportError:
+    from distutils.core import setup
 
 setup(name='parquet',
     version='1.0',
@@ -6,10 +9,15 @@ setup(name='parquet',
     author='Joe Crobak',
     author_email='joecrow@gmail.com',
     packages=[ 'parquet' ],
-    requires=[
+    install_requires=[
         'thrift',
     ],
     extras_require = {
         'snappy support': ['python-snappy']
     },
+    entry_points={
+        'console_scripts': [
+            'parquet = parquet.__main__:main',
+        ]
+    },
 )

+ 77 - 0
test/test_encoding.py

@@ -1,8 +1,85 @@
 import array
+import struct
 import StringIO
 import unittest
 
 import parquet.encoding
+from parquet.ttypes import Type
+
+
+class TestPlain(unittest.TestCase):
+
+    def test_int32(self):
+        self.assertEquals(
+            999,
+            parquet.encoding.read_plain_int32(
+                StringIO.StringIO(struct.pack("<i", 999))))
+
+    def test_int64(self):
+        self.assertEquals(
+            999,
+            parquet.encoding.read_plain_int64(
+                StringIO.StringIO(struct.pack("<q", 999))))
+
+    def test_int96(self):
+        self.assertEquals(
+            999,
+            parquet.encoding.read_plain_int96(
+                StringIO.StringIO(struct.pack("<qi", 0, 999))))
+
+    def test_float(self):
+        self.assertAlmostEquals(
+            9.99,
+            parquet.encoding.read_plain_float(
+                StringIO.StringIO(struct.pack("<f", 9.99))),
+            2)
+
+    def test_double(self):
+        self.assertEquals(
+            9.99,
+            parquet.encoding.read_plain_double(
+                StringIO.StringIO(struct.pack("<d", 9.99))))
+
+    def test_fixed(self):
+        data = "foobar"
+        fo = StringIO.StringIO(data)
+        self.assertEquals(
+            data[:3],
+            parquet.encoding.read_plain_byte_array_fixed(
+                fo, 3))
+        self.assertEquals(
+            data[3:],
+            parquet.encoding.read_plain_byte_array_fixed(
+                fo, 3))
+
+    def test_fixed_read_plain(self):
+        data = "foobar"
+        fo = StringIO.StringIO(data)
+        self.assertEquals(
+            data[:3],
+            parquet.encoding.read_plain(
+                fo, Type.FIXED_LEN_BYTE_ARRAY, 3))
+
+
+class TestRle(unittest.TestCase):
+
+    def testFourByteValue(self):
+        fo = StringIO.StringIO(struct.pack("<i", 1 << 30))
+        out = parquet.encoding.read_rle(fo, 2 << 1, 30)
+        self.assertEquals([1 << 30] * 2, list(out))
+
+
+class TestVarInt(unittest.TestCase):
+
+    def testSingleByte(self):
+        fo = StringIO.StringIO(struct.pack("<B", 0x7F))
+        out = parquet.encoding.read_unsigned_var_int(fo)
+        self.assertEquals(0x7F, out)
+
+    def testFourByte(self):
+        fo = StringIO.StringIO(struct.pack("<BBBB", 0xFF, 0xFF, 0xFF, 0x7F))
+        out = parquet.encoding.read_unsigned_var_int(fo)
+        self.assertEquals(0x0FFFFFFF, out)
 
 
 class TestBitPacked(unittest.TestCase):

+ 54 - 6
test/test_read_support.py

@@ -1,4 +1,5 @@
 import csv
+import json
 import os
 import StringIO
 import tempfile
@@ -48,12 +49,22 @@ class TestMetadata(unittest.TestCase):
         parquet.dump_metadata(self.f, data)
 
 
-class Options():
-    col = None
-    format = 'csv'
-    no_headers = True
-    limit = -1
+class Options(object):
 
+    def __init__(self, col=None, format='csv', no_headers=True, limit=-1):
+        self.col = col
+        self.format = format
+        self.no_headers = no_headers
+        self.limit = limit
+
+
+class TestReadApi(unittest.TestCase):
+
+    def test_projection(self):
+        pass
+
+    def test_limit(self):
+        pass
 
 class TestCompatibility(object):
 
@@ -66,7 +77,7 @@ class TestCompatibility(object):
     def _test_file_csv(self, parquet_file, csv_file):
         """ Given the parquet_file and csv_file representation, converts the
             parquet_file to a csv using the dump utility and then compares the
-            result to the csv_file using column agnostic ordering.
+            result to the csv_file.
         """
         expected_data = []
         with open(csv_file, 'rb') as f:
@@ -80,6 +91,43 @@ class TestCompatibility(object):
         assert expected_data == actual_data, "{0} != {1}".format(
             str(expected_data), str(actual_data))
 
+        actual_raw_data = StringIO.StringIO()
+        parquet.dump(parquet_file, Options(no_headers=False),
+                     out=actual_raw_data)
+        actual_raw_data.seek(0, 0)
+        actual_data = list(csv.reader(actual_raw_data, delimiter='\t'))[1:]
+
+        assert expected_data == actual_data, "{0} != {1}".format(
+            str(expected_data), str(actual_data))
+
+    def _test_file_json(self, parquet_file, csv_file):
+        """ Given the parquet_file and csv_file representation, converts the
+            parquet_file to json using the dump utility and then compares the
+            result to the csv_file using column agnostic ordering.
+        """
+        expected_data = []
+        with open(csv_file, 'rb') as f:
+            expected_data = list(csv.reader(f, delimiter='|'))
+
+        actual_raw_data = StringIO.StringIO()
+        parquet.dump(parquet_file, Options(format='json'),
+                     out=actual_raw_data)
+        actual_raw_data.seek(0, 0)
+        actual_data = [json.loads(x.rstrip()) for x in
+                       actual_raw_data.read().split("\n") if len(x) > 0]
+
+        assert len(expected_data) == len(actual_data)
+        footer = parquet.read_footer(parquet_file)
+        cols = [s.name for s in footer.schema]
+        for expected, actual in zip(expected_data, actual_raw_data):
+            assert len(expected) == len(actual)
+            for i, c in enumerate(cols):
+                if c in actual:
+                    assert expected[i] == actual[c]
+
+
+
     def test_all_files(self):
         for parquet_file, csv_file in self.files:
             yield self._test_file_csv, parquet_file, csv_file
+            yield self._test_file_json, parquet_file, csv_file