from os.path import split
from pathlib import Path
from typing import List, Dict, Any
import threading
import time

import jaydebeapi

from .base import MetadataProvider


class HiveMetadataProvider(MetadataProvider):
    """Metadata provider that reads databases, tables and comments from Hive over JDBC.

    Table-level metadata is cached per schema in ``self.cache``; each entry has
    the shape ``{"cache_time": <epoch seconds>, "tables": {<table>: {...}}}``.
    """

    # Seconds a per-schema cache entry stays valid before it is rebuilt.
    CACHE_TTL_SECONDS = 600

    def __init__(self, config: Dict):
        super().__init__(config)
        # schema name -> {"cache_time": int, "tables": {table name -> table info}}
        self.cache: Dict[str, Dict[str, Any]] = {}
        # Lock that serializes cache rebuilds in refresh_db_cache.
        self._refresh_lock = threading.Lock()

    def _get_connection(self):
        """Create a Hive JDBC connection from ``self.config``.

        Reads ``config["jars"]`` (comma-separated jar paths, resolved against
        the directory two levels above this file) and ``config["url"]``;
        ``driver``, ``username`` and ``password`` are optional.
        """
        base_dir = Path(__file__).resolve().parent.parent
        jar_names = (jar.strip() for jar in str(self.config["jars"]).split(","))
        # Skip empty entries (e.g. from a trailing comma) so the base directory
        # itself is never passed as a jar.
        jdbc_jars = [f"{base_dir}/{jar}" for jar in jar_names if jar]
        return jaydebeapi.connect(
            jclassname=self.config.get("driver", "org.apache.hive.jdbc.HiveDriver"),
            url=self.config["url"],
            driver_args=[
                self.config.get("username", "tjrd"),
                self.config.get("password", ""),
            ],
            jars=jdbc_jars,
        )

    def get_databases(self) -> List[str]:
        """Return the names of all databases reported by ``SHOW DATABASES``."""
        conn = self._get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute("SHOW DATABASES")
            return [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()

    def get_tables(self, database: str) -> List[str]:
        """Return the names of all tables in ``database``."""
        conn = self._get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(f"SHOW TABLES IN `{database}`")
            return [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()

    def get_columns(self, table: str, database: str) -> List[Dict[str, Any]]:
        """Return the columns of ``database.table`` as ``name``/``type``/``comment`` dicts.

        Best effort: on any error the failure is printed and an empty list is
        returned.
        """
        conn = self._get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(f"DESCRIBE `{database}`.`{table}`")
            columns: List[Dict[str, Any]] = []
            seen_names = set()
            for row in cursor.fetchall():
                # Hive JDBC DESCRIBE: col_name, data_type, comment
                if not row:
                    continue
                # Normalize first: a NULL col_name must be skipped, not crash
                # the whole listing.
                name = "" if row[0] is None else str(row[0])
                # Skip blank separator rows, "# ..." section headers, and
                # names already collected (DESCRIBE can repeat a column, e.g.
                # in a partition-information section).
                if not name.strip() or name.startswith("#") or name in seen_names:
                    continue
                seen_names.add(name)
                columns.append({
                    "name": name,
                    "type": row[1] if len(row) > 1 else "",
                    "comment": row[2] if len(row) > 2 else "",
                })
            return columns
        except Exception as e:
            print(f"Error fetching columns for {database}.{table}: {e}")
            return []
        finally:
            conn.close()

    def has_table(self, table_name: str, schema: str) -> bool:
        """Return True if ``table_name`` is present in the (refreshed) cache for ``schema``."""
        self.refresh_db_cache(schema)
        # Table entries live under the "tables" key of the schema entry.
        return table_name in self.cache.get(schema, {}).get("tables", {})

    def get_table_comment(self, table_name: str, schema: str) -> str:
        """Return the cached comment of ``schema.table_name``, or ``""`` if unknown."""
        self.refresh_db_cache(schema)
        tables = self.cache.get(schema, {}).get("tables", {})
        return tables.get(table_name, {}).get("comment", "")

    def fectch_distinct_values(self, column_name: str, table_name: str, schema: str, max_num: int = 5) -> list[str]:
        """Return sample distinct values of a column.

        Not implemented for this provider: always returns an empty list.
        """
        return []

    def get_table_comment_db(self, table_name: str, schema: str) -> str:
        """Query Hive for the comment of ``schema.table_name``.

        Scans the ``DESCRIBE FORMATTED`` output for the first row whose second
        field is ``comment`` and returns its third field; returns ``""`` when
        no such row exists or the value is NULL.
        """
        conn = self._get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(f"DESCRIBE FORMATTED `{schema}`.`{table_name}`")
            for row in cursor.fetchall():
                if len(row) >= 3 and str(row[1]).strip() == "comment":
                    # A NULL value would otherwise become the string "None".
                    return "" if row[2] is None else str(row[2]).strip()
            return ""
        finally:
            conn.close()

    def _is_cache_fresh(self, schema: str) -> bool:
        """Return True if ``schema`` has a cache entry younger than the TTL."""
        entry = self.cache.get(schema)
        if entry is None:
            return False
        return int(time.time()) - entry["cache_time"] < self.CACHE_TTL_SECONDS

    def refresh_db_cache(self, schema: str = ""):
        """Rebuild the table cache for ``schema`` unless it is still fresh."""
        if self._is_cache_fresh(schema):
            return

        with self._refresh_lock:
            # Re-read the cache under the lock: another thread may have
            # rebuilt this schema while we were waiting.
            if self._is_cache_fresh(schema):
                return

            # Rebuild the cache.
            table_cache: Dict[str, Dict[str, Any]] = {}
            for table in self.get_tables(schema):
                comment = self.get_table_comment_db(table, schema)
                # Columns are not loaded here; the entry keeps an empty mapping.
                columns = {}
                table_cache[table] = {
                    "comment": comment,
                    "columns": columns,
                    "table": table,
                    "schema": schema,
                }
                print(f"{schema}.{table} comment: {comment}, columns: {len(columns)}")

            # Publish the new entry.
            self.cache[schema] = {"cache_time": int(time.time()), "tables": table_cache}

    def refresh_cache(self):
        """Refresh the cache of every database returned by ``get_databases``."""
        for database in self.get_databases():
            self.refresh_db_cache(database)