Xenova HF staff committed on
Commit
1af09e8
1 Parent(s): 685050c

Upload 3 files

Browse files
Files changed (3) hide show
  1. index.css +119 -0
  2. index.html +41 -28
  3. index.js +325 -79
index.css ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
/* Global reset: border-box sizing and one sans-serif font everywhere */
* {
    box-sizing: border-box;
    padding: 0;
    margin: 0;
    font-family: sans-serif;
}

html,
body {
    height: 100%;
}

body {
    padding: 16px 32px;
}

/* Centred column layout for the page, the image container and the upload button */
body,
#container,
#upload-button {
    display: flex;
    flex-direction: column;
    justify-content: center;
    align-items: center;
}

h1,
h3 {
    text-align: center;
}

/* Dashed drop-zone; the selected image is shown as its background */
#container {
    position: relative;
    width: 640px;
    height: 420px;
    max-width: 100%;
    max-height: 100%;
    border: 2px dashed #D1D5DB;
    border-radius: 0.75rem;
    overflow: hidden;
    cursor: pointer;
    margin-top: 1rem;
    background-size: 100% 100%;
    background-position: center;
    background-repeat: no-repeat;
}

/* Mask canvas overlays the image; clicks fall through to the container */
#mask-output {
    position: absolute;
    width: 100%;
    height: 100%;
    pointer-events: none;
}

#upload-button {
    gap: 0.4rem;
    font-size: 18px;
    cursor: pointer;
    opacity: 0.2; /* dimmed until the model has loaded (see index.js) */
}

/* Hidden native file input — triggered via its <label> */
#upload {
    display: none;
}

svg {
    pointer-events: none;
}

#example {
    font-size: 14px;
    text-decoration: underline;
    cursor: pointer;
    pointer-events: none; /* re-enabled from JS once the model is ready */
}

#example:hover {
    color: #2563EB;
}

canvas {
    position: absolute;
    width: 100%;
    height: 100%;
    opacity: 0.6;
}

#status {
    min-height: 16px;
    margin: 8px 0;
}

/* Star/cross point markers, centred on the clicked position */
.icon {
    height: 16px;
    width: 16px;
    position: absolute;
    transform: translate(-50%, -50%);
}

#controls>button {
    padding: 6px 12px;
    background-color: #3498db;
    color: white;
    border: 1px solid #2980b9;
    border-radius: 5px;
    cursor: pointer;
    font-size: 16px;
}

#controls>button:disabled {
    background-color: #d1d5db;
    color: #6b7280;
    border: 1px solid #9ca3af;
    cursor: not-allowed;
}

#information {
    margin-top: 0.25rem;
    font-size: 15px;
}
index.html CHANGED
@@ -1,29 +1,42 @@
1
- <!DOCTYPE html>
2
- <html lang="en">
3
-
4
- <head>
5
- <meta charset="UTF-8" />
6
- <link rel="stylesheet" href="style.css" />
7
-
8
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
9
- <title>Transformers.js - Object Detection</title>
10
- </head>
11
-
12
- <body>
13
- <h1>Object Detection w/ 🤗 Transformers.js</h1>
14
- <label id="container" for="upload">
15
- <svg width="25" height="25" viewBox="0 0 25 25" fill="none" xmlns="http://www.w3.org/2000/svg">
16
- <path fill="#000"
17
- d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z">
18
- </path>
19
- </svg>
20
- Click to upload image
21
- <label id="example">(or try example)</label>
22
- </label>
23
- <label id="status">Loading model...</label>
24
- <input id="upload" type="file" accept="image/*" />
25
-
26
- <script src="index.js" type="module"></script>
27
- </body>
28
-
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  </html>
 
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8" />
    <link rel="stylesheet" href="index.css" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Transformers.js - Segment Anything WebGPU</title>
</head>

