import { SamModel, AutoProcessor, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.4';

// Reference the elements we will use
const statusLabel = document.getElementById('status');
const fileUpload = document.getElementById('upload');
const imageContainer = document.getElementById('container');
const example = document.getElementById('example');
const maskCanvas = document.getElementById('mask-output');
const uploadButton = document.getElementById('upload-button');
const resetButton = document.getElementById('reset-image');
const clearButton = document.getElementById('clear-points');
const cutButton = document.getElementById('cut-mask');

// Constants
// The host must be exactly `huggingface.co` (no trailing dot), otherwise the
// assets are requested from a different origin than the one that serves them.
const BASE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/';
const EXAMPLE_URL = `${BASE_URL}corgi.jpg`;

// RGBA colour used to paint the selected mask onto the overlay canvas.
const MASK_COLOUR = Object.freeze([0, 114, 189, 255]);

// Delay before releasing the download's object URL, so the browser has had
// time to start the download triggered by `link.click()`.
const OBJECT_URL_REVOKE_DELAY_MS = 1000;

// Preload star and cross images to avoid lag on first click
const star = new Image();
star.src = `${BASE_URL}star-icon.png`;
star.className = 'icon';

const cross = new Image();
cross.src = `${BASE_URL}cross-icon.png`;
cross.className = 'icon';

// State variables
let lastPoints = null; // Prompt points waiting to be (or being) decoded
let isDecoding = false; // True while a decode() call is awaiting the model
let isMultiMaskMode = false; // True once the user has clicked (vs. just hovered)
let imageDataURI = null; // Source (URL or data URI) of the current image
let imageInputs = null; // Processor output for the current image
let imageEmbeddings = null; // Image embeddings for the current image
let imageVersion = 0; // Bumped whenever the current image is replaced or reset
let promptVersion = 0; // Bumped whenever points and mask are cleared

/**
 * Paint the best-scoring mask onto the overlay canvas and show its score.
 *
 * @param {RawImage} mask - Image whose `data` is indexed here as
 *   `scores.length` interleaved values per pixel (one per candidate mask).
 * @param {ArrayLike<number>} scores - Score of each candidate mask.
 */
function drawMask(mask, scores) {
  // Update canvas dimensions (if different)
  if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
    maskCanvas.width = mask.width;
    maskCanvas.height = mask.height;
  }

  // Create context and allocate buffer for pixel data
  const context = maskCanvas.getContext('2d');
  const imageData = context.createImageData(maskCanvas.width, maskCanvas.height);

  // Select best mask
  const numMasks = scores.length;
  let bestIndex = 0;
  for (let i = 1; i < numMasks; ++i) {
    if (scores[i] > scores[bestIndex]) {
      bestIndex = i;
    }
  }
  statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

  // Fill mask with colour. Iterate over pixels (not over RGBA bytes): each
  // pixel occupies 4 entries of `pixelData`.
  const pixelData = imageData.data;
  const numPixels = maskCanvas.width * maskCanvas.height;
  for (let i = 0; i < numPixels; ++i) {
    if (mask.data[numMasks * i + bestIndex] === 1) {
      pixelData.set(MASK_COLOUR, 4 * i);
    }
  }

  // Draw image data to context
  context.putImageData(imageData, 0, 0);
}

/**
 * Run the mask decoder for the points currently stored in `lastPoints` and
 * draw the resulting mask.
 *
 * Callers do not await this function, so it handles its own errors. State is
 * snapshotted before the first `await`; a result that arrives after the image
 * was reset/replaced or the points were cleared is discarded.
 *
 * @returns {Promise<void>}
 */
async function decode() {
  const inputs = imageInputs;
  const embeddings = imageEmbeddings;
  const promptPoints = lastPoints;
  if (!inputs || !embeddings || !promptPoints || promptPoints.length === 0) {
    return;
  }

  const startedAtPromptVersion = promptVersion;
  isDecoding = true;

  let mask;
  let scores;
  try {
    // Prepare inputs for decoding: scale the normalised [0, 1] points by the
    // reshaped input size (indexed as [height, width]).
    const reshaped = inputs.reshaped_input_sizes[0];
    const points = promptPoints.map(({ point }) => [point[0] * reshaped[1], point[1] * reshaped[0]]);
    const labels = promptPoints.map(({ label }) => BigInt(label));

    const input_points = new Tensor('float32', points.flat(Infinity), [1, 1, points.length, 2]);
    const input_labels = new Tensor('int64', labels, [1, 1, labels.length]);

    // Generate the mask
    const { pred_masks, iou_scores } = await model({
      ...embeddings,
      input_points,
      input_labels,
    });

    // Post-process the mask
    const masks = await processor.post_process_masks(
      pred_masks,
      inputs.original_sizes,
      inputs.reshaped_input_sizes,
    );

    mask = RawImage.fromTensor(masks[0][0]);
    scores = iou_scores.data;
  } catch (err) {
    console.error(err);
    statusLabel.textContent = 'Failed to decode mask';
    return;
  } finally {
    // Always release the flag, otherwise hover decoding would stay disabled
    // forever after a single failure.
    isDecoding = false;
  }

  if (!isMultiMaskMode && lastPoints) {
    // Hover mode: the pointer may have moved while we were busy, so decode
    // again with the most recent point. decode() reads `lastPoints`
    // synchronously before its first await, so it is safe to clear it here.
    decode();
    lastPoints = null;
  }

  // Discard results that no longer match what is on screen.
  if (imageEmbeddings !== embeddings || promptVersion !== startedAtPromptVersion) {
    return;
  }

  drawMask(mask, scores);
}

