import {
  AutoModelForCausalLM,
  AutoTokenizer,
  TextStreamer,
} from "@huggingface/transformers";

// Worker state
let model: any = null;
let tokenizer: any = null;
let pastKeyValues: any = null;
let isGenerating = false;

// Cache for loaded models
const modelCache: {
  [modelId: string]: {
    model: any;
    tokenizer: any;
  };
} = {};

// Message types from main thread
interface LoadMessage {
  type: "load";
  modelId: string;
}

interface GenerateMessage {
  type: "generate";
  messages: Array<{ role: string; content: string }>;
  tools: Array<any>;
}

interface InterruptMessage {
  type: "interrupt";
}

interface ResetMessage {
  type: "reset";
}

type WorkerMessage =
  | LoadMessage
  | GenerateMessage
  | InterruptMessage
  | ResetMessage;

// Message types to main thread
interface ProgressMessage {
  type: "progress";
  progress: number;
  file?: string;
}

interface ReadyMessage {
  type: "ready";
}

interface UpdateMessage {
  type: "update";
  token: string;
  tokensPerSecond: number;
  numTokens: number;
}

interface CompleteMessage {
  type: "complete";
  text: string;
}

interface ErrorMessage {
  type: "error";
  error: string;
}

type WorkerResponse =
  | ProgressMessage
  | ReadyMessage
  | UpdateMessage
  | CompleteMessage
  | ErrorMessage;

function postMessage(message: WorkerResponse) {
  self.postMessage(message);
}

// Load model
async function loadModel(modelId: string) {
  try {
    // Check cache first
    if (modelCache[modelId]) {
      model = modelCache[modelId].model;
      tokenizer = modelCache[modelId].tokenizer;
      postMessage({ type: "ready" });
      return;
    }

    // Only report download progress for the large .onnx_data weight file
    const progressCallback = (progress: any) => {
      if (
        progress.status === "progress" &&
        progress.file.endsWith(".onnx_data")
      ) {
        const percentage = Math.round(
          (progress.loaded / progress.total) * 100
        );
        postMessage({
          type: "progress",
          progress: percentage,
          file: progress.file,
        });
      }
    };

    // Load tokenizer
    tokenizer = await AutoTokenizer.from_pretrained(modelId, {
      progress_callback: progressCallback,
    });

    // Load model
    model = await AutoModelForCausalLM.from_pretrained(modelId, {
      dtype: "q4f16",
      device: "webgpu",
      progress_callback: progressCallback,
    });

    // Pre-warm the model with a dummy input for shader compilation
    const dummyInput = tokenizer("Hello", {
      return_tensors: "pt",
      padding: false,
      truncation: false,
    });
    await model.generate({
      ...dummyInput,
      max_new_tokens: 1,
      do_sample: false,
    });

    // Cache the loaded model
    modelCache[modelId] = { model, tokenizer };

    postMessage({ type: "ready" });
  } catch (error) {
    postMessage({
      type: "error",
      error: error instanceof Error ? error.message : "Failed to load model",
    });
  }
}

// Generate response
async function generate(
  messages: Array<{ role: string; content: string }>,
  tools: Array<any>
) {
  if (!model || !tokenizer) {
    postMessage({ type: "error", error: "Model not loaded" });
    return;
  }

  try {
    isGenerating = true;

    // Apply chat template with tools
    const input = tokenizer.apply_chat_template(messages, {
      tools,
      add_generation_prompt: true,
      return_dict: true,
    });

    // Track tokens and timing
    const startTime = performance.now();
    let tokenCount = 0;

    const streamer = new TextStreamer(tokenizer, {
      skip_prompt: true,
      skip_special_tokens: false,
      callback_function: (token: string) => {
        if (!isGenerating) return; // Check if interrupted

        tokenCount++;
        const elapsed = (performance.now() - startTime) / 1000;
        const tps = tokenCount / elapsed;

        postMessage({
          type: "update",
          token,
          tokensPerSecond: tps,
          numTokens: tokenCount,
        });
      },
    });

    // Generate the response, reusing the KV cache from previous turns
    const { sequences, past_key_values } = await model.generate({
      ...input,
      past_key_values: pastKeyValues,
      max_new_tokens: 1024,
      do_sample: false,
      streamer,
      return_dict_in_generate: true,
    });
    pastKeyValues = past_key_values;

    // Decode only the newly generated tokens (everything after the prompt)
    const response = tokenizer
      .batch_decode(sequences.slice(null, [input.input_ids.dims[1], null]), {
        skip_special_tokens: false,
      })[0]
      .replace(/<\|im_end\|>$/, "")
      .replace(/<\|end_of_text\|>$/, "");

    if (isGenerating) {
      postMessage({ type: "complete", text: response });
    }
    isGenerating = false;
  } catch (error) {
    isGenerating = false;
    postMessage({
      type: "error",
      error: error instanceof Error ? error.message : "Generation failed",
    });
  }
}

// Interrupt generation
function interrupt() {
  isGenerating = false;
  // Send a completion message with empty text to resolve the promise
  postMessage({ type: "complete", text: "" });
}

// Reset past key values
function reset() {
  pastKeyValues = null;
}

// Handle messages from main thread
self.onmessage = async (e: MessageEvent<WorkerMessage>) => {
  const message = e.data;

  switch (message.type) {
    case "load":
      await loadModel(message.modelId);
      break;
    case "generate":
      await generate(message.messages, message.tools);
      break;
    case "interrupt":
      interrupt();
      break;
    case "reset":
      reset();
      break;
  }
};

// Export for TypeScript
export {};
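
/*
 * Example main-thread usage (a sketch under assumptions, not part of this
 * worker's implementation). It assumes the worker file is bundled as a module
 * worker at "./worker.ts" and that the model id is any transformers.js-compatible
 * ONNX model; adjust both to your setup. The message shapes mirror the
 * WorkerMessage / WorkerResponse types defined above.
 *
 *   const worker = new Worker(new URL("./worker.ts", import.meta.url), {
 *     type: "module",
 *   });
 *
 *   worker.onmessage = (e) => {
 *     switch (e.data.type) {
 *       case "progress":  // update a loading bar with e.data.progress
 *         break;
 *       case "ready":     // model loaded; enable the chat input
 *         break;
 *       case "update":    // append e.data.token to the streamed reply
 *         break;
 *       case "complete":  // e.data.text is the full decoded response
 *         break;
 *       case "error":     // surface e.data.error to the user
 *         break;
 *     }
 *   };
 *
 *   worker.postMessage({ type: "load", modelId: "<your-model-id>" });
 *   worker.postMessage({
 *     type: "generate",
 *     messages: [{ role: "user", content: "Hello!" }],
 *     tools: [],
 *   });
 */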