GilbertAkham committed on
Commit
3590b90
·
verified ·
1 Parent(s): 73212ba

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -7
handler.py CHANGED
@@ -2,22 +2,32 @@
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from peft import PeftModel
 
5
 
6
- # Model path in the repo
7
  BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
8
- ADAPTER_PATH = "."
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
- print("Loading tokenizer and model...")
13
  self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
 
14
  base_model = AutoModelForCausalLM.from_pretrained(
15
- BASE_MODEL, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
 
 
 
16
  )
17
- self.model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
18
- self.model = self.model.merge_and_unload()
 
 
 
 
 
19
  self.model.eval()
20
- print("Model loaded successfully.")
 
21
 
22
  def __call__(self, data):
23
  prompt = data.get("inputs", "")
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from peft import PeftModel
5
+ from huggingface_hub import snapshot_download
6
 
 
7
  BASE_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
8
+ ADAPTER_PATH = "GilbertAkham/deepseek-R1-multitask-lora"
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
+ print("🚀 Loading base model...")
13
  self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
14
+
15
  base_model = AutoModelForCausalLM.from_pretrained(
16
+ BASE_MODEL,
17
+ torch_dtype=torch.float16,
18
+ device_map="auto",
19
+ trust_remote_code=True
20
  )
21
+
22
+ print(f"🔗 Downloading LoRA adapter from {ADAPTER_PATH}...")
23
+ adapter_local_path = snapshot_download(repo_id=ADAPTER_PATH, allow_patterns=["*adapter*"])
24
+ print(f"📁 Adapter files cached at {adapter_local_path}")
25
+
26
+ print("🧩 Attaching LoRA adapter...")
27
+ self.model = PeftModel.from_pretrained(base_model, adapter_local_path)
28
  self.model.eval()
29
+
30
+ print("✅ Model + LoRA adapter loaded successfully.")
31
 
32
  def __call__(self, data):
33
  prompt = data.get("inputs", "")