<body>
    <h1>Segment Anything WebGPU</h1>
    <h3>In-browser image segmentation w/ <a href="https://hf.co/docs/transformers.js" target="_blank">🤗
            Transformers.js</a></h3>

    <!-- Drop-zone: the upload label is hidden once an image is chosen, and the
         segmentation mask is drawn onto the overlaid canvas -->
    <div id="container">
        <label id="upload-button" for="upload">
            <svg width="25" height="25" viewBox="0 0 25 25" fill="none" xmlns="http://www.w3.org/2000/svg">
                <path fill="#000"
                    d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z">
                </path>
            </svg>
            Click to upload image
            <label id="example">(or try example)</label>
        </label>
        <canvas id="mask-output"></canvas>
    </div>

    <label id="status"></label>

    <div id="controls">
        <button id="reset-image">Reset image</button>
        <button id="clear-points">Clear points</button>
        <button id="cut-mask" disabled>Cut mask</button>
    </div>

    <p id="information">
        Left click = positive points, right click = negative points.
    </p>

    <!-- Hidden file input; enabled from index.js once the model has loaded -->
    <input id="upload" type="file" accept="image/*" disabled />

    <script src="index.js" type="module"></script>
</body>

</html>
index.js CHANGED
@@ -1,79 +1,325 @@
1
- import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.10.1';
2
-
3
- // Since we will download the model from the Hugging Face Hub, we can skip the local model check
4
- env.allowLocalModels = false;
5
-
6
- // Reference the elements that we will need
7
- const status = document.getElementById('status');
8
- const fileUpload = document.getElementById('upload');
9
- const imageContainer = document.getElementById('container');
10
- const example = document.getElementById('example');
11
-
12
- const EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/city-streets.jpg';
13
-
14
- // Create a new object detection pipeline
15
- status.textContent = 'Loading model...';
16
- const detector = await pipeline('object-detection', 'Xenova/detr-resnet-50');
17
- status.textContent = 'Ready';
18
-
19
- example.addEventListener('click', (e) => {
20
- e.preventDefault();
21
- detect(EXAMPLE_URL);
22
- });
23
-
24
- fileUpload.addEventListener('change', function (e) {
25
- const file = e.target.files[0];
26
- if (!file) {
27
- return;
28
- }
29
-
30
- const reader = new FileReader();
31
-
32
- // Set up a callback when the file is loaded
33
- reader.onload = e2 => detect(e2.target.result);
34
-
35
- reader.readAsDataURL(file);
36
- });
37
-
38
-
39
- // Detect objects in the image
40
- async function detect(img) {
41
- imageContainer.innerHTML = '';
42
- imageContainer.style.backgroundImage = `url(${img})`;
43
-
44
- status.textContent = 'Analysing...';
45
- const output = await detector(img, {
46
- threshold: 0.5,
47
- percentage: true,
48
- });
49
- status.textContent = '';
50
- output.forEach(renderBox);
51
- }
52
-
53
- // Render a bounding box and label on the image
54
- function renderBox({ box, label }) {
55
- const { xmax, xmin, ymax, ymin } = box;
56
-
57
- // Generate a random color for the box
58
- const color = '#' + Math.floor(Math.random() * 0xFFFFFF).toString(16).padStart(6, 0);
59
-
60
- // Draw the box
61
- const boxElement = document.createElement('div');
62
- boxElement.className = 'bounding-box';
63
- Object.assign(boxElement.style, {
64
- borderColor: color,
65
- left: 100 * xmin + '%',
66
- top: 100 * ymin + '%',
67
- width: 100 * (xmax - xmin) + '%',
68
- height: 100 * (ymax - ymin) + '%',
69
- })
70
-
71
- // Draw label
72
- const labelElement = document.createElement('span');
73
- labelElement.textContent = label;
74
- labelElement.className = 'bounding-box-label';
75
- labelElement.style.backgroundColor = color;
76
-
77
- boxElement.appendChild(labelElement);
78
- imageContainer.appendChild(boxElement);
79
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import { SamModel, AutoProcessor, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.4';

// UI element references
const statusLabel = document.getElementById('status');
const fileUpload = document.getElementById('upload');
const imageContainer = document.getElementById('container');
const example = document.getElementById('example');
const maskCanvas = document.getElementById('mask-output');
const uploadButton = document.getElementById('upload-button');
const resetButton = document.getElementById('reset-image');
const clearButton = document.getElementById('clear-points');
const cutButton = document.getElementById('cut-mask');

// Mutable application state
let lastPoints = null;       // prompt points used by the next decode()
let isDecoding = false;      // a decode() call is currently awaiting the model
let isMultiMaskMode = false; // true after the first click (point-by-point mode)
let imageDataURI = null;     // source (data URI or URL) of the loaded image
let imageInputs = null;      // processor outputs for the loaded image
let imageEmbeddings = null;  // cached encoder outputs for the loaded image

// Remote assets
const BASE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/';
const EXAMPLE_URL = BASE_URL + 'corgi.jpg';

// Preload the star (positive) and cross (negative) marker images so the
// first click does not stall on a network fetch.
const star = new Image();
star.src = BASE_URL + 'star-icon.png';
star.className = 'icon';

const cross = new Image();
cross.src = BASE_URL + 'cross-icon.png';
cross.className = 'icon';
// Decode the current prompt points into a segmentation mask and draw the
// best-scoring candidate onto the mask canvas. Requires that segment() has
// already populated `imageInputs` and `imageEmbeddings`.
async function decode() {
    if (!imageInputs || !imageEmbeddings) {
        return;
    }
    isDecoding = true;

    // Convert normalised [0, 1] points into pixel coordinates of the
    // reshaped model input (reshaped = [height, width]).
    const reshaped = imageInputs.reshaped_input_sizes[0];
    const points = lastPoints.map(x => [x.point[0] * reshaped[1], x.point[1] * reshaped[0]]);
    const labels = lastPoints.map(x => BigInt(x.label));

    const input_points = new Tensor(
        'float32',
        points.flat(Infinity),
        [1, 1, points.length, 2],
    );
    const input_labels = new Tensor(
        'int64',
        labels.flat(Infinity),
        [1, 1, labels.length],
    );

    // Generate the mask
    const { pred_masks, iou_scores } = await model({
        ...imageEmbeddings,
        input_points,
        input_labels,
    });

    // Post-process the mask back to the original image resolution
    const masks = await processor.post_process_masks(
        pred_masks,
        imageInputs.original_sizes,
        imageInputs.reshaped_input_sizes,
    );

    const data = {
        mask: RawImage.fromTensor(masks[0][0]),
        scores: iou_scores.data,
    };
    isDecoding = false;

    if (!isMultiMaskMode && lastPoints) {
        // Hover mode: the mouse may have moved while this decode was in
        // flight, so kick off another decode with the latest point, then
        // clear it so the chain terminates.
        decode();
        lastPoints = null;
    }

    const { mask, scores } = data;

    // Update canvas dimensions (if different)
    if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
        maskCanvas.width = mask.width;
        maskCanvas.height = mask.height;
    }

    // Create context and allocate buffer for pixel data
    const context = maskCanvas.getContext('2d');
    const imageData = context.createImageData(maskCanvas.width, maskCanvas.height);

    // Select the candidate mask with the highest IoU score.
    // `mask` stores `numMasks` interleaved channels per pixel.
    const numMasks = scores.length; // 3
    let bestIndex = 0;
    for (let i = 1; i < numMasks; ++i) {
        if (scores[i] > scores[bestIndex]) {
            bestIndex = i;
        }
    }
    statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

    // Fill the selected mask with a translucent blue.
    // FIX: iterate over pixels (length / 4), not RGBA bytes — the previous
    // bound read past the end of `mask.data` and did 4x the necessary work
    // (out-of-range reads were `undefined`, so the output was unchanged).
    const pixelData = imageData.data;
    const numPixels = pixelData.length / 4;
    for (let i = 0; i < numPixels; ++i) {
        if (mask.data[numMasks * i + bestIndex] === 1) {
            const offset = 4 * i;
            pixelData[offset] = 0;       // red
            pixelData[offset + 1] = 114; // green
            pixelData[offset + 2] = 189; // blue
            pixelData[offset + 3] = 255; // alpha
        }
    }

    // Draw image data to context
    context.putImageData(imageData, 0, 0);
}
// Remove all prompt points and erase the mask overlay, returning the app
// to hover (single-point preview) mode.
function clearPointsAndMask() {
    // Back to hover mode with no stored points
    isMultiMaskMode = false;
    lastPoints = null;

    // Drop any star/cross markers left over from the previous mask
    for (const el of document.querySelectorAll('.icon')) {
        el.remove();
    }

    // Nothing left to cut
    cutButton.disabled = true;

    // Wipe the mask canvas
    const ctx = maskCanvas.getContext('2d');
    ctx.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
}
clearButton.addEventListener('click', clearPointsAndMask);
// Reset the demo to its initial (no image loaded) state.
// (Deduplicated: the original assigned `imageEmbeddings = null` twice in
// two separate "reset state" stanzas.)
resetButton.addEventListener('click', () => {
    // Reset state
    imageDataURI = null;
    imageInputs = null;
    imageEmbeddings = null;
    isDecoding = false;

    // Clear points and mask (if present)
    clearPointsAndMask();

    // Update UI
    cutButton.disabled = true;
    imageContainer.style.backgroundImage = 'none';
    uploadButton.style.display = 'flex';
    statusLabel.textContent = 'Ready';
});
// Load an image (data URI or URL), display it, and compute + cache its
// embeddings so that subsequent decode() calls are cheap.
async function segment(data) {
    statusLabel.textContent = 'Extracting image embedding...';

    // Update state
    imageEmbeddings = null;
    imageDataURI = data;

    // Update UI
    imageContainer.style.backgroundImage = `url(${data})`;
    uploadButton.style.display = 'none';
    cutButton.disabled = true;

    try {
        // Read the image and recompute image embeddings
        const image = await RawImage.read(data);
        imageInputs = await processor(image);
        imageEmbeddings = await model.get_image_embeddings(imageInputs);
        statusLabel.textContent = 'Embedding extracted!';
    } catch (err) {
        // FIX: surface failures (bad file, network error) instead of
        // leaving the UI stuck on the "Extracting..." message forever.
        statusLabel.textContent = 'Failed to load image.';
        uploadButton.style.display = 'flex';
        console.error(err);
    }
}
// Handle file selection: read the chosen file as a data URI, then encode it.
fileUpload.addEventListener('change', (e) => {
    const [file] = e.target.files;
    if (!file) {
        return;
    }

    const reader = new FileReader();
    reader.onload = (loadEvent) => segment(loadEvent.target.result);
    reader.readAsDataURL(file);
});

// Run the demo on the bundled example image instead of an upload.
example.addEventListener('click', (e) => {
    e.preventDefault();
    segment(EXAMPLE_URL);
});
// Place a star (positive) or cross (negative) marker at a normalised
// [0, 1] point inside the image container.
function addIcon({ point, label }) {
    const template = label === 1 ? star : cross;
    const icon = template.cloneNode();
    const [x, y] = point;
    icon.style.left = `${x * 100}%`;
    icon.style.top = `${y * 100}%`;
    imageContainer.appendChild(icon);
}
// Add a prompt point on left (positive) or right (negative) click.
// (The original comment said "hover event" — this is the click handler.)
imageContainer.addEventListener('mousedown', e => {
    const isHandledButton = e.button === 0 || e.button === 2;
    if (!isHandledButton) {
        return; // Ignore other buttons
    }
    if (!imageEmbeddings) {
        return; // Ignore if not encoded yet
    }

    // The first click switches from hover preview to multi-point mode
    if (!isMultiMaskMode) {
        lastPoints = [];
        isMultiMaskMode = true;
        cutButton.disabled = false;
    }

    const point = getPoint(e);
    lastPoints.push(point);

    // Show a marker for the new point, then re-run the decoder
    addIcon(point);
    decode();
});
// Clamp a value inside a range [min, max]
function clamp(x, min = 0, max = 1) {
    return Math.min(Math.max(x, min), max);
}

// Convert a mouse event into a prompt point: coordinates normalised to
// [0, 1] within the image container, labelled 1 (positive, left click)
// or 0 (negative, right click).
function getPoint(e) {
    const bb = imageContainer.getBoundingClientRect();

    // Mouse position relative to the container, clamped into [0, 1]
    const mouseX = clamp((e.clientX - bb.left) / bb.width);
    const mouseY = clamp((e.clientY - bb.top) / bb.height);

    const isRightClick = e.button === 2;
    return {
        point: [mouseX, mouseY],
        label: isRightClick ? 0 : 1,
    };
}
// Suppress the browser context menu so right-click can add negative points.
imageContainer.addEventListener('contextmenu', e => {
    e.preventDefault();
});

// Hover preview: before any point has been clicked, continuously decode the
// mask under the cursor.
imageContainer.addEventListener('mousemove', e => {
    if (!imageEmbeddings || isMultiMaskMode) {
        // Skip until the image is encoded, and once in multi-point mode
        return;
    }
    lastPoints = [getPoint(e)];

    // Only start a decode if one is not already in flight
    if (!isDecoding) {
        decode();
    }
});
// Handle cut button click: copy the masked region of the image onto a
// transparent canvas and download it as a PNG.
cutButton.addEventListener('click', () => {
    const [w, h] = [maskCanvas.width, maskCanvas.height];

    // Get the mask pixel data
    const maskContext = maskCanvas.getContext('2d');
    const maskPixelData = maskContext.getImageData(0, 0, w, h);

    // Load the image
    const image = new Image();
    image.crossOrigin = 'anonymous';
    image.onload = async () => {
        // Draw the source image at mask resolution
        const imageCanvas = new OffscreenCanvas(w, h);
        const imageContext = imageCanvas.getContext('2d');
        imageContext.drawImage(image, 0, 0, w, h);
        const imagePixelData = imageContext.getImageData(0, 0, w, h);

        // Create a new canvas to hold the cut-out
        const cutCanvas = new OffscreenCanvas(w, h);
        const cutContext = cutCanvas.getContext('2d');
        const cutPixelData = cutContext.getImageData(0, 0, w, h);

        // Copy image pixels wherever the mask's alpha channel is set
        // (i walks the alpha bytes; i-3..i are the RGBA bytes of one pixel).
        for (let i = 3; i < maskPixelData.data.length; i += 4) {
            if (maskPixelData.data[i] > 0) {
                for (let j = 0; j < 4; ++j) {
                    const offset = i - j;
                    cutPixelData.data[offset] = imagePixelData.data[offset];
                }
            }
        }
        cutContext.putImageData(cutPixelData, 0, 0);

        // Trigger the download
        const link = document.createElement('a');
        link.download = 'image.png';
        link.href = URL.createObjectURL(await cutCanvas.convertToBlob());
        link.click();
        link.remove();
        // FIX: release the object URL — the original leaked one blob per cut
        URL.revokeObjectURL(link.href);
    };
    image.src = imageDataURI;
});
// Load the SlimSAM model and its processor (fp16 on WebGPU), then enable
// the user interface.
const model_id = 'Xenova/slimsam-77-uniform';
statusLabel.textContent = 'Loading model...';

let model;
let processor;
try {
    model = await SamModel.from_pretrained(model_id, {
        dtype: 'fp16',
        device: 'webgpu',
    });
    processor = await AutoProcessor.from_pretrained(model_id);
} catch (err) {
    // FIX: surface load failures (e.g. browsers without WebGPU) instead of
    // leaving the page stuck on "Loading model...".
    statusLabel.textContent = 'Failed to load model. Does your browser support WebGPU?';
    throw err;
}
statusLabel.textContent = 'Ready';

// Enable the user interface
fileUpload.disabled = false;
uploadButton.style.opacity = 1;
example.style.pointerEvents = 'auto';