hive.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
import logging
import threading
import time
from contextlib import contextmanager
from os.path import split
from pathlib import Path
from typing import Any, Dict, List

import jaydebeapi

from .base import MetadataProvider
  8. class HiveMetadataProvider(MetadataProvider):
  9. def __init__(self, config: Dict):
  10. super().__init__(config)
  11. self.cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
  12. # 创建一个线程锁
  13. self._refresh_lock = threading.Lock()
  14. def _get_connection(self):
  15. """
  16. 创建 Hive JDBC 连接
  17. """
  18. BASE_DIR = Path(__file__).resolve().parent.parent
  19. jdbc_jars = [f"{Path(BASE_DIR)}/{jar.strip()}" for jar in str(self.config["jars"]).split(",")]
  20. return jaydebeapi.connect(
  21. jclassname= self.config.get("driver", "org.apache.hive.jdbc.HiveDriver"),
  22. url= self.config["url"],
  23. driver_args=[self.config.get("username", "tjrd"), self.config.get("password", "")],
  24. jars=jdbc_jars,
  25. )
  26. def get_databases(self) -> List[str]:
  27. conn = self._get_connection()
  28. try:
  29. cursor = conn.cursor()
  30. cursor.execute("SHOW DATABASES")
  31. return [row[0] for row in cursor.fetchall()]
  32. finally:
  33. conn.close()
  34. def get_tables(self, database: str) -> List[str]:
  35. conn = self._get_connection()
  36. try:
  37. cursor = conn.cursor()
  38. cursor.execute(f"SHOW TABLES IN `{database}`")
  39. return [row[0] for row in cursor.fetchall()]
  40. finally:
  41. conn.close()
  42. def get_columns(self, table: str, database: str) -> List[Dict[str, Any]]:
  43. conn = self._get_connection()
  44. try:
  45. cursor = conn.cursor()
  46. cursor.execute(f"DESCRIBE `{database}`.`{table}`")
  47. columns = []
  48. for row in cursor.fetchall():
  49. # Hive JDBC DESCRIBE: col_name, data_type, comment
  50. if not row or row[0].startswith("#") or str(row[0]).strip() == "" or any(c.get("name") == row[0] for c in columns):
  51. continue
  52. columns.append({
  53. "name": row[0],
  54. "type": row[1] if len(row) > 1 else "",
  55. "comment": row[2] if len(row) > 2 else "",
  56. })
  57. return columns
  58. except Exception as e:
  59. print(f"Error fetching columns for {database}.{table}: {e}")
  60. return []
  61. finally:
  62. conn.close()
  63. def has_table(self, table_name: str, schema: str) -> bool:
  64. self.refresh_db_cache(schema)
  65. return schema in self.cache and table_name in self.cache[schema]
  66. def get_table_comment(self, table_name: str, schema: str) -> str:
  67. self.refresh_db_cache(schema)
  68. return self.cache.get(schema, {}).get(table_name, {}).get("comment", "")
  69. def fectch_distinct_values(self, column_name: str, table_name: str, schema: str, max_num: int = 5) -> list[str]:
  70. """检查表是否存在"""
  71. return []
  72. def get_table_comment_db(self, table_name: str, schema: str) -> str:
  73. with self._get_connection() as conn:
  74. cursor = conn.cursor()
  75. cursor.execute(f"DESCRIBE FORMATTED `{schema}`.`{table_name}`")
  76. for row in cursor.fetchall():
  77. if len(row) >= 3 and str(row[1]).strip() == "comment":
  78. return str(row[2]).strip()
  79. return ""
  80. def refresh_db_cache(self, schema: str = ""):
  81. db_cache = self.cache.get(schema, None)
  82. time_secs = int(time.time())
  83. if db_cache is not None and time_secs - db_cache["cache_time"] < 600: # 10 minutes, 缓存有有效
  84. return
  85. with self._refresh_lock:
  86. if db_cache is not None and time_secs - db_cache["cache_time"] < 600: # double check 缓存有有效
  87. return
  88. # 重新构建缓存
  89. table_cache: Dict[str, Dict[str, Any]] = {}
  90. tables = self.get_tables(schema)
  91. for table in tables:
  92. comment = self.get_table_comment_db(table, schema)
  93. columns = {} # self.get_columns(table, database)
  94. table_cache[table] = {"comment": comment, "columns": columns, "table": table, "schema": schema}
  95. print(f"{schema}.{table} comment: {comment}, columns: {len(columns)}")
  96. # 更新缓存
  97. self.cache[schema] = {"cache_time": int(time.time()), "tables": table_cache}
  98. def refresh_cache(self):
  99. databases = self.get_databases()
  100. for database in databases:
  101. self.refresh_db_cache(database)