HandsomeSB committed on
Commit 6a50f6f · 0 Parent(s):
.gradio/flagged/dataset1.csv ADDED
@@ -0,0 +1,7 @@
+ Prompt,Max Tokens,Gamma (draft lookahead),Confidence Threshold,Speculative Decoding Visualization,timestamp
+ def fibonacci(n):,10,5,0.5,"<div style='font-family: monospace;'><div style='margin-bottom: 20px; padding: 10px; background: #f0f0f0; border-radius: 5px;'><b>Final Output:</b><br/><|im_start|>system
+ You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
+ <|im_start|>user
+ def fibonacci(n):<|im_end|>
+ <|im_start|>assistant
+ I understand that you're looking for a Python function</div><div style='margin-bottom: 20px; padding: 10px; background: #e0e0e0; border-radius: 5px;'><b>Acceptance Rate:</b> 8/15 = 53.3%</div><div style='margin-bottom: 10px;'><b>Decoding Steps:</b></div><div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'><b>Step 1:</b> <span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>Certainly</span><span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>!</span> → <span style='background: #87CEEB; padding: 2px 4px; border-radius: 3px;'> I</span></div><div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'><b>Step 2:</b> <span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>'m</span><span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'> sorry</span><span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>,</span><span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'> but</span><span style='background: #FFB6C1; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'> I</span> → <span style='background: #87CEEB; padding: 2px 4px; border-radius: 3px;'> understand</span></div><div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'><b>Step 3:</b> <span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> that</span><span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> you</span><span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'>'re</span></div><div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'><b>Step 4:</b> <span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> looking</span><span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> for</span><span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> a</span><span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> Python</span><span style='background: #90EE90; padding: 2px 4px; margin: 2px; border-radius: 3px;'> function</span></div></div>",2025-12-14 13:54:04.330173
README.md ADDED
@@ -0,0 +1,3 @@
+ # Speculative Decoding
+
+ A project implementing speculative decoding: a small Qwen2.5-Coder-0.5B draft model proposes tokens that a larger Qwen2.5-Coder-3B verify model accepts or resamples, with a Gradio demo that visualizes each decoding step.
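Both `main.ipynb` and `main.py` below implement the same speculative-sampling accept/reject rule inside `verify()`: a drafted token `x` is kept outright when `q(x) <= p(x)`, kept with probability `p(x)/q(x)` otherwise, and on rejection a replacement is drawn from the normalized residual `max(0, p - q)`. A minimal sketch of that rule in isolation (the function name and the standalone `p`, `q`, `x` arguments are illustrative, not part of this repo):

```python
import torch

def accept_or_resample(p: torch.Tensor, q: torch.Tensor, x: int) -> tuple[int, bool]:
    """One speculative-decoding decision for a single drafted token.

    p: verify-model probabilities over the vocabulary (1-D tensor)
    q: draft-model probabilities over the vocabulary (1-D tensor)
    x: token id proposed by the draft model
    """
    # Accept outright when the verifier is at least as confident as the draft;
    # otherwise accept with probability p(x)/q(x).
    if q[x] <= p[x] or torch.rand(1).item() < (p[x] / q[x]).item():
        return x, True
    # Rejected: resample from the normalized residual max(0, p - q), which
    # keeps the overall output distribution identical to sampling from p.
    residual = torch.clamp(p - q, min=0)
    residual = residual / residual.sum()
    return int(torch.multinomial(residual, num_samples=1).item()), False
```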
__pycache__/main.cpython-313.pyc ADDED
Binary file (9.49 kB).
main.ipynb ADDED
@@ -0,0 +1,458 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e0ef8d28",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/yiyunzhu/Library/Caches/pypoetry/virtualenvs/speculativedecoding-B0TTdUOs-py3.13/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "`torch_dtype` is deprecated! Use `dtype` instead!\n",
+ "Fetching 2 files: 100%|██████████| 2/2 [00:58<00:00, 29.32s/it]\n",
+ "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 27.91it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed\n",
+ "\n",
+ "set_seed(67)\n",
+ "\n",
+ "device = \"mps\"\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-Coder-0.5B-Instruct\") # HuggingFaceTB/SmolLM2-135M-Instruct\n",
+ "draft_model = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen2.5-Coder-0.5B-Instruct\", dtype=torch.bfloat16).to(device)\n",
+ "verify_model = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen2.5-Coder-3B-Instruct\", dtype=torch.bfloat16).to(device) # HuggingFaceTB/SmolLM2-1.7B-Instruct"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "30d81505",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'The quick brown fox jumps over the lazy dog'"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "prompt = \"The quick brown fox\"\n",
+ "inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
+ "input_ids = inputs[\"input_ids\"]\n",
+ "\n",
+ "generated = input_ids.clone() # [1, seq_len]\n",
+ "draft_probs = []\n",
+ "for _ in range(5): # gamma = 5\n",
+ "    with torch.no_grad():\n",
+ "        outputs = draft_model(generated)\n",
+ "        logits = outputs.logits[:, -1, :] # batch, seq_len, vocab_size\n",
+ "\n",
+ "    probs = torch.softmax(logits, dim=-1) # [1, vocab_size]\n",
+ "    next_token = torch.multinomial(probs, num_samples=1)\n",
+ "\n",
+ "    draft_probs.append(probs)\n",
+ "    generated = torch.cat([generated, next_token], dim=-1)\n",
+ "\n",
+ "tokenizer.decode(generated[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "2492ee58",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Token: ' jumps', q(x)=0.9727, p(x)=0.9023\n",
+ " -> Accepted\n",
+ "Token: ' over', q(x)=1.0000, p(x)=0.9961\n",
+ " -> Accepted\n",
+ "Token: ' the', q(x)=0.9023, p(x)=0.9102\n",
+ " -> Accepted\n",
+ "Token: ' lazy', q(x)=1.0000, p(x)=0.9922\n",
+ " -> Accepted\n",
+ "Token: ' dog', q(x)=0.9922, p(x)=0.9844\n",
+ " -> Accepted\n",
+ " jumps over the lazy dog\n",
+ "[34208, 916, 279, 15678, 5562]\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ "    target_outputs = verify_model(generated)\n",
+ "    target_logits = target_outputs.logits[:, -6:-1, :]\n",
+ "\n",
+ "target_probs = torch.softmax(target_logits, dim=-1) # [1, 5, vocab_size]\n",
+ "accepted_tokens = []\n",
+ "for i in range(5):\n",
+ "    # if q(x) <= p(x), keep\n",
+ "    # if q(x) > p(x), reject with 1-p(x)/q(x) chance\n",
+ "    # if rejected, we sample from norm(max(0, p(x) - q(x)))\n",
+ "    q = draft_probs[i] # [1, vocab_size]\n",
+ "    p = target_probs[:, i, :] # [1, vocab_size]\n",
+ "    token = generated[:, i - 5] # [1]\n",
+ "    # assume unbatched for now\n",
+ "    x = token[0].item()\n",
+ "    q_x = q[0, x].item()\n",
+ "    p_x = p[0, x].item()\n",
+ "\n",
+ "    print(f\"Token: '{tokenizer.decode(x)}', q(x)={q_x:.4f}, p(x)={p_x:.4f}\")\n",
+ "    if q_x <= p_x:\n",
+ "        print(\" -> Accepted\")\n",
+ "        accepted_tokens.append(x)\n",
+ "    else:\n",
+ "        r = torch.rand(1).item()\n",
+ "        acceptance_rate = p_x / q_x\n",
+ "\n",
+ "        if r < acceptance_rate:\n",
+ "            print(\" -> Accepted\")\n",
+ "            accepted_tokens.append(x)\n",
+ "        else:\n",
+ "            print(\" -> Rejected\")\n",
+ "            adjusted = torch.clamp(p - q, min=0)\n",
+ "            adjusted = adjusted / adjusted.sum()\n",
+ "            new_token = torch.multinomial(adjusted, num_samples=1)[0].item()\n",
+ "            accepted_tokens.append(new_token)\n",
+ "            break\n",
+ "\n",
+ "print(tokenizer.decode(accepted_tokens))\n",
+ "print(accepted_tokens)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "2df0e2a9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def draft(input_ids, gamma, confidence_threshold, eos_token, past_kv):\n",
+ "    generated = input_ids.clone() # [1, seq_len]\n",
+ "    draft_probs = []\n",
+ "    for _ in range(gamma):\n",
+ "        with torch.no_grad():\n",
+ "            outputs = draft_model(\n",
+ "                generated if past_kv is None else generated[:, -1:],\n",
+ "                past_key_values=past_kv,\n",
+ "                use_cache=True\n",
+ "            )\n",
+ "            logits = outputs.logits[:, -1, :] # batch, seq_len, vocab_size\n",
+ "            past_kv = outputs.past_key_values\n",
+ "\n",
+ "        probs = torch.softmax(logits, dim=-1) # [1, vocab_size]\n",
+ "\n",
+ "        confidence = probs.max().item() # dynamic speculative decoding\n",
+ "        if confidence < confidence_threshold and len(draft_probs) > 0:\n",
+ "            break\n",
+ "\n",
+ "        next_token = torch.argmax(probs, dim=-1, keepdim=True)\n",
+ "\n",
+ "        draft_probs.append(probs)\n",
+ "        generated = torch.cat([generated, next_token], dim=-1)\n",
+ "\n",
+ "        if next_token.item() == eos_token:\n",
+ "            break\n",
+ "\n",
+ "    return generated, draft_probs, past_kv\n",
+ "\n",
+ "def verify(drafted, drafted_probs, eos_token, past_kv):\n",
+ "    draft_len = len(drafted_probs) # number of new drafted tokens\n",
+ "    with torch.no_grad():\n",
+ "        if past_kv is None:\n",
+ "            target_outputs = verify_model(drafted, use_cache=True)\n",
+ "            target_logits = target_outputs.logits[:, -draft_len - 1:-1, :]\n",
+ "        else:\n",
+ "            target_outputs = verify_model(\n",
+ "                drafted[:, -(draft_len + 1):], # extra token\n",
+ "                past_key_values=past_kv,\n",
+ "                use_cache=True\n",
+ "            )\n",
+ "            target_logits = target_outputs.logits[:, :-1, :] # Drop last (predicts bonus token)\n",
+ "\n",
+ "        past_kv = target_outputs.past_key_values\n",
+ "\n",
+ "    target_probs = torch.softmax(target_logits, dim=-1) # [1, draft_len, vocab_size]\n",
+ "    accepted_tokens = []\n",
+ "    num_accepted = 0 # number of drafted tokens that were accepted\n",
+ "    for i in range(draft_len):\n",
+ "        # if q(x) <= p(x), keep\n",
+ "        # if q(x) > p(x), reject with 1-p(x)/q(x) chance\n",
+ "        # if rejected, we sample from norm(max(0, p(x) - q(x)))\n",
+ "        q = drafted_probs[i] # [1, vocab_size]\n",
+ "        p = target_probs[:, i, :] # [1, vocab_size]\n",
+ "        token = drafted[:, i - draft_len] # [1]\n",
+ "        # assume unbatched for now\n",
+ "        x = token[0].item()\n",
+ "        q_x = q[0, x].item()\n",
+ "        p_x = p[0, x].item()\n",
+ "\n",
+ "        print(f\"Token: '{tokenizer.decode(x)}'\", end=\"\")\n",
+ "        if q_x <= p_x:\n",
+ "            print(\" -> Accepted\")\n",
+ "            accepted_tokens.append(x)\n",
+ "            num_accepted += 1\n",
+ "        else:\n",
+ "            r = torch.rand(1).item()\n",
+ "            acceptance_rate = p_x / q_x\n",
+ "\n",
+ "            if r < acceptance_rate:\n",
+ "                print(\" -> Accepted\")\n",
+ "                accepted_tokens.append(x)\n",
+ "                num_accepted += 1\n",
+ "            else:\n",
+ "                print(\" -> Rejected\", end=\"\")\n",
+ "                adjusted = torch.clamp(p - q, min=0)\n",
+ "                adjusted = adjusted / adjusted.sum()\n",
+ "                new_token = torch.multinomial(adjusted, num_samples=1)[0].item()\n",
+ "                accepted_tokens.append(new_token)\n",
+ "                print(tokenizer.decode(new_token))\n",
+ "                break\n",
+ "        if accepted_tokens[-1] == eos_token:\n",
+ "            break\n",
+ "\n",
+ "    return accepted_tokens, num_accepted, past_kv"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "5378f5d5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Token: 'A' -> Accepted\n",
+ "Token: ' deal' -> Accepted\n",
+ "Token: ' flow' -> Accepted\n",
+ "Token: ' in' -> Accepted\n",
+ "Token: ' a' -> Accepted\n",
+ "Token: ' VC' -> Rejected venture\n",
+ "Token: ' capital' -> Accepted\n",
+ "Token: ' fund' -> Rejected (\n",
+ "Token: 'VC' -> Accepted\n",
+ "Token: ')' -> Accepted\n",
+ "Token: ' fund' -> Accepted\n",
+ "Token: ' is' -> Rejected refers\n",
+ "Token: ' to' -> Accepted\n",
+ "Token: ' the' -> Accepted\n",
+ "Token: ' process' -> Accepted\n",
+ "Token: ' of' -> Accepted\n",
+ "Token: ' setting' -> Rejected screening\n",
+ "Token: ',' -> Accepted\n",
+ "Token: ' evaluating' -> Accepted\n",
+ "Token: ',' -> Accepted\n",
+ "Token: ' and' -> Accepted\n",
+ "Token: ' selecting' -> Accepted\n",
+ "Token: ' investors' -> Rejected potential\n",
+ "Token: ' investment' -> Accepted\n",
+ "Token: ' in' -> Rejected opportunities\n",
+ "Token: ' for' -> Rejected.\n",
+ "Token: ' It' -> Accepted\n",
+ "Token: ' involves' -> Accepted\n",
+ "Token: ' several' -> Rejected identifying\n",
+ "Token: ' potential' -> Accepted\n",
+ "Token: ' investors' -> Rejected companies\n",
+ "Token: ' that' -> Accepted\n",
+ "Token: ' could' -> Rejected the\n",
+ "Token: ' VC' -> Accepted\n",
+ "Token: ' fund' -> Accepted\n",
+ "Token: 'ers' -> Rejected is\n",
+ "Token: ' interested' -> Accepted\n",
+ "Token: ' in' -> Accepted\n",
+ "Token: ',' -> Accepted\n",
+ "Token: ' assessing' -> Accepted\n",
+ "Token: ' their' -> Accepted\n",
+ "Token: ' financial' -> Rejected growth\n",
+ "Token: ' potential' -> Accepted\n",
+ "Token: ',' -> Accepted\n",
+ "Token: ' and' -> Accepted\n",
+ "Token: ' evaluating' -> Rejected business\n",
+ "Token: ' model' -> Accepted\n",
+ "Token: ',' -> Accepted\n",
+ "Token: ' and' -> Accepted\n",
+ "Token: ',' -> Rejected market\n",
+ "Token: ' demand' -> Accepted\n",
+ "Token: ',' -> Accepted\n",
+ "Token: ' and' -> Accepted\n",
+ "Token: ' then' -> Accepted\n",
+ "Token: ' making' -> Accepted\n",
+ "Token: ' a' -> Accepted\n",
+ "Token: ' decision' -> Accepted\n",
+ "Token: ' on' -> Accepted\n",
+ "Token: ' whether' -> Accepted\n",
+ "Token: ' to' -> Accepted\n",
+ "Token: ' invest' -> Accepted\n",
+ "Token: ' in' -> Accepted\n",
+ "Token: ' those' -> Rejected the\n",
+ "Token: ' VC' -> Rejected fund\n",
+ "Token: ' in' -> Rejected’s\n",
+ "Token: ' capital' -> Accepted\n",
+ "Token: '.' -> Accepted\n",
+ "Token: '<|im_end|>' -> Accepted\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is a deal flow in a VC fund?<|im_end|>\\n<|im_start|>assistant\\nA deal flow in a venture capital (VC) fund refers to the process of screening, evaluating, and selecting potential investment opportunities. It involves identifying potential companies that the VC fund is interested in, assessing their growth potential, and business model, and market demand, and then making a decision on whether to invest in the fund’s capital.<|im_end|>'"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "messages = [{\"role\": \"user\", \"content\": \"What is a deal flow in a VC fund?\"}]\n",
+ "prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
+ "\n",
+ "max_tokens = 80\n",
+ "eos_token = tokenizer.eos_token_id\n",
+ "im_end_token = tokenizer.convert_tokens_to_ids(\"<|im_end|>\")\n",
+ "gamma = 15\n",
+ "confidence_threshold = 0.5\n",
+ "\n",
+ "inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
+ "result = inputs[\"input_ids\"].clone()\n",
+ "\n",
+ "draft_kv = None\n",
+ "verify_kv = None\n",
+ "\n",
+ "total_drafted = 0\n",
+ "total_accepted = 0\n",
+ "\n",
+ "while result.shape[-1] - inputs[\"input_ids\"].shape[-1] < max_tokens:\n",
+ "    drafted, drafted_probs, draft_kv = draft(result, gamma, confidence_threshold, eos_token, draft_kv)\n",
+ "    accepted_tokens, num_accepted, verify_kv = verify(drafted, drafted_probs, eos_token, verify_kv)\n",
+ "\n",
+ "    total_drafted += len(drafted_probs)\n",
+ "    total_accepted += num_accepted\n",
+ "\n",
+ "    valid_len = result.shape[-1] + num_accepted\n",
+ "    result = torch.cat([result, torch.tensor([accepted_tokens], device=device)], dim=-1)\n",
+ "\n",
+ "    if draft_kv is not None:\n",
+ "        draft_kv.crop(max_length=valid_len)\n",
+ "    if verify_kv is not None:\n",
+ "        verify_kv.crop(max_length=valid_len)\n",
+ "\n",
+ "    if eos_token in accepted_tokens or im_end_token in accepted_tokens:\n",
+ "        break\n",
+ "\n",
+ "tokenizer.decode(result[0])\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "40c92741",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.6071428571428571"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "total_accepted / total_drafted"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "0c661940",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat is a deal flow in a VC fund?<|im_end|>\\n<|im_start|>assistant\\nA deal flow in a VC fund refers to the collection and processing of new investment opportunities presented to the fund for screening, evaluation, and ultimately investment.\\n\\nHere’s a basic overview:\\n\\n* **Deal Flow Collection**: Fund managers typically collect deals through email, cold calls, calendars, meeting notes, investor referrals, and referrals by other investors. They utilize various channels including online searches, investor networks, and industry'"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
+ "result = inputs[\"input_ids\"].clone()\n",
+ "\n",
+ "past_kv = None\n",
+ "\n",
+ "while result.shape[-1] - inputs[\"input_ids\"].shape[-1] < max_tokens:\n",
+ "    with torch.no_grad():\n",
+ "        output = verify_model(\n",
+ "            result if past_kv is None else result[:, -1:],\n",
+ "            past_key_values=past_kv,\n",
+ "            use_cache=True\n",
+ "        )\n",
+ "        logits = output.logits[:, -1, :] # batch, vocab\n",
+ "\n",
+ "    past_kv = output.past_key_values\n",
+ "    probs = torch.softmax(logits, dim=-1)\n",
+ "    next_token = torch.multinomial(probs, num_samples=1)\n",
+ "\n",
+ "    result = torch.cat([result, next_token], dim=-1)\n",
+ "    if eos_token in next_token or im_end_token in next_token:\n",
+ "        break\n",
+ "\n",
+ "tokenizer.decode(result[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c7a26417",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "speculativedecoding-B0TTdUOs-py3.13",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
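The acceptance-rate cell above (execution count 25) reports `total_accepted / total_drafted ≈ 0.607` for this run. As a back-of-envelope sanity check only (it assumes every drafted token is accepted i.i.d. with a fixed probability α, which real runs do not strictly satisfy), the analysis in Leviathan et al., "Fast Inference from Transformers via Speculative Decoding" (2023), gives the expected number of tokens produced per verify-model forward pass with lookahead γ as (1 − α^(γ+1)) / (1 − α):

```python
# Rough estimate under an i.i.d.-acceptance assumption, not a benchmark.
alpha = 0.6071  # measured: total_accepted / total_drafted
gamma = 15      # draft lookahead used in the notebook
expected_tokens = (1 - alpha ** (gamma + 1)) / (1 - alpha)
print(f"~{expected_tokens:.2f} tokens per verify forward pass")  # ~2.54
```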
main.py ADDED
@@ -0,0 +1,202 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
+ import gradio as gr
+
+ set_seed(67)
+
+ device = "mps"
+
+ # Initialize models and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-0.5B-Instruct")
+ draft_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-0.5B-Instruct", dtype=torch.bfloat16).to(device)
+ verify_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-3B-Instruct", dtype=torch.bfloat16).to(device)
+
+ def draft(input_ids, gamma, confidence_threshold, eos_token, past_kv):
+     generated = input_ids.clone()
+     draft_probs = []
+     for _ in range(gamma):
+         with torch.no_grad():
+             outputs = draft_model(
+                 generated if past_kv is None else generated[:, -1:],
+                 past_key_values=past_kv,
+                 use_cache=True
+             )
+             logits = outputs.logits[:, -1, :]
+             past_kv = outputs.past_key_values
+
+         probs = torch.softmax(logits, dim=-1)
+
+         confidence = probs.max().item()
+         if confidence < confidence_threshold and len(draft_probs) > 0:
+             break
+
+         next_token = torch.argmax(probs, dim=-1, keepdim=True)
+
+         draft_probs.append(probs)
+         generated = torch.cat([generated, next_token], dim=-1)
+
+         if next_token.item() == eos_token:
+             break
+
+     return generated, draft_probs, past_kv
+
+ def verify(drafted, drafted_probs, eos_token, past_kv):
+     draft_len = len(drafted_probs)
+     with torch.no_grad():
+         if past_kv is None:
+             target_outputs = verify_model(drafted, use_cache=True)
+             target_logits = target_outputs.logits[:, -draft_len - 1:-1, :]
+         else:
+             target_outputs = verify_model(
+                 drafted[:, -(draft_len + 1):],
+                 past_key_values=past_kv,
+                 use_cache=True
+             )
+             target_logits = target_outputs.logits[:, :-1, :]
+
+         past_kv = target_outputs.past_key_values
+
+     target_probs = torch.softmax(target_logits, dim=-1)
+     accepted_tokens = []
+     num_accepted = 0
+     for i in range(draft_len):
+         q = drafted_probs[i]
+         p = target_probs[:, i, :]
+         token = drafted[:, i - draft_len]
+         x = token[0].item()
+         q_x = q[0, x].item()
+         p_x = p[0, x].item()
+
+         if q_x <= p_x:
+             accepted_tokens.append(x)
+             num_accepted += 1
+         else:
+             r = torch.rand(1).item()
+             acceptance_rate = p_x / q_x
+
+             if r < acceptance_rate:
+                 accepted_tokens.append(x)
+                 num_accepted += 1
+             else:
+                 adjusted = torch.clamp(p - q, min=0)
+                 adjusted = adjusted / adjusted.sum()
+                 new_token = torch.multinomial(adjusted, num_samples=1)[0].item()
+                 accepted_tokens.append(new_token)
+                 break
+         if accepted_tokens[-1] == eos_token:
+             break
+
+     return accepted_tokens, num_accepted, past_kv
+
+ def generate_visual(prompt, max_tokens=50, gamma=15, confidence_threshold=0.5):
+     # Prepare input
+     messages = [{"role": "user", "content": prompt}]
+     formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+     eos_token = tokenizer.eos_token_id
+     im_end_token = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
+     inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
+     result = inputs["input_ids"].clone()
+
+     draft_kv = None
+     verify_kv = None
+
+     total_drafted = 0
+     total_accepted = 0
+
+     steps = []
+
+     while result.shape[-1] - inputs["input_ids"].shape[-1] < max_tokens:
+         # print(steps)  # leftover debug output
+         drafted, drafted_probs, draft_kv = draft(result, gamma, confidence_threshold, eos_token, draft_kv)
+         accepted_tokens, num_accepted, verify_kv = verify(drafted, drafted_probs, eos_token, verify_kv)
+
+         total_drafted += len(drafted_probs)
+         total_accepted += num_accepted
+
+         # Extract token IDs for visualization
+         drafted_token_ids = drafted[0, -len(drafted_probs):].tolist()
+
+         step = {
+             "drafted": [tokenizer.decode([t]) for t in drafted_token_ids],
+             "accepted": num_accepted,
+             "resampled": tokenizer.decode([accepted_tokens[-1]]) if num_accepted < len(accepted_tokens) else None
+         }
+         steps.append(step)
+
+         valid_len = result.shape[-1] + num_accepted
+         result = torch.cat([result, torch.tensor([accepted_tokens], device=device)], dim=-1)
+
+         if draft_kv is not None:
+             draft_kv.crop(max_length=valid_len)
+         if verify_kv is not None:
+             verify_kv.crop(max_length=valid_len)
+
+         if eos_token in accepted_tokens or im_end_token in accepted_tokens:
+             break
+
+     # Extract final output
+     final_output = tokenizer.decode(result[0])
+
+     # Build HTML visualization
+     html = "<div style='font-family: monospace;'>"
+     html += "<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
+     html += f"<b>Final Output:</b><br/>{final_output}"
+     html += "</div>"
+     html += "<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
+     html += f"<b>Acceptance Rate:</b> {total_accepted}/{total_drafted} = {total_accepted/total_drafted*100:.1f}%"
+     html += "</div>"
+     html += "<div style='margin-bottom: 10px;'><b>Decoding Steps:</b></div>"
+
+     for i, step in enumerate(steps):
+         html += "<div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'>"
+         html += f"<b>Step {i+1}:</b> "
+
+         for j, token in enumerate(step["drafted"]):
+             # Escape HTML special characters
+             token_display = token.replace("<", "&lt;").replace(">", "&gt;")
+             if j < step["accepted"]:
+                 html += f"<span style='background: #66CC66; padding: 2px 4px; margin: 2px; border-radius: 3px;'>{token_display}</span>"
+             else:
+                 html += f"<span style='background: #FF8B9A; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>{token_display}</span>"
+
+         if step["resampled"]:
+             resampled_display = step["resampled"].replace("<", "&lt;").replace(">", "&gt;")
+             html += f" → <span style='background: #5AADCC; padding: 2px 4px; border-radius: 3px;'>{resampled_display}</span>"
+
+         html += "</div>"
+     html += "</div>"
+
+     return html
+
+ demo = gr.Interface(
+     fn=generate_visual,
+     inputs=[
+         gr.Textbox(label="Prompt", value="What is a deal flow in a VC fund?", lines=3),
+         gr.Slider(minimum=10, maximum=100, value=50, step=10, label="Max Tokens"),
+         gr.Slider(minimum=1, maximum=30, value=15, step=1, label="Gamma (draft lookahead)"),
+         gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold")
+     ],
+     outputs=gr.HTML(label="Speculative Decoding Visualization"),
+     title="🚀 Speculative Decoding Demo",
+     description="""
+ **Speculative Decoding Visualization** using Qwen2.5-Coder models
+
+ - **Draft Model**: Qwen2.5-Coder-0.5B-Instruct (fast)
+ - **Verify Model**: Qwen2.5-Coder-3B-Instruct (accurate)
+
+ **Color Legend:**
+ - 🟢 Green = Accepted tokens from draft model
+ - 🔴 Red = Rejected tokens (with strikethrough)
+ - 🔵 Blue = Resampled tokens from verify model
+ """,
+     examples=[
+         ["What is a deal flow in a VC fund?", 80, 15, 0.5],
+         ["def fibonacci(n):", 50, 15, 0.5],
+         ["Explain the concept of attention in transformers", 60, 10, 0.6]
+     ]
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
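Because `demo.launch()` is guarded by `if __name__ == "__main__":`, `generate_visual` can also be exercised without starting the Gradio UI. A hypothetical smoke test, not part of the repo (it assumes both Qwen checkpoints can be downloaded and that the hard-coded `device = "mps"` matches your machine; edit `main.py` to use `"cuda"` or `"cpu"` otherwise):

```python
# Run from the project root, e.g. inside `poetry run python`.
# Importing main loads both models, which is slow on first use.
from main import generate_visual

html = generate_visual(
    "def fibonacci(n):",   # same prompt as one of the Gradio examples
    max_tokens=50,
    gamma=15,
    confidence_threshold=0.5,
)
print(html[:300])  # preview of the HTML the Gradio app would render
```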
poetry.lock ADDED
The diff for this file is too large to render.
pyproject.toml ADDED
@@ -0,0 +1,19 @@
+ [tool.poetry]
+ name = "speculativedecoding"
+ version = "0.1.0"
+ description = ""
+ authors = ["Harrison <zhuyiyun060209@gmail.com>"]
+ readme = "README.md"
+ package-mode = false
+
+ [tool.poetry.dependencies]
+ python = ">=3.13"
+ transformers = ">=4.57.3,<5.0.0"
+ torch = ">=2.0.0"
+ ipykernel = "^7.1.0"
+ gradio = "^6.1.0"
+
+
+ [build-system]
+ requires = ["poetry-core>=2.0.0,<3.0.0"]
+ build-backend = "poetry.core.masonry.api"