Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, Form | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from typing import List, Optional | |
| import numpy as np | |
| import io | |
| import os | |
| import gc | |
| from dotenv import load_dotenv | |
| from pydub import AudioSegment | |
| from utils import ( | |
| authenticate, | |
| split_documents, | |
| build_vectorstore, | |
| retrieve_context, | |
| retrieve_context_approx, | |
| build_prompt, | |
| ask_gemini, | |
| load_documents_gradio, | |
| transcribe | |
| ) | |
import logging

# Pull variables from a local .env file into os.environ before any
# configuration is read (status() checks GOOGLE_API_KEY; authenticate() is
# presumably configured from the environment as well — confirm in utils).
load_dotenv()

# Process-wide logging: INFO and above, timestamped, tagged with logger name.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

app = FastAPI()

# Browser origins permitted by the CORS middleware.
# NOTE(review): a "*" entry makes Starlette's CORSMiddleware allow every
# origin, so the explicit entries are effectively redundant, and the Starlette
# docs say wildcards should not be combined with allow_credentials=True.
# Same-origin requests never go through CORS, so "*" should not be needed for
# a same-origin proxy — consider removing it once deployments are verified.
origins = [
    "http://localhost:3000",  # local frontend during development
    "https://chat-docx-ai-vercel.vercel.app",  # Vercel-hosted frontend
    "https://huggingface.co",  # Hugging Face Spaces site
    "https://codegeass321-chatdocxai.hf.space",  # previous HF Space
    "https://codegeass321-backendserver.hf.space",  # current HF Space UI
    "https://codegeass321-backendserver-8000.hf.space",  # current HF Space API
    "*",  # wildcard — see NOTE above
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Client handed to ask_gemini() when answering questions.
client = authenticate()

# Mutable holder for the active vector store; None until an upload succeeds.
store = dict(value=None)
async def root():
    """Report that the API is up and point callers at the main endpoints.

    Returns:
        dict: a single ``message`` entry describing available endpoints.
    """
    logger.info("Root endpoint called")
    hint = "API is running. Use /status, /upload, or /ask endpoints."
    return {"message": hint}
async def options_upload():
    """Answer a CORS preflight with permissive headers.

    Presumably routed as the OPTIONS handler for the upload endpoint (the
    route decorator is not visible here — confirm).

    Returns:
        JSONResponse: ``{"status": "ok"}`` carrying wildcard-origin CORS headers.
    """
    preflight_headers = dict()
    preflight_headers["Access-Control-Allow-Origin"] = "*"
    preflight_headers["Access-Control-Allow-Methods"] = "POST, OPTIONS"
    preflight_headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
    body = {"status": "ok"}
    return JSONResponse(content=body, headers=preflight_headers)
async def upload(files: List[UploadFile] = File(...)):
    """Replace the current document context with the uploaded files.

    Pipeline: load the files (``load_documents_gradio``), split them into
    chunks (``split_documents``) and build a vector store
    (``build_vectorstore``), which is kept in ``store["value"]``. The previous
    store is dropped *before* processing starts so its memory can be
    reclaimed first; consequently a failed upload leaves no document loaded.

    Args:
        files: Uploaded files from the multipart form.

    Returns:
        JSONResponse: ``{"status": "success", ...}`` on success, otherwise
        ``{"status": "error", "message": ...}`` with HTTP 400 (no files, or no
        extractable content) or 500 (a processing step raised). Every
        response carries the same permissive CORS headers.
    """
    headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "Content-Type, Authorization",
    }

    def error_response(message, status_code):
        # Single place that shapes error payloads so every failure path keeps
        # the CORS headers.
        return JSONResponse(
            content={"status": "error", "message": message},
            status_code=status_code,
            headers=headers,
        )

    try:
        logger.info(f"Upload request received with {len(files)} files")
        for i, file in enumerate(files):
            logger.info(f"File {i+1}: {file.filename}, content_type: {file.content_type}")
        if not files:
            return error_response("No files uploaded.", 400)

        # Drop the previous vector store so it can be reclaimed before the
        # new one is built.
        logger.info("Clearing previous vector store from memory...")
        had_previous_store = store.get("value") is not None
        store["value"] = None
        gc.collect()  # no argument: full collection of all generations
        if had_previous_store:
            # Extra explicit passes per generation. Best-effort only: a
            # failure here must not abort the upload.
            try:
                for generation in range(3):
                    gc.collect(generation)
            except Exception as e:
                logger.warning(f"Error during aggressive garbage collection: {e}")
        logger.info("Memory cleared.")

        logger.info("Starting document processing...")
        try:
            raw_docs = load_documents_gradio(files)
            logger.info(f"Documents loaded: {len(raw_docs)} documents")
        except Exception as doc_error:
            logger.exception(f"Error loading documents: {doc_error}")
            return error_response(f"Error loading documents: {str(doc_error)}", 500)
        if not raw_docs:
            return error_response("No content could be extracted from the uploaded files.", 400)

        logger.info("Documents loaded. Splitting documents...")
        try:
            chunks = split_documents(raw_docs)
            logger.info(f"Documents split into {len(chunks)} chunks")
        except Exception as split_error:
            logger.exception(f"Error splitting documents: {split_error}")
            return error_response(f"Error splitting documents: {str(split_error)}", 500)

        logger.info("Documents split. Building vector store...")
        try:
            store["value"] = build_vectorstore(chunks)
            logger.info("Vector store built successfully.")
        except Exception as vector_error:
            logger.exception(f"Error building vector store: {vector_error}")
            return error_response(f"Error building vector store: {str(vector_error)}", 500)

        return JSONResponse(
            content={"status": "success", "message": "Document processed successfully! You can now ask questions."},
            headers=headers,
        )
    except Exception as e:
        # Last-resort boundary: logger.exception records the traceback in the
        # same log record as the message.
        logger.exception(f"An error occurred during upload: {e}")
        return error_response(f"An internal server error occurred: {str(e)}", 500)
