import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import gradio as gr

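# Speculative decoding demo: a small draft model proposes several tokens per
# step, and a larger verify model checks them all in a single forward pass,
# keeping accepted tokens and resampling a replacement at the first rejection.
# A Gradio UI streams each draft/verify round so the process is visible.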
set_seed(67)  # fixed seed so the stochastic accept/reject is reproducible

device = "cpu"

# Initialize models and tokenizer
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
draft_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct", torch_dtype=torch.bfloat16).to(device)
verify_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct", torch_dtype=torch.bfloat16).to(device)

def draft(input_ids, gamma, confidence_threshold, eos_token, past_kv):
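    """Greedily propose up to `gamma` tokens with the small draft model.

    Drafting stops early if the top-token confidence falls below
    `confidence_threshold` (after at least one token) or if EOS is produced.
    Returns the extended sequence, the draft distribution for each proposed
    token, and the updated KV cache. Tokens are chosen by argmax rather than
    sampled from the distribution, so the accept/reject test in verify() is
    a heuristic rather than an exact speculative-sampling guarantee.
    """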
    generated = input_ids.clone()
    draft_probs = []
    for _ in range(gamma):
        with torch.no_grad():
            outputs = draft_model(
                generated if past_kv is None else generated[:, -1:],
                past_key_values=past_kv,
                use_cache=True
            )
            logits = outputs.logits[:, -1, :]
            past_kv = outputs.past_key_values

        probs = torch.softmax(logits, dim=-1)

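        # Stop drafting once the draft model becomes unsure, but always
        # keep at least one proposed token.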
        confidence = probs.max().item()
        if confidence < confidence_threshold and len(draft_probs) > 0:
            break

        next_token = torch.argmax(probs, dim=-1, keepdim=True)

        draft_probs.append(probs)
        generated = torch.cat([generated, next_token], dim=-1)

        if next_token.item() == eos_token:
            break

    return generated, draft_probs, past_kv

def verify(drafted, drafted_probs, eos_token, past_kv):
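    """Check the drafted tokens against the large model in one forward pass.

    Returns the accepted (and possibly resampled) token ids, the number of
    draft tokens accepted, and the updated KV cache.
    """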
    draft_len = len(drafted_probs)
    with torch.no_grad():
        if past_kv is None:
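            # No cache yet: run the full sequence. Logits at position i
            # predict token i + 1, so predictions for the draft_len drafted
            # tokens sit at positions [-draft_len - 1, -2].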
            target_outputs = verify_model(drafted, use_cache=True)
            target_logits = target_outputs.logits[:, -draft_len - 1:-1, :]
        else:
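            # The cache covers everything except the last committed token,
            # so feed that token plus the draft_len drafted tokens; the last
            # position predicts the token *after* the draft and is dropped.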
            target_outputs = verify_model(
                drafted[:, -(draft_len + 1):],
                past_key_values=past_kv,
                use_cache=True
            )
            target_logits = target_outputs.logits[:, :-1, :]

        past_kv = target_outputs.past_key_values

    target_probs = torch.softmax(target_logits, dim=-1)
    accepted_tokens = []
    num_accepted = 0
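    # Speculative-sampling test per drafted token x: accept outright if
    # q(x) <= p(x), otherwise accept with probability p(x) / q(x); resample
    # and stop at the first rejection.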
    for i in range(draft_len):
        q = drafted_probs[i]
        p = target_probs[:, i, :]
        token = drafted[:, i - draft_len]
        x = token[0].item()
        q_x = q[0, x].item()
        p_x = p[0, x].item()

        if q_x <= p_x:
            accepted_tokens.append(x)
            num_accepted += 1
        else:
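            # Accept token x with probability p(x) / q(x)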
            r = torch.rand(1).item()
            accept_prob = p_x / q_x

            if r < accept_prob:
                accepted_tokens.append(x)
                num_accepted += 1
            else:
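                # Rejected: resample from the normalized residual
                # distribution max(0, p - q)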
                adjusted = torch.clamp(p - q, min=0)
                adjusted = adjusted / adjusted.sum()
                new_token = torch.multinomial(adjusted, num_samples=1)[0].item()
                accepted_tokens.append(new_token)
                break
        if accepted_tokens[-1] == eos_token:
            break

    return accepted_tokens, num_accepted, past_kv

def generate_visual(prompt, max_tokens=50, gamma=15, confidence_threshold=0.5):
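    """Run speculative decoding on `prompt`, streaming HTML snapshots.

    Yields twice per round: once right after drafting (all draft tokens
    shown optimistically as accepted) and once after verification (tokens
    recolored as accepted, rejected, or resampled).
    """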
    # Prepare input
    messages = [{"role": "user", "content": prompt}]
    formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    eos_token = tokenizer.eos_token_id
    im_end_token = tokenizer.convert_tokens_to_ids("<|im_end|>")

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    result = inputs["input_ids"].clone()

    draft_kv = None
    verify_kv = None

    total_drafted = 0
    total_accepted = 0

    steps = []

    # Track the clean output tokens (only accepted/resampled)
    clean_output_tokens = []
    all_tokens = []
    # Metadata for ALL tokens: 'accepted', 'rejected', or 'resampled'
    all_token_metadata = []

    def build_html():
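        """Render the current decoding state as a streaming HTML snapshot."""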
        html = "<div style='font-family: monospace;'>"

        # Clean final output box
        html += f"<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
        html += f"<b>Final Output (Clean):</b><br/>"
        if clean_output_tokens:
            clean_text = tokenizer.decode(clean_output_tokens)
            html += clean_text
        html += "</div>"

        # Detailed output box
        html += f"<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
        html += f"<b>Detailed Output (All Tokens):</b><br/>"
        if all_tokens:
            for i, token_id in enumerate(all_tokens):
                token_text = tokenizer.decode([token_id])
                token_display = token_text.replace("<", "&lt;").replace(">", "&gt;")

                if i < len(all_token_metadata):
                    if all_token_metadata[i] == 'accepted':
                        html += f"<span style='background: #66CC66; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token_display}</span>"
                    elif all_token_metadata[i] == 'resampled':
                        html += f"<span style='background: #5AADCC; padding: 2px 4px; margin: 1px; border-radius: 3px;'>{token_display}</span>"
                    elif all_token_metadata[i] == 'rejected':
                        html += f"<span style='background: #FF8B9A; padding: 2px 4px; margin: 1px; text-decoration: line-through; border-radius: 3px;'>{token_display}</span>"
                else:
                    html += token_display
        html += "</div>"

        # Acceptance rate
        if total_drafted > 0:
            html += f"<div style='margin-bottom: 20px; padding: 10px; background: transparent; border: 2px solid white; border-radius: 5px;'>"
            html += f"<b>Acceptance Rate:</b> {total_accepted}/{total_drafted} = {total_accepted/total_drafted*100:.1f}%"
            html += "</div>"

        # Decoding steps
        html += "<div style='margin-bottom: 10px;'><b>Decoding Steps:</b></div>"
        for i, step in enumerate(steps):
            html += f"<div style='margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px;'>"
            html += f"<b>Step {i+1}:</b> "

            for j, token in enumerate(step["drafted"]):
                token_display = token.replace("<", "&lt;").replace(">", "&gt;")
                if j < step["accepted"]:
                    html += f"<span style='background: #66CC66; padding: 2px 4px; margin: 2px; border-radius: 3px;'>{token_display}</span>"
                else:
                    html += f"<span style='background: #FF8B9A; padding: 2px 4px; margin: 2px; text-decoration: line-through; border-radius: 3px;'>{token_display}</span>"

            if step["resampled"]:
                resampled_display = step["resampled"].replace("<", "&lt;").replace(">", "&gt;")
                html += f" β†’ <span style='background: #5AADCC; padding: 2px 4px; border-radius: 3px;'>{resampled_display}</span>"

            html += "</div>"
        html += "</div>"
        return html

    while result.shape[-1] - inputs["input_ids"].shape[-1] < max_tokens:
        # Draft
        drafted, drafted_probs, draft_kv = draft(result, gamma, confidence_threshold, eos_token, draft_kv)
        drafted_token_ids = drafted[0, -len(drafted_probs):].tolist()
        drafted_tokens = [tokenizer.decode([t]) for t in drafted_token_ids]

        clean_output_tokens.extend(drafted_token_ids)
        all_tokens.extend(drafted_token_ids)
        all_token_metadata.extend(['accepted'] * len(drafted_token_ids))

        temp_step = {
            "drafted": drafted_tokens,
            "accepted": len(drafted_tokens),
            "resampled": None
        }
        steps.append(temp_step)
        total_drafted += len(drafted_probs)

        yield build_html()

        # Verify
        accepted_tokens, num_accepted, verify_kv = verify(drafted, drafted_probs, eos_token, verify_kv)
        total_accepted += num_accepted

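        # Roll back the optimistic "all accepted" bookkeeping from the draft
        # phase and redo it with the verifier's verdict.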
        clean_output_tokens = clean_output_tokens[:-len(drafted_token_ids)]
        all_token_metadata = all_token_metadata[:-len(drafted_token_ids)]

        for i, token_id in enumerate(drafted_token_ids):
            if i < num_accepted:
                all_token_metadata.append('accepted')
            else:
                all_token_metadata.append('rejected')

        clean_output_tokens.extend(accepted_tokens)

        if num_accepted < len(accepted_tokens):
            all_tokens.append(accepted_tokens[-1])
            all_token_metadata.append('resampled')

        steps[-1] = {
            "drafted": drafted_tokens,
            "accepted": num_accepted,
            "resampled": tokenizer.decode([accepted_tokens[-1]]) if num_accepted < len(accepted_tokens) else None
        }

        yield build_html()

        result = torch.cat([result, torch.tensor([accepted_tokens], device=device)], dim=-1)

        # Trim both KV caches so they cover everything in `result` except the
        # last token; the next forward pass then re-feeds exactly one token,
        # keeping cache positions aligned even when every draft was accepted.
        valid_len = result.shape[-1] - 1
        if draft_kv is not None:
            draft_kv.crop(max_length=valid_len)
        if verify_kv is not None:
            verify_kv.crop(max_length=valid_len)

        if eos_token in accepted_tokens or im_end_token in accepted_tokens:
            break

    yield build_html()

demo = gr.Interface(
    fn=generate_visual,
    inputs=[
        gr.Textbox(label="Prompt", value="What is the capital of France?", lines=3),
        gr.Slider(minimum=10, maximum=100, value=50, step=10, label="Max Tokens"),
        gr.Slider(minimum=1, maximum=30, value=15, step=1, label="Gamma (draft lookahead)"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold")
    ],
    outputs=gr.HTML(label="Speculative Decoding Visualization"),
    title="πŸš€ Speculative Decoding Demo",
    description="""
    **Speculative Decoding Visualization** using SmolLM2 models

    - **Draft Model**: HuggingFaceTB/SmolLM2-135M-Instruct (fast)
    - **Verify Model**: HuggingFaceTB/SmolLM2-1.7B-Instruct (accurate)

    **Color Legend:**
    - 🟒 Green = Accepted tokens from draft model
    - πŸ”΄ Red = Rejected tokens (with strikethrough)
    - πŸ”΅ Blue = Resampled tokens from verify model

    **Watch the tokens stream in real-time!** Draft tokens appear immediately, then get accepted or rejected by the verify model.
    """,
    examples=[
        ["What is the capital of France?", 80, 15, 0.5],
        ["Complete the python function \n def fibonacci(n):", 50, 15, 0.5],
        ["Explain the concept of attention in transformers", 60, 10, 0.6]
    ]
)

if __name__ == "__main__":
    demo.launch()