import os
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import logging
import threading
import uvicorn
from pathlib import Path
import time
import multiprocessing
# CPU Performance Optimization
os.environ["OMP_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["MKL_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["OPENBLAS_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["VECLIB_MAXIMUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["NUMEXPR_NUM_THREADS"] = str(multiprocessing.cpu_count())

# Set PyTorch to use all CPU cores
torch.set_num_threads(multiprocessing.cpu_count())
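# Note: these thread-count environment variables are only honored if they are set
# before the underlying math libraries (OpenMP/MKL/OpenBLAS) spin up their thread
# pools, which is why they are configured here at module import time.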
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="FastAPI Chatbot",
    description="Chatbot with FastAPI backend",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
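# Wide-open CORS (all origins, methods, and headers) keeps the demo easy to call
# from anywhere; tighten allow_origins before exposing anything sensitive.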
# Pydantic models with fixed namespace conflicts
class ChatRequest(BaseModel):
    message: str
    max_length: int = 100
    temperature: float = 0.7
    top_p: float = 0.9

    class Config:
        protected_namespaces = ()


class ChatResponse(BaseModel):
    response: str
    model_name: str
    response_time: float

    class Config:
        protected_namespaces = ()


class HealthResponse(BaseModel):
    status: str
    is_model_loaded: bool
    model_name: str
    cache_directory: str
    startup_time: float

    class Config:
        protected_namespaces = ()
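# Pydantic v2 reserves the "model_" prefix as a protected namespace, so fields
# such as model_name would emit warnings without protected_namespaces = ().
#
# Example request body accepted by ChatRequest (illustrative values only):
# {"message": "Hello there!", "max_length": 100, "temperature": 0.7, "top_p": 0.9}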
# Global variables
tokenizer = None
model = None
generator = None
startup_time = time.time()
model_loaded = False

# Configuration
MODEL_NAME = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
CACHE_DIR = os.getenv("TRANSFORMERS_CACHE", "/app/model_cache")
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "100"))
DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.7"))

# CPU Optimization settings
CPU_CORES = multiprocessing.cpu_count()
INTRAOP_THREADS = CPU_CORES
INTEROP_THREADS = max(1, CPU_CORES // 2)  # Use half the cores for inter-op parallelism
def ensure_cache_dir():
    """Ensure cache directory exists"""
    Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
    logger.info(f"Cache directory: {CACHE_DIR}")


def is_model_cached(model_name: str) -> bool:
    """Check if model is already cached"""
    try:
        model_path = Path(CACHE_DIR) / f"models--{model_name.replace('/', '--')}"
        is_cached = model_path.exists() and any(model_path.iterdir())
        logger.info(f"Model cached: {is_cached}")
        return is_cached
    except Exception as e:
        logger.error(f"Error checking cache: {e}")
        return False
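# The "models--{org}--{name}" folder checked above mirrors the directory naming
# used by the huggingface_hub cache, e.g. models--microsoft--DialoGPT-medium.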
def load_model():
    """Load the Hugging Face model with caching and CPU optimization"""
    global tokenizer, model, generator, model_loaded
    try:
        ensure_cache_dir()

        # Set PyTorch threading for optimal CPU performance
        torch.set_num_interop_threads(INTEROP_THREADS)
        torch.set_num_threads(INTRAOP_THREADS)

        logger.info(f"Loading model: {MODEL_NAME}")
        logger.info(f"Cache dir: {CACHE_DIR}")
        logger.info(f"CPU cores: {CPU_CORES}")
        logger.info(f"Intra-op threads: {INTRAOP_THREADS}")
        logger.info(f"Inter-op threads: {INTEROP_THREADS}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")

        start_time = time.time()

        # Load tokenizer first
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            local_files_only=False,
            resume_download=False  # Suppress deprecation warning
        )

        # Add padding token if it doesn't exist
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with CPU optimization
        logger.info("Loading model...")
        device_map = "auto" if torch.cuda.is_available() else "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map=device_map,
            low_cpu_mem_usage=True,
            local_files_only=False,
            resume_download=False,  # Suppress deprecation warning
            # CPU-specific optimizations
            use_cache=True,  # Enable KV cache for faster generation
        )

        # Enable CPU-specific optimizations (no manual .to('cpu') needed with device_map)
        model.eval()  # Set to evaluation mode

        # Enable torch.jit optimization for CPU (optional, can improve performance)
        try:
            # This is experimental and might not work with all models
            # model = torch.jit.script(model)
            logger.info("Model loaded in CPU mode with optimizations")
        except Exception as e:
            logger.warning(f"JIT compilation not available: {e}")

        # Create text generation pipeline with optimized settings
        # (no device arg, to avoid conflicting with accelerate's device_map)
        logger.info("Creating pipeline...")
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            torch_dtype=dtype,
            # CPU optimization: batch processing
            batch_size=1,  # Optimal for single requests
            model_kwargs={
                "use_cache": True,  # Enable KV caching
            }
        )

        load_time = time.time() - start_time
        model_loaded = True
        logger.info(f"Model loaded successfully in {load_time:.2f} seconds!")
        if hasattr(model, 'device'):
            logger.info(f"Model device: {model.device}")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}", exc_info=True)
        return False
def generate_response(message: str, max_length: int = 100, temperature: float = 0.7, top_p: float = 0.9) -> tuple[str, float]:
    """Generate response using the loaded model with CPU optimizations"""
    if not generator:
        return "Model not loaded. Please wait for initialization...", 0.0
    try:
        start_time = time.time()

        # Limit input length to prevent excessive computation
        max_input_length = 512  # Reasonable limit for DialoGPT
        if len(message) > max_input_length:
            message = message[:max_input_length]
            logger.info(f"Input truncated to {max_input_length} characters")

        # Calculate total max length (input + generation)
        input_length = len(tokenizer.encode(message))
        total_max_length = min(input_length + max_length, 1024)  # DialoGPT max context

        # Generate response with parameters tuned for CPU inference
        with torch.no_grad():  # Disable gradient computation for inference
            response = generator(
                message,
                max_length=total_max_length,
                min_length=input_length + 10,  # Ensure some generation
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=True,
                repetition_penalty=1.1,
                length_penalty=1.0,
                early_stopping=True,  # Only affects beam search; harmless with sampling
                # Note: truncation=True is not accepted by this call and was removed
            )
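        # The text-generation pipeline returns a list of dicts, e.g.
        # [{"generated_text": "<prompt followed by the model's continuation>"}],
        # which is why the prompt is stripped from the front below.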
        # Extract generated text
        generated_text = response[0]['generated_text']

        # Clean up response - remove input prompt
        if generated_text.startswith(message):
            bot_response = generated_text[len(message):].strip()
        else:
            bot_response = generated_text.strip()

        # Post-process response
        if bot_response:
            # Remove any repetitive patterns
            sentences = bot_response.split('.')
            if len(sentences) > 1:
                # Take only the first complete sentence to avoid repetition
                bot_response = sentences[0].strip() + '.'
            # Ensure response isn't too short or just punctuation
            if len(bot_response.replace('.', '').replace('!', '').replace('?', '').strip()) < 3:
                bot_response = "I understand. Could you tell me more about that?"
        else:
            bot_response = "I'm not sure how to respond to that. Could you try rephrasing?"

        response_time = time.time() - start_time
        logger.info(f"Generated response in {response_time:.2f}s (length: {len(bot_response)} chars)")
        return bot_response, response_time
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}", exc_info=True)
        return f"Error generating response: {str(e)}", 0.0
# FastAPI endpoints
# (Route decorators restored here; the paths are assumed from the function names.)
@app.get("/")
async def root():
    """Root endpoint"""
    return {"message": "FastAPI Chatbot API", "status": "running"}


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint with detailed information"""
    return HealthResponse(
        status="healthy" if model_loaded else "initializing",
        is_model_loaded=model_loaded,
        model_name=MODEL_NAME,
        cache_directory=CACHE_DIR,
        startup_time=time.time() - startup_time
    )
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Chat endpoint for API access"""
    if not model_loaded:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded yet. Please wait for initialization."
        )

    # Validate input - stricter limits for free tier
    if not request.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")
    if len(request.message) > 500:  # Reduced limit for HF Spaces
        raise HTTPException(status_code=400, detail="Message too long (max 500 characters)")

    # Generate response
    response_text, response_time = generate_response(
        request.message.strip(),
        request.max_length,
        request.temperature,
        request.top_p
    )

    return ChatResponse(
        response=response_text,
        model_name=MODEL_NAME,
        response_time=response_time
    )
@app.get("/model/info")  # Path assumed; adjust to match your client
async def get_model_info():
    """Get detailed model information including CPU optimization details"""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if model and hasattr(model, 'device'):
        device = str(model.device)
    return {
        "model_name": MODEL_NAME,
        "model_loaded": model_loaded,
        "device": device,
        "cache_directory": CACHE_DIR,
        "model_cached": is_model_cached(MODEL_NAME),
        "cpu_optimization": {
            "cpu_cores": CPU_CORES,
            "intra_op_threads": INTRAOP_THREADS,
            "inter_op_threads": INTEROP_THREADS,
            "torch_threads": torch.get_num_threads(),
        },
        "parameters": {
            "max_length": MAX_LENGTH,
            "default_temperature": DEFAULT_TEMPERATURE
        }
    }
@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    logger.info("Starting FastAPI Chatbot...")
    logger.info("Loading model...")

    # Load model in a background thread so startup isn't blocked
    def load_model_background():
        global model_loaded
        model_loaded = load_model()
        if model_loaded:
            logger.info("Model loaded successfully!")
        else:
            logger.error("Failed to load model.")

    # Start model loading in background
    threading.Thread(target=load_model_background, daemon=True).start()
def run_fastapi():
    """Run FastAPI server with CPU optimization"""
    # Additional uvicorn tuning
    config = uvicorn.Config(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        access_log=True,
        workers=1,  # Single worker so the model is only loaded once
        loop="asyncio",  # Standard asyncio event loop
        http="httptools",  # Use httptools for faster HTTP parsing
    )
    server = uvicorn.Server(config)
    server.run()


if __name__ == "__main__":
    logger.info("Starting FastAPI Chatbot with CPU optimization...")
    logger.info(f"CPU cores available: {CPU_CORES}")
    logger.info(f"Thread configuration - Intra-op: {INTRAOP_THREADS}, Inter-op: {INTEROP_THREADS}")
    run_fastapi()
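# --- Usage sketch (not part of the app) ---------------------------------------
# A minimal client call, assuming the server is reachable on port 7860 and the
# chat route is mounted at /chat as above:
#
#     import requests
#     payload = {"message": "Hello!", "max_length": 60, "temperature": 0.7, "top_p": 0.9}
#     r = requests.post("http://localhost:7860/chat", json=payload, timeout=120)
#     print(r.json()["response"], r.json()["response_time"])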