async def ask(
    text: Optional[str] = Form(None),
    audio: Optional[UploadFile] = File(None)
):
    """Answer a question about the currently loaded document.

    The question is taken from ``text`` when it is non-blank. Otherwise
    ``audio`` is decoded with pydub, down-mixed to mono, peak-normalised to
    [-1, 1] and passed to ``transcribe`` as ``(sample_rate, samples)``.

    Args:
        text: Typed question (form field), optional.
        audio: Recorded question (uploaded file), optional; used only when
            ``text`` is missing or blank.

    Returns:
        On success, ``{"status": "success", "answer": ..., "transcribed": ...}``
        where ``transcribed`` is the audio transcription (``None`` for typed
        questions). Otherwise a 400 ``JSONResponse`` with
        ``{"status": "error", "message": ...}`` when no document is loaded,
        no question was given, or the audio could not be decoded or yielded
        no text.
    """
    logger.info(f"Ask endpoint called: text={bool(text)}, audio={bool(audio)}")
    transcribed = None
    # Take a local reference now: ``await audio.read()`` below yields to the
    # event loop, and a concurrent upload/clear request may reset
    # store["value"] to None before it is used.
    vectorstore = store["value"]
    if vectorstore is None:
        logger.warning("Ask called but no document is loaded")
        return JSONResponse({"status": "error", "message": "Please upload and process a document first."}, status_code=400)

    if text and text.strip():
        query = text.strip()
    elif audio is not None:
        audio_bytes = await audio.read()
        try:
            audio_seg = AudioSegment.from_file(io.BytesIO(audio_bytes))
            samples = np.array(audio_seg.get_array_of_samples()).astype(np.float32)
            channels = audio_seg.channels
            if channels > 1:
                # pydub interleaves channel samples frame by frame; averaging
                # each frame down-mixes any channel count (not only stereo).
                samples = samples.reshape((-1, channels)).mean(axis=1)
            # Peak-normalise to [-1, 1]. Silent input has a peak of 0 and
            # dividing by it would turn every sample into NaN.
            peak = np.max(np.abs(samples))
            if peak > 0:
                samples /= peak
            transcribed = transcribe((audio_seg.frame_rate, samples))
        except FileNotFoundError:
            return JSONResponse({"status": "error", "message": "Audio decode failed: ffmpeg is not installed or not in PATH. Please install ffmpeg."}, status_code=400)
        except Exception as e:
            return JSONResponse({"status": "error", "message": f"Audio decode failed: {str(e)}"}, status_code=400)
        # Assumes transcribe() returns a str (or None) — confirm in utils.
        # An empty transcription would otherwise be sent on as a blank query.
        if not transcribed or not transcribed.strip():
            logger.warning("Audio transcription produced no text")
            return JSONResponse({"status": "error", "message": "Could not recognise any speech in the audio. Please try again or type your question."}, status_code=400)
        query = transcribed
    else:
        logger.warning("Ask called with no text or audio")
        return JSONResponse({"status": "error", "message": "Please provide a question by typing or speaking."}, status_code=400)

    logger.info(f"Processing query: {query[:100]}...")
    # Small stores use retrieve_context, larger ones retrieve_context_approx.
    # NOTE(review): assumes the store is a mapping whose "chunks" entry is a
    # number comparable to 50 — confirm against build_vectorstore.
    if vectorstore["chunks"] <= 50:
        top_chunks = retrieve_context(query, vectorstore)
    else:
        top_chunks = retrieve_context_approx(query, vectorstore)
    prompt = build_prompt(top_chunks, query)
    answer = ask_gemini(prompt, client)
    logger.info(f"Generated answer: {answer[:100]}...")
    return {"status": "success", "answer": answer.strip(), "transcribed": transcribed}
async def status():
    """Report liveness, configuration flags and basic resource usage.

    Returns:
        dict with:
        - ``status`` and ``message``;
        - whether ``GOOGLE_API_KEY`` is set;
        - whether a vector store is loaded;
        - a ``system_info`` block: platform, Python version, process RSS and
          available system memory in MiB, and process CPU percent;
        - the ``PORT``, ``SPACE_ID`` and ``SYSTEM`` environment variables.

        The CPU percent covers the interval since the previous /status call
        in this process; the first call reports 0.0.
    """
    import platform
    import sys
    import psutil
    global _STATUS_PROCESS
    logger.info("Status endpoint called")
    # psutil's non-blocking cpu_percent() compares against the previous call
    # on the *same* Process object and returns 0.0 on the first call, so a
    # fresh Process per request always reported 0.0. Reuse one per PID (it is
    # re-created if the PID changes, e.g. in a forked worker).
    pid = os.getpid()
    if _STATUS_PROCESS is None or _STATUS_PROCESS.pid != pid:
        _STATUS_PROCESS = psutil.Process(pid)
    process = _STATUS_PROCESS
    memory_info = process.memory_info()
    status_info = {
        "status": "ok",
        "message": "Server is running",
        "google_api_key_set": bool(os.environ.get("GOOGLE_API_KEY")),
        "vectorstore_loaded": store.get("value") is not None,
        "system_info": {
            "platform": platform.platform(),
            "python_version": sys.version,
            "memory_usage_mb": memory_info.rss / (1024 * 1024),
            "cpu_percent": process.cpu_percent(),
            "available_memory_mb": psutil.virtual_memory().available / (1024 * 1024)
        },
        "env_vars": {
            "PORT": os.environ.get("PORT"),
            "SPACE_ID": os.environ.get("SPACE_ID"),
            "SYSTEM": os.environ.get("SYSTEM")
        }
    }
    logger.info(f"Status response: {status_info}")
    return status_info


# psutil.Process cached by status() for CPU sampling; created lazily on first call.
_STATUS_PROCESS = None
async def clear_context():
    """Drop the loaded vector store, if any, and run garbage collection.

    Returns:
        dict: ``status`` is ``"success"`` when a store was cleared, or
        ``"info"`` when nothing was loaded; ``message`` describes which.
    """
    logger.info("Clearing document context...")
    if store.get("value") is None:
        logger.info("No document context to clear.")
        return {"status": "info", "message": "No document context was loaded."}

    # The dict is mutated in place (never rebound), so no `global` is needed.
    store["value"] = None
    gc.collect()
    # Extra explicit passes over each generation; best-effort only.
    try:
        for generation in (0, 1, 2):
            gc.collect(generation)
    except Exception as err:
        logger.warning(f"Error during aggressive garbage collection: {err}")

    logger.info("Document context cleared successfully.")
    return {"status": "success", "message": "Document context cleared successfully."}