/**
 * Remove all prompt points and the mask overlay, and leave multi-mask mode.
 */
function clearPointsAndMask() {
  // Reset state
  isMultiMaskMode = false;
  lastPoints = null;
  promptVersion += 1; // Invalidate any decode that is still in flight

  // Remove points from previous mask (if any)
  document.querySelectorAll('.icon').forEach((e) => e.remove());

  // Disable cut button
  cutButton.disabled = true;

  // Reset mask canvas
  maskCanvas.getContext('2d').clearRect(0, 0, maskCanvas.width, maskCanvas.height);
}
clearButton.addEventListener('click', clearPointsAndMask);

resetButton.addEventListener('click', () => {
  // Reset the state
  imageVersion += 1; // Invalidate any segment() call that is still in flight
  imageDataURI = null;
  imageInputs = null;
  imageEmbeddings = null;
  isDecoding = false;

  // Clear points and mask (if present); this also disables the cut button
  clearPointsAndMask();

  // Update UI
  imageContainer.style.backgroundImage = 'none';
  uploadButton.style.display = 'flex';
  statusLabel.textContent = 'Ready';
});

/**
 * Show an image and compute its embeddings so that it can be segmented.
 *
 * Callers do not await this function, so it handles its own errors. If the
 * image is reset or replaced while the embeddings are being computed, the
 * late result is dropped.
 *
 * @param {string} data - URL or data URI of the image.
 * @returns {Promise<void>}
 */
async function segment(data) {
  imageVersion += 1;
  const requestVersion = imageVersion;

  statusLabel.textContent = 'Extracting image embedding...';

  // Update state
  imageEmbeddings = null;
  imageDataURI = data;

  // Update UI
  imageContainer.style.backgroundImage = `url(${data})`;
  uploadButton.style.display = 'none';
  cutButton.disabled = true;

  try {
    // Read the image and recompute image embeddings
    const image = await RawImage.read(data);
    const inputs = await processor(image);
    const embeddings = await model.get_image_embeddings(inputs);

    if (requestVersion !== imageVersion) {
      return; // Superseded by a reset or a newer image
    }

    imageInputs = inputs;
    imageEmbeddings = embeddings;
    statusLabel.textContent = 'Embedding extracted!';
  } catch (err) {
    if (requestVersion !== imageVersion) {
      return; // Superseded; whoever superseded us owns the UI now
    }
    console.error(err);

    // Return to the upload state so the user can try again
    imageDataURI = null;
    imageInputs = null;
    imageContainer.style.backgroundImage = 'none';
    uploadButton.style.display = 'flex';
    statusLabel.textContent = 'Failed to process image';
  }
}

// Handle file selection
fileUpload.addEventListener('change', (e) => {
  const file = e.target.files[0];
  if (!file) {
    return;
  }

  const reader = new FileReader();

  // Set up a callback when the file is loaded
  reader.onload = (e2) => segment(e2.target.result);

  reader.readAsDataURL(file);
});

example.addEventListener('click', (e) => {
  e.preventDefault();
  segment(EXAMPLE_URL);
});

/**
 * Add a marker icon for a prompt point to the image container.
 *
 * @param {{ point: number[], label: number }} prompt - `point` holds the
 *   normalised [x, y] position; label 1 gets a star, anything else a cross.
 */
function addIcon({ point, label }) {
  const icon = (label === 1 ? star : cross).cloneNode();
  icon.style.left = `${point[0] * 100}%`;
  icon.style.top = `${point[1] * 100}%`;
  imageContainer.appendChild(icon);
}

// Attach click event to image container: every click adds a prompt point
imageContainer.addEventListener('mousedown', (e) => {
  if (e.button !== 0 && e.button !== 2) {
    return; // Ignore other buttons
  }
  if (!imageEmbeddings) {
    return; // Ignore if not encoded yet
  }
  if (!isMultiMaskMode) {
    // First click: start a fresh list of points and stop following the hover
    lastPoints = [];
    isMultiMaskMode = true;
    cutButton.disabled = false;
  }

  const point = getPoint(e);
  lastPoints.push(point);

  // add icon
  addIcon(point);

  decode();
});

