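"""Georges Perec: a lipogram chatbot.

A small Solara app that forces Qwen/Qwen3-0.6B to avoid a chosen vowel, in
the spirit of Perec's novel "La Disparition", written entirely without the
letter "e". Every token whose decoded text contains the forbidden vowel
(plain, upper-case, or accented) is masked to -inf by a custom
LogitsProcessor at each generation step.

Launch with, e.g., `solara run app.py` (assuming this file is saved as app.py).
"""
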
from threading import Thread
from typing import List
from unicodedata import normalize

import solara
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.generation import LogitsProcessor
from typing_extensions import TypedDict
# Auto-select device (CUDA > MPS > CPU).
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model_id = "Qwen/Qwen3-0.6B"
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
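
# Streamed generation: `model.generate` runs in a background thread while a
# TextIteratorStreamer yields decoded text chunks as they are produced, so the
# UI can render the reply incrementally.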
def response_generator(user_input, logits_processor=None, enable_thinking=False):
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_input}],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    )
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # A fresh streamer per call avoids leftover text from a previous generation.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        model_inputs,
        streamer=streamer,
        logits_processor=logits_processor or [],
        max_new_tokens=4 * 1024,
        do_sample=True,
        temperature=0.7,
        top_p=1.0,
        top_k=50,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for chunk in streamer:
        # Strip end-of-sequence / padding markers the model may emit.
        for special in (tokenizer.eos_token, tokenizer.pad_token):
            if special and special in chunk:
                chunk = chunk.split(special)[0]
        yield chunk
    thread.join()
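
# Precompute, for each vowel, every token ID whose decoded text contains that
# vowel in any common form: plain, upper-case, or composed with a combining
# mark (U+0300 grave, U+0301 acute, U+0302 circumflex, U+0303 tilde,
# U+0308 diaeresis). Note that tokenizer.vocab_size covers only the base
# vocabulary, so added special tokens such as <|im_end|> are never forbidden.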
list_of_vowels = ["a", "e", "i", "o", "u"]
tokens_per_vowel = dict()
for vowel in list_of_vowels:
    variants = [vowel] + [
        normalize("NFC", f"{vowel}{mark}")
        for mark in ("\u0300", "\u0301", "\u0302", "\u0303", "\u0308")
    ]
    tokens_containing_a_given_vowel = []
    for token_id in range(tokenizer.vocab_size):
        # Decode once and lower-case, which also catches upper-case and
        # upper-case accented forms such as "E" and "É".
        decoded = tokenizer.decode(token_id).lower()
        if any(variant in decoded for variant in variants):
            tokens_containing_a_given_vowel.append(token_id)
    tokens_per_vowel[vowel] = tokens_containing_a_given_vowel
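
# At each decoding step, set the score of every forbidden token to -inf so it
# can never be sampled; this is the same masking pattern as transformers'
# built-in SuppressTokensLogitsProcessor.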
class GeorgePerecLogitsProcessor(LogitsProcessor):
    def __init__(self, forbidden_tokens: List[int]):
        self.forbidden_tokens = forbidden_tokens

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
        forbidden_tokens = torch.tensor(self.forbidden_tokens, device=scores.device)
        forbidden_tokens_mask = torch.isin(vocab_tensor, forbidden_tokens)
        return torch.where(forbidden_tokens_mask, -torch.inf, scores)
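
# Hypothetical sanity check (not run by the app): with the "e" processor
# attached, no generated token ID should belong to tokens_per_vowel["e"]:
#
#     inputs = tokenizer("Say hi", return_tensors="pt").to(device)
#     out = model.generate(
#         **inputs,
#         logits_processor=[GeorgePerecLogitsProcessor(tokens_per_vowel["e"])],
#         max_new_tokens=20,
#     )
#     new_tokens = out[0][inputs["input_ids"].shape[-1]:].tolist()
#     assert not any(t in tokens_per_vowel["e"] for t in new_tokens)


# Append a streamed chunk to the last (assistant) message. Solara reactive
# values must be replaced, not mutated in place, for the UI to re-render.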
def add_chunk_to_ai_message(chunk: str):
    messages.value = [
        *messages.value[:-1],
        {
            "role": "assistant",
            "content": messages.value[-1]["content"] + chunk,
        },
    ]
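
# Chat state: the message history plus the two sidebar toggles.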
class MessageDict(TypedDict):
    role: str
    content: str


messages: solara.Reactive[List[MessageDict]] = solara.reactive([])
enable_thinking_options = [True, False]
enable_thinking = solara.reactive(False)
vowels = ["a", "e", "i", "o", "u", "None"]
vowel = solara.reactive("e")
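
# Main page: sidebar controls, a scrollable chat box, and a chat input pinned
# to the bottom. Solara re-renders the component whenever one of the reactive
# values above changes.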
@solara.component
def Page():
    solara.lab.theme.themes.light.primary = "#0000ff"
    solara.lab.theme.themes.light.secondary = "#0000ff"
    solara.lab.theme.themes.dark.primary = "#0000ff"
    solara.lab.theme.themes.dark.secondary = "#0000ff"
    title = "Georges Perec"
    with solara.Head():
        solara.Title(title)
    with solara.Column(align="center"):
        with solara.Sidebar():
            solara.Markdown("# G⎵org⎵s P⎵r⎵c")
            solara.Markdown("## Forcing a language model not to use a vowel")
            solara.Markdown("Select a forbidden vowel:")
            solara.ToggleButtonsSingle(value=vowel, values=vowels)
            solara.Markdown("Enable thinking:")
            solara.ToggleButtonsSingle(value=enable_thinking, values=enable_thinking_options)
        if vowel.value == "None":
            logits_processor = []
        else:
            logits_processor = [
                GeorgePerecLogitsProcessor(
                    forbidden_tokens=tokens_per_vowel[vowel.value],
                )
            ]
        user_message_count = len([m for m in messages.value if m["role"] == "user"])

        def send(message):
            messages.value = [*messages.value, {"role": "user", "content": message}]

        def response(message):
            messages.value = [*messages.value, {"role": "assistant", "content": ""}]
            for chunk in response_generator(
                message,
                logits_processor=logits_processor,
                enable_thinking=enable_thinking.value,
            ):
                add_chunk_to_ai_message(chunk)

        def result():
            if messages.value != []:
                response(messages.value[-1]["content"])

        # Re-run generation as a background task whenever a new user message arrives.
        solara.lab.use_task(result, dependencies=[user_message_count])
        with solara.lab.ChatBox(style={"position": "fixed", "overflow-y": "scroll", "scrollbar-width": "none", "-ms-overflow-style": "none", "top": "0", "bottom": "10rem", "width": "60%"}):
            for item in messages.value:
                with solara.lab.ChatMessage(
                    user=item["role"] == "user",
                    name="User" if item["role"] == "user" else "Assistant",
                    avatar_background_color="#33cccc" if item["role"] == "assistant" else "#ff991f",
                    border_radius="20px",
                    style="background-color:darkgrey!important;" if solara.lab.theme.dark_effective else "background-color:lightgrey!important;",
                ):
                    solara.Markdown(item["content"])
        solara.lab.ChatInput(send_callback=send, style={"position": "fixed", "bottom": "3rem", "width": "60%"})