import { select } from "https://cdn.skypack.dev/d3-selection@3"; import { color, hcl } from "https://cdn.skypack.dev/d3-color@3"; import { piecewise, interpolateRgb, } from "https://cdn.skypack.dev/d3-interpolate@3"; const d3 = { select, color, hcl, piecewise, interpolateRgb }; const colorPalette = d3.piecewise(d3.interpolateRgb.gamma(2.2), [ "#0E21A0", "#4D2DB7", "#9D44C0", "#EC53B0", "#F57B55", "#FFDB5A", ]); class WorkerPool { constructor(size) { this.workers = []; this.size = size; this.init(); } get size() { return this._size; } set size(size) { this._size = size; } init() { for (let i = 0; i < this.size; i++) { this.workers.push({ worker: new Worker("./llama2c.worker.js", { type: "module" }), inUse: false, id: window.crypto.randomUUID(), }); } } getWorker() { return new Promise((resolve) => { const check = () => { for (let worker of this.workers) { if (!worker.inUse) { worker.inUse = true; resolve(worker.worker); return; } } setTimeout(check, 100); }; check(); }); } releaseWorker(worker) { this.workers.find((w) => w.worker === worker).inUse = false; } } // base url for audio examples const MODELS_BASE_URL = "https://huggingface.co./karpathy/tinyllamas/resolve/main"; // models base url const MODELS = { stories15M: { url: "stories15M.bin", seq_len: 256, }, stories42M: { url: "stories42M.bin", seq_len: 1024, }, stories110M: { url: "stories110M.bin", seq_len: 1024, }, }; const workerPool = new WorkerPool(window.navigator.hardwareConcurrency); async function generateSequence({ worker, prompt, weightsURL, modelID, maxSeqLen, temp, repeatPenalty, top_p, contentEl, controller, }) { return new Promise((resolve, reject) => { const seed = BigInt( `0x${ Math.floor(Math.random() * Number.MAX_SAFE_INTEGER).toString(16) + Date.now().toString(16) }` ); worker.postMessage({ weightsURL, modelID, tokenizerURL: "tokenizer.json", prompt, temp, repeatPenalty, top_p, seed: seed, maxSeqLen, command: "start", }); function handleAbort() { worker.postMessage({ command: "abort" }); } function updateStatus(data) { if (data.status === "loading") { contentEl.innerHTML = ``; } if (data.status === "aborted") { contentEl.innerHTML = ``; } if (data.status === "generating") { const { message, prompt, sentence, tokensSec, totalTime } = data; contentEl.innerHTML = `${prompt} ${sentence.replace(/\|\<\/s\>/g, "")}`; } } const handleMessage = (event) => { const { status, error } = event.data; if (status) updateStatus(event.data); if (error) { workerPool.releaseWorker(worker); worker.removeEventListener("message", handleMessage); reject(new Error(error)); } if (status === "complete" || status === "aborted") { workerPool.releaseWorker(worker); worker.removeEventListener("message", handleMessage); resolve(event.data); } }; controller.signal.addEventListener("abort", handleAbort); worker.addEventListener("message", handleMessage); }); } async function initWorkers() { const containerEl = d3.select("#container"); d3.select("#model").on("input", (e) => { const model = MODELS[e.target.value]; d3.select("#max-seq").property("max", model.seq_len); }); const containers = await Promise.all( workerPool.workers.map(async (_, i) => { const contentEl = d3.select(document.createElement("div")); contentEl.on("pointerover", (e) => { e.currentTarget.classList.add("c-hover"); }); contentEl.on("pointerdown", (e) => { e.currentTarget.classList.toggle("c-hover"); }); contentEl.on("pointerout pointercancel pointerleave", (e) => { e.currentTarget.classList.remove("c-hover"); }); const bgColor = colorPalette(i / workerPool.size); const fontColor = d3.hcl(bgColor).l < 50 ? "#fff" : "#000"; contentEl .style("background-color", bgColor) .style("color", fontColor) .style("grid-row-start", `${i + 1}`) .classed("c-block ", true); containerEl.append(() => contentEl.node()); return [contentEl.node(), await workerPool.getWorker()]; }) ); return containers; } async function fetchArrayBuffer(url) { const cacheName = "llama2c-candle-cache"; const cache = await caches.open(cacheName); const cachedResponse = await cache.match(url); if (cachedResponse) { const data = await cachedResponse.arrayBuffer(); return new Uint8Array(data); } const res = await fetch(url, { cache: "force-cache" }); const resClone = res.clone(); const arrayBuffer = await res.arrayBuffer(); await cache.put(url, resClone); return new Uint8Array(arrayBuffer); } async function run(containers, controller) { const getValue = (e) => e.value; const prompt = document.querySelector("#prompt"); const maxSeqLen = document.querySelector("#max-seq"); const temp = document.querySelector("#temperature"); const repeatPenalty = document.querySelector("#repeat-penalty"); const topP = document.querySelector("#top-p"); const modelID = document.querySelector("#model"); const weightsURL = `${MODELS_BASE_URL}/${MODELS[getValue(modelID)].url}`; maxSeqLen.disabled = true; temp.disabled = true; repeatPenalty.disabled = true; modelID.disabled = true; // pre fetch and cache weights and tokenizer await Promise.all([ fetchArrayBuffer(weightsURL), fetchArrayBuffer(`tokenizer.json`), ]); await Promise.all( containers.map(([container, worker]) => generateSequence({ worker, prompt: getValue(prompt), weightsURL, modelID: getValue(modelID), maxSeqLen: getValue(maxSeqLen), temp: getValue(temp), top_p: getValue(topP), repeatPenalty: getValue(repeatPenalty), contentEl: container, controller, }) ) ); maxSeqLen.disabled = false; temp.disabled = false; repeatPenalty.disabled = false; modelID.disabled = false; } initWorkers().then((containers) => { const runBtn = document.querySelector("#run"); let runController = new AbortController(); let isRunning = false; d3.select("#form").on("submit", async (e) => { e.preventDefault(); if (isRunning) { stopRunning(); } else { startRunning(); await run(containers, runController); stopRunning(); } }); function startRunning() { isRunning = true; runBtn.innerText = "Stop"; } function stopRunning() { runBtn.innerText = "Run"; runController.abort(); runController = new AbortController(); isRunning = false; } });