impala.py 4.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from .base import MetadataProvider
  2. from typing import List, Dict, Any
  3. import time
  4. import threading
  5. from impala.dbapi import connect
  6. import getpass
  7. class ImpalaMetadataProvider(MetadataProvider):
  8. def __init__(self, config: Dict):
  9. super().__init__(config)
  10. self.cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
  11. # 创建一个线程锁
  12. self._refresh_lock = threading.Lock()
  13. def _get_connection(self):
  14. return connect(
  15. host=self.config["host"],
  16. port=self.config["port"],
  17. auth_mechanism='NOSASL' # NOSASL 方式,传递的user参数不生效,而实际是使用服务启动用户
  18. )
  19. def get_databases(self) -> List[str]:
  20. with self._get_connection() as conn:
  21. cursor = conn.cursor()
  22. if "tjrd" == getpass.getuser():
  23. cursor.execute("INVALIDATE METADATA")
  24. cursor.execute("SHOW SCHEMAS")
  25. return [row[0] for row in cursor.fetchall()]
  26. def get_tables(self, database: str) -> List[str]:
  27. with self._get_connection() as conn:
  28. cursor = conn.cursor()
  29. cursor.execute(f"SHOW TABLES IN `{database}`")
  30. return [row[0] for row in cursor.fetchall()]
  31. def get_columns(self, table: str, database: str) -> List[Dict[str, Any]]:
  32. with self._get_connection() as conn:
  33. cursor = conn.cursor()
  34. cursor.execute(f"DESCRIBE `{database}`.`{table}`")
  35. columns = []
  36. for row in cursor.fetchall():
  37. # Impala: (name, type, comment)
  38. if not row or row[0].startswith("#") or str(row[0]).strip() == "" or any(c.get("name") == row[0] for c in columns) :
  39. continue
  40. columns.append({
  41. "name": row[0],
  42. "type": row[1] if len(row) > 1 else "",
  43. "comment": row[2] if len(row) > 2 else "",
  44. })
  45. return columns
  46. def has_table(self, table_name: str, schema: str) -> bool:
  47. self.refresh_db_cache(schema)
  48. return schema in self.cache and table_name in self.cache[schema]
  49. def get_table_comment(self, table_name: str, schema: str) -> str:
  50. self.refresh_db_cache(schema)
  51. return self.cache.get(schema, {}).get(table_name, {}).get("comment", "")
  52. def fectch_distinct_values(self, column_name: str, table_name: str, schema: str, max_num: int = 5) -> list[str]:
  53. """检查表是否存在"""
  54. return []
  55. def get_table_comment_db(self, table_name: str, schema: str) -> str:
  56. with self._get_connection() as conn:
  57. cursor = conn.cursor()
  58. cursor.execute(f"DESCRIBE FORMATTED `{schema}`.`{table_name}`")
  59. for row in cursor.fetchall():
  60. if len(row) >= 3 and str(row[1]).strip() == "comment":
  61. return str(row[2]).strip()
  62. return ""
  63. def refresh_db_cache(self, schema: str = ""):
  64. db_cache = self.cache.get(schema, None)
  65. time_secs = int(time.time())
  66. if db_cache is not None and time_secs - db_cache["cache_time"] < 600: # 10 minutes, 缓存有有效
  67. return
  68. with self._refresh_lock:
  69. if db_cache is not None and time_secs - db_cache["cache_time"] < 600: # double check 缓存有有效
  70. return
  71. # 重新构建缓存
  72. table_cache: Dict[str, Dict[str, Any]] = {}
  73. tables = self.get_tables(schema)
  74. for table in tables:
  75. comment = self.get_table_comment_db(table, schema)
  76. columns = {} # self.get_columns(table, database)
  77. table_cache[table] = {"comment": comment, "columns": columns, "table": table, "schema": schema}
  78. print(f"{schema}.{table} comment: {comment}, columns: {len(columns)}")
  79. # 更新缓存
  80. self.cache[schema] = {"cache_time": int(time.time()), "tables": table_cache}
  81. def refresh_cache(self):
  82. databases = self.get_databases()
  83. for database in databases:
  84. self.refresh_db_cache(database)