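"""Model loading and text-generation helpers for this Hugging Face Space.

Loads MODEL_NAME from the Hub (cached under CACHE_DIR) and exposes
generate_response() on top of a transformers text-generation pipeline.
"""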
import time
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from logger import logger
from config import CACHE_DIR, MODEL_NAME
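
# Module-level state populated by load_model().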
tokenizer = None
model = None
generator = None
model_loaded = False
startup_time = time.time()


def is_model_cached(model_name: str) -> bool:
    """Return True if model_name already exists in the local Hub cache."""
    try:
        # The Hub cache stores repos as models--{org}--{name} directories.
        model_path = Path(CACHE_DIR) / f"models--{model_name.replace('/', '--')}"
        return model_path.exists() and any(model_path.iterdir())
    except Exception as e:
        logger.error(f"Error checking cache: {e}")
        return False


def load_model():
    """Load tokenizer, model, and generation pipeline into module globals."""
    global tokenizer, model, generator, model_loaded
    try:
        logger.info(f"Loading model: {MODEL_NAME}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")
        start = time.time()
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
        if tokenizer.pad_token is None:
            # Many causal LMs ship without a pad token; reuse EOS for padding.
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            cache_dir=CACHE_DIR,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            low_cpu_mem_usage=True,
        )
        if torch.cuda.is_available():
            # device_map="auto" already placed the model; recent transformers
            # versions reject an extra device argument in that case.
            generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
        else:
            generator = pipeline(
                "text-generation", model=model, tokenizer=tokenizer, device=-1
            )
        model_loaded = True
        logger.info(f"✅ Model loaded in {time.time()-start:.2f}s on {model.device}")
        return True
    except Exception as e:
        logger.error(f"❌ Error loading model: {e}", exc_info=True)
        model_loaded = False
        return False


def generate_response(message: str, max_length: int, temperature: float, top_p: float):
    """Generate a reply for message; returns (reply_text, elapsed_seconds)."""
    if generator is None:
        return "❌ Model not loaded yet", 0.0
    start = time.time()
    try:
        result = generator(
            message,
            max_length=max_length,  # total length, prompt tokens included
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            repetition_penalty=1.1,
        )
        # The pipeline echoes the prompt before the continuation; strip it
        # so only the model's reply is returned.
        text = result[0]['generated_text']
        reply = text[len(message):].strip() if text.startswith(message) else text.strip()
        if not reply:
            reply = "I'm not sure how to respond to that."
        return reply, time.time() - start
    except Exception as e:
        logger.error(f"Generation error: {e}", exc_info=True)
        return f"❌ Error: {e}", 0.0
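

if __name__ == "__main__":
    # Minimal smoke test, assuming the module is run directly; the Space's
    # UI layer presumably calls load_model()/generate_response() instead.
    # The prompt and sampling values below are illustrative, not from config.
    logger.info(f"Model cached locally: {is_model_cached(MODEL_NAME)}")
    if load_model():
        reply, elapsed = generate_response(
            "Hello, how are you?", max_length=100, temperature=0.7, top_p=0.9
        )
        logger.info(f"Reply ({elapsed:.2f}s): {reply}")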