// Reference the elements we will use
const statusLabel = document.getElementById('status');
const fileUpload = document.getElementById('upload');
const imageContainer = document.getElementById('container');
const example = document.getElementById('example');
const maskCanvas = document.getElementById('mask-output');
const uploadButton = document.getElementById('upload-button');
const resetButton = document.getElementById('reset-image');
const clearButton = document.getElementById('clear-points');
const cutButton = document.getElementById('cut-mask');

// State variables
let lastPoints = null; // Pending prompt point(s) to send to the worker, or null
let isEncoded = false; // True once the worker reported the embedding is extracted
let isDecoding = false; // True while a decode request is in flight
let isMultiMaskMode = false; // True after the first click, until points are cleared
let modelReady = false; // True once the worker reported 'ready'
let imageDataURI = null; // Source (data URI or URL) of the image being segmented

// Constants
const BASE_URL = 'https://huggingface.co./datasets/Xenova/transformers.js-docs/resolve/main/';
const EXAMPLE_URL = `${BASE_URL}corgi.jpg`;

// How long to keep the cut-out's object URL alive after triggering the download.
const OBJECT_URL_REVOKE_DELAY_MS = 10000;

// Create a web worker so that the main (UI) thread is not blocked during inference.
const worker = new Worker('worker.js', {
  type: 'module',
});

// Preload star and cross images to avoid lag on first click
const star = new Image();
star.src = `${BASE_URL}star-icon.png`;
star.className = 'icon';

const cross = new Image();
cross.src = `${BASE_URL}cross-icon.png`;
cross.className = 'icon';

// Set up message handler
worker.addEventListener('message', (e) => {
  const { type, data } = e.data;

  if (type === 'ready') {
    modelReady = true;
    statusLabel.textContent = 'Ready';
  } else if (type === 'decode_result') {
    isDecoding = false;

    if (!isEncoded) {
      return; // We are not ready to decode yet
    }

    if (!isMultiMaskMode && lastPoints) {
      // A newer hover point arrived while we were decoding: decode it now
      // and mark it as consumed.
      decode();
      lastPoints = null;
    }

    drawBestMask(data);
  } else if (type === 'segment_result') {
    if (data === 'start') {
      statusLabel.textContent = 'Extracting image embedding...';
    } else {
      statusLabel.textContent = 'Embedding extracted!';
      isEncoded = true;
    }
  }
});

/**
 * Paint the highest-scoring mask of a decode result onto the mask canvas and
 * show its score in the status label.
 *
 * @param {{mask: {width: number, height: number, data: ArrayLike<number>}, scores: ArrayLike<number>}} result
 *   Decode result from the worker. `mask.data` is read at index
 *   `numMasks * pixel + maskIndex`, where `numMasks` is `scores.length`.
 */
function drawBestMask({ mask, scores }) {
  // Update canvas dimensions (if different)
  if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
    maskCanvas.width = mask.width;
    maskCanvas.height = mask.height;
  }

  // Create context and allocate buffer for pixel data
  const context = maskCanvas.getContext('2d');
  const imageData = context.createImageData(maskCanvas.width, maskCanvas.height);

  // Select best mask
  const numMasks = scores.length; // presumably 3 - confirm against the worker
  let bestIndex = 0;
  for (let i = 1; i < numMasks; ++i) {
    if (scores[i] > scores[bestIndex]) {
      bestIndex = i;
    }
  }
  statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

  // Fill mask with colour. `pixelData` holds 4 bytes (RGBA) per pixel, so the
  // pixel count is a quarter of its length.
  const pixelData = imageData.data;
  const numPixels = pixelData.length / 4;
  for (let i = 0; i < numPixels; ++i) {
    if (mask.data[numMasks * i + bestIndex] === 1) {
      const offset = 4 * i;
      pixelData[offset] = 0; // red
      pixelData[offset + 1] = 114; // green
      pixelData[offset + 2] = 189; // blue
      pixelData[offset + 3] = 255; // alpha
    }
  }

  // Draw image data to context
  context.putImageData(imageData, 0, 0);
}

/**
 * Ask the worker to decode a mask for the current `lastPoints` and mark a
 * decode as in flight.
 */
function decode() {
  isDecoding = true;
  worker.postMessage({ type: 'decode', data: lastPoints });
}

/**
 * Leave multi-mask mode, drop pending points, remove all point icons, disable
 * the cut button and clear the mask canvas.
 */
function clearPointsAndMask() {
  // Reset state
  isMultiMaskMode = false;
  lastPoints = null;

  // Remove points from previous mask (if any)
  document.querySelectorAll('.icon').forEach((icon) => icon.remove());

  // Disable cut button
  cutButton.disabled = true;

  // Reset mask canvas
  maskCanvas.getContext('2d').clearRect(0, 0, maskCanvas.width, maskCanvas.height);
}
clearButton.addEventListener('click', clearPointsAndMask);

resetButton.addEventListener('click', () => {
  // Update state
  isEncoded = false;
  imageDataURI = null;

  // Indicate to worker that we have reset the state
  worker.postMessage({ type: 'reset' });

  // Clear points and mask (if present); this also disables the cut button
  clearPointsAndMask();

  // Update UI
  imageContainer.style.backgroundImage = 'none';
  uploadButton.style.display = 'flex';
  statusLabel.textContent = 'Ready';
});

/**
 * Show an image in the container and ask the worker to segment it.
 *
 * @param {string} data - Image source: a data URI from the file reader or a URL.
 */
