import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import snapshot_download

# === Base & adapter config ===
BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
ADAPTER_PATH = "GilbertAkham/deepseek-R1-multitask-lora"

# === System message ===
SYSTEM_PROMPT = (
    "You are Chat-Bot, a helpful and logical assistant trained for reasoning, "
    "email, chatting, summarization, story continuation, and report writing.\n\n"
)

class EndpointHandler:
    def __init__(self, path=""):
        print("🚀 Loading base model...")
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)

        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )

        print(f"🔗 Downloading LoRA adapter from {ADAPTER_PATH}...")
        adapter_local_path = snapshot_download(repo_id=ADAPTER_PATH, allow_patterns=["*adapter*"])
        print(f"📁 Adapter files cached at {adapter_local_path}")

        print("🧩 Attaching LoRA adapter...")
        self.model = PeftModel.from_pretrained(base_model, adapter_local_path)
        self.model.eval()

        print("✅ Model + LoRA adapter loaded successfully.")

    def __call__(self, data):
        # === Combine system + user prompt ===
        user_prompt = data.get("inputs", "")
        full_prompt = SYSTEM_PROMPT + user_prompt

        params = data.get("parameters", {})
        max_new_tokens = params.get("max_new_tokens", 512)
        temperature = params.get("temperature", 0.7)
        top_p = params.get("top_p", 0.9)

        # === Tokenize and run generation ===
        inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # === Decode and strip system message ===
        text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        if text.startswith(SYSTEM_PROMPT):
            text = text[len(SYSTEM_PROMPT):].strip()

        return {"generated_text": text}