import time
from pathlib import Path

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from config import CACHE_DIR, MODEL_NAME
from logger import logger

# Module-level state shared between load_model() and generate_response().
tokenizer = None
model = None
generator = None
model_loaded = False
startup_time = time.time()


def is_model_cached(model_name: str) -> bool:
    """Return True if the model already has a non-empty Hugging Face cache directory."""
    try:
        model_path = Path(CACHE_DIR) / f"models--{model_name.replace('/', '--')}"
        return model_path.exists() and any(model_path.iterdir())
    except Exception as e:
        logger.error(f"Error checking cache: {e}")
        return False


def load_model():
    """Load the tokenizer, model, and text-generation pipeline into module globals."""
    global tokenizer, model, generator, model_loaded
    try:
        logger.info(f"Loading model: {MODEL_NAME}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")
        start = time.time()

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        use_cuda = torch.cuda.is_available()
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
            low_cpu_mem_usage=True,
        )

        # When the model is dispatched with device_map="auto" (accelerate), it is already
        # placed on its device(s); also passing `device` to pipeline() is redundant and,
        # depending on the transformers version, warns or raises. Only set an explicit
        # device on the CPU path.
        if use_cuda:
            generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
        else:
            generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)

        model_loaded = True
        logger.info(f"✅ Model loaded in {time.time() - start:.2f}s on {model.device}")
        return True
    except Exception as e:
        logger.error(f"❌ Error loading model: {e}", exc_info=True)
        model_loaded = False
        return False


def generate_response(message: str, max_length: int, temperature: float, top_p: float):
    """Generate a reply to `message`; returns (reply_text, elapsed_seconds)."""
    if not generator:
        return "❌ Model not loaded yet", 0.0
    start = time.time()
    try:
        result = generator(
            message,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            repetition_penalty=1.1,
        )
        # The pipeline echoes the prompt; strip it so only the continuation is returned.
        text = result[0]['generated_text']
        reply = text[len(message):].strip() if text.startswith(message) else text.strip()
        if not reply:
            reply = "I'm not sure how to respond to that."
        return reply, time.time() - start
    except Exception as e:
        logger.error(f"Generation error: {e}", exc_info=True)
        return f"❌ Error: {e}", 0.0
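

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal smoke test showing how load_model() and generate_response() fit together.
# The prompt and the sampling values below (max_length=128, temperature=0.7, top_p=0.9)
# are assumed placeholders, not values defined elsewhere in this project.
if __name__ == "__main__":
    if load_model():
        reply, elapsed = generate_response(
            "Hello, how are you?", max_length=128, temperature=0.7, top_p=0.9
        )
        logger.info(f"Reply ({elapsed:.2f}s): {reply}")
    else:
        logger.error("Model failed to load; see logs above.")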