import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import snapshot_download
# === Base & adapter config ===
BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
ADAPTER_PATH = "GilbertAkham/deepseek-R1-multitask-lora"
# === System message ===
SYSTEM_PROMPT = (
    "You are Chat-Bot, a helpful and logical assistant trained for reasoning, "
    "email, chatting, summarization, story continuation, and report writing.\n\n"
)


class EndpointHandler:
    """Serves the base model with the LoRA adapter attached on top."""

    def __init__(self, path=""):
        print("🚀 Loading base model...")
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        print(f"🔗 Downloading LoRA adapter from {ADAPTER_PATH}...")
        adapter_local_path = snapshot_download(repo_id=ADAPTER_PATH, allow_patterns=["*adapter*"])
        print(f"📁 Adapter files cached at {adapter_local_path}")

        print("🧩 Attaching LoRA adapter...")
        self.model = PeftModel.from_pretrained(base_model, adapter_local_path)
        self.model.eval()
        print("✅ Model + LoRA adapter loaded successfully.")

    def __call__(self, data):
        # === Combine system + user prompt ===
        user_prompt = data.get("inputs", "")
        full_prompt = SYSTEM_PROMPT + user_prompt

        # Generation parameters, with defaults used when the request omits them.
        params = data.get("parameters", {})
        max_new_tokens = params.get("max_new_tokens", 512)
        temperature = params.get("temperature", 0.7)
        top_p = params.get("top_p", 0.9)

        # === Tokenize and run generation ===
        inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # === Decode and strip system message ===
        text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        if text.startswith(SYSTEM_PROMPT):
            text = text[len(SYSTEM_PROMPT):].strip()

        return {"generated_text": text}