function segment(data) {
  // Update state
  isEncoded = false;
  if (!modelReady) {
    statusLabel.textContent = 'Loading model...';
  }
  imageDataURI = data;

  // Update UI
  imageContainer.style.backgroundImage = `url(${data})`;
  uploadButton.style.display = 'none';
  cutButton.disabled = true;

  // Instruct worker to segment the image
  worker.postMessage({ type: 'segment', data });
}

// Handle file selection
fileUpload.addEventListener('change', function (e) {
  const file = e.target.files[0];
  if (!file) {
    return;
  }

  const reader = new FileReader();

  // Set up a callback when the file is loaded
  reader.onload = (e2) => segment(e2.target.result);

  reader.readAsDataURL(file);
});

example.addEventListener('click', (e) => {
  e.preventDefault();
  segment(EXAMPLE_URL);
});

/**
 * Add a marker icon to the image container for a prompt point.
 *
 * @param {{point: number[], label: number}} prompt - `point` is `[x, y]` as
 *   fractions of the container size; `label` 1 draws a star, anything else a
 *   cross.
 */
function addIcon({ point, label }) {
  const icon = (label === 1 ? star : cross).cloneNode();
  icon.style.left = `${point[0] * 100}%`;
  icon.style.top = `${point[1] * 100}%`;
  imageContainer.appendChild(icon);
}

// Attach click (mousedown) event to image container
imageContainer.addEventListener('mousedown', (e) => {
  if (e.button !== 0 && e.button !== 2) {
    return; // Ignore other buttons
  }
  if (!isEncoded) {
    return; // Ignore if not encoded yet
  }
  if (!isMultiMaskMode) {
    // First click: start accumulating points instead of following the cursor
    lastPoints = [];
    isMultiMaskMode = true;
    cutButton.disabled = false;
  }

  const point = getPoint(e);
  lastPoints.push(point);

  // add icon
  addIcon(point);

  decode();
});

/**
 * Clamp a value inside a range [min, max].
 *
 * @param {number} x - Value to clamp.
 * @param {number} [min=0] - Lower bound.
 * @param {number} [max=1] - Upper bound.
 * @returns {number} `x` limited to the range.
 */
function clamp(x, min = 0, max = 1) {
  return Math.max(Math.min(x, max), min);
}

/**
 * Convert a mouse event into a prompt point relative to the image container.
 *
 * @param {MouseEvent} e - Mouse event on the image container.
 * @returns {{point: number[], label: number}} `point` is `[x, y]` clamped to
 *   [0, 1]; `label` is 0 (negative prompt) for a right click, 1 (positive
 *   prompt) otherwise.
 */
function getPoint(e) {
  // Get bounding box
  const bb = imageContainer.getBoundingClientRect();

  // Get the mouse coordinates relative to the container
  const mouseX = clamp((e.clientX - bb.left) / bb.width);
  const mouseY = clamp((e.clientY - bb.top) / bb.height);

  return {
    point: [mouseX, mouseY],
    label:
      e.button === 2 // right click
        ? 0 // negative prompt
        : 1, // positive prompt
  };
}

// Do not show context menu on right click
imageContainer.addEventListener('contextmenu', (e) => {
  e.preventDefault();
});

// Attach hover event to image container
imageContainer.addEventListener('mousemove', (e) => {
  if (!isEncoded || isMultiMaskMode) {
    // Ignore mousemove events if the image is not encoded yet,
    // or we are in multi-mask mode
    return;
  }
  lastPoints = [getPoint(e)];

  if (!isDecoding) {
    decode(); // Only decode if we are not already decoding
  }
});

// Handle cut button click
cutButton.addEventListener('click', () => {
  const { width: w, height: h } = maskCanvas;

  // Get the mask pixel data
  const maskContext = maskCanvas.getContext('2d');
  const maskPixelData = maskContext.getImageData(0, 0, w, h);

  // Load the image
  const image = new Image();
  image.crossOrigin = 'anonymous';
  image.onerror = () => {
    // Without this the cut-out would fail silently.
    console.error('Failed to load image for cut-out');
  };
  image.onload = async () => {
    // Create a new canvas to hold the image
    const imageCanvas = new OffscreenCanvas(w, h);
    const imageContext = imageCanvas.getContext('2d');
    imageContext.drawImage(image, 0, 0, w, h);
    const imagePixelData = imageContext.getImageData(0, 0, w, h);

    // Create a new canvas to hold the cut-out
    const cutCanvas = new OffscreenCanvas(w, h);
    const cutContext = cutCanvas.getContext('2d');
    const cutPixelData = cutContext.getImageData(0, 0, w, h);

    // Copy all four channels of every pixel the mask covers (alpha > 0)
    for (let offset = 0; offset < maskPixelData.data.length; offset += 4) {
      if (maskPixelData.data[offset + 3] > 0) {
        for (let channel = 0; channel < 4; ++channel) {
          cutPixelData.data[offset + channel] = imagePixelData.data[offset + channel];
        }
      }
    }
    cutContext.putImageData(cutPixelData, 0, 0);

    // Download image
    const blobUrl = URL.createObjectURL(await cutCanvas.convertToBlob());
    const link = document.createElement('a');
    link.download = 'image.png';
    link.href = blobUrl;
    link.click();
    link.remove();

    // Release the blob once the download has had time to start; otherwise
    // every cut-out stays in memory until the page is unloaded.
    setTimeout(() => URL.revokeObjectURL(blobUrl), OBJECT_URL_REVOKE_DELAY_MS);
  };
  image.src = imageDataURI;
});