schema.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. """Utils for working with the parquet thrift models."""
  2. from __future__ import (absolute_import, division, print_function,
  3. unicode_literals)
  4. import os
  5. import thriftpy2 as thriftpy
  6. THRIFT_FILE = os.path.join(os.path.dirname(__file__), "parquet.thrift")
  7. parquet_thrift = thriftpy.load(THRIFT_FILE, module_name=str("parquet_thrift")) # pylint: disable=invalid-name
  8. class SchemaHelper:
  9. """Utility providing convenience methods for schema_elements."""
  10. def __init__(self, schema_elements):
  11. """Initialize with the specified schema_elements."""
  12. self.schema_elements = schema_elements
  13. self.schema_elements_by_name = {se.name: se for se in schema_elements}
  14. assert len(self.schema_elements) == len(self.schema_elements_by_name)
  15. def schema_element(self, name):
  16. """Get the schema element with the given name."""
  17. return self.schema_elements_by_name[name]
  18. def is_required(self, name):
  19. """Return true iff the schema element with the given name is required."""
  20. return self.schema_element(name).repetition_type == parquet_thrift.FieldRepetitionType.REQUIRED
  21. def max_repetition_level(self, path):
  22. """Get the max repetition level for the given schema path."""
  23. max_level = 0
  24. for part in path:
  25. element = self.schema_element(part)
  26. if element.repetition_type == parquet_thrift.FieldRepetitionType.REQUIRED:
  27. max_level += 1
  28. return max_level
  29. def max_definition_level(self, path):
  30. """Get the max definition level for the given schema path."""
  31. max_level = 0
  32. for part in path:
  33. element = self.schema_element(part)
  34. if element.repetition_type != parquet_thrift.FieldRepetitionType.REQUIRED:
  35. max_level += 1
  36. return max_level