/**
 * Clamp a value inside a range [min, max].
 *
 * @param {number} x - Value to clamp.
 * @param {number} [min=0] - Lower bound.
 * @param {number} [max=1] - Upper bound.
 * @returns {number} The clamped value.
 */
function clamp(x, min = 0, max = 1) {
  return Math.max(Math.min(x, max), min);
}

/**
 * Convert a mouse event into a prompt point.
 *
 * @param {MouseEvent} e - Mouse event fired on the image container.
 * @returns {{ point: number[], label: number }} Position normalised to [0, 1]
 *   relative to the container, and a label of 0 (right button, negative
 *   prompt) or 1 (any other button, positive prompt).
 */
function getPoint(e) {
  // Get bounding box
  const bb = imageContainer.getBoundingClientRect();

  // Get the mouse coordinates relative to the container
  const mouseX = clamp((e.clientX - bb.left) / bb.width);
  const mouseY = clamp((e.clientY - bb.top) / bb.height);

  return {
    point: [mouseX, mouseY],
    label:
      e.button === 2 // right click
        ? 0 // negative prompt
        : 1, // positive prompt
  };
}

// Do not show context menu on right click
imageContainer.addEventListener('contextmenu', (e) => {
  e.preventDefault();
});

// Attach hover event to image container
imageContainer.addEventListener('mousemove', (e) => {
  if (!imageEmbeddings || isMultiMaskMode) {
    // Ignore mousemove events if the image is not encoded yet,
    // or we are in multi-mask mode
    return;
  }
  lastPoints = [getPoint(e)];

  if (!isDecoding) {
    decode(); // Only decode if we are not already decoding
  }
});

// Handle cut button click: download the part of the image covered by the mask
cutButton.addEventListener('click', () => {
  if (!imageDataURI) {
    return; // Nothing to cut from
  }

  const [w, h] = [maskCanvas.width, maskCanvas.height];

  // Get the mask pixel data
  const maskContext = maskCanvas.getContext('2d');
  const maskPixelData = maskContext.getImageData(0, 0, w, h);

  // Load the image
  const image = new Image();
  image.crossOrigin = 'anonymous';
  image.onload = async () => {
    // Create a new canvas to hold the image
    const imageCanvas = new OffscreenCanvas(w, h);
    const imageContext = imageCanvas.getContext('2d');
    imageContext.drawImage(image, 0, 0, w, h);
    const imagePixelData = imageContext.getImageData(0, 0, w, h);

    // Create a new canvas to hold the cut-out (starts fully transparent)
    const cutCanvas = new OffscreenCanvas(w, h);
    const cutContext = cutCanvas.getContext('2d');
    const cutPixelData = cutContext.createImageData(w, h);

    // Copy the RGBA values of every pixel whose mask alpha is non-zero
    const maskData = maskPixelData.data;
    for (let i = 0; i < maskData.length; i += 4) {
      if (maskData[i + 3] > 0) {
        for (let channel = 0; channel < 4; ++channel) {
          cutPixelData.data[i + channel] = imagePixelData.data[i + channel];
        }
      }
    }
    cutContext.putImageData(cutPixelData, 0, 0);

    // Download image
    const objectUrl = URL.createObjectURL(await cutCanvas.convertToBlob());
    const link = document.createElement('a');
    link.download = 'image.png';
    link.href = objectUrl;
    link.click();
    link.remove();

    // Release the blob once the download has had a chance to start;
    // otherwise every cut keeps its blob alive for the lifetime of the page.
    setTimeout(() => URL.revokeObjectURL(objectUrl), OBJECT_URL_REVOKE_DELAY_MS);
  };
  image.src = imageDataURI;
});

// Load the model and processor
const MODEL_ID = 'Xenova/slimsam-77-uniform';
statusLabel.textContent = 'Loading model...';

let model;
let processor;
try {
  // The two downloads are independent, so run them in parallel
  [model, processor] = await Promise.all([
    SamModel.from_pretrained(MODEL_ID, {
      dtype: 'fp16',
      device: 'webgpu',
    }),
    AutoProcessor.from_pretrained(MODEL_ID),
  ]);
} catch (err) {
  // Surface the failure instead of leaving the status at "Loading model..."
  statusLabel.textContent = 'Failed to load model';
  throw err;
}
statusLabel.textContent = 'Ready';

// Enable the user interface
fileUpload.disabled = false;
uploadButton.style.opacity = 1;
example.style.pointerEvents = 'auto';