from .base import MetadataProvider
from typing import List, Dict, Any
import time
import threading
from impala.dbapi import connect
import getpass
import logging

logger = logging.getLogger(__name__)


class ImpalaMetadataProvider(MetadataProvider):
    """Metadata provider that reads schema/table information from Impala.

    Table names and table comments are cached per schema in ``self.cache``::

        {schema: {"cache_time": <int epoch seconds>,
                  "tables": {table: {"comment": str, "columns": dict,
                                     "table": str, "schema": str}}}}

    An entry older than ``CACHE_TTL_SECS`` is rebuilt on the next lookup.
    """

    # Seconds a per-schema cache entry stays valid before it is rebuilt.
    CACHE_TTL_SECS = 600

    def __init__(self, config: Dict):
        super().__init__(config)
        # schema -> {"cache_time": int, "tables": {table: info-dict}}
        self.cache: Dict[str, Dict[str, Any]] = {}
        # Serializes cache rebuilds across threads.
        self._refresh_lock = threading.Lock()

    def _get_connection(self):
        """Open a new Impala connection to ``config["host"]``:``config["port"]``."""
        return connect(
            host=self.config["host"],
            port=self.config["port"],
            # With NOSASL a passed ``user`` argument has no effect; queries
            # actually run as the user the Impala service was started as.
            auth_mechanism='NOSASL',
        )

    def get_databases(self) -> List[str]:
        """Return the names of all schemas returned by ``SHOW SCHEMAS``.

        When this process runs as OS user ``tjrd`` an ``INVALIDATE METADATA``
        is issued first.
        """
        with self._get_connection() as conn:
            cursor = conn.cursor()
            # NOTE(review): environment-specific hack, presumably to pick up
            # metadata changed outside Impala; INVALIDATE METADATA on the whole
            # catalog can be expensive.
            if getpass.getuser() == "tjrd":
                cursor.execute("INVALIDATE METADATA")
            cursor.execute("SHOW SCHEMAS")
            return [row[0] for row in cursor.fetchall()]

    def get_tables(self, database: str) -> List[str]:
        """Return the table names in ``database`` (``SHOW TABLES IN``)."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(f"SHOW TABLES IN `{database}`")
            return [row[0] for row in cursor.fetchall()]

    def get_columns(self, table: str, database: str) -> List[Dict[str, Any]]:
        """Return ``[{"name", "type", "comment"}, ...]`` for ``database.table``.

        Rows from ``DESCRIBE`` are expected as ``(name, type, comment)``.
        Empty rows, rows with a NULL/blank name, ``#``-prefixed section
        header rows, and repeated column names are skipped; the first
        occurrence of a name wins.
        """
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(f"DESCRIBE `{database}`.`{table}`")
            columns: List[Dict[str, Any]] = []
            seen = set()
            for row in cursor.fetchall():
                if not row or row[0] is None:
                    continue
                name = row[0]
                text = str(name).strip()
                if not text or text.startswith("#") or name in seen:
                    continue
                seen.add(name)
                columns.append({
                    "name": name,
                    "type": row[1] if len(row) > 1 else "",
                    "comment": row[2] if len(row) > 2 else "",
                })
            return columns

    def has_table(self, table_name: str, schema: str) -> bool:
        """Return True if ``schema.table_name`` is in the (refreshed) cache."""
        self.refresh_db_cache(schema)
        # Table entries live under the "tables" key, next to "cache_time".
        return table_name in self.cache.get(schema, {}).get("tables", {})

    def get_table_comment(self, table_name: str, schema: str) -> str:
        """Return the cached comment of ``schema.table_name``, or "" if unknown."""
        self.refresh_db_cache(schema)
        tables = self.cache.get(schema, {}).get("tables", {})
        return tables.get(table_name, {}).get("comment", "")

    def fectch_distinct_values(self, column_name: str, table_name: str, schema: str, max_num: int = 5) -> list[str]:
        """Return up to ``max_num`` distinct values of a column.

        Not implemented for Impala: always returns an empty list. The
        misspelled method name is presumably required by the
        ``MetadataProvider`` interface — do not rename without checking.
        """
        return []

    def get_table_comment_db(self, table_name: str, schema: str) -> str:
        """Query Impala directly (no cache) for the comment of ``schema.table_name``."""
        with self._get_connection() as conn:
            return self._query_table_comment(conn.cursor(), table_name, schema)

    @staticmethod
    def _query_table_comment(cursor, table_name: str, schema: str) -> str:
        """Run ``DESCRIBE FORMATTED`` on ``cursor`` and extract the table comment.

        The comment is taken from the first row whose second field is
        ``comment``; returns "" when no such row exists.
        """
        cursor.execute(f"DESCRIBE FORMATTED `{schema}`.`{table_name}`")
        for row in cursor.fetchall():
            if len(row) >= 3 and str(row[1]).strip() == "comment":
                return str(row[2]).strip()
        return ""

    def _is_fresh(self, db_cache, now: int) -> bool:
        """Return True if ``db_cache`` exists and is younger than ``CACHE_TTL_SECS``."""
        return db_cache is not None and now - db_cache["cache_time"] < self.CACHE_TTL_SECS

    def refresh_db_cache(self, schema: str = ""):
        """Rebuild the cache entry for ``schema`` unless it is still fresh.

        NOTE(review): the default ``schema=""`` yields ``SHOW TABLES IN ``
        with an empty identifier; callers are presumably expected to pass a
        real schema name.
        """
        if self._is_fresh(self.cache.get(schema), int(time.time())):
            return
        with self._refresh_lock:
            # Re-read under the lock: another thread may have rebuilt this
            # entry while we were waiting for the lock.
            if self._is_fresh(self.cache.get(schema), int(time.time())):
                return
            table_cache: Dict[str, Dict[str, Any]] = {}
            tables = self.get_tables(schema)
            if tables:
                # One connection for all per-table comment lookups instead of
                # one connection per table.
                with self._get_connection() as conn:
                    cursor = conn.cursor()
                    for table in tables:
                        comment = self._query_table_comment(cursor, table, schema)
                        # Column details are not loaded during a refresh; use
                        # get_columns() for them.
                        columns: Dict[str, Any] = {}
                        table_cache[table] = {"comment": comment, "columns": columns, "table": table, "schema": schema}
                        logger.debug("%s.%s comment: %s, columns: %d", schema, table, comment, len(columns))
            self.cache[schema] = {"cache_time": int(time.time()), "tables": table_cache}

    def refresh_cache(self):
        """Refresh (if stale) the cache entry of every schema from get_databases()."""
        for database in self.get_databases():
            self.refresh_db_cache(database)