import time
from pathlib import Path

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from config import CACHE_DIR, MODEL_NAME
from logger import logger
# Module-level state, populated by load_model() at startup.
tokenizer = None
model = None
generator = None
model_loaded = False
startup_time = time.time()
def is_model_cached(model_name: str) -> bool:
    """Return True if the model's weights already exist in the local Hugging Face cache."""
    try:
        model_path = Path(CACHE_DIR) / f"models--{model_name.replace('/', '--')}"
        return model_path.exists() and any(model_path.iterdir())
    except Exception as e:
        logger.error(f"Error checking cache: {e}")
        return False
def load_model():
    """Load the tokenizer and model, then build the text-generation pipeline."""
    global tokenizer, model, generator, model_loaded
    try:
        logger.info(f"Loading model: {MODEL_NAME}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")
        start = time.time()

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
        if tokenizer.pad_token is None:
            # Many causal LMs ship without a pad token; reuse EOS for padding.
            tokenizer.pad_token = tokenizer.eos_token

        use_cuda = torch.cuda.is_available()
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
            low_cpu_mem_usage=True,
        )

        # With device_map="auto", accelerate has already placed the model, so the
        # pipeline must not be given an explicit device; on CPU, pass device=-1.
        if use_cuda:
            generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
        else:
            generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)

        model_loaded = True
        logger.info(f"✅ Model loaded in {time.time() - start:.2f}s on {model.device}")
        return True
    except Exception as e:
        logger.error(f"❌ Error loading model: {e}", exc_info=True)
        model_loaded = False
        return False
def generate_response(message: str, max_length: int, temperature: float, top_p: float):
    """Generate a reply for `message`; returns (reply_text, elapsed_seconds)."""
    if not generator:
        return "⚠️ Model not loaded yet", 0.0
    start = time.time()
    try:
        result = generator(
            message,
            max_length=max_length,  # total length, including the prompt tokens
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            repetition_penalty=1.1,
        )
        text = result[0]["generated_text"]
        # The pipeline echoes the prompt, so strip it off and keep only the reply.
        reply = text[len(message):].strip() if text.startswith(message) else text.strip()
        if not reply:
            reply = "I'm not sure how to respond to that."
        return reply, time.time() - start
    except Exception as e:
        logger.error(f"Generation error: {e}", exc_info=True)
        return f"❌ Error: {e}", 0.0