Commit
Β·
6a50f6f
0
Parent(s):
init
Browse files- .gradio/flagged/dataset1.csv +7 -0
- README.md +3 -0
- __pycache__/main.cpython-313.pyc +0 -0
- main.ipynb +458 -0
- main.py +202 -0
- poetry.lock +0 -0
- pyproject.toml +19 -0
.gradio/flagged/dataset1.csv
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Prompt,Max Tokens,Gamma (draft lookahead),Confidence Threshold,Speculative Decoding Visualization,timestamp
|
| 2 |
+
def fibonacci(n):,10,5,0.5,"<div style='font-family: monospace;'><div style='margin-bottom: 20px; padding: 10px; background: #f0f0f0; border-radius: 5px;'><b>Final Output:</b><br/><|im_start|>system
|
| 3 |
+
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
|
| 4 |
+
<|im_start|>user
|
| 5 |
+
def fibonacci(n):<|im_end|>
|
| 6 |
+
<|im_start|>assistant
|
| 7 |
+
I understand that you're looking for a Python function</div><div style='margin-bottom: 20px; padding: 10px; background: #e0e0e0; border-radius: 5px;'><b>Acceptance Rate:</b> 8/15 = 53.3%</div><div style='margin-bottom: 10px;'><b>Decoding Steps:</b></div><div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'><b>Step 1:</b> <span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>Certainly</span><span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>!</span> β <span style='background: #87CEEB; padding: 2px 4px; border-radius: 3px;'> I</span></div><div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'><b>Step 2:</b> <span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>'m</span><span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'> sorry</span><span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>,</span><span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'> but</span><span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'> I</span> β <span style='background: #87CEEB; padding: 2px 4px; border-radius: 3px;'> understand</span></div><div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'><b>Step 3:</b> <span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> that</span><span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> you</span><span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'>'re</span></div><div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'><b>Step 4:</b> <span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> looking</span><span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> for</span><span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> a</span><span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> Python</span><span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> function</span></div></div>",2025-12-14 13:54:04.330173
|
README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Speculative Decoding
|
| 2 |
+
|
| 3 |
+
A project implementing speculative decoding techniques.
|
__pycache__/main.cpython-313.pyc
ADDED
|
Binary file (9.49 kB). View file
|
|
|
main.ipynb
ADDED
|
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "e0ef8d28",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stderr",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"/Users/yiyunzhu/Library/Caches/pypoetry/virtualenvs/speculativedecoding-B0TTdUOs-py3.13/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 14 |
+
" from .autonotebook import tqdm as notebook_tqdm\n",
|
| 15 |
+
"`torch_dtype` is deprecated! Use `dtype` instead!\n",
|
| 16 |
+
"Fetching 2 files: 100%|ββββββββββ| 2/2 [00:58<00:00, 29.32s/it]\n",
|
| 17 |
+
"Loading checkpoint shards: 100%|ββββββββββ| 2/2 [00:00<00:00, 27.91it/s]\n"
|
| 18 |
+
]
|
| 19 |
+
}
|
| 20 |
+
],
|
| 21 |
+
"source": [
|
| 22 |
+
"import torch\n",
|
| 23 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed\n",
|
| 24 |
+
"\n",
|
| 25 |
+
"set_seed(67)\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"device = \"mps\"\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-Coder-0.5B-Instruct\") #HuggingFaceTB/SmolLM2-135M-Instruct\n",
|
| 30 |
+
"draft_model = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen2.5-Coder-0.5B-Instruct\", torch_dtype=torch.bfloat16).to(device)\n",
|
| 31 |
+
"verify_model = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen2.5-Coder-3B-Instruct\", torch_dtype=torch.bfloat16).to(device) #HuggingFaceTB/SmolLM2-1.7B-Instruct"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": 8,
|
| 37 |
+
"id": "30d81505",
|
| 38 |
+
"metadata": {},
|
| 39 |
+
"outputs": [
|
| 40 |
+
{
|
| 41 |
+
"data": {
|
| 42 |
+
"text/plain": [
|
| 43 |
+
"'The quick brown fox jumps over the lazy dog'"
|
| 44 |
+
]
|
| 45 |
+
},
|
| 46 |
+
"execution_count": 8,
|
| 47 |
+
"metadata": {},
|
| 48 |
+
"output_type": "execute_result"
|
| 49 |
+
}
|
| 50 |
+
],
|
| 51 |
+
"source": [
|
| 52 |
+
"prompt = \"The quick brown fox\"\n",
|
| 53 |
+
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
|
| 54 |
+
"input_ids = inputs[\"input_ids\"]\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"generated = input_ids.clone() # [1, seq_len]\n",
|
| 57 |
+
"draft_probs = []\n",
|
| 58 |
+
"for _ in range(5): # gamma = 5\n",
|
| 59 |
+
" with torch.no_grad():\n",
|
| 60 |
+
" outputs = draft_model(generated)\n",
|
| 61 |
+
" logits = outputs.logits[:, -1, :] # batch, seq_len, vocab_size\n",
|
| 62 |
+
" \n",
|
| 63 |
+
" probs = torch.softmax(logits, dim=-1) # [1, 50257]\n",
|
| 64 |
+
" next_token = torch.multinomial(probs, num_samples=1)\n",
|
| 65 |
+
"\n",
|
| 66 |
+
" draft_probs.append(probs)\n",
|
| 67 |
+
" generated = torch.cat([generated, next_token], dim=-1)\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"tokenizer.decode(generated[0])"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"cell_type": "code",
|
| 74 |
+
"execution_count": 9,
|
| 75 |
+
"id": "2492ee58",
|
| 76 |
+
"metadata": {},
|
| 77 |
+
"outputs": [
|
| 78 |
+
{
|
| 79 |
+
"name": "stdout",
|
| 80 |
+
"output_type": "stream",
|
| 81 |
+
"text": [
|
| 82 |
+
"Token: ' jumps', q(x)=0.9727, p(x)=0.9023\n",
|
| 83 |
+
" -> Accepted\n",
|
| 84 |
+
"Token: ' over', q(x)=1.0000, p(x)=0.9961\n",
|
| 85 |
+
" -> Accepted\n",
|
| 86 |
+
"Token: ' the', q(x)=0.9023, p(x)=0.9102\n",
|
| 87 |
+
" -> Accepted\n",
|
| 88 |
+
"Token: ' lazy', q(x)=1.0000, p(x)=0.9922\n",
|
| 89 |
+
" -> Accepted\n",
|
| 90 |
+
"Token: ' dog', q(x)=0.9922, p(x)=0.9844\n",
|
| 91 |
+
" -> Accepted\n",
|
| 92 |
+
" jumps over the lazy dog\n",
|
| 93 |
+
"[34208, 916, 279, 15678, 5562]\n"
|
| 94 |
+
]
|
| 95 |
+
}
|
| 96 |
+
],
|
| 97 |
+
"source": [
|
| 98 |
+
"with torch.no_grad():\n",
|
| 99 |
+
" target_outputs = verify_model(generated)\n",
|
| 100 |
+
" target_logits = target_outputs.logits[:, -6:-1, :]\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"target_probs = torch.softmax(target_logits, dim=-1) # [1, 5, 50257]\n",
|
| 103 |
+
"accepted_tokens = []\n",
|
| 104 |
+
"for i in range(5):\n",
|
| 105 |
+
" # if q(x) <= p(x), keep\n",
|
| 106 |
+
" # if q(x) > p(x), reject with 1-p(x)/q(x) chance\n",
|
| 107 |
+
" # if rejected, we sample from norm(max(0, p(x) - q(x)))\n",
|
| 108 |
+
" q = draft_probs[i] # [1, 50257]\n",
|
| 109 |
+
" p = target_probs[:, i, :] # [1, 50257]\n",
|
| 110 |
+
" token = generated[:, i - 5] # [1]\n",
|
| 111 |
+
" # assume unbatched for now\n",
|
| 112 |
+
" x = token[0].item()\n",
|
| 113 |
+
" q_x = q[0, x].item()\n",
|
| 114 |
+
" p_x = p[0, x].item()\n",
|
| 115 |
+
"\n",
|
| 116 |
+
" print(f\"Token: '{tokenizer.decode(x)}', q(x)={q_x:.4f}, p(x)={p_x:.4f}\")\n",
|
| 117 |
+
" if q_x <= p_x:\n",
|
| 118 |
+
" print(\" -> Accepted\")\n",
|
| 119 |
+
" accepted_tokens.append(x)\n",
|
| 120 |
+
" else:\n",
|
| 121 |
+
" r = torch.rand(1).item()\n",
|
| 122 |
+
" acceptance_rate = p_x / q_x\n",
|
| 123 |
+
"\n",
|
| 124 |
+
" if r < acceptance_rate:\n",
|
| 125 |
+
" print(\" -> Accepted\")\n",
|
| 126 |
+
" accepted_tokens.append(x)\n",
|
| 127 |
+
" else:\n",
|
| 128 |
+
" print(\" -> Rejected\")\n",
|
| 129 |
+
" adjusted = torch.clamp(p-q, min=0)\n",
|
| 130 |
+
" adjusted = adjusted / adjusted.sum()\n",
|
| 131 |
+
" new_token = torch.multinomial(adjusted, num_samples=1)[0].item()\n",
|
| 132 |
+
" accepted_tokens.append(new_token)\n",
|
| 133 |
+
" break\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"print(tokenizer.decode(accepted_tokens))\n",
|
| 136 |
+
"print(accepted_tokens)"
|
| 137 |
+
]
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"cell_type": "code",
|
| 141 |
+
"execution_count": 20,
|
| 142 |
+
"id": "2df0e2a9",
|
| 143 |
+
"metadata": {},
|
| 144 |
+
"outputs": [],
|
| 145 |
+
"source": [
|
| 146 |
+
"def draft(input_ids, gamma, confidence_threshold, eos_token, past_kv):\n",
|
| 147 |
+
" generated = input_ids.clone() # [1, seq_len]\n",
|
| 148 |
+
" draft_probs = []\n",
|
| 149 |
+
" for _ in range(gamma): \n",
|
| 150 |
+
" with torch.no_grad():\n",
|
| 151 |
+
" outputs = draft_model(\n",
|
| 152 |
+
" generated if past_kv is None else generated[:, -1:],\n",
|
| 153 |
+
" past_key_values=past_kv,\n",
|
| 154 |
+
" use_cache=True\n",
|
| 155 |
+
" )\n",
|
| 156 |
+
" logits = outputs.logits[:, -1, :] # batch, seq_len, vocab_size\n",
|
| 157 |
+
" past_kv = outputs.past_key_values\n",
|
| 158 |
+
" \n",
|
| 159 |
+
" probs = torch.softmax(logits, dim=-1) # [1, 50257]\n",
|
| 160 |
+
"\n",
|
| 161 |
+
" confidence = probs.max().item() # dynamic speculative decoding\n",
|
| 162 |
+
" if confidence < confidence_threshold and len(draft_probs) > 0:\n",
|
| 163 |
+
" break \n",
|
| 164 |
+
"\n",
|
| 165 |
+
" next_token = torch.argmax(probs, dim=-1, keepdim=True)\n",
|
| 166 |
+
"\n",
|
| 167 |
+
" draft_probs.append(probs)\n",
|
| 168 |
+
" generated = torch.cat([generated, next_token], dim=-1)\n",
|
| 169 |
+
"\n",
|
| 170 |
+
" if next_token.item() == eos_token:\n",
|
| 171 |
+
" break;\n",
|
| 172 |
+
"\n",
|
| 173 |
+
" return generated, draft_probs, past_kv\n",
|
| 174 |
+
"\n",
|
| 175 |
+
"def verify(drafted, drafted_probs, eos_token, past_kv):\n",
|
| 176 |
+
" draft_len = len(drafted_probs) # number of new drafted tokens\n",
|
| 177 |
+
" with torch.no_grad():\n",
|
| 178 |
+
" if past_kv is None:\n",
|
| 179 |
+
" target_outputs = verify_model(drafted, use_cache=True)\n",
|
| 180 |
+
" target_logits = target_outputs.logits[:, -draft_len - 1:-1, :]\n",
|
| 181 |
+
" else:\n",
|
| 182 |
+
" target_outputs = verify_model(\n",
|
| 183 |
+
" drafted[:, -(draft_len + 1):], # extra token \n",
|
| 184 |
+
" past_key_values=past_kv,\n",
|
| 185 |
+
" use_cache=True\n",
|
| 186 |
+
" )\n",
|
| 187 |
+
" target_logits = target_outputs.logits[:, :-1, :] # Drop last (predicts bonus token)\n",
|
| 188 |
+
"\n",
|
| 189 |
+
" past_kv = target_outputs.past_key_values\n",
|
| 190 |
+
"\n",
|
| 191 |
+
" target_probs = torch.softmax(target_logits, dim=-1) # [1, 5, 50257]\n",
|
| 192 |
+
" accepted_tokens = []\n",
|
| 193 |
+
" num_accepted = 0 # number of tokens from drafted that is accepted\n",
|
| 194 |
+
" for i in range(draft_len):\n",
|
| 195 |
+
" # if q(x) <= p(x), keep\n",
|
| 196 |
+
" # if q(x) > p(x), reject with 1-p(x)/q(x) chance\n",
|
| 197 |
+
" # if rejected, we sample from norm(max(0, p(x) - q(x)))\n",
|
| 198 |
+
" q = drafted_probs[i] # [1, 50257]\n",
|
| 199 |
+
" p = target_probs[:, i, :] # [1, 50257]\n",
|
| 200 |
+
" token = drafted[:, i - draft_len] # [1]\n",
|
| 201 |
+
" # assume unbatched for now\n",
|
| 202 |
+
" x = token[0].item()\n",
|
| 203 |
+
" q_x = q[0, x].item()\n",
|
| 204 |
+
" p_x = p[0, x].item()\n",
|
| 205 |
+
"\n",
|
| 206 |
+
" print(f\"Token: '{tokenizer.decode(x)}'\", end = \"\")\n",
|
| 207 |
+
" if q_x <= p_x:\n",
|
| 208 |
+
" print(\" -> Accepted\")\n",
|
| 209 |
+
" accepted_tokens.append(x)\n",
|
| 210 |
+
" num_accepted+=1\n",
|
| 211 |
+
" else:\n",
|
| 212 |
+
" r = torch.rand(1).item()\n",
|
| 213 |
+
" acceptance_rate = p_x / q_x\n",
|
| 214 |
+
"\n",
|
| 215 |
+
" if r < acceptance_rate:\n",
|
| 216 |
+
" print(\" -> Accepted\")\n",
|
| 217 |
+
" accepted_tokens.append(x)\n",
|
| 218 |
+
" num_accepted+=1\n",
|
| 219 |
+
" else:\n",
|
| 220 |
+
" print(\" -> Rejected\", end = \"\")\n",
|
| 221 |
+
" adjusted = torch.clamp(p-q, min=0)\n",
|
| 222 |
+
" adjusted = adjusted / adjusted.sum()\n",
|
| 223 |
+
" new_token = torch.multinomial(adjusted, num_samples=1)[0].item()\n",
|
| 224 |
+
" accepted_tokens.append(new_token)\n",
|
| 225 |
+
" print(tokenizer.decode(new_token))\n",
|
| 226 |
+
" break\n",
|
| 227 |
+
" if accepted_tokens[-1] == eos_token:\n",
|
| 228 |
+
" break\n",
|
| 229 |
+
"\n",
|
| 230 |
+
" return accepted_tokens, num_accepted, past_kv"
|
| 231 |
+
]
|
| 232 |
+
},
|
| 233 |
+
{
|
| 234 |
+
"cell_type": "code",
|
| 235 |
+
"execution_count": 24,
|
| 236 |
+
"id": "5378f5d5",
|
| 237 |
+
"metadata": {},
|
| 238 |
+
"outputs": [
|
| 239 |
+
{
|
| 240 |
+
"name": "stdout",
|
| 241 |
+
"output_type": "stream",
|
| 242 |
+
"text": [
|
| 243 |
+
"Token: 'A' -> Accepted\n",
|
| 244 |
+
"Token: ' deal' -> Accepted\n",
|
| 245 |
+
"Token: ' flow' -> Accepted\n",
|
| 246 |
+
"Token: ' in' -> Accepted\n",
|
| 247 |
+
"Token: ' a' -> Accepted\n",
|
| 248 |
+
"Token: ' VC' -> Rejected venture\n",
|
| 249 |
+
"Token: ' capital' -> Accepted\n",
|
| 250 |
+
"Token: ' fund' -> Rejected (\n",
|
| 251 |
+
"Token: 'VC' -> Accepted\n",
|
| 252 |
+
"Token: ')' -> Accepted\n",
|
| 253 |
+
"Token: ' fund' -> Accepted\n",
|
| 254 |
+
"Token: ' is' -> Rejected refers\n",
|
| 255 |
+
"Token: ' to' -> Accepted\n",
|
| 256 |
+
"Token: ' the' -> Accepted\n",
|
| 257 |
+
"Token: ' process' -> Accepted\n",
|
| 258 |
+
"Token: ' of' -> Accepted\n",
|
| 259 |
+
"Token: ' setting' -> Rejected screening\n",
|
| 260 |
+
"Token: ',' -> Accepted\n",
|
| 261 |
+
"Token: ' evaluating' -> Accepted\n",
|
| 262 |
+
"Token: ',' -> Accepted\n",
|
| 263 |
+
"Token: ' and' -> Accepted\n",
|
| 264 |
+
"Token: ' selecting' -> Accepted\n",
|
| 265 |
+
"Token: ' investors' -> Rejected potential\n",
|
| 266 |
+
"Token: ' investment' -> Accepted\n",
|
| 267 |
+
"Token: ' in' -> Rejected opportunities\n",
|
| 268 |
+
"Token: ' for' -> Rejected.\n",
|
| 269 |
+
"Token: ' It' -> Accepted\n",
|
| 270 |
+
"Token: ' involves' -> Accepted\n",
|
| 271 |
+
"Token: ' several' -> Rejected identifying\n",
|
| 272 |
+
"Token: ' potential' -> Accepted\n",
|
| 273 |
+
"Token: ' investors' -> Rejected companies\n",
|
| 274 |
+
"Token: ' that' -> Accepted\n",
|
| 275 |
+
"Token: ' could' -> Rejected the\n",
|
| 276 |
+
"Token: ' VC' -> Accepted\n",
|
| 277 |
+
"Token: ' fund' -> Accepted\n",
|
| 278 |
+
"Token: 'ers' -> Rejected is\n",
|
| 279 |
+
"Token: ' interested' -> Accepted\n",
|
| 280 |
+
"Token: ' in' -> Accepted\n",
|
| 281 |
+
"Token: ',' -> Accepted\n",
|
| 282 |
+
"Token: ' assessing' -> Accepted\n",
|
| 283 |
+
"Token: ' their' -> Accepted\n",
|
| 284 |
+
"Token: ' financial' -> Rejected growth\n",
|
| 285 |
+
"Token: ' potential' -> Accepted\n",
|
| 286 |
+
"Token: ',' -> Accepted\n",
|
| 287 |
+
"Token: ' and' -> Accepted\n",
|
| 288 |
+
"Token: ' evaluating' -> Rejected business\n",
|
| 289 |
+
"Token: ' model' -> Accepted\n",
|
| 290 |
+
"Token: ',' -> Accepted\n",
|
| 291 |
+
"Token: ' and' -> Accepted\n",
|
| 292 |
+
"Token: ',' -> Rejected market\n",
|
| 293 |
+
"Token: ' demand' -> Accepted\n",
|
| 294 |
+
"Token: ',' -> Accepted\n",
|
| 295 |
+
"Token: ' and' -> Accepted\n",
|
| 296 |
+
"Token: ' then' -> Accepted\n",
|
| 297 |
+
"Token: ' making' -> Accepted\n",
|
| 298 |
+
"Token: ' a' -> Accepted\n",
|
| 299 |
+
"Token: ' decision' -> Accepted\n",
|
| 300 |
+
"Token: ' on' -> Accepted\n",
|
| 301 |
+
"Token: ' whether' -> Accepted\n",
|
| 302 |
+
"Token: ' to' -> Accepted\n",
|
| 303 |
+
"Token: ' invest' -> Accepted\n",
|
| 304 |
+
"Token: ' in' -> Accepted\n",
|
| 305 |
+
"Token: ' those' -> Rejected the\n",
|
| 306 |
+
"Token: ' VC' -> Rejected fund\n",
|
| 307 |
+
"Token: ' in' -> Rejectedβs\n",
|
| 308 |
+
"Token: ' capital' -> Accepted\n",
|
| 309 |
+
"Token: '.' -> Accepted\n",
|
| 310 |
+
"Token: '<|im_end|>' -> Accepted\n"
|
| 311 |
+
]
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
"data": {
|
| 315 |
+
"text/plain": [
|
| 316 |
+
"'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is a deal flow in a VC fund?<|im_end|>\\n<|im_start|>assistant\\nA deal flow in a venture capital (VC) fund refers to the process of screening, evaluating, and selecting potential investment opportunities. It involves identifying potential companies that the VC fund is interested in, assessing their growth potential, and business model, and market demand, and then making a decision on whether to invest in the fundβs capital.<|im_end|>'"
|
| 317 |
+
]
|
| 318 |
+
},
|
| 319 |
+
"execution_count": 24,
|
| 320 |
+
"metadata": {},
|
| 321 |
+
"output_type": "execute_result"
|
| 322 |
+
}
|
| 323 |
+
],
|
| 324 |
+
"source": [
|
| 325 |
+
"messages = [{\"role\": \"user\", \"content\": \"What is a deal flow in a VC fund?\"}]\n",
|
| 326 |
+
"prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
|
| 327 |
+
"\n",
|
| 328 |
+
"max_tokens = 80\n",
|
| 329 |
+
"eos_token = tokenizer.eos_token_id\n",
|
| 330 |
+
"im_end_token = tokenizer.convert_tokens_to_ids(\"<|im_end|>\")\n",
|
| 331 |
+
"gamma = 15\n",
|
| 332 |
+
"confidence_threshold = 0.5\n",
|
| 333 |
+
"\n",
|
| 334 |
+
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
|
| 335 |
+
"result = inputs[\"input_ids\"].clone()\n",
|
| 336 |
+
"\n",
|
| 337 |
+
"draft_kv = None\n",
|
| 338 |
+
"verify_kv = None\n",
|
| 339 |
+
"\n",
|
| 340 |
+
"total_drafted = 0\n",
|
| 341 |
+
"total_accepted = 0\n",
|
| 342 |
+
"\n",
|
| 343 |
+
"while result.shape[-1] - inputs[\"input_ids\"].shape[-1] < max_tokens:\n",
|
| 344 |
+
" drafted, drafted_probs, draft_kv = draft(result, gamma, confidence_threshold, eos_token, draft_kv)\n",
|
| 345 |
+
" accepted_tokens, num_accepted, verify_kv = verify(drafted, drafted_probs, eos_token, verify_kv)\n",
|
| 346 |
+
"\n",
|
| 347 |
+
" total_drafted += len(drafted_probs)\n",
|
| 348 |
+
" total_accepted += num_accepted\n",
|
| 349 |
+
" \n",
|
| 350 |
+
" valid_len = result.shape[-1] + num_accepted\n",
|
| 351 |
+
" result = torch.cat([result, torch.tensor([accepted_tokens], device=device)], dim=-1)\n",
|
| 352 |
+
"\n",
|
| 353 |
+
" if draft_kv is not None:\n",
|
| 354 |
+
" draft_kv.crop(max_length=valid_len)\n",
|
| 355 |
+
" if verify_kv is not None:\n",
|
| 356 |
+
" verify_kv.crop(max_length=valid_len)\n",
|
| 357 |
+
"\n",
|
| 358 |
+
" if eos_token in accepted_tokens or im_end_token in accepted_tokens:\n",
|
| 359 |
+
" break\n",
|
| 360 |
+
"\n",
|
| 361 |
+
"tokenizer.decode(result[0])\n"
|
| 362 |
+
]
|
| 363 |
+
},
|
| 364 |
+
{
|
| 365 |
+
"cell_type": "code",
|
| 366 |
+
"execution_count": 25,
|
| 367 |
+
"id": "40c92741",
|
| 368 |
+
"metadata": {},
|
| 369 |
+
"outputs": [
|
| 370 |
+
{
|
| 371 |
+
"data": {
|
| 372 |
+
"text/plain": [
|
| 373 |
+
"0.6071428571428571"
|
| 374 |
+
]
|
| 375 |
+
},
|
| 376 |
+
"execution_count": 25,
|
| 377 |
+
"metadata": {},
|
| 378 |
+
"output_type": "execute_result"
|
| 379 |
+
}
|
| 380 |
+
],
|
| 381 |
+
"source": [
|
| 382 |
+
"total_accepted / total_drafted"
|
| 383 |
+
]
|
| 384 |
+
},
|
| 385 |
+
{
|
| 386 |
+
"cell_type": "code",
|
| 387 |
+
"execution_count": 26,
|
| 388 |
+
"id": "0c661940",
|
| 389 |
+
"metadata": {},
|
| 390 |
+
"outputs": [
|
| 391 |
+
{
|
| 392 |
+
"data": {
|
| 393 |
+
"text/plain": [
|
| 394 |
+
"'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is a deal flow in a VC fund?<|im_end|>\\n<|im_start|>assistant\\nA deal flow in a VC fund refers to the collection and processing of new investment opportunities presented to the fund for screening, evaluation, and ultimately investment.\\n\\nHereβs a basic overview:\\n\\n* **Deal Flow Collection**: Fund managers typically collect deals through email, cold calls, calendars, meeting notes, investor referrals, and referrals by other investors. They utilize various channels including online searches, investor networks, and industry'"
|
| 395 |
+
]
|
| 396 |
+
},
|
| 397 |
+
"execution_count": 26,
|
| 398 |
+
"metadata": {},
|
| 399 |
+
"output_type": "execute_result"
|
| 400 |
+
}
|
| 401 |
+
],
|
| 402 |
+
"source": [
|
| 403 |
+
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
|
| 404 |
+
"result = inputs[\"input_ids\"].clone()\n",
|
| 405 |
+
"\n",
|
| 406 |
+
"past_kv = None\n",
|
| 407 |
+
"\n",
|
| 408 |
+
"while result.shape[-1] - inputs[\"input_ids\"].shape[-1] < max_tokens:\n",
|
| 409 |
+
" with torch.no_grad():\n",
|
| 410 |
+
" output = verify_model(\n",
|
| 411 |
+
" result if past_kv is None else result[:, -1:],\n",
|
| 412 |
+
" past_key_values=past_kv,\n",
|
| 413 |
+
" use_cache=True\n",
|
| 414 |
+
" )\n",
|
| 415 |
+
" logits = output.logits[:, -1, :] # batch, vocab\n",
|
| 416 |
+
" \n",
|
| 417 |
+
" past_kv = output.past_key_values\n",
|
| 418 |
+
" probs = torch.softmax(logits, dim=-1)\n",
|
| 419 |
+
" next_token = torch.multinomial(probs, num_samples=1)\n",
|
| 420 |
+
"\n",
|
| 421 |
+
" result = torch.cat([result, next_token], dim=-1)\n",
|
| 422 |
+
" if eos_token in next_token or im_end_token in next_token:\n",
|
| 423 |
+
" break\n",
|
| 424 |
+
"\n",
|
| 425 |
+
"tokenizer.decode(result[0])"
|
| 426 |
+
]
|
| 427 |
+
},
|
| 428 |
+
{
|
| 429 |
+
"cell_type": "code",
|
| 430 |
+
"execution_count": null,
|
| 431 |
+
"id": "c7a26417",
|
| 432 |
+
"metadata": {},
|
| 433 |
+
"outputs": [],
|
| 434 |
+
"source": []
|
| 435 |
+
}
|
| 436 |
+
],
|
| 437 |
+
"metadata": {
|
| 438 |
+
"kernelspec": {
|
| 439 |
+
"display_name": "speculativedecoding-B0TTdUOs-py3.13",
|
| 440 |
+
"language": "python",
|
| 441 |
+
"name": "python3"
|
| 442 |
+
},
|
| 443 |
+
"language_info": {
|
| 444 |
+
"codemirror_mode": {
|
| 445 |
+
"name": "ipython",
|
| 446 |
+
"version": 3
|
| 447 |
+
},
|
| 448 |
+
"file_extension": ".py",
|
| 449 |
+
"mimetype": "text/x-python",
|
| 450 |
+
"name": "python",
|
| 451 |
+
"nbconvert_exporter": "python",
|
| 452 |
+
"pygments_lexer": "ipython3",
|
| 453 |
+
"version": "3.13.5"
|
| 454 |
+
}
|
| 455 |
+
},
|
| 456 |
+
"nbformat": 4,
|
| 457 |
+
"nbformat_minor": 5
|
| 458 |
+
}
|
main.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
|
| 3 |
+
import gradio as gr
|
| 4 |
+
|
| 5 |
+
set_seed(67)
|
| 6 |
+
|
| 7 |
+
device = "mps"
|
| 8 |
+
|
| 9 |
+
# Initialize models and tokenizer
|
| 10 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-0.5B-Instruct")
|
| 11 |
+
draft_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-0.5B-Instruct", torch_dtype=torch.bfloat16).to(device)
|
| 12 |
+
verify_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-3B-Instruct", torch_dtype=torch.bfloat16).to(device)
|
| 13 |
+
|
| 14 |
+
def draft(input_ids, gamma, confidence_threshold, eos_token, past_kv):
|
| 15 |
+
generated = input_ids.clone()
|
| 16 |
+
draft_probs = []
|
| 17 |
+
for _ in range(gamma):
|
| 18 |
+
with torch.no_grad():
|
| 19 |
+
outputs = draft_model(
|
| 20 |
+
generated if past_kv is None else generated[:, -1:],
|
| 21 |
+
past_key_values=past_kv,
|
| 22 |
+
use_cache=True
|
| 23 |
+
)
|
| 24 |
+
logits = outputs.logits[:, -1, :]
|
| 25 |
+
past_kv = outputs.past_key_values
|
| 26 |
+
|
| 27 |
+
probs = torch.softmax(logits, dim=-1)
|
| 28 |
+
|
| 29 |
+
confidence = probs.max().item()
|
| 30 |
+
if confidence < confidence_threshold and len(draft_probs) > 0:
|
| 31 |
+
break
|
| 32 |
+
|
| 33 |
+
next_token = torch.argmax(probs, dim=-1, keepdim=True)
|
| 34 |
+
|
| 35 |
+
draft_probs.append(probs)
|
| 36 |
+
generated = torch.cat([generated, next_token], dim=-1)
|
| 37 |
+
|
| 38 |
+
if next_token.item() == eos_token:
|
| 39 |
+
break
|
| 40 |
+
|
| 41 |
+
return generated, draft_probs, past_kv
|
| 42 |
+
|
| 43 |
+
def verify(drafted, drafted_probs, eos_token, past_kv):
|
| 44 |
+
draft_len = len(drafted_probs)
|
| 45 |
+
with torch.no_grad():
|
| 46 |
+
if past_kv is None:
|
| 47 |
+
target_outputs = verify_model(drafted, use_cache=True)
|
| 48 |
+
target_logits = target_outputs.logits[:, -draft_len - 1:-1, :]
|
| 49 |
+
else:
|
| 50 |
+
target_outputs = verify_model(
|
| 51 |
+
drafted[:, -(draft_len + 1):],
|
| 52 |
+
past_key_values=past_kv,
|
| 53 |
+
use_cache=True
|
| 54 |
+
)
|
| 55 |
+
target_logits = target_outputs.logits[:, :-1, :]
|
| 56 |
+
|
| 57 |
+
past_kv = target_outputs.past_key_values
|
| 58 |
+
|
| 59 |
+
target_probs = torch.softmax(target_logits, dim=-1)
|
| 60 |
+
accepted_tokens = []
|
| 61 |
+
num_accepted = 0
|
| 62 |
+
for i in range(draft_len):
|
| 63 |
+
q = drafted_probs[i]
|
| 64 |
+
p = target_probs[:, i, :]
|
| 65 |
+
token = drafted[:, i - draft_len]
|
| 66 |
+
x = token[0].item()
|
| 67 |
+
q_x = q[0, x].item()
|
| 68 |
+
p_x = p[0, x].item()
|
| 69 |
+
|
| 70 |
+
if q_x <= p_x:
|
| 71 |
+
accepted_tokens.append(x)
|
| 72 |
+
num_accepted += 1
|
| 73 |
+
else:
|
| 74 |
+
r = torch.rand(1).item()
|
| 75 |
+
acceptance_rate = p_x / q_x
|
| 76 |
+
|
| 77 |
+
if r < acceptance_rate:
|
| 78 |
+
accepted_tokens.append(x)
|
| 79 |
+
num_accepted += 1
|
| 80 |
+
else:
|
| 81 |
+
adjusted = torch.clamp(p - q, min=0)
|
| 82 |
+
adjusted = adjusted / adjusted.sum()
|
| 83 |
+
new_token = torch.multinomial(adjusted, num_samples=1)[0].item()
|
| 84 |
+
accepted_tokens.append(new_token)
|
| 85 |
+
break
|
| 86 |
+
if accepted_tokens[-1] == eos_token:
|
| 87 |
+
break
|
| 88 |
+
|
| 89 |
+
return accepted_tokens, num_accepted, past_kv
|
| 90 |
+
|
| 91 |
+
def generate_visual(prompt, max_tokens=50, gamma=15, confidence_threshold=0.5):
|
| 92 |
+
# Prepare input
|
| 93 |
+
messages = [{"role": "user", "content": prompt}]
|
| 94 |
+
formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 95 |
+
|
| 96 |
+
eos_token = tokenizer.eos_token_id
|
| 97 |
+
im_end_token = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
| 98 |
+
|
| 99 |
+
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
|
| 100 |
+
result = inputs["input_ids"].clone()
|
| 101 |
+
|
| 102 |
+
draft_kv = None
|
| 103 |
+
verify_kv = None
|
| 104 |
+
|
| 105 |
+
total_drafted = 0
|
| 106 |
+
total_accepted = 0
|
| 107 |
+
|
| 108 |
+
steps = []
|
| 109 |
+
|
| 110 |
+
while result.shape[-1] - inputs["input_ids"].shape[-1] < max_tokens:
|
| 111 |
+
print(steps)
|
| 112 |
+
drafted, drafted_probs, draft_kv = draft(result, gamma, confidence_threshold, eos_token, draft_kv)
|
| 113 |
+
accepted_tokens, num_accepted, verify_kv = verify(drafted, drafted_probs, eos_token, verify_kv)
|
| 114 |
+
|
| 115 |
+
total_drafted += len(drafted_probs)
|
| 116 |
+
total_accepted += num_accepted
|
| 117 |
+
|
| 118 |
+
# Extract token IDs for visualization
|
| 119 |
+
drafted_token_ids = drafted[0, -len(drafted_probs):].tolist()
|
| 120 |
+
|
| 121 |
+
step = {
|
| 122 |
+
"drafted": [tokenizer.decode([t]) for t in drafted_token_ids],
|
| 123 |
+
"accepted": num_accepted,
|
| 124 |
+
"resampled": tokenizer.decode([accepted_tokens[-1]]) if num_accepted < len(accepted_tokens) else None
|
| 125 |
+
}
|
| 126 |
+
steps.append(step)
|
| 127 |
+
|
| 128 |
+
valid_len = result.shape[-1] + num_accepted
|
| 129 |
+
result = torch.cat([result, torch.tensor([accepted_tokens], device=device)], dim=-1)
|
| 130 |
+
|
| 131 |
+
if draft_kv is not None:
|
| 132 |
+
draft_kv.crop(max_length=valid_len)
|
| 133 |
+
if verify_kv is not None:
|
| 134 |
+
verify_kv.crop(max_length=valid_len)
|
| 135 |
+
|
| 136 |
+
if eos_token in accepted_tokens or im_end_token in accepted_tokens:
|
| 137 |
+
break
|
| 138 |
+
|
| 139 |
+
# Extract final output
|
| 140 |
+
final_output = tokenizer.decode(result[0])
|
| 141 |
+
|
| 142 |
+
# Build HTML visualization
|
| 143 |
+
html = "<div style='font-family: monospace;'>"
|
| 144 |
+
html += f"<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
|
| 145 |
+
html += f"<b>Final Output:</b><br/>{final_output}"
|
| 146 |
+
html += "</div>"
|
| 147 |
+
html += f"<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2pd solid white; border-radius: 5px;'>"
|
| 148 |
+
html += f"<b>Acceptance Rate:</b> {total_accepted}/{total_drafted} = {total_accepted/total_drafted*100:.1f}%"
|
| 149 |
+
html += "</div>"
|
| 150 |
+
html += "<div style='margin-bottom: 10px;'><b>Decoding Steps:</b></div>"
|
| 151 |
+
|
| 152 |
+
for i, step in enumerate(steps):
|
| 153 |
+
html += f"<div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'>"
|
| 154 |
+
html += f"<b>Step {i+1}:</b> "
|
| 155 |
+
|
| 156 |
+
for j, token in enumerate(step["drafted"]):
|
| 157 |
+
# Escape HTML special characters
|
| 158 |
+
token_display = token.replace("<", "<").replace(">", ">")
|
| 159 |
+
if j < step["accepted"]:
|
| 160 |
+
html += f"<span style='background: #66CC66; padding: 2px 4px; margin: 2px; border-radius: 3px;'>{token_display}</span>"
|
| 161 |
+
else:
|
| 162 |
+
html += f"<span style='background: #FF8B9A; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>{token_display}</span>"
|
| 163 |
+
|
| 164 |
+
if step["resampled"]:
|
| 165 |
+
resampled_display = step["resampled"].replace("<", "<").replace(">", ">")
|
| 166 |
+
html += f" β <span style='background: #5AADCC; padding: 2px 4px; border-radius: 3px;'>{resampled_display}</span>"
|
| 167 |
+
|
| 168 |
+
html += "</div>"
|
| 169 |
+
html += "</div>"
|
| 170 |
+
|
| 171 |
+
return html
|
| 172 |
+
|
| 173 |
+
demo = gr.Interface(
|
| 174 |
+
fn=generate_visual,
|
| 175 |
+
inputs=[
|
| 176 |
+
gr.Textbox(label="Prompt", value="What is a deal flow in a VC fund?", lines=3),
|
| 177 |
+
gr.Slider(minimum=10, maximum=100, value=50, step=10, label="Max Tokens"),
|
| 178 |
+
gr.Slider(minimum=1, maximum=30, value=15, step=1, label="Gamma (draft lookahead)"),
|
| 179 |
+
gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold")
|
| 180 |
+
],
|
| 181 |
+
outputs=gr.HTML(label="Speculative Decoding Visualization"),
|
| 182 |
+
title="π Speculative Decoding Demo",
|
| 183 |
+
description="""
|
| 184 |
+
**Speculative Decoding Visualization** using Qwen2.5-Coder models
|
| 185 |
+
|
| 186 |
+
- **Draft Model**: Qwen2.5-Coder-0.5B-Instruct (fast)
|
| 187 |
+
- **Verify Model**: Qwen2.5-Coder-3B-Instruct (accurate)
|
| 188 |
+
|
| 189 |
+
**Color Legend:**
|
| 190 |
+
- π’ Green = Accepted tokens from draft model
|
| 191 |
+
- π΄ Red = Rejected tokens (with strikethrough)
|
| 192 |
+
- π΅ Blue = Resampled tokens from verify model
|
| 193 |
+
""",
|
| 194 |
+
examples=[
|
| 195 |
+
["What is a deal flow in a VC fund?", 80, 15, 0.5],
|
| 196 |
+
["def fibonacci(n):", 50, 15, 0.5],
|
| 197 |
+
["Explain the concept of attention in transformers", 60, 10, 0.6]
|
| 198 |
+
]
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
if __name__ == "__main__":
|
| 202 |
+
demo.launch()
|
poetry.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
pyproject.toml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.poetry]
|
| 2 |
+
name = "speculativedecoding"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = ""
|
| 5 |
+
authors = ["Harrison <zhuyiyun060209@gmail.com>"]
|
| 6 |
+
readme = "README.md"
|
| 7 |
+
package-mode = false
|
| 8 |
+
|
| 9 |
+
[tool.poetry.dependencies]
|
| 10 |
+
python = ">=3.13"
|
| 11 |
+
transformers = ">=4.57.3,<5.0.0"
|
| 12 |
+
torch = ">=2.0.0"
|
| 13 |
+
ipykernel = "^7.1.0"
|
| 14 |
+
gradio = "^6.1.0"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
[build-system]
|
| 18 |
+
requires = ["poetry-core>=2.0.0,<3.0.0"]
|
| 19 |
+
build-backend = "poetry.core.masonry.api"
|