"""Gradio demo that visualizes speculative decoding with two SmolLM2 models.

A small draft model greedily proposes tokens; a larger verify model accepts
or rejects them with the speculative-sampling acceptance rule and resamples
on rejection. The UI streams HTML showing which tokens survived.
"""
import html

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import gradio as gr

set_seed(67)

device = "cpu"

# Initialize models and tokenizer
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
draft_model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M-Instruct", torch_dtype=torch.bfloat16
).to(device)
verify_model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-1.7B-Instruct", torch_dtype=torch.bfloat16
).to(device)


def draft(input_ids, gamma, confidence_threshold, eos_token, past_kv):
    """Greedily propose up to ``gamma`` tokens with the draft model.

    Args:
        input_ids: token ids of the current context; only batch row 0 is
            meaningful (``next_token.item()`` below requires batch size 1).
        gamma: maximum number of tokens to draft.
        confidence_threshold: stop drafting once the draft model's top
            probability falls below this value. The first drafted token is
            always kept so every call makes progress.
        eos_token: token id that ends drafting immediately.
        past_kv: draft-model cache covering every token of ``input_ids``
            except the last one, or ``None`` to run on the full context.

    Returns:
        ``(generated, draft_probs, past_kv)``: ``input_ids`` with the drafted
        tokens appended, one float32 distribution per drafted token, and the
        updated cache.
    """
    generated = input_ids.clone()
    draft_probs = []
    for _ in range(gamma):
        with torch.no_grad():
            outputs = draft_model(
                generated if past_kv is None else generated[:, -1:],
                past_key_values=past_kv,
                use_cache=True,
            )
        past_kv = outputs.past_key_values
        # Softmax in float32: bfloat16 is too coarse for the p/q ratios used
        # during verification.
        probs = torch.softmax(outputs.logits[:, -1, :].float(), dim=-1)
        confidence = probs.max().item()
        if confidence < confidence_threshold and draft_probs:
            break
        next_token = torch.argmax(probs, dim=-1, keepdim=True)
        draft_probs.append(probs)
        generated = torch.cat([generated, next_token], dim=-1)
        if next_token.item() == eos_token:
            break
    return generated, draft_probs, past_kv


def verify(drafted, drafted_probs, eos_token, past_kv):
    """Score drafted tokens with the verify model and accept/reject them.

    Each drafted token ``x`` is accepted outright when ``q(x) <= p(x)``,
    otherwise with probability ``p(x) / q(x)``. On the first rejection a
    replacement is sampled from ``max(p - q, 0)`` (renormalized) and
    verification stops.

    NOTE(review): the draft is chosen by argmax rather than sampled from
    ``q``, so this rule does not make the output exactly distributed as the
    verify model; it is kept as the demo's behavior.

    Args:
        drafted: context plus drafted tokens (batch size 1).
        drafted_probs: the draft distributions returned by ``draft``.
        eos_token: accepting this token ends verification early.
        past_kv: verify-model cache covering every context token except the
            last one, or ``None`` to run on the full sequence.

    Returns:
        ``(accepted_tokens, num_accepted, past_kv)``: the tokens to append
        (accepted drafts, possibly followed by one resampled token), how many
        drafted tokens were accepted, and the updated cache.
    """
    draft_len = len(drafted_probs)
    with torch.no_grad():
        if past_kv is None:
            target_outputs = verify_model(drafted, use_cache=True)
            target_logits = target_outputs.logits[:, -draft_len - 1:-1, :]
        else:
            target_outputs = verify_model(
                drafted[:, -(draft_len + 1):], past_key_values=past_kv, use_cache=True
            )
            target_logits = target_outputs.logits[:, :-1, :]
    past_kv = target_outputs.past_key_values
    target_probs = torch.softmax(target_logits.float(), dim=-1)

    accepted_tokens = []
    num_accepted = 0
    for i in range(draft_len):
        q = drafted_probs[i]
        p = target_probs[:, i, :]
        x = drafted[0, i - draft_len].item()
        q_x = q[0, x].item()
        p_x = p[0, x].item()
        # Short-circuit keeps the RNG draw only for the q(x) > p(x) case.
        if q_x <= p_x or torch.rand(1).item() < p_x / q_x:
            accepted_tokens.append(x)
            num_accepted += 1
        else:
            adjusted = torch.clamp(p - q, min=0)
            total = adjusted.sum()
            # Guard against an all-zero residual (rounding), which would give
            # NaN weights; fall back to the verify distribution itself.
            adjusted = adjusted / total if total > 0 else p
            new_token = torch.multinomial(adjusted, num_samples=1)[0].item()
            accepted_tokens.append(new_token)
            break
        if accepted_tokens[-1] == eos_token:
            break
    return accepted_tokens, num_accepted, past_kv


