| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- import traceback
- from fastapi import APIRouter, HTTPException, Query, Body
- from typing import Dict, Any
- from services.metadata_service import MetadataService
- from services.chat_service import ChatService
- from threading import Lock
# Guards one-time construction of the lazily created service singletons;
# shared by get_metadata_svc() and get_chat_svc().
_init_lock = Lock()

# Service singletons, left as None until the matching getter builds them.
_chat_svc: ChatService | None = None
_metadata_svc: MetadataService | None = None

# Sub-router carrying the metadata-browsing and SQL-generation endpoints.
router = APIRouter(tags=["metadata"], prefix="")
def get_chat_svc() -> ChatService:
    """Return the shared ChatService, constructing it on first use.

    Uses double-checked locking: the fast path returns the cached instance
    without taking the lock; only the first caller(s) contend for
    ``_init_lock``, and the re-check inside the lock ensures a single
    instance is built. If construction raises, the global stays None and the
    next call retries.

    NOTE(review): ``_init_lock`` is a plain (non-reentrant) ``threading.Lock``
    also used by get_metadata_svc(); if ChatService() ever calls
    get_metadata_svc() while that singleton is still None, this deadlocks —
    confirm it does not.
    """
    global _chat_svc
    if _chat_svc is not None:
        return _chat_svc
    with _init_lock:
        if _chat_svc is None:
            _chat_svc = ChatService()
    return _chat_svc
def get_metadata_svc() -> MetadataService:
    """Return the shared MetadataService, constructing it on first use.

    Uses double-checked locking: the fast path returns the cached instance
    without taking the lock; only the first caller(s) contend for
    ``_init_lock``, and the re-check inside the lock ensures a single
    instance is built. If construction raises, the global stays None and the
    next call retries.

    NOTE(review): ``_init_lock`` is a plain (non-reentrant) ``threading.Lock``
    also used by get_chat_svc(); if MetadataService() ever calls
    get_chat_svc() while that singleton is still None, this deadlocks —
    confirm it does not.
    """
    global _metadata_svc
    if _metadata_svc is not None:
        return _metadata_svc
    with _init_lock:
        if _metadata_svc is None:
            _metadata_svc = MetadataService()
    return _metadata_svc
@router.get("/sources")
def list_sources():
    """GET /sources — list the data sources known to the metadata service.

    Returns:
        ``{"sources": <result of MetadataService.list_sources()>}``.
    """
    svc = get_metadata_svc()
    return dict(sources=svc.list_sources())
@router.get("/databases")
def list_databases(source: str = Query(...)):
    """GET /databases — list the databases of one data source.

    Args:
        source: Name of the data source (required query parameter).

    Returns:
        ``{"source": source, "databases": <service result>}``.

    Raises:
        HTTPException: 404 when the service returns an empty/falsy result
            (unknown source, or a source with no databases).
    """
    svc = get_metadata_svc()
    databases = svc.get_databases(source)
    if databases:
        return {"source": source, "databases": databases}
    raise HTTPException(404, detail=f"Source '{source}' not found or no databases")
@router.get("/tables")
def list_tables(source: str = Query(...), database: str = Query(...)):
    """GET /tables — list the tables of one database in a data source.

    Args:
        source: Name of the data source (required query parameter).
        database: Database within that source (required query parameter).

    Returns:
        ``{"source": source, "database": database, "tables": <service result>}``.

    Raises:
        HTTPException: 404 when the service returns an empty/falsy result.
    """
    tables = get_metadata_svc().get_tables(source, database)
    if not tables:
        # An empty result may mean an unknown source/database OR an existing
        # database that simply has no tables, so the message must not claim
        # "not found" outright. Name what was looked up, matching the style of
        # the /databases endpoint's 404 message.
        raise HTTPException(
            404,
            detail=(
                f"Database '{database}' in source '{source}' "
                "not found or has no tables"
            ),
        )
    return {"source": source, "database": database, "tables": tables}
@router.get("/columns")
def list_columns(
    source: str = Query(...),
    database: str = Query(...),
    table: str = Query(...),
):
    """GET /columns — list the columns of one table.

    Args:
        source: Name of the data source (required query parameter).
        database: Database within that source (required query parameter).
        table: Table within that database (required query parameter).

    Returns:
        ``{"source", "database", "table", "columns"}``, where ``columns`` is
        the service result.

    Raises:
        HTTPException: 404 when the service returns an empty/falsy result.
    """
    columns = get_metadata_svc().get_columns(source, database, table)
    if columns:
        return dict(source=source, database=database, table=table, columns=columns)
    raise HTTPException(404, detail="Table not found")
@router.post("/generate_sql")
def generate_sql(payload: Dict[str, Any] = Body(...)):
    """POST /generate_sql — generate SQL for a natural-language requirement.

    Example request body::

        {
            "source": "hive_prod",
            "database": "sales",
            "requirement": "top 10 users by sales over the last 7 days"
        }

    The database schema is built by MetadataService.build_mschema() and passed,
    together with the requirement and source, to ChatService.generate_sql().

    Returns:
        The ``"structured_response"`` entry of the chat service's result.

    Raises:
        HTTPException: 400 if any of the three fields is missing/empty or is
            not a string; any HTTPException raised by the services is passed
            through unchanged; any other failure becomes a 500 whose detail is
            the exception text.
    """
    import logging  # local import keeps this block self-contained

    source = payload.get("source")
    database = payload.get("database")
    requirement = payload.get("requirement")
    if not all([source, database, requirement]):
        raise HTTPException(400, "Missing 'source', 'database' or 'requirement'")
    # Truthy non-strings (numbers, lists, objects) would otherwise reach the
    # services and surface as an opaque 500; reject them as a client error.
    if not all(isinstance(v, str) for v in (source, database, requirement)):
        raise HTTPException(
            400, "'source', 'database' and 'requirement' must be strings"
        )

    try:
        schema = get_metadata_svc().build_mschema(source, database)
        result = get_chat_svc().generate_sql(requirement, schema, source)
        return result["structured_response"]
    except HTTPException:
        # Already carries its own status code; don't mask it as a 500.
        raise
    except Exception as e:
        # Route the traceback through logging (instead of printing to stderr)
        # so it honours the server's logging configuration.
        logging.getLogger(__name__).exception(
            "generate_sql failed (source=%r, database=%r)", source, database
        )
        raise HTTPException(500, detail=str(e)) from e
|