schema.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. """Utils for working with the parquet thrift models."""
  2. from __future__ import absolute_import
  3. from __future__ import division
  4. from __future__ import print_function
  5. from __future__ import unicode_literals
  6. import os
  7. import thriftpy
  8. THRIFT_FILE = os.path.join(os.path.dirname(__file__), "parquet.thrift")
  9. parquet_thrift = thriftpy.load(THRIFT_FILE, module_name=str("parquet_thrift")) # pylint: disable=invalid-name
  10. class SchemaHelper(object):
  11. """Utility providing convenience methods for schema_elements."""
  12. def __init__(self, schema_elements):
  13. """Initialize with the specified schema_elements."""
  14. self.schema_elements = schema_elements
  15. self.schema_elements_by_name = dict(
  16. [(se.name, se) for se in schema_elements])
  17. assert len(self.schema_elements) == len(self.schema_elements_by_name)
  18. def schema_element(self, name):
  19. """Get the schema element with the given name."""
  20. return self.schema_elements_by_name[name]
  21. def is_required(self, name):
  22. """Return true iff the schema element with the given name is required."""
  23. return self.schema_element(name).repetition_type == parquet_thrift.FieldRepetitionType.REQUIRED
  24. def max_repetition_level(self, path):
  25. """Get the max repetition level for the given schema path."""
  26. max_level = 0
  27. for part in path:
  28. element = self.schema_element(part)
  29. if element.repetition_type == parquet_thrift.FieldRepetitionType.REQUIRED:
  30. max_level += 1
  31. return max_level
  32. def max_definition_level(self, path):
  33. """Get the max definition level for the given schema path."""
  34. max_level = 0
  35. for part in path:
  36. element = self.schema_element(part)
  37. if element.repetition_type != parquet_thrift.FieldRepetitionType.REQUIRED:
  38. max_level += 1
  39. return max_level