# Inline CSS for the visualization (rendered by gr.HTML).
_BOX_STYLE = (
    "border: 1px solid #ccc; border-radius: 8px; padding: 12px; "
    "margin-bottom: 12px; line-height: 1.8;"
)
_TOKEN_STYLES = {
    "accepted": "background-color: #c8e6c9; color: #1b5e20;",
    "rejected": "background-color: #ffcdd2; color: #b71c1c; text-decoration: line-through;",
    "resampled": "background-color: #bbdefb; color: #0d47a1;",
}


def _escape(text):
    """Return ``text`` HTML-escaped so model output cannot inject markup."""
    return html.escape(text)


def _token_span(text, status=None):
    """Wrap one token's text in a span colored by ``status``.

    Statuses not in ``_TOKEN_STYLES`` (including ``None``) render as plain
    escaped text.
    """
    escaped = _escape(text)
    style = _TOKEN_STYLES.get(status)
    if style is None:
        return escaped
    return (
        f'<span style="{style} white-space: pre-wrap; border-radius: 3px;">'
        f"{escaped}</span>"
    )


def _build_html(clean_output_tokens, all_tokens, all_token_metadata,
                total_drafted, total_accepted, steps):
    """Render the current decoding state as one HTML string.

    Args:
        clean_output_tokens: ids of tokens that survived verification.
        all_tokens: ids of every drafted token plus resampled replacements.
        all_token_metadata: status per entry of ``all_tokens``
            ('accepted', 'rejected' or 'resampled').
        total_drafted: number of tokens drafted so far.
        total_accepted: number of drafted tokens accepted so far.
        steps: per-step dicts with keys 'drafted' (decoded token strings),
            'accepted' (count) and 'resampled' (decoded string or None).
    """
    parts = ['<div style="font-family: sans-serif;">']

    # Clean final output box
    parts.append(f'<div style="{_BOX_STYLE}">')
    parts.append("<b>Final Output (Clean):</b><br>")
    if clean_output_tokens:
        clean_text = tokenizer.decode(clean_output_tokens)
        parts.append(f'<span style="white-space: pre-wrap;">{_escape(clean_text)}</span>')
    parts.append("</div>")

    # Detailed output box: every token, colored by what happened to it
    parts.append(f'<div style="{_BOX_STYLE}">')
    parts.append("<b>Detailed Output (All Tokens):</b><br>")
    for i, token_id in enumerate(all_tokens):
        status = all_token_metadata[i] if i < len(all_token_metadata) else None
        parts.append(_token_span(tokenizer.decode([token_id]), status))
    parts.append("</div>")

    # Acceptance rate
    if total_drafted > 0:
        parts.append(f'<div style="{_BOX_STYLE}">')
        parts.append(
            f"<b>Acceptance Rate:</b> {total_accepted}/{total_drafted} = "
            f"{total_accepted / total_drafted * 100:.1f}%"
        )
        parts.append("</div>")

    # Decoding steps
    parts.append(f'<div style="{_BOX_STYLE}">')
    parts.append("<b>Decoding Steps:</b>")
    for i, step in enumerate(steps):
        parts.append('<div style="margin-top: 6px;">')
        parts.append(f"<b>Step {i + 1}:</b> ")
        for j, token in enumerate(step["drafted"]):
            status = "accepted" if j < step["accepted"] else "rejected"
            parts.append(_token_span(token, status))
        if step["resampled"]:
            parts.append(" → " + _token_span(step["resampled"], "resampled"))
        parts.append("</div>")
    parts.append("</div>")

    parts.append("</div>")
    return "".join(parts)


def generate_visual(prompt, max_tokens=50, gamma=15, confidence_threshold=0.5):
    """Run speculative decoding on ``prompt``, yielding HTML as it progresses.

    After each draft phase the drafted tokens are shown optimistically as
    accepted; after verification the view is updated with the real verdict.

    Args:
        prompt: user message, wrapped with the tokenizer's chat template.
        max_tokens: token budget for the generated continuation.
        gamma: maximum draft length per step.
        confidence_threshold: early-stop threshold forwarded to ``draft``.

    Yields:
        HTML strings for ``gr.HTML``.
    """
    # Gradio sliders may deliver floats; range() and slicing need ints.
    max_tokens = int(max_tokens)
    gamma = max(1, int(gamma))

    # Prepare input
    messages = [{"role": "user", "content": prompt}]
    formatted_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    eos_token = tokenizer.eos_token_id
    im_end_token = tokenizer.convert_tokens_to_ids("<|im_end|>")
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[-1]
    result = inputs["input_ids"].clone()

    draft_kv = None
    verify_kv = None
    total_drafted = 0
    total_accepted = 0
    steps = []
    # Clean output: only accepted/resampled tokens.
    clean_output_tokens = []
    # Every drafted token plus resampled replacements, with a parallel status
    # list: 'accepted', 'rejected', or 'resampled'.
    all_tokens = []
    all_token_metadata = []

    def render():
        return _build_html(clean_output_tokens, all_tokens, all_token_metadata,
                           total_drafted, total_accepted, steps)

    while result.shape[-1] - prompt_len < max_tokens:
        # Never draft past the budget. A step adds at most as many tokens as
        # were drafted (one resample replaces a rejected draft).
        remaining = max_tokens - (result.shape[-1] - prompt_len)

        # Draft
        drafted, drafted_probs, draft_kv = draft(
            result, min(gamma, remaining), confidence_threshold, eos_token, draft_kv
        )
        num_drafted = len(drafted_probs)
        drafted_token_ids = drafted[0, result.shape[-1]:].tolist()
        drafted_tokens = [tokenizer.decode([t]) for t in drafted_token_ids]

        # Show the draft optimistically until the verifier has run.
        clean_len = len(clean_output_tokens)
        meta_len = len(all_token_metadata)
        clean_output_tokens.extend(drafted_token_ids)
        all_tokens.extend(drafted_token_ids)
        all_token_metadata.extend(["accepted"] * num_drafted)
        steps.append({"drafted": drafted_tokens, "accepted": num_drafted, "resampled": None})
        total_drafted += num_drafted
        yield render()

        # Verify
        accepted_tokens, num_accepted, verify_kv = verify(
            drafted, drafted_probs, eos_token, verify_kv
        )
        total_accepted += num_accepted
        if not accepted_tokens:
            break

        # Replace the optimistic view with the verifier's verdict.
        del clean_output_tokens[clean_len:]
        del all_token_metadata[meta_len:]
        clean_output_tokens.extend(accepted_tokens)
        all_token_metadata.extend(
            "accepted" if i < num_accepted else "rejected" for i in range(num_drafted)
        )
        resampled = num_accepted < len(accepted_tokens)
        if resampled:
            all_tokens.append(accepted_tokens[-1])
            all_token_metadata.append("resampled")
        steps[-1] = {
            "drafted": drafted_tokens,
            "accepted": num_accepted,
            "resampled": tokenizer.decode([accepted_tokens[-1]]) if resampled else None,
        }
        yield render()

        result = torch.cat([result, torch.tensor([accepted_tokens], device=device)], dim=-1)

        # draft() and verify() both re-feed the last token of `result`, so each
        # cache must cover everything except that token. Cropping only to the
        # accepted prefix left a duplicated entry when the full draft was
        # accepted.
        cache_len = result.shape[-1] - 1
        if draft_kv is not None:
            draft_kv.crop(max_length=cache_len)
        if verify_kv is not None:
            verify_kv.crop(max_length=cache_len)

        if eos_token in accepted_tokens or im_end_token in accepted_tokens:
            break

    yield render()


demo = gr.Interface(
    fn=generate_visual,
    inputs=[
        gr.Textbox(label="Prompt", value="What is the capital of France?", lines=3),
        gr.Slider(minimum=10, maximum=100, value=50, step=10, label="Max Tokens"),
        gr.Slider(minimum=1, maximum=30, value=15, step=1, label="Gamma (draft lookahead)"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold"),
    ],
    outputs=gr.HTML(label="Speculative Decoding Visualization"),
    title="🚀 Speculative Decoding Demo",
    description="""
**Speculative Decoding Visualization** using SmolLM2 models

- **Draft Model**: HuggingFaceTB/SmolLM2-135M-Instruct (fast)
- **Verify Model**: HuggingFaceTB/SmolLM2-1.7B-Instruct (accurate)

**Color Legend:**

- 🟢 Green = Accepted tokens from draft model
- 🔴 Red = Rejected tokens (with strikethrough)
- 🔵 Blue = Resampled tokens from verify model

**Watch the tokens stream in real-time!** Draft tokens appear immediately, then get accepted or rejected by the verify model.
""",
    examples=[
        ["What is the capital of France?", 80, 15, 0.5],
        ["Complete the python function \n def fibonacci(n):", 50, 15, 0.5],
        ["Explain the concept of attention in transformers", 60, 10, 0.6],
    ],
)

if __name__ == "__main__":
    demo.launch()