Soumik555 committed
Commit ef34958 · 1 Parent(s): 5344861
Files changed (1)
  1. main.py +278 -15
main.py CHANGED
@@ -1,27 +1,290 @@
- from fastapi import FastAPI
  from fastapi.middleware.cors import CORSMiddleware
- from chat_routes import router as chat_router
- from model_service import load_model
  import threading
  import uvicorn
- from logger import logger

- app = FastAPI(title="FastAPI Chatbot", version="1.0.0")
  app.add_middleware(
      CORSMiddleware,
-     allow_origins=["*"], allow_credentials=True,
-     allow_methods=["*"], allow_headers=["*"],
  )

- app.include_router(chat_router)

  @app.on_event("startup")
- def startup():
-     def load_in_bg():
-         success = load_model()
-         if success: logger.info("Model loaded on startup.")
-         else: logger.error("Model failed to load.")
-     threading.Thread(target=load_in_bg, daemon=True).start()

  if __name__ == "__main__":
-     uvicorn.run("app.main:app", host="0.0.0.0", port=7860, reload=False)
 
+ import os
+ from fastapi import FastAPI, HTTPException
  from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ from pydantic import BaseModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ import torch
+ import logging
  import threading
  import uvicorn
+ from pathlib import Path
+ import time

+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ # FastAPI app
+ app = FastAPI(
+     title="FastAPI Chatbot",
+     description="Chatbot with FastAPI backend",
+     version="1.0.0"
+ )
+
+ # Add CORS middleware
  app.add_middleware(
      CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
  )

+ # Pydantic models with fixed namespace conflicts
+ class ChatRequest(BaseModel):
+     message: str
+     max_length: int = 100
+     temperature: float = 0.7
+     top_p: float = 0.9
+
+     class Config:
+         protected_namespaces = ()
+
+ class ChatResponse(BaseModel):
+     response: str
+     model_name: str
+     response_time: float
+
+     class Config:
+         protected_namespaces = ()
+
+ class HealthResponse(BaseModel):
+     status: str
+     is_model_loaded: bool
+     model_name: str
+     cache_directory: str
+     startup_time: float
+
+     class Config:
+         protected_namespaces = ()
+
+ # Global variables
+ tokenizer = None
+ model = None
+ generator = None
+ startup_time = time.time()
+ model_loaded = False
+
+ # Configuration
+ MODEL_NAME = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
+ CACHE_DIR = os.getenv("TRANSFORMERS_CACHE", "/app/model_cache")
+ MAX_LENGTH = int(os.getenv("MAX_LENGTH", "100"))
+ DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.7"))
+
+ def ensure_cache_dir():
+     """Ensure cache directory exists"""
+     Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
+     logger.info(f"Cache directory: {CACHE_DIR}")
+
+ def is_model_cached(model_name: str) -> bool:
+     """Check if model is already cached"""
+     try:
+         model_path = Path(CACHE_DIR) / f"models--{model_name.replace('/', '--')}"
+         is_cached = model_path.exists() and any(model_path.iterdir())
+         logger.info(f"Model cached: {is_cached}")
+         return is_cached
+     except Exception as e:
+         logger.error(f"Error checking cache: {e}")
+         return False
+
+ def load_model():
+     """Load the Hugging Face model with caching"""
+     global tokenizer, model, generator, model_loaded
+
+     try:
+         ensure_cache_dir()
+
+         logger.info(f"Loading model: {MODEL_NAME}")
+         logger.info(f"Cache dir: {CACHE_DIR}")
+         logger.info(f"CUDA available: {torch.cuda.is_available()}")
+
+         start_time = time.time()
+
+         # Load tokenizer first
+         logger.info("Loading tokenizer...")
+         tokenizer = AutoTokenizer.from_pretrained(
+             MODEL_NAME,
+             cache_dir=CACHE_DIR,
+             local_files_only=False
+         )
+
+         # Add padding token if it doesn't exist
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         # Load model
+         logger.info("Loading model...")
+         model = AutoModelForCausalLM.from_pretrained(
+             MODEL_NAME,
+             cache_dir=CACHE_DIR,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             device_map="auto" if torch.cuda.is_available() else None,
+             low_cpu_mem_usage=True,
+             local_files_only=False
+         )
+
+         # Create text generation pipeline
+         logger.info("Creating pipeline...")
+         device = 0 if torch.cuda.is_available() else -1
+         generator = pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             device=device,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+         )
+
+         load_time = time.time() - start_time
+         model_loaded = True
+         logger.info(f"✅ Model loaded successfully in {load_time:.2f} seconds!")
+         logger.info(f"Model device: {model.device}")
+
+         return True
+
+     except Exception as e:
+         logger.error(f"❌ Error loading model: {str(e)}", exc_info=True)
+         return False
+
+ def generate_response(message: str, max_length: int = 100, temperature: float = 0.7, top_p: float = 0.9) -> tuple[str, float]:
+     """Generate a response with the loaded model; returns (text, seconds taken)"""
+     if not generator:
+         return "❌ Model not loaded. Please wait for initialization...", 0.0
+
+     try:
+         start_time = time.time()
+
+         # Generate response with parameters
+         response = generator(
+             message,
+             max_length=max_length,
+             temperature=temperature,
+             top_p=top_p,
+             num_return_sequences=1,
+             pad_token_id=tokenizer.eos_token_id,
+             do_sample=True,
+             truncation=True,
+             repetition_penalty=1.1
+         )
+
+         # Extract generated text
+         generated_text = response[0]['generated_text']
+
+         # Clean up response: drop the echoed prompt if present
+         if generated_text.startswith(message):
+             bot_response = generated_text[len(message):].strip()
+         else:
+             bot_response = generated_text.strip()
+
+         # Fallback if empty response
+         if not bot_response:
+             bot_response = "I'm not sure how to respond to that. Could you try rephrasing?"
+
+         response_time = time.time() - start_time
+         logger.info(f"Generated response in {response_time:.2f}s")
+
+         return bot_response, response_time
+
+     except Exception as e:
+         logger.error(f"Error generating response: {str(e)}", exc_info=True)
+         return f"❌ Error generating response: {str(e)}", 0.0
+
+ # FastAPI endpoints
+ @app.get("/")
+ async def root():
+     """Root endpoint"""
+     return {"message": "FastAPI Chatbot API", "status": "running"}
+
+ @app.get("/health", response_model=HealthResponse)
+ async def health_check():
+     """Health check endpoint with detailed information"""
+     return HealthResponse(
+         status="healthy" if model_loaded else "initializing",
+         is_model_loaded=model_loaded,
+         model_name=MODEL_NAME,
+         cache_directory=CACHE_DIR,
+         startup_time=time.time() - startup_time
+     )
+
+ @app.post("/chat", response_model=ChatResponse)
+ async def chat_endpoint(request: ChatRequest):
+     """Chat endpoint for API access"""
+     if not model_loaded:
+         raise HTTPException(
+             status_code=503,
+             detail="Model not loaded yet. Please wait for initialization."
+         )
+
+     # Validate input
+     if not request.message.strip():
+         raise HTTPException(status_code=400, detail="Message cannot be empty")
+
+     if len(request.message) > 1000:
+         raise HTTPException(status_code=400, detail="Message too long (max 1000 characters)")
+
+     # Generate response
+     response_text, response_time = generate_response(
+         request.message.strip(),
+         request.max_length,
+         request.temperature,
+         request.top_p
+     )
+
+     return ChatResponse(
+         response=response_text,
+         model_name=MODEL_NAME,
+         response_time=response_time
+     )
+
+ @app.get("/model-info")
+ async def get_model_info():
+     """Get detailed model information"""
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     if model and hasattr(model, 'device'):
+         device = str(model.device)
+
+     return {
+         "model_name": MODEL_NAME,
+         "model_loaded": model_loaded,
+         "device": device,
+         "cache_directory": CACHE_DIR,
+         "model_cached": is_model_cached(MODEL_NAME),
+         "parameters": {
+             "max_length": MAX_LENGTH,
+             "default_temperature": DEFAULT_TEMPERATURE
+         }
+     }

  @app.on_event("startup")
+ async def startup_event():
+     """Load model on startup"""
+     logger.info("🚀 Starting FastAPI Chatbot...")
+     logger.info("📦 Loading model...")
+
+     # Load model in background thread to not block startup
+     def load_model_background():
+         global model_loaded
+         model_loaded = load_model()
+         if model_loaded:
+             logger.info("✅ Model loaded successfully!")
+         else:
+             logger.error("❌ Failed to load model.")
+
+     # Start model loading in background
+     threading.Thread(target=load_model_background, daemon=True).start()
+
+ def run_fastapi():
+     """Run FastAPI server"""
+     uvicorn.run(
+         app,
+         host="0.0.0.0",
+         port=7860,  # Changed to 7860 for HuggingFace
+         log_level="info",
+         access_log=True
+     )

  if __name__ == "__main__":
+     run_fastapi()
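
Usage note (not part of the commit): the sketch below shows one way to exercise the new /health and /chat endpoints once the app is running. The base URL, the local port mapping, and the `requests` dependency are assumptions for illustration; the request and response fields come from the ChatRequest and ChatResponse models defined in the diff above.

import requests

BASE_URL = "http://localhost:7860"  # assumption: server reachable locally on the port used above

# Poll readiness: /health reports "initializing" until the background thread finishes load_model()
health = requests.get(f"{BASE_URL}/health").json()
print(health["status"], health["is_model_loaded"])

# Call /chat with the fields defined by ChatRequest
payload = {"message": "Hello there!", "max_length": 100, "temperature": 0.7, "top_p": 0.9}
reply = requests.post(f"{BASE_URL}/chat", json=payload, timeout=60)
if reply.status_code == 200:
    data = reply.json()  # ChatResponse: response, model_name, response_time
    print(f"{data['model_name']}: {data['response']} ({data['response_time']:.2f}s)")
else:
    # 503 while the model is still loading, 400 for empty or over-long messages
    print("Request failed:", reply.status_code, reply.text)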