routes.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import traceback
  2. from fastapi import APIRouter, HTTPException, Query, Body
  3. from typing import Dict, Any
  4. from services.metadata_service import MetadataService
  5. from services.chat_service import ChatService
  6. from threading import Lock
  # Sub-router for the metadata / text-to-SQL endpoints (no path prefix).
router = APIRouter(tags=["metadata"])

# Shared service instances. They start as None and are created on first use
# by get_metadata_svc() / get_chat_svc().
_metadata_svc: MetadataService | None = None
_chat_svc: ChatService | None = None
# Guards first-time construction of both services above.
_init_lock = Lock()
  # Lazily-initialised accessor for the shared ChatService.
def get_chat_svc() -> ChatService:
    """Return the module-wide ``ChatService``, creating it on first use.

    The instance is cached in the module global ``_chat_svc``. Creation is
    guarded by ``_init_lock`` with a check / lock / re-check sequence. As a
    result, concurrent first calls construct at most one instance, and later
    calls never touch the lock.
    """
    global _chat_svc
    if _chat_svc is not None:
        # Fast path: the service already exists, so no locking is needed.
        return _chat_svc
    with _init_lock:
        # Re-check while holding the lock: another thread may have built
        # the service while this one was waiting.
        if _chat_svc is None:
            # NOTE(review): _init_lock is a non-reentrant threading.Lock that
            # get_metadata_svc() also uses. If ChatService() ever calls that
            # accessor during construction, this would deadlock.
            _chat_svc = ChatService()
    return _chat_svc
  # Lazily-initialised accessor for the shared MetadataService.
def get_metadata_svc() -> MetadataService:
    """Return the module-wide ``MetadataService``, creating it on first use.

    The instance is cached in the module global ``_metadata_svc``. Its first
    construction happens under ``_init_lock`` and is re-checked there, so
    concurrent first calls still produce a single instance.
    """
    global _metadata_svc
    svc = _metadata_svc
    if svc is None:
        with _init_lock:
            # Re-check under the lock: another thread may have won the race.
            if _metadata_svc is None:
                _metadata_svc = MetadataService()
            svc = _metadata_svc
    return svc
  # GET /sources
@router.get("/sources")
def list_sources():
    """List the data sources reported by the metadata service."""
    sources = get_metadata_svc().list_sources()
    return {"sources": sources}
  # GET /databases?source=...
@router.get("/databases")
def list_databases(source: str = Query(...)):
    """Return the databases the metadata service reports for ``source``.

    Raises:
        HTTPException: 404 when the service returns an empty/falsy result.
    """
    svc = get_metadata_svc()
    databases = svc.get_databases(source)
    if databases:
        return {"source": source, "databases": databases}
    raise HTTPException(404, detail=f"Source '{source}' not found or no databases")
  # GET /tables?source=...&database=...
@router.get("/tables")
def list_tables(source: str = Query(...), database: str = Query(...)):
    """Return the tables the metadata service reports for ``source``/``database``.

    Raises:
        HTTPException: 404 when the service returns an empty/falsy result.
    """
    tables = get_metadata_svc().get_tables(source, database)
    if not tables:
        # An empty result may mean the source/database is unknown *or* that
        # the database simply has no tables. Say so instead of claiming it
        # was "not found", and name what was requested.
        raise HTTPException(
            404,
            detail=f"Database '{database}' in source '{source}' not found or has no tables",
        )
    return {"source": source, "database": database, "tables": tables}
  # GET /columns?source=...&database=...&table=...
@router.get("/columns")
def list_columns(
    source: str = Query(...),
    database: str = Query(...),
    table: str = Query(...),
):
    """Return the columns the metadata service reports for one table.

    Raises:
        HTTPException: 404 when the service returns an empty/falsy result.
    """
    columns = get_metadata_svc().get_columns(source, database, table)
    if columns:
        return dict(source=source, database=database, table=table, columns=columns)
    raise HTTPException(404, detail="Table not found")
  # POST /generate_sql: natural-language requirement -> SQL.
@router.post("/generate_sql")
def generate_sql(payload: Dict[str, Any] = Body(...)):
    """Generate SQL for a natural-language requirement against one database.

    Example request body::

        {
            "source": "hive_prod",
            "database": "sales",
            "requirement": "top 10 users by sales amount over the last 7 days"
        }

    The schema built by ``build_mschema(source, database)`` is passed,
    together with the requirement and source, to the chat service's
    ``generate_sql``. The ``"structured_response"`` entry of its result is
    returned as the response body.

    Raises:
        HTTPException: 400 if ``source``, ``database`` or ``requirement`` is
            missing, not a string, or blank. An HTTPException raised while
            generating is propagated unchanged. Any other error becomes a
            500 whose detail is the exception text.
    """
    source = payload.get("source")
    database = payload.get("database")
    requirement = payload.get("requirement")
    # A plain truthiness check would let values like "   " or 123 through
    # to the metadata and chat services, so require non-blank strings.
    if not all(
        isinstance(value, str) and value.strip()
        for value in (source, database, requirement)
    ):
        raise HTTPException(400, "Missing 'source', 'database' or 'requirement'")
    try:
        schema = get_metadata_svc().build_mschema(source, database)
        result = get_chat_svc().generate_sql(requirement, schema, source)
        return result["structured_response"]
    except HTTPException:
        # Already carries the intended status/detail; don't mask it as a 500.
        raise
    except Exception as e:
        traceback.print_exc()
        # NOTE(review): str(e) exposes internal exception text to API
        # clients; consider a generic detail outside development.
        raise HTTPException(500, detail=str(e)) from e