| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- from os.path import split
- from .base import MetadataProvider
- from typing import List, Dict, Any
- import time
- import threading
- import jaydebeapi
- from pathlib import Path
class HiveMetadataProvider(MetadataProvider):
    """Metadata provider backed by Hive over JDBC (via jaydebeapi).

    Every query opens a fresh JDBC connection and closes it afterwards.
    Table names and table comments are cached per schema in ``self.cache``
    and rebuilt once an entry is older than ``CACHE_TTL_SECONDS``.
    """

    # How long (seconds) a per-schema cache entry is considered fresh.
    CACHE_TTL_SECONDS = 600

    def __init__(self, config: Dict):
        super().__init__(config)
        # Per-schema cache, shaped as
        #   {schema: {"cache_time": <int epoch seconds>,
        #             "tables": {table: {"comment", "columns", "table", "schema"}}}}
        self.cache: Dict[str, Dict[str, Any]] = {}
        # Serializes cache rebuilds across threads (one lock shared by all schemas).
        self._refresh_lock = threading.Lock()

    def _get_connection(self):
        """Open a new Hive JDBC connection configured from ``self.config``.

        Config keys: ``jars`` (comma-separated jar paths; relative entries are
        resolved against the directory two levels above this file, absolute
        entries are used as-is), ``url``, and optional ``driver``,
        ``username`` and ``password``. The caller must close the connection.
        """
        base_dir = Path(__file__).resolve().parent.parent
        jdbc_jars = [
            str(base_dir / jar.strip())
            for jar in str(self.config["jars"]).split(",")
            if jar.strip()  # ignore blank entries, e.g. from a trailing comma
        ]
        return jaydebeapi.connect(
            jclassname=self.config.get("driver", "org.apache.hive.jdbc.HiveDriver"),
            url=self.config["url"],
            driver_args=[self.config.get("username", "tjrd"), self.config.get("password", "")],
            jars=jdbc_jars,
        )

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Quote a Hive identifier with backticks, doubling any embedded backtick."""
        return "`" + str(name).replace("`", "``") + "`"

    def _fetch_all(self, sql: str) -> List[Any]:
        """Execute ``sql`` on a fresh connection and return every result row.

        The cursor and connection are always closed; query errors propagate
        to the caller.
        """
        conn = self._get_connection()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(sql)
                return cursor.fetchall()
            finally:
                cursor.close()
        finally:
            conn.close()

    def get_databases(self) -> List[str]:
        """Return the names of all Hive databases (``SHOW DATABASES``)."""
        return [row[0] for row in self._fetch_all("SHOW DATABASES")]

    def get_tables(self, database: str) -> List[str]:
        """Return the table names in ``database`` (``SHOW TABLES IN ...``)."""
        sql = f"SHOW TABLES IN {self._quote_ident(database)}"
        return [row[0] for row in self._fetch_all(sql)]

    def get_columns(self, table: str, database: str) -> List[Dict[str, Any]]:
        """Return ``[{"name", "type", "comment"}, ...]`` for ``database.table``.

        Best-effort: if the DESCRIBE query fails, the error is printed and an
        empty list is returned.
        """
        sql = f"DESCRIBE {self._quote_ident(database)}.{self._quote_ident(table)}"
        try:
            rows = self._fetch_all(sql)
        except Exception as e:
            print(f"Error fetching columns for {database}.{table}: {e}")
            return []

        columns: List[Dict[str, Any]] = []
        seen = set()
        for row in rows:
            # Hive JDBC DESCRIBE rows: (col_name, data_type, comment).
            if not row:
                continue
            name = row[0]
            label = "" if name is None else str(name)
            # Skip "#" section-header rows, blank separator rows, and names
            # already collected (presumably partition columns that Hive lists
            # a second time under a "#" header).
            if label.startswith("#") or not label.strip() or name in seen:
                continue
            seen.add(name)
            columns.append({
                "name": name,
                "type": row[1] if len(row) > 1 else "",
                "comment": row[2] if len(row) > 2 else "",
            })
        return columns

    def _cached_tables(self, schema: str) -> Dict[str, Dict[str, Any]]:
        """Refresh ``schema`` if stale and return its cached ``{table: info}`` map."""
        self.refresh_db_cache(schema)
        return self.cache.get(schema, {}).get("tables", {})

    def has_table(self, table_name: str, schema: str) -> bool:
        """Return True if ``table_name`` is listed for ``schema`` in the cache."""
        return table_name in self._cached_tables(schema)

    def get_table_comment(self, table_name: str, schema: str) -> str:
        """Return the cached comment of ``schema.table_name``, or "" if unknown."""
        return self._cached_tables(schema).get(table_name, {}).get("comment", "")

    def fectch_distinct_values(self, column_name: str, table_name: str, schema: str, max_num: int = 5) -> list[str]:
        """Return up to ``max_num`` distinct values of a column.

        Not implemented for Hive: always returns an empty list. The method
        name's spelling is kept unchanged for interface compatibility.
        """
        return []

    def get_table_comment_db(self, table_name: str, schema: str) -> str:
        """Query Hive directly (bypassing the cache) for a table's comment.

        Scans ``DESCRIBE FORMATTED`` output for a row whose second column is
        ``comment`` and returns its third column stripped, else "".
        """
        sql = f"DESCRIBE FORMATTED {self._quote_ident(schema)}.{self._quote_ident(table_name)}"
        for row in self._fetch_all(sql):
            if len(row) >= 3 and str(row[1]).strip() == "comment":
                return str(row[2]).strip()
        return ""

    def _is_cache_fresh(self, schema: str) -> bool:
        """Return True if ``schema`` has a cache entry younger than the TTL."""
        entry = self.cache.get(schema)
        return entry is not None and int(time.time()) - entry["cache_time"] < self.CACHE_TTL_SECONDS

    def refresh_db_cache(self, schema: str = ""):
        """Rebuild the cached table list and comments for ``schema`` if stale.

        Errors from ``get_tables`` / ``get_table_comment_db`` propagate; the
        previous cache entry is only replaced after a complete rebuild.
        """
        if self._is_cache_fresh(schema):
            return
        with self._refresh_lock:
            # Re-read the cache under the lock: another thread may have
            # rebuilt this schema while we were waiting.
            if self._is_cache_fresh(schema):
                return
            table_cache: Dict[str, Dict[str, Any]] = {}
            for table in self.get_tables(schema):
                comment = self.get_table_comment_db(table, schema)
                # Columns are not fetched during refresh; keep an empty placeholder.
                columns: Dict[str, Any] = {}
                table_cache[table] = {"comment": comment, "columns": columns, "table": table, "schema": schema}
                print(f"{schema}.{table} comment: {comment}, columns: {len(columns)}")
            # Publish the new entry only after the whole schema was rebuilt.
            self.cache[schema] = {"cache_time": int(time.time()), "tables": table_cache}

    def refresh_cache(self):
        """Refresh (if stale) the cache of every database from ``get_databases``."""
        for database in self.get_databases():
            self.refresh_db_cache(database)
|