import {
  AutoTokenizer,
  AutoModelForCausalLM,
  TextStreamer,
  InterruptableStoppingCriteria,
} from "@huggingface/transformers";
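
// Feature detection: verify that the browser exposes a WebGPU adapter,
// reporting an error back to the main thread if it does not.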
async function check() {
  try {
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
      throw new Error("WebGPU is not supported (no adapter found)");
    }
  } catch (e) {
    self.postMessage({
      status: "error",
      data: e.toString(),
    });
  }
}
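
// Uses the singleton pattern to lazily load the tokenizer and model once
// and re-use them for every subsequent request.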
class TextGenerationPipeline {
  static model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct";

  static async getInstance(progress_callback = null) {
    this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
      progress_callback,
    });

    this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
      dtype: "q4f16",
      device: "webgpu",
      progress_callback,
    });

    return Promise.all([this.tokenizer, this.model]);
  }
}

const stopping_criteria = new InterruptableStoppingCriteria();

let past_key_values_cache = null;
async function generate(messages) {
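  // Retrieve the (cached) tokenizer and model instances.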
  const [tokenizer, model] = await TextGenerationPipeline.getInstance();

  const inputs = tokenizer.apply_chat_template(messages, {
    add_generation_prompt: true,
    return_dict: true,
  });
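
  // Track decoding speed: start the clock at the first generated token
  // and report tokens per second alongside each streamed update.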
  let startTime;
  let numTokens = 0;
  let tps;
  const token_callback_function = () => {
    startTime ??= performance.now();

    if (numTokens++ > 0) {
      tps = (numTokens / (performance.now() - startTime)) * 1000;
    }
  };
  const callback_function = (output) => {
    self.postMessage({
      status: "update",
      output,
      tps,
      numTokens,
    });
  };

  const streamer = new TextStreamer(tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    callback_function,
    token_callback_function,
  });
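
  // Tell the main thread that generation has started.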
  self.postMessage({ status: "start" });

  const { past_key_values, sequences } = await model.generate({
    ...inputs,
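    // Re-use the key/value cache from the previous turn (if any) so earlier
    // context does not have to be re-processed.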
    past_key_values: past_key_values_cache,
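    // (Further generation options, e.g. do_sample or temperature, could be set here.)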
    max_new_tokens: 1024,
    streamer,
    stopping_criteria,
    return_dict_in_generate: true,
  });
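  // Save the updated key/value cache for the next turn of the conversation.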
  past_key_values_cache = past_key_values;

  const decoded = tokenizer.batch_decode(sequences, {
    skip_special_tokens: true,
  });
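
  // Send the complete output back to the main thread.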
  self.postMessage({
    status: "complete",
    output: decoded,
  });
}

async function load() {
  self.postMessage({
    status: "loading",
    data: "Loading model...",
  });
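
  // Load the pipeline, forwarding download-progress events to the main thread.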
  const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
    self.postMessage(x);
  });

  self.postMessage({
    status: "loading",
    data: "Compiling shaders and warming up model...",
  });
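
  // Run the model with a dummy input to compile the WebGPU shaders.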
  const inputs = tokenizer("a");
  await model.generate({ ...inputs, max_new_tokens: 1 });
  self.postMessage({ status: "ready" });
}
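
// Listen for commands from the main thread.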
self.addEventListener("message", async (e) => {
  const { type, data } = e.data;

  switch (type) {
    case "check":
      check();
      break;

    case "load":
      load();
      break;

    case "generate":
      stopping_criteria.reset();
      generate(data);
      break;

    case "interrupt":
      stopping_criteria.interrupt();
      break;

    case "reset":
      past_key_values_cache = null;
      stopping_criteria.reset();
      break;
  }
});