| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- from .base import MetadataProvider
- from typing import List, Dict, Any
- import time
- import threading
- from impala.dbapi import connect
- import getpass
class ImpalaMetadataProvider(MetadataProvider):
    """Metadata provider that reads schemas, tables and comments from Impala.

    Table comments are cached per schema in ``self.cache``. Each entry has
    the shape ``{"cache_time": <epoch seconds>, "tables": {<table>: <info>}}``
    where ``<info>`` holds the keys ``comment``, ``columns``, ``table`` and
    ``schema``. An entry is rebuilt once it is older than
    ``_CACHE_TTL_SECONDS``.
    """

    # Seconds a per-schema cache entry stays valid (10 minutes).
    _CACHE_TTL_SECONDS = 600

    def __init__(self, config: Dict):
        """Store the configuration and create an empty cache.

        Args:
            config: Provider configuration; ``host`` and ``port`` are read
                when a connection is opened.
        """
        super().__init__(config)
        # schema -> {"cache_time": int, "tables": {table -> table info}}
        self.cache: Dict[str, Dict[str, Any]] = {}
        # Serializes cache rebuilds across threads.
        self._refresh_lock = threading.Lock()

    def _get_connection(self):
        """Open a new Impala connection using ``host``/``port`` from config."""
        return connect(
            host=self.config["host"],
            port=self.config["port"],
            # With NOSASL a passed-in user argument has no effect; the user
            # that started the service is the one actually used.
            auth_mechanism='NOSASL',
        )

    def get_databases(self) -> List[str]:
        """Return the names of all schemas reported by ``SHOW SCHEMAS``."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            # Only the "tjrd" OS user triggers a global metadata invalidation
            # before listing schemas.
            if "tjrd" == getpass.getuser():
                cursor.execute("INVALIDATE METADATA")
            cursor.execute("SHOW SCHEMAS")
            return [row[0] for row in cursor.fetchall()]

    def get_tables(self, database: str) -> List[str]:
        """Return the table names of ``database``."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(f"SHOW TABLES IN `{database}`")
            return [row[0] for row in cursor.fetchall()]

    def get_columns(self, table: str, database: str) -> List[Dict[str, Any]]:
        """Return the columns of ``database.table`` from ``DESCRIBE``.

        Rows that are empty, have a blank name, have a name starting with
        ``#``, or repeat an already seen name are skipped.

        Returns:
            A list of dicts with the keys ``name``, ``type`` and ``comment``,
            in the order the rows were returned.
        """
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(f"DESCRIBE `{database}`.`{table}`")
            columns: List[Dict[str, Any]] = []
            seen_names = set()  # O(1) duplicate detection
            for row in cursor.fetchall():
                # Impala: (name, type, comment)
                if not row or row[0] is None:
                    continue
                name = row[0]
                name_text = str(name)
                if (
                    name_text.startswith("#")
                    or name_text.strip() == ""
                    or name in seen_names
                ):
                    continue
                seen_names.add(name)
                columns.append({
                    "name": name,
                    "type": row[1] if len(row) > 1 else "",
                    "comment": row[2] if len(row) > 2 else "",
                })
            return columns

    def has_table(self, table_name: str, schema: str) -> bool:
        """Return whether ``table_name`` exists in ``schema``.

        The schema cache is refreshed first if it is missing or expired.
        """
        self.refresh_db_cache(schema)
        # Table entries live under the "tables" key of the schema entry.
        return table_name in self.cache.get(schema, {}).get("tables", {})

    def get_table_comment(self, table_name: str, schema: str) -> str:
        """Return the cached comment of ``schema.table_name``.

        The schema cache is refreshed first if it is missing or expired.
        Returns ``""`` when the table is not in the cache.
        """
        self.refresh_db_cache(schema)
        tables = self.cache.get(schema, {}).get("tables", {})
        return tables.get(table_name, {}).get("comment", "")

    def fectch_distinct_values(self, column_name: str, table_name: str, schema: str, max_num: int = 5) -> list[str]:
        """Return sample distinct values of a column.

        Not implemented for this provider: always returns an empty list.
        """
        return []

    def get_table_comment_db(self, table_name: str, schema: str) -> str:
        """Read the comment of ``schema.table_name`` directly from Impala.

        Scans the ``DESCRIBE FORMATTED`` output for the first row whose
        second field, stripped, equals ``comment`` and returns its third
        field, stripped. Returns ``""`` when no such row exists.
        """
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(f"DESCRIBE FORMATTED `{schema}`.`{table_name}`")
            for row in cursor.fetchall():
                if len(row) >= 3 and str(row[1]).strip() == "comment":
                    value = row[2]
                    # Avoid turning a null comment into the text "None".
                    return "" if value is None else str(value).strip()
            return ""

    def _is_cache_fresh(self, schema: str) -> bool:
        """Return whether ``schema`` has a cache entry younger than the TTL."""
        entry = self.cache.get(schema)
        if entry is None:
            return False
        return int(time.time()) - entry["cache_time"] < self._CACHE_TTL_SECONDS

    def refresh_db_cache(self, schema: str = ""):
        """Rebuild the cache entry of ``schema`` if it is missing or expired.

        For each table of the schema the comment is fetched from Impala and
        one progress line is printed. Column metadata is not loaded; the
        ``columns`` value of every table is an empty dict.
        """
        if self._is_cache_fresh(schema):
            return
        with self._refresh_lock:
            # Look at the cache again now that the lock is held: another
            # thread may have finished a rebuild while this one was waiting.
            if self._is_cache_fresh(schema):
                return
            # Rebuild the cache.
            table_cache: Dict[str, Dict[str, Any]] = {}
            for table in self.get_tables(schema):
                comment = self.get_table_comment_db(table, schema)
                columns: Dict[str, Any] = {}  # column details are not fetched here
                table_cache[table] = {"comment": comment, "columns": columns, "table": table, "schema": schema}
                print(f"{schema}.{table} comment: {comment}, columns: {len(columns)}")
            # Publish the new entry.
            self.cache[schema] = {"cache_time": int(time.time()), "tables": table_cache}

    def refresh_cache(self):
        """Refresh the cache entry of every schema returned by ``get_databases``."""
        for database in self.get_databases():
            self.refresh_db_cache(database)
|