import os
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import logging
import threading
import uvicorn
from pathlib import Path
import time
import multiprocessing

# CPU Performance Optimization
os.environ["OMP_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["MKL_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["OPENBLAS_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["VECLIB_MAXIMUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["NUMEXPR_NUM_THREADS"] = str(multiprocessing.cpu_count())

# Set PyTorch to use all CPU cores
torch.set_num_threads(multiprocessing.cpu_count())

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="FastAPI Chatbot",
    description="Chatbot with FastAPI backend",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Pydantic models with fixed namespace conflicts
class ChatRequest(BaseModel):
    message: str
    max_length: int = 100
    temperature: float = 0.7
    top_p: float = 0.9

    class Config:
        protected_namespaces = ()


class ChatResponse(BaseModel):
    response: str
    model_name: str
    response_time: float

    class Config:
        protected_namespaces = ()


class HealthResponse(BaseModel):
    status: str
    is_model_loaded: bool
    model_name: str
    cache_directory: str
    startup_time: float

    class Config:
        protected_namespaces = ()


# Global variables
tokenizer = None
model = None
generator = None
startup_time = time.time()
model_loaded = False

# Configuration
MODEL_NAME = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
CACHE_DIR = os.getenv("TRANSFORMERS_CACHE", "/app/model_cache")
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "100"))
DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.7"))

# CPU Optimization settings
CPU_CORES = multiprocessing.cpu_count()
INTRAOP_THREADS = CPU_CORES
INTEROP_THREADS = max(1, CPU_CORES // 2)  # Use half the cores for inter-op parallelism


def ensure_cache_dir():
    """Ensure the cache directory exists"""
    Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
    logger.info(f"Cache directory: {CACHE_DIR}")


def is_model_cached(model_name: str) -> bool:
    """Check if the model is already cached"""
    try:
        model_path = Path(CACHE_DIR) / f"models--{model_name.replace('/', '--')}"
        is_cached = model_path.exists() and any(model_path.iterdir())
        logger.info(f"Model cached: {is_cached}")
        return is_cached
    except Exception as e:
        logger.error(f"Error checking cache: {e}")
        return False


def load_model():
    """Load the Hugging Face model with caching and CPU optimization"""
    global tokenizer, model, generator, model_loaded
    try:
        ensure_cache_dir()

        # Set PyTorch threading for optimal CPU performance
        torch.set_num_interop_threads(INTEROP_THREADS)
        torch.set_num_threads(INTRAOP_THREADS)

        logger.info(f"Loading model: {MODEL_NAME}")
        logger.info(f"Cache dir: {CACHE_DIR}")
        logger.info(f"CPU cores: {CPU_CORES}")
        logger.info(f"Intra-op threads: {INTRAOP_THREADS}")
        logger.info(f"Inter-op threads: {INTEROP_THREADS}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")

        start_time = time.time()

        # Load tokenizer first
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            local_files_only=False
        )

        # Add padding token if it doesn't exist
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with CPU optimization
        logger.info("Loading model...")
        device_map = "auto" if torch.cuda.is_available() else "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map=device_map,
            low_cpu_mem_usage=True,
            local_files_only=False,
            # CPU-specific optimizations
            use_cache=True,  # Enable KV cache for faster generation
        )

        # Enable CPU-specific optimizations (no manual .to('cpu') needed with device_map)
        model.eval()  # Set to evaluation mode

        # Enable torch.jit optimization for CPU (optional, can improve performance)
        try:
            # This is experimental and might not work with all models
            # model = torch.jit.script(model)
            logger.info("Model loaded in CPU mode with optimizations")
        except Exception as e:
            logger.warning(f"JIT compilation not available: {e}")

        # Create text generation pipeline with optimized settings
        # (no device argument, to avoid conflicting with accelerate's device_map)
        logger.info("Creating pipeline...")
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            torch_dtype=dtype,
            # CPU optimization: batch processing
            batch_size=1,  # Optimal for single requests
            model_kwargs={
                "use_cache": True,  # Enable KV caching
            }
        )

        load_time = time.time() - start_time
        model_loaded = True

        logger.info(f"✅ Model loaded successfully in {load_time:.2f} seconds!")
        if hasattr(model, 'device'):
            logger.info(f"Model device: {model.device}")

        return True

    except Exception as e:
        logger.error(f"❌ Error loading model: {str(e)}", exc_info=True)
        return False


def generate_response(
    message: str,
    max_length: int = 100,
    temperature: float = 0.7,
    top_p: float = 0.9
) -> tuple[str, float]:
    """Generate a response using the loaded model with CPU optimizations"""
    if not generator:
        return "❌ Model not loaded. Please wait for initialization...", 0.0

    try:
        start_time = time.time()

        # Optimize input length to prevent excessive computation
        max_input_length = 512  # Reasonable limit for DialoGPT
        if len(message) > max_input_length:
            message = message[:max_input_length]
            logger.info(f"Input truncated to {max_input_length} characters")

        # Calculate total max length (input + generation)
        input_length = len(tokenizer.encode(message))
        total_max_length = min(input_length + max_length, 1024)  # DialoGPT max context

        # Generate response with optimized parameters for CPU
        with torch.no_grad():  # Disable gradient computation for inference
            response = generator(
                message,
                max_length=total_max_length,
                min_length=input_length + 10,  # Ensure some generation
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=True,
                repetition_penalty=1.1,
                length_penalty=1.0,
                # Unsupported/ineffective parameters are omitted:
                # truncation=True caused an error here, and early_stopping
                # only applies to beam search, not sampling
            )

        # Extract generated text
        generated_text = response[0]['generated_text']

        # Clean up response - remove the input prompt
        if generated_text.startswith(message):
            bot_response = generated_text[len(message):].strip()
        else:
            bot_response = generated_text.strip()

        # Post-process response
        if bot_response:
            # Remove any repetitive patterns
            sentences = bot_response.split('.')
            if len(sentences) > 1:
                # Take only the first complete sentence to avoid repetition
                bot_response = sentences[0].strip() + '.'

            # Ensure the response isn't too short or just punctuation
            if len(bot_response.replace('.', '').replace('!', '').replace('?', '').strip()) < 3:
                bot_response = "I understand. Could you tell me more about that?"
        else:
            bot_response = "I'm not sure how to respond to that. Could you try rephrasing?"

        response_time = time.time() - start_time
        logger.info(f"Generated response in {response_time:.2f}s (length: {len(bot_response)} chars)")

        return bot_response, response_time

    except Exception as e:
        logger.error(f"Error generating response: {str(e)}", exc_info=True)
        return f"❌ Error generating response: {str(e)}", 0.0


# FastAPI endpoints
@app.get("/")
async def root():
    """Root endpoint"""
    return {"message": "FastAPI Chatbot API", "status": "running"}


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint with detailed information"""
    return HealthResponse(
        status="healthy" if model_loaded else "initializing",
        is_model_loaded=model_loaded,
        model_name=MODEL_NAME,
        cache_directory=CACHE_DIR,
        startup_time=time.time() - startup_time
    )


@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Chat endpoint for API access"""
    if not model_loaded:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded yet. Please wait for initialization."
        )

    # Validate input - stricter limits for free tier
    if not request.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")

    if len(request.message) > 500:  # Reduced limit for HF Spaces
        raise HTTPException(status_code=400, detail="Message too long (max 500 characters)")

    # Generate response
    response_text, response_time = generate_response(
        request.message.strip(),
        request.max_length,
        request.temperature,
        request.top_p
    )

    return ChatResponse(
        response=response_text,
        model_name=MODEL_NAME,
        response_time=response_time
    )


@app.get("/model-info")
async def get_model_info():
    """Get detailed model information including CPU optimization details"""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if model and hasattr(model, 'device'):
        device = str(model.device)

    return {
        "model_name": MODEL_NAME,
        "model_loaded": model_loaded,
        "device": device,
        "cache_directory": CACHE_DIR,
        "model_cached": is_model_cached(MODEL_NAME),
        "cpu_optimization": {
            "cpu_cores": CPU_CORES,
            "intra_op_threads": INTRAOP_THREADS,
            "inter_op_threads": INTEROP_THREADS,
            "torch_threads": torch.get_num_threads(),
        },
        "parameters": {
            "max_length": MAX_LENGTH,
            "default_temperature": DEFAULT_TEMPERATURE
        }
    }


@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    logger.info("🚀 Starting FastAPI Chatbot...")
    logger.info("📦 Loading model...")

    # Load the model in a background thread so it does not block startup
    def load_model_background():
        global model_loaded
        model_loaded = load_model()
        if model_loaded:
            logger.info("✅ Model loaded successfully!")
        else:
            logger.error("❌ Failed to load model.")

    # Start model loading in background
    threading.Thread(target=load_model_background, daemon=True).start()


def run_fastapi():
    """Run FastAPI server with CPU optimization"""
    # Additional CPU optimization for uvicorn
    config = uvicorn.Config(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        access_log=True,
        workers=1,  # Single worker to avoid loading the model multiple times
        loop="asyncio",  # Use asyncio loop for better performance
        http="httptools",  # Use httptools for faster HTTP parsing
    )
    server = uvicorn.Server(config)
    server.run()


if __name__ == "__main__":
    logger.info("🚀 Starting FastAPI Chatbot with CPU optimization...")
    logger.info(f"💻 CPU cores available: {CPU_CORES}")
    logger.info(f"🧵 Thread configuration - Intra-op: {INTRAOP_THREADS}, Inter-op: {INTEROP_THREADS}")
    run_fastapi()
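

# Example usage (illustrative sketch, not part of the application):
# assumes the server is running with the defaults configured above
# (port 7860, /chat endpoint, ChatRequest fields); adjust the URL if you
# change run_fastapi().
#
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "Hello there!", "max_length": 80, "temperature": 0.7, "top_p": 0.9}'
#
# A successful reply matches the ChatResponse model, e.g.:
#   {"response": "...", "model_name": "microsoft/DialoGPT-medium", "response_time": 1.23}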