import os
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import logging
import threading
import uvicorn
from pathlib import Path
import time
import multiprocessing

# CPU Performance Optimization
os.environ["OMP_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["MKL_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["OPENBLAS_NUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["VECLIB_MAXIMUM_THREADS"] = str(multiprocessing.cpu_count())
os.environ["NUMEXPR_NUM_THREADS"] = str(multiprocessing.cpu_count())

# Set PyTorch to use all CPU cores
torch.set_num_threads(multiprocessing.cpu_count())

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="FastAPI Chatbot", 
    description="Chatbot with FastAPI backend",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Pydantic models (protected_namespaces is cleared so fields like `model_name`
# don't trigger Pydantic's protected "model_" namespace warning)
class ChatRequest(BaseModel):
    message: str
    max_length: int = 100
    temperature: float = 0.7
    top_p: float = 0.9

    class Config:
        protected_namespaces = ()

class ChatResponse(BaseModel):
    response: str
    model_name: str
    response_time: float

    class Config:
        protected_namespaces = ()

class HealthResponse(BaseModel):
    status: str
    is_model_loaded: bool
    model_name: str
    cache_directory: str
    startup_time: float

    class Config:
        protected_namespaces = ()
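
# Example ChatRequest body (illustrative values only; the field defaults above are
# what the API actually assumes):
#   {"message": "Hello there!", "max_length": 100, "temperature": 0.7, "top_p": 0.9}
# A ChatResponse carries the generated text plus the model name and timing, e.g.:
#   {"response": "...", "model_name": "microsoft/DialoGPT-medium", "response_time": 1.42}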

# Global variables
tokenizer = None
model = None
generator = None
startup_time = time.time()
model_loaded = False

# Configuration
MODEL_NAME = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
CACHE_DIR = os.getenv("TRANSFORMERS_CACHE", "/app/model_cache")
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "100"))
DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.7"))
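
# Example environment overrides (a sketch; these values are illustrative assumptions,
# not defaults shipped with any particular deployment):
#   export MODEL_NAME="microsoft/DialoGPT-small"
#   export TRANSFORMERS_CACHE="/data/model_cache"
#   export MAX_LENGTH="80"
#   export DEFAULT_TEMPERATURE="0.8"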

# CPU Optimization settings
CPU_CORES = multiprocessing.cpu_count()
INTRAOP_THREADS = CPU_CORES
INTEROP_THREADS = max(1, CPU_CORES // 2)  # Use half cores for inter-op parallelism

def ensure_cache_dir():
    """Ensure cache directory exists"""
    Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
    logger.info(f"Cache directory: {CACHE_DIR}")

def is_model_cached(model_name: str) -> bool:
    """Check if model is already cached"""
    try:
        model_path = Path(CACHE_DIR) / f"models--{model_name.replace('/', '--')}"
        is_cached = model_path.exists() and any(model_path.iterdir())
        logger.info(f"Model cached: {is_cached}")
        return is_cached
    except Exception as e:
        logger.error(f"Error checking cache: {e}")
        return False

def load_model():
    """Load the Hugging Face model with caching and CPU optimization"""
    global tokenizer, model, generator, model_loaded
    
    try:
        ensure_cache_dir()
        
        # Set PyTorch threading for optimal CPU performance. The inter-op thread count
        # can only be set once, before any parallel work starts, so guard against the
        # RuntimeError PyTorch raises if that has already happened.
        try:
            torch.set_num_interop_threads(INTEROP_THREADS)
        except RuntimeError:
            logger.warning("Inter-op thread count already set; keeping existing value")
        torch.set_num_threads(INTRAOP_THREADS)
        
        logger.info(f"Loading model: {MODEL_NAME}")
        logger.info(f"Cache dir: {CACHE_DIR}")
        logger.info(f"CPU cores: {CPU_CORES}")
        logger.info(f"Intra-op threads: {INTRAOP_THREADS}")
        logger.info(f"Inter-op threads: {INTEROP_THREADS}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")
        
        start_time = time.time()
        
        # Load tokenizer first
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            local_files_only=False,
            resume_download=False  # Suppress deprecation warning
        )
        
        # Add padding token if it doesn't exist
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        # Load model with CPU optimization
        logger.info("Loading model...")
        device_map = "auto" if torch.cuda.is_available() else "cpu"
        
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map=device_map,
            low_cpu_mem_usage=True,
            local_files_only=False,
            resume_download=False,  # Suppress deprecation warning
            # CPU-specific optimizations
            use_cache=True,  # Enable KV cache for faster generation
        )
        
        # No manual .to('cpu') is needed here: device_map has already placed the model
        model.eval()  # Evaluation mode (disables dropout) for inference

        # torch.jit scripting is intentionally skipped: it is experimental and does not
        # work reliably with all causal LM architectures, so the model stays in eager mode
        logger.info("Model loaded in CPU mode with optimizations")
        
        # Create text generation pipeline with optimized settings (no device arg to avoid accelerate conflict)
        logger.info("Creating pipeline...")
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        
        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            torch_dtype=dtype,
            # CPU optimization: batch processing
            batch_size=1,  # Optimal for single requests
            model_kwargs={
                "use_cache": True,  # Enable KV caching
            }
        )
        
        load_time = time.time() - start_time
        model_loaded = True
        logger.info(f"✅ Model loaded successfully in {load_time:.2f} seconds!")
        
        if hasattr(model, 'device'):
            logger.info(f"Model device: {model.device}")
        
        return True
        
    except Exception as e:
        logger.error(f"❌ Error loading model: {str(e)}", exc_info=True)
        return False

def generate_response(message: str, max_length: int = 100, temperature: float = 0.7, top_p: float = 0.9) -> tuple[str, float]:
    """Generate response using the loaded model with CPU optimizations"""
    if not generator:
        return "❌ Model not loaded. Please wait for initialization...", 0.0
    
    try:
        start_time = time.time()
        
        # Optimize input length to prevent excessive computation
        max_input_length = 512  # Reasonable limit for DialoGPT
        if len(message) > max_input_length:
            message = message[:max_input_length]
            logger.info(f"Input truncated to {max_input_length} characters")
        
        # Calculate total max length (input + generation)
        input_length = len(tokenizer.encode(message))
        total_max_length = min(input_length + max_length, 1024)  # DialoGPT max context
        
        # Generate response with optimized parameters for CPU
        with torch.no_grad():  # Disable gradient computation for inference
            response = generator(
                message,
                max_length=total_max_length,
                min_length=input_length + 10,  # Ensure some generation
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=True,
                repetition_penalty=1.1,
                # length_penalty and early_stopping only apply to beam search, so they
                # are omitted here; generation already stops when eos_token_id is produced
            )
        
        # Extract generated text
        generated_text = response[0]['generated_text']
        
        # Clean up response - remove input prompt
        if generated_text.startswith(message):
            bot_response = generated_text[len(message):].strip()
        else:
            bot_response = generated_text.strip()
        
        # Post-process response
        if bot_response:
            # Remove any repetitive patterns
            sentences = bot_response.split('.')
            if len(sentences) > 1:
                # Take only the first complete sentence to avoid repetition
                bot_response = sentences[0].strip() + '.'
            
            # Ensure response isn't too short or just punctuation
            if len(bot_response.replace('.', '').replace('!', '').replace('?', '').strip()) < 3:
                bot_response = "I understand. Could you tell me more about that?"
        else:
            bot_response = "I'm not sure how to respond to that. Could you try rephrasing?"
        
        response_time = time.time() - start_time
        logger.info(f"Generated response in {response_time:.2f}s (length: {len(bot_response)} chars)")
        
        return bot_response, response_time
        
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}", exc_info=True)
        return f"❌ Error generating response: {str(e)}", 0.0

# FastAPI endpoints
@app.get("/")
async def root():
    """Root endpoint"""
    return {"message": "FastAPI Chatbot API", "status": "running"}

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint with detailed information"""
    return HealthResponse(
        status="healthy" if model_loaded else "initializing",
        is_model_loaded=model_loaded,
        model_name=MODEL_NAME,
        cache_directory=CACHE_DIR,
        startup_time=time.time() - startup_time
    )

@app.post("/chat", response_model=ChatResponse)
def chat_endpoint(request: ChatRequest):
    """Chat endpoint for API access. Declared as a sync function so FastAPI runs the
    blocking model inference in its threadpool instead of blocking the event loop."""
    if not model_loaded:
        raise HTTPException(
            status_code=503, 
            detail="Model not loaded yet. Please wait for initialization."
        )
    
    # Validate input - stricter limits for free tier
    if not request.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")
    
    if len(request.message) > 500:  # Reduced limit for HF Spaces
        raise HTTPException(status_code=400, detail="Message too long (max 500 characters)")
    
    # Generate response
    response_text, response_time = generate_response(
        request.message.strip(), 
        request.max_length, 
        request.temperature,
        request.top_p
    )
    
    return ChatResponse(
        response=response_text,
        model_name=MODEL_NAME,
        response_time=response_time
    )

@app.get("/model-info")
async def get_model_info():
    """Get detailed model information including CPU optimization details"""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if model and hasattr(model, 'device'):
        device = str(model.device)
    
    return {
        "model_name": MODEL_NAME,
        "model_loaded": model_loaded,
        "device": device,
        "cache_directory": CACHE_DIR,
        "model_cached": is_model_cached(MODEL_NAME),
        "cpu_optimization": {
            "cpu_cores": CPU_CORES,
            "intra_op_threads": INTRAOP_THREADS,
            "inter_op_threads": INTEROP_THREADS,
            "torch_threads": torch.get_num_threads(),
        },
        "parameters": {
            "max_length": MAX_LENGTH,
            "default_temperature": DEFAULT_TEMPERATURE
        }
    }

@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    logger.info("🚀 Starting FastAPI Chatbot...")
    logger.info("📦 Loading model...")
    
    # Load model in background thread to not block startup
    def load_model_background():
        global model_loaded
        model_loaded = load_model()
        if model_loaded:
            logger.info("✅ Model loaded successfully!")
        else:
            logger.error("❌ Failed to load model.")
    
    # Start model loading in background
    threading.Thread(target=load_model_background, daemon=True).start()

def run_fastapi():
    """Run FastAPI server with CPU optimization"""
    # Additional CPU optimization for uvicorn
    config = uvicorn.Config(
        app, 
        host="0.0.0.0", 
        port=7860,
        log_level="info",
        access_log=True,
        workers=1,  # Single worker to avoid model loading multiple times
        loop="asyncio",  # Explicit asyncio event loop (uvloop is not assumed to be installed)
        http="httptools",  # Use httptools for faster HTTP parsing
    )
    
    server = uvicorn.Server(config)
    server.run()

if __name__ == "__main__":
    logger.info(f"🚀 Starting FastAPI Chatbot with CPU optimization...")
    logger.info(f"💻 CPU cores available: {CPU_CORES}")
    logger.info(f"🧵 Thread configuration - Intra-op: {INTRAOP_THREADS}, Inter-op: {INTEROP_THREADS}")
    run_fastapi()
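
# Example client call (a minimal sketch; assumes the server is reachable at
# http://localhost:7860 and that the `requests` package is installed, which is not
# required by this app itself):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/chat",
#       json={"message": "Hello!", "max_length": 60, "temperature": 0.7, "top_p": 0.9},
#       timeout=120,
#   )
#   print(resp.json()["response"])
#
# Or with curl:
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello!"}'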