schema.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. """Utils for working with the parquet thrift models"""
  2. from __future__ import absolute_import
  3. from __future__ import division
  4. from __future__ import print_function
  5. from __future__ import unicode_literals
  6. import os
  7. import thriftpy
  8. THRIFT_FILE = os.path.join(os.path.dirname(__file__), "parquet.thrift")
  9. parquet_thrift = thriftpy.load(THRIFT_FILE, module_name=str("parquet_thrift"))
  10. class SchemaHelper(object):
  11. def __init__(self, schema_elements):
  12. self.schema_elements = schema_elements
  13. self.schema_elements_by_name = dict(
  14. [(se.name, se) for se in schema_elements])
  15. assert len(self.schema_elements) == len(self.schema_elements_by_name)
  16. def schema_element(self, name):
  17. """Get the schema element with the given name."""
  18. return self.schema_elements_by_name[name]
  19. def is_required(self, name):
  20. """Returns true iff the schema element with the given name is
  21. required"""
  22. return self.schema_element(name).repetition_type == parquet_thrift.FieldRepetitionType.REQUIRED
  23. def max_repetition_level(self, path):
  24. """get the max repetition level for the given schema path."""
  25. max_level = 0
  26. for part in path:
  27. se = self.schema_element(part)
  28. if se.repetition_type == parquet_thrift.FieldRepetitionType.REQUIRED:
  29. max_level += 1
  30. return max_level
  31. def max_definition_level(self, path):
  32. """get the max definition level for the given schema path."""
  33. max_level = 0
  34. for part in path:
  35. se = self.schema_element(part)
  36. if se.repetition_type != parquet_thrift.FieldRepetitionType.REQUIRED:
  37. max_level += 1
  38. return max_level