# chatbot / model_service.py
# Last commit by Soumik555: "hello" (5344861)
import time
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from logger import logger
from config import CACHE_DIR, MODEL_NAME
# Module-level state shared by load_model() and generate_response().
# The three model objects start empty; load_model() assigns them via `global`.
tokenizer = model = generator = None
# Set to True only by a successful load_model(); reset to False on failure.
model_loaded = False
# Wall-clock timestamp (time.time()) captured when this module is imported.
startup_time = time.time()
def is_model_cached(model_name: str, cache_dir=None) -> bool:
    """Report whether *model_name* has a non-empty entry in the model cache.

    The entry is the directory ``models--<repo id with '/' replaced by '--'>``
    under the cache root, matching the Hugging Face Hub cache layout.

    Args:
        model_name: Hub repo id, e.g. ``"org/model"``.
        cache_dir: Cache root to inspect (str or Path). Defaults to the
            configured ``CACHE_DIR`` when omitted, preserving the old behavior.

    Returns:
        True if the repo directory exists and contains at least one entry;
        False otherwise, including when the check itself fails (the error is
        logged rather than raised).
    """
    try:
        root = CACHE_DIR if cache_dir is None else cache_dir
        model_path = Path(root) / f"models--{model_name.replace('/', '--')}"
        # is_dir() rather than exists(): a stray *file* at this path used to
        # make iterdir() raise and log a spurious error.
        return model_path.is_dir() and any(model_path.iterdir())
    except Exception as e:
        logger.error(f"Error checking cache: {e}")
        return False
def load_model():
    """Load the tokenizer, model and text-generation pipeline into module globals.

    On success this assigns the module-level ``tokenizer``, ``model`` and
    ``generator`` and sets ``model_loaded`` to True. Any exception is logged
    with its traceback and reported through the return value instead of
    being raised; ``model_loaded`` is then set to False.

    Returns:
        True if everything loaded, False otherwise.
    """
    global tokenizer, model, generator, model_loaded
    try:
        use_cuda = torch.cuda.is_available()
        logger.info(f"Loading model: {MODEL_NAME}")
        logger.info(f"CUDA available: {use_cuda}")
        start = time.time()

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
        if tokenizer.pad_token is None:
            # Some tokenizers define no pad token; fall back to EOS.
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            # float16 on GPU, float32 on CPU.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
            low_cpu_mem_usage=True,
        )

        # With device_map="auto" the model is already placed by accelerate, and
        # transformers' pipeline raises ValueError if a `device` is also given.
        # Only pin a device (-1 = CPU) when device_map was not used.
        pipeline_kwargs = {} if use_cuda else {"device": -1}
        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            **pipeline_kwargs,
        )
        model_loaded = True
        logger.info(f"✅ Model loaded in {time.time() - start:.2f}s on {model.device}")
        return True
    except Exception as e:
        logger.error(f"❌ Error loading model: {e}", exc_info=True)
        model_loaded = False
        return False
def generate_response(message: str, max_length: int, temperature: float, top_p: float):
    """Generate a reply to *message* with the loaded text-generation pipeline.

    Args:
        message: Prompt text fed to the model.
        max_length: Forwarded unchanged as the pipeline's ``max_length``.
        temperature: Sampling temperature. Values <= 0 select greedy
            (deterministic) decoding instead of sampling.
        top_p: Nucleus-sampling threshold; only used when sampling.

    Returns:
        A ``(reply, seconds)`` tuple: the generated text with the prompt
        excluded, and the wall-clock generation time. If the model is not
        loaded or generation fails, the reply is an error string and the
        time is 0.0 (errors are logged, not raised).
    """
    if generator is None:
        return "❌ Model not loaded yet", 0.0
    start = time.time()
    try:
        # transformers rejects non-positive (and non-float) temperatures when
        # sampling, so treat temperature <= 0 as a request for greedy decoding.
        do_sample = temperature > 0
        sampling_kwargs = (
            {"temperature": float(temperature), "top_p": top_p} if do_sample else {}
        )
        result = generator(
            message,
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=do_sample,
            repetition_penalty=1.1,
            # Let the pipeline drop the prompt itself. The decoded output may not
            # reproduce `message` character-for-character, in which case a
            # startswith() check would leak the prompt into the reply.
            return_full_text=False,
            **sampling_kwargs,
        )
        reply = result[0]["generated_text"].strip()
        if not reply:
            reply = "I'm not sure how to respond to that."
        return reply, time.time() - start
    except Exception as e:
        logger.error(f"Generation error: {e}", exc_info=True)
        return f"❌ Error: {e}", 0.0