"""FastAPI backend service for the GraphRAG demo, deployed on Hugging Face Spaces."""

import os

# Set before the project modules below are imported; presumably this disables
# HuggingFace `tokenizers` parallelism before rag_engine loads a tokenizer.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json
from typing import Optional, List

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Project-local modules. On HF Spaces these files must live in the same
# directory as this one.
from database_setup_lite import setup_databases
from rag_engine import RAGEngine

app = FastAPI(title="GraphRAG Backend API")

# CORS: every origin is allowed for the demo.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; in production restrict allow_origins to the frontend's domain,
# e.g. ["https://your-frontend.vercel.app"].
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Build the databases and the RAG engine once at import time and share them
# across all requests, so they are not re-initialised per call.
print("正在初始化数据库...")
graph_db, vector_db = setup_databases()
rag_engine = RAGEngine(graph_db, vector_db)

# Static data shown by the frontend (used by /api/styles).
# NOTE: the path is resolved against the process working directory.
with open("mock_data.json", "r", encoding="utf-8") as f:
    mock_data = json.load(f)

# Maximum number of features returned by /api/features/search.
MAX_FEATURE_RESULTS = 10


# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------

class SearchRequest(BaseModel):
    """Body of POST /api/search."""

    query: str
    product_name: Optional[str] = ""  # optional filter; None/"" means "not given"
    style_name: Optional[str] = ""  # optional filter; None/"" means "not given"


class GenerateRequest(BaseModel):
    """Body of POST /api/generate."""

    query: str
    product_name: str
    style_name: str
    use_graph: bool = True  # forwarded to RAGEngine.generate_copywriting


class FeatureSearchRequest(BaseModel):
    """Body of POST /api/features/search."""

    query: str


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.get("/")
def root():
    """Describe the API: name, version and the list of available endpoints."""
    return {
        "message": "GraphRAG Backend API",
        "version": "1.0.0",
        "endpoints": [
            "GET /api/products",
            "GET /api/styles",
            "GET /api/graph",
            "GET /api/vector-db",
            "POST /api/search",
            "POST /api/generate",
            "POST /api/features/search",
        ],
    }


@app.get("/api/products")
def get_products():
    """Return the product list (currently a single hard-coded demo product)."""
    demo_product = {
        "id": "P_DEMO",
        "name": "真丝睡眠眼罩",
    }
    return [demo_product]


@app.get("/api/styles")
def get_styles():
    """Return the available copywriting styles as ``{"id", "name"}`` dicts."""
    return [{"id": s["id"], "name": s["name"]} for s in mock_data["styles"]]


@app.get("/api/graph")
def get_graph():
    """Return the whole knowledge graph as ``{"nodes": [...], "edges": [...]}``.

    Each node's display label is its ``name`` property, else the first 20
    characters of its ``content`` property, else the node id.
    """

    def _label(node_id, properties):
        # Guard against a content property that is present but None.
        return (
            properties.get("name")
            or (properties.get("content") or "")[:20]
            or node_id
        )

    nodes = [
        {
            "id": node_id,
            "type": node_data["type"],
            "label": _label(node_id, node_data["properties"]),
            "properties": node_data["properties"],
        }
        for node_id, node_data in graph_db.nodes.items()
    ]
    edges = [
        {
            "source": edge["source"],
            "target": edge["target"],
            "relationship": edge["relationship"],
        }
        for edge in graph_db.edges
    ]
    return {"nodes": nodes, "edges": edges}


@app.post("/api/search")
def search(request: SearchRequest):
    """Compare retrieval strategies for a query via RAGEngine.compare_retrieval.

    Raises:
        HTTPException(400): if the query is empty or whitespace-only.
    """
    if not request.query.strip():
        raise HTTPException(status_code=400, detail="查询不能为空")

    return rag_engine.compare_retrieval(
        request.query,
        request.product_name or "",
        request.style_name or "",
    )


@app.post("/api/generate")
def generate(request: GenerateRequest):
    """Generate marketing copy via RAGEngine.generate_copywriting.

    Raises:
        HTTPException(400): if query, product_name or style_name is empty.
    """
    if not all([request.query, request.product_name, request.style_name]):
        raise HTTPException(status_code=400, detail="缺少必要参数")

    return rag_engine.generate_copywriting(
        request.query,
        request.product_name,
        request.style_name,
        request.use_graph,
    )


@app.get("/api/vector-db")
def get_vector_db():
    """Return every document stored in the traditional-RAG vector collection.

    Returns:
        ``{"total": int, "documents": [{"id", "content", "metadata"}, ...]}``.
        A missing content falls back to ``""`` and a missing metadata to ``{}``.

    Raises:
        HTTPException(500): if reading or formatting the collection fails.
    """
    try:
        all_docs = vector_db.collection.get()
        # The "documents"/"metadatas" fields may be absent or None; treat both
        # as empty instead of failing on len(None).
        contents = all_docs.get("documents") or []
        metadatas = all_docs.get("metadatas") or []
        documents = [
            {
                "id": doc_id,
                "content": contents[i] if i < len(contents) else "",
                "metadata": metadatas[i] if i < len(metadatas) else {},
            }
            for i, doc_id in enumerate(all_docs["ids"])
        ]
        return {"total": len(documents), "documents": documents}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


def _products_by_feature():
    """Map each feature node id to the names of products linked by HAS_FEATURE.

    Built in a single pass over the edges (preserving edge order), so callers
    avoid rescanning every edge once per feature.
    """
    index = {}
    for edge in graph_db.edges:
        if edge["relationship"] != "HAS_FEATURE":
            continue
        product_node = graph_db.nodes.get(edge["source"], {})
        if product_node.get("type") == "Product":
            index.setdefault(edge["target"], []).append(
                product_node["properties"].get("name", edge["source"])
            )
    return index


@app.post("/api/features/search")
def search_features(request: FeatureSearchRequest):
    """Find Feature nodes whose name contains any whitespace-separated keyword.

    Matching is case-insensitive substring matching. At most
    MAX_FEATURE_RESULTS features are returned, each with the names of the
    products linked to it through HAS_FEATURE edges.
    """
    keywords = request.query.lower().split()
    if not keywords:
        # Empty or whitespace-only query.
        return {"features": []}

    matched = []  # (node, display_name) pairs, in graph order
    for node in graph_db.find_nodes_by_type("Feature"):
        # Fall back to the node id when the name is missing or None.
        display_name = node["properties"].get("name") or node["id"]
        if any(keyword in display_name.lower() for keyword in keywords):
            matched.append((node, display_name))
            if len(matched) >= MAX_FEATURE_RESULTS:
                break

    if not matched:
        return {"features": []}

    products_by_feature = _products_by_feature()
    return {
        "features": [
            {
                "id": node["id"],
                "name": display_name,
                "related_products": list(products_by_feature.get(node["id"], [])),
            }
            for node, display_name in matched
        ]
    }


# HF Spaces starts the app through the CMD in the Dockerfile.
# For local testing, uncomment the block below.
# if __name__ == "__main__":
#     import uvicorn
#     port = int(os.getenv("PORT", 7860))
#     uvicorn.run(app, host="0.0.0.0", port=port)