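"""FastAPI chatbot service.

Serves a Hugging Face causal language model (microsoft/DialoGPT-medium by default)
behind a small REST API (/chat, /health, /model-info), with CPU-oriented threading
and model-caching optimizations for environments such as Hugging Face Spaces.
"""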
import os
import multiprocessing

# CPU Performance Optimization: set the threading environment variables before
# importing torch/transformers so the OpenMP/MKL/BLAS runtimes pick them up.
os.environ["OMP_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["MKL_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["OPENBLAS_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["VECLIB_MAXIMUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["NUMEXPR_NUM_THREADS"] = str(multiprocessing.cpu_count())

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import logging
import threading
import uvicorn
from pathlib import Path
import time

# Set PyTorch to use all CPU cores
torch.set_num_threads(multiprocessing.cpu_count())
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# FastAPI app
app = FastAPI(
title="FastAPI Chatbot",
description="Chatbot with FastAPI backend",
version="1.0.0"
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
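# Note: allow_origins=["*"] together with allow_credentials=True is very permissive;
# restrict allow_origins to known frontend URLs for production deployments.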
# Pydantic models (protected_namespaces cleared so fields like "model_name" don't
# trigger Pydantic v2's "model_" namespace warnings)
class ChatRequest(BaseModel):
message: str
max_length: int = 100
temperature: float = 0.7
top_p: float = 0.9
class Config:
protected_namespaces = ()
class ChatResponse(BaseModel):
response: str
model_name: str
response_time: float
class Config:
protected_namespaces = ()
class HealthResponse(BaseModel):
status: str
is_model_loaded: bool
model_name: str
cache_directory: str
startup_time: float
class Config:
protected_namespaces = ()
# Global variables
tokenizer = None
model = None
generator = None
startup_time = time.time()
model_loaded = False
# Configuration
MODEL_NAME = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
CACHE_DIR = os.getenv("TRANSFORMERS_CACHE", "/app/model_cache")
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "100"))
DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.7"))
# CPU Optimization settings
CPU_CORES = multiprocessing.cpu_count()
INTRAOP_THREADS = CPU_CORES
INTEROP_THREADS = max(1, CPU_CORES // 2) # Use half cores for inter-op parallelism
def ensure_cache_dir():
"""Ensure cache directory exists"""
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
logger.info(f"Cache directory: {CACHE_DIR}")
def is_model_cached(model_name: str) -> bool:
"""Check if model is already cached"""
try:
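        # huggingface_hub stores repositories under "<cache_dir>/models--{org}--{name}"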
model_path = Path(CACHE_DIR) / f"models--{model_name.replace('/', '--')}"
is_cached = model_path.exists() and any(model_path.iterdir())
logger.info(f"Model cached: {is_cached}")
return is_cached
except Exception as e:
logger.error(f"Error checking cache: {e}")
return False
def load_model():
"""Load the Hugging Face model with caching and CPU optimization"""
global tokenizer, model, generator, model_loaded
try:
ensure_cache_dir()
# Set PyTorch threading for optimal CPU performance
torch.set_num_interop_threads(INTEROP_THREADS)
torch.set_num_threads(INTRAOP_THREADS)
logger.info(f"Loading model: {MODEL_NAME}")
logger.info(f"Cache dir: {CACHE_DIR}")
logger.info(f"CPU cores: {CPU_CORES}")
logger.info(f"Intra-op threads: {INTRAOP_THREADS}")
logger.info(f"Inter-op threads: {INTEROP_THREADS}")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
start_time = time.time()
# Load tokenizer first
logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            local_files_only=False,
        )
# Add padding token if it doesn't exist
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load model with CPU optimization
logger.info("Loading model...")
device_map = "auto" if torch.cuda.is_available() else "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map=device_map,
            low_cpu_mem_usage=True,
            local_files_only=False,
            use_cache=True,  # enable the KV cache for faster autoregressive generation
        )
# Enable CPU-specific optimizations (no manual .to('cpu') needed with device_map)
model.eval() # Set to evaluation mode
        # torch.jit scripting is experimental and does not work with all models,
        # so it stays disabled here.
        # model = torch.jit.script(model)
        logger.info(f"Model loaded in {'GPU' if torch.cuda.is_available() else 'CPU'} mode with optimizations")
# Create text generation pipeline with optimized settings (no device arg to avoid accelerate conflict)
logger.info("Creating pipeline...")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
generator = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
torch_dtype=dtype,
# CPU optimization: batch processing
batch_size=1, # Optimal for single requests
model_kwargs={
"use_cache": True, # Enable KV caching
}
)
load_time = time.time() - start_time
model_loaded = True
logger.info(f"βœ… Model loaded successfully in {load_time:.2f} seconds!")
if hasattr(model, 'device'):
logger.info(f"Model device: {model.device}")
return True
except Exception as e:
logger.error(f"❌ Error loading model: {str(e)}", exc_info=True)
return False
def generate_response(message: str, max_length: int = 100, temperature: float = 0.7, top_p: float = 0.9) -> tuple[str, float]:
"""Generate response using the loaded model with CPU optimizations"""
if not generator:
return "❌ Model not loaded. Please wait for initialization...", 0.0
try:
start_time = time.time()
# Optimize input length to prevent excessive computation
max_input_length = 512 # Reasonable limit for DialoGPT
if len(message) > max_input_length:
message = message[:max_input_length]
logger.info(f"Input truncated to {max_input_length} characters")
# Calculate total max length (input + generation)
input_length = len(tokenizer.encode(message))
total_max_length = min(input_length + max_length, 1024) # DialoGPT max context
# Generate response with optimized parameters for CPU
with torch.no_grad(): # Disable gradient computation for inference
response = generator(
message,
max_length=total_max_length,
min_length=input_length + 10, # Ensure some generation
temperature=temperature,
top_p=top_p,
num_return_sequences=1,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=True,
repetition_penalty=1.1,
                # length_penalty and early_stopping only affect beam search, so they are
                # omitted here (generation uses sampling with a single beam).
                # truncation=True is not supported by this pipeline call and caused an
                # error, so it stays removed.
)
# Extract generated text
generated_text = response[0]['generated_text']
# Clean up response - remove input prompt
if generated_text.startswith(message):
bot_response = generated_text[len(message):].strip()
else:
bot_response = generated_text.strip()
# Post-process response
if bot_response:
# Remove any repetitive patterns
sentences = bot_response.split('.')
if len(sentences) > 1:
# Take only the first complete sentence to avoid repetition
bot_response = sentences[0].strip() + '.'
# Ensure response isn't too short or just punctuation
if len(bot_response.replace('.', '').replace('!', '').replace('?', '').strip()) < 3:
bot_response = "I understand. Could you tell me more about that?"
else:
bot_response = "I'm not sure how to respond to that. Could you try rephrasing?"
response_time = time.time() - start_time
logger.info(f"Generated response in {response_time:.2f}s (length: {len(bot_response)} chars)")
return bot_response, response_time
except Exception as e:
logger.error(f"Error generating response: {str(e)}", exc_info=True)
return f"❌ Error generating response: {str(e)}", 0.0
# FastAPI endpoints
@app.get("/")
async def root():
"""Root endpoint"""
return {"message": "FastAPI Chatbot API", "status": "running"}
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""Health check endpoint with detailed information"""
return HealthResponse(
status="healthy" if model_loaded else "initializing",
is_model_loaded=model_loaded,
model_name=MODEL_NAME,
cache_directory=CACHE_DIR,
startup_time=time.time() - startup_time
)
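# Example (assuming the server is started via run_fastapi below, i.e. on port 7860):
#   curl http://localhost:7860/health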
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
"""Chat endpoint for API access"""
if not model_loaded:
raise HTTPException(
status_code=503,
detail="Model not loaded yet. Please wait for initialization."
)
# Validate input - stricter limits for free tier
if not request.message.strip():
raise HTTPException(status_code=400, detail="Message cannot be empty")
if len(request.message) > 500: # Reduced limit for HF Spaces
raise HTTPException(status_code=400, detail="Message too long (max 500 characters)")
# Generate response
response_text, response_time = generate_response(
request.message.strip(),
request.max_length,
request.temperature,
request.top_p
)
return ChatResponse(
response=response_text,
model_name=MODEL_NAME,
response_time=response_time
)
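# Example request, assuming the server is running locally on port 7860 (as configured
# in run_fastapi below); max_length, temperature and top_p are optional overrides:
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello there!", "max_length": 80, "temperature": 0.7}'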
@app.get("/model-info")
async def get_model_info():
"""Get detailed model information including CPU optimization details"""
device = "cuda" if torch.cuda.is_available() else "cpu"
if model and hasattr(model, 'device'):
device = str(model.device)
return {
"model_name": MODEL_NAME,
"model_loaded": model_loaded,
"device": device,
"cache_directory": CACHE_DIR,
"model_cached": is_model_cached(MODEL_NAME),
"cpu_optimization": {
"cpu_cores": CPU_CORES,
"intra_op_threads": INTRAOP_THREADS,
"inter_op_threads": INTEROP_THREADS,
"torch_threads": torch.get_num_threads(),
},
"parameters": {
"max_length": MAX_LENGTH,
"default_temperature": DEFAULT_TEMPERATURE
}
}
@app.on_event("startup")
async def startup_event():
"""Load model on startup"""
logger.info("πŸš€ Starting FastAPI Chatbot...")
logger.info("πŸ“¦ Loading model...")
# Load model in background thread to not block startup
def load_model_background():
global model_loaded
model_loaded = load_model()
if model_loaded:
logger.info("βœ… Model loaded successfully!")
else:
logger.error("❌ Failed to load model.")
# Start model loading in background
threading.Thread(target=load_model_background, daemon=True).start()
def run_fastapi():
"""Run FastAPI server with CPU optimization"""
# Additional CPU optimization for uvicorn
config = uvicorn.Config(
app,
host="0.0.0.0",
port=7860,
log_level="info",
access_log=True,
workers=1, # Single worker to avoid model loading multiple times
loop="asyncio", # Use asyncio loop for better performance
http="httptools", # Use httptools for faster HTTP parsing
)
server = uvicorn.Server(config)
server.run()
if __name__ == "__main__":
logger.info(f"πŸš€ Starting FastAPI Chatbot with CPU optimization...")
logger.info(f"πŸ’» CPU cores available: {CPU_CORES}")
logger.info(f"🧡 Thread configuration - Intra-op: {INTRAOP_THREADS}, Inter-op: {INTEROP_THREADS}")
run_fastapi()
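# To launch locally: `python main.py` starts uvicorn on 0.0.0.0:7860 via run_fastapi above.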