Xenova HF staff commited on
Commit
b6b4104
1 Parent(s): 1af09e8

Update index.js

Browse files
Files changed (1) hide show
  1. index.js +325 -325
index.js CHANGED
@@ -1,325 +1,325 @@
1
- import { SamModel, AutoProcessor, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@huggingface/[email protected]';
2
-
3
- // Reference the elements we will use
4
- const statusLabel = document.getElementById('status');
5
- const fileUpload = document.getElementById('upload');
6
- const imageContainer = document.getElementById('container');
7
- const example = document.getElementById('example');
8
- const maskCanvas = document.getElementById('mask-output');
9
- const uploadButton = document.getElementById('upload-button');
10
- const resetButton = document.getElementById('reset-image');
11
- const clearButton = document.getElementById('clear-points');
12
- const cutButton = document.getElementById('cut-mask');
13
-
14
- // State variables
15
- let lastPoints = null;
16
- let isDecoding = false;
17
- let isMultiMaskMode = false;
18
- let imageDataURI = null;
19
- let imageInputs = null;
20
- let imageEmbeddings = null;
21
-
22
- // Constants
23
- const BASE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/';
24
- const EXAMPLE_URL = BASE_URL + 'corgi.jpg';
25
-
26
- // Preload star and cross images to avoid lag on first click
27
- const star = new Image();
28
- star.src = BASE_URL + 'star-icon.png';
29
- star.className = 'icon';
30
-
31
- const cross = new Image();
32
- cross.src = BASE_URL + 'cross-icon.png';
33
- cross.className = 'icon';
34
-
35
- async function decode() {
36
- if (!imageInputs || !imageEmbeddings) {
37
- return;
38
- }
39
- isDecoding = true;
40
-
41
- // Prepare inputs for decoding
42
- const reshaped = imageInputs.reshaped_input_sizes[0];
43
- const points = lastPoints.map(x => [x.point[0] * reshaped[1], x.point[1] * reshaped[0]])
44
- const labels = lastPoints.map(x => BigInt(x.label));
45
-
46
- const input_points = new Tensor(
47
- 'float32',
48
- points.flat(Infinity),
49
- [1, 1, points.length, 2],
50
- )
51
- const input_labels = new Tensor(
52
- 'int64',
53
- labels.flat(Infinity),
54
- [1, 1, labels.length],
55
- )
56
-
57
- // Generate the mask
58
- const { pred_masks, iou_scores } = await model({
59
- ...imageEmbeddings,
60
- input_points,
61
- input_labels,
62
- })
63
-
64
- // Post-process the mask
65
- const masks = await processor.post_process_masks(
66
- pred_masks,
67
- imageInputs.original_sizes,
68
- imageInputs.reshaped_input_sizes,
69
- );
70
-
71
- const data = {
72
- mask: RawImage.fromTensor(masks[0][0]),
73
- scores: iou_scores.data,
74
- };
75
- isDecoding = false;
76
-
77
- if (!isMultiMaskMode && lastPoints) {
78
- // Perform decoding with the last point
79
- decode();
80
- lastPoints = null;
81
- }
82
-
83
- const { mask, scores } = data;
84
-
85
- // Update canvas dimensions (if different)
86
- if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
87
- maskCanvas.width = mask.width;
88
- maskCanvas.height = mask.height;
89
- }
90
-
91
- // Create context and allocate buffer for pixel data
92
- const context = maskCanvas.getContext('2d');
93
- const imageData = context.createImageData(maskCanvas.width, maskCanvas.height);
94
-
95
- // Select best mask
96
- const numMasks = scores.length; // 3
97
- let bestIndex = 0;
98
- for (let i = 1; i < numMasks; ++i) {
99
- if (scores[i] > scores[bestIndex]) {
100
- bestIndex = i;
101
- }
102
- }
103
- statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;
104
-
105
- // Fill mask with colour
106
- const pixelData = imageData.data;
107
- for (let i = 0; i < pixelData.length; ++i) {
108
- if (mask.data[numMasks * i + bestIndex] === 1) {
109
- const offset = 4 * i;
110
- pixelData[offset] = 0; // red
111
- pixelData[offset + 1] = 114; // green
112
- pixelData[offset + 2] = 189; // blue
113
- pixelData[offset + 3] = 255; // alpha
114
- }
115
- }
116
-
117
- // Draw image data to context
118
- context.putImageData(imageData, 0, 0);
119
- }
120
-
121
- function clearPointsAndMask() {
122
- // Reset state
123
- isMultiMaskMode = false;
124
- lastPoints = null;
125
-
126
- // Remove points from previous mask (if any)
127
- document.querySelectorAll('.icon').forEach(e => e.remove());
128
-
129
- // Disable cut button
130
- cutButton.disabled = true;
131
-
132
- // Reset mask canvas
133
- maskCanvas.getContext('2d').clearRect(0, 0, maskCanvas.width, maskCanvas.height);
134
- }
135
- clearButton.addEventListener('click', clearPointsAndMask);
136
-
137
- resetButton.addEventListener('click', () => {
138
- // Update state
139
- imageEmbeddings = null;
140
- imageDataURI = null;
141
-
142
- // Reset the state
143
- imageInputs = null;
144
- imageEmbeddings = null;
145
- isDecoding = false;
146
-
147
- // Clear points and mask (if present)
148
- clearPointsAndMask();
149
-
150
- // Update UI
151
- cutButton.disabled = true;
152
- imageContainer.style.backgroundImage = 'none';
153
- uploadButton.style.display = 'flex';
154
- statusLabel.textContent = 'Ready';
155
- });
156
-
157
- async function segment(data) {
158
- statusLabel.textContent = 'Extracting image embedding...';
159
-
160
- // Update state
161
- imageEmbeddings = null;
162
- imageDataURI = data;
163
-
164
- // Update UI
165
- imageContainer.style.backgroundImage = `url(${data})`;
166
- uploadButton.style.display = 'none';
167
- cutButton.disabled = true;
168
-
169
- // Read the image and recompute image embeddings
170
- const image = await RawImage.read(data);
171
- imageInputs = await processor(image);
172
- imageEmbeddings = await model.get_image_embeddings(imageInputs)
173
-
174
- statusLabel.textContent = 'Embedding extracted!';
175
- }
176
-
177
- // Handle file selection
178
- fileUpload.addEventListener('change', function (e) {
179
- const file = e.target.files[0];
180
- if (!file) {
181
- return;
182
- }
183
-
184
- const reader = new FileReader();
185
-
186
- // Set up a callback when the file is loaded
187
- reader.onload = e2 => segment(e2.target.result);
188
-
189
- reader.readAsDataURL(file);
190
- });
191
-
192
- example.addEventListener('click', (e) => {
193
- e.preventDefault();
194
- segment(EXAMPLE_URL);
195
- });
196
-
197
- function addIcon({ point, label }) {
198
- const icon = (label === 1 ? star : cross).cloneNode();
199
- icon.style.left = `${point[0] * 100}%`;
200
- icon.style.top = `${point[1] * 100}%`;
201
- imageContainer.appendChild(icon);
202
- }
203
-
204
- // Attach hover event to image container
205
- imageContainer.addEventListener('mousedown', e => {
206
- if (e.button !== 0 && e.button !== 2) {
207
- return; // Ignore other buttons
208
- }
209
- if (!imageEmbeddings) {
210
- return; // Ignore if not encoded yet
211
- }
212
- if (!isMultiMaskMode) {
213
- lastPoints = [];
214
- isMultiMaskMode = true;
215
- cutButton.disabled = false;
216
- }
217
-
218
- const point = getPoint(e);
219
- lastPoints.push(point);
220
-
221
- // add icon
222
- addIcon(point);
223
-
224
- decode();
225
- });
226
-
227
-
228
- // Clamp a value inside a range [min, max]
229
- function clamp(x, min = 0, max = 1) {
230
- return Math.max(Math.min(x, max), min)
231
- }
232
-
233
- function getPoint(e) {
234
- // Get bounding box
235
- const bb = imageContainer.getBoundingClientRect();
236
-
237
- // Get the mouse coordinates relative to the container
238
- const mouseX = clamp((e.clientX - bb.left) / bb.width);
239
- const mouseY = clamp((e.clientY - bb.top) / bb.height);
240
-
241
- return {
242
- point: [mouseX, mouseY],
243
- label: e.button === 2 // right click
244
- ? 0 // negative prompt
245
- : 1, // positive prompt
246
- }
247
- }
248
-
249
- // Do not show context menu on right click
250
- imageContainer.addEventListener('contextmenu', e => {
251
- e.preventDefault();
252
- });
253
-
254
- // Attach hover event to image container
255
- imageContainer.addEventListener('mousemove', e => {
256
- if (!imageEmbeddings || isMultiMaskMode) {
257
- // Ignore mousemove events if the image is not encoded yet,
258
- // or we are in multi-mask mode
259
- return;
260
- }
261
- lastPoints = [getPoint(e)];
262
-
263
- if (!isDecoding) {
264
- decode(); // Only decode if we are not already decoding
265
- }
266
- });
267
-
268
- // Handle cut button click
269
- cutButton.addEventListener('click', () => {
270
- const [w, h] = [maskCanvas.width, maskCanvas.height];
271
-
272
- // Get the mask pixel data
273
- const maskContext = maskCanvas.getContext('2d');
274
- const maskPixelData = maskContext.getImageData(0, 0, w, h);
275
-
276
- // Load the image
277
- const image = new Image();
278
- image.crossOrigin = 'anonymous';
279
- image.onload = async () => {
280
- // Create a new canvas to hold the image
281
- const imageCanvas = new OffscreenCanvas(w, h);
282
- const imageContext = imageCanvas.getContext('2d');
283
- imageContext.drawImage(image, 0, 0, w, h);
284
- const imagePixelData = imageContext.getImageData(0, 0, w, h);
285
-
286
- // Create a new canvas to hold the cut-out
287
- const cutCanvas = new OffscreenCanvas(w, h);
288
- const cutContext = cutCanvas.getContext('2d');
289
- const cutPixelData = cutContext.getImageData(0, 0, w, h);
290
-
291
- // Copy the image pixel data to the cut canvas
292
- for (let i = 3; i < maskPixelData.data.length; i += 4) {
293
- if (maskPixelData.data[i] > 0) {
294
- for (let j = 0; j < 4; ++j) {
295
- const offset = i - j;
296
- cutPixelData.data[offset] = imagePixelData.data[offset];
297
- }
298
- }
299
- }
300
- cutContext.putImageData(cutPixelData, 0, 0);
301
-
302
- // Download image
303
- const link = document.createElement('a');
304
- link.download = 'image.png';
305
- link.href = URL.createObjectURL(await cutCanvas.convertToBlob());
306
- link.click();
307
- link.remove();
308
- }
309
- image.src = imageDataURI;
310
- });
311
-
312
-
313
- const model_id = 'Xenova/slimsam-77-uniform';
314
- statusLabel.textContent = 'Loading model...';
315
- const model = await SamModel.from_pretrained(model_id, {
316
- dtype: 'fp16',
317
- device: 'webgpu',
318
- });
319
- const processor = await AutoProcessor.from_pretrained(model_id);
320
- statusLabel.textContent = 'Ready';
321
-
322
- // Enable the user interface
323
- fileUpload.disabled = false;
324
- uploadButton.style.opacity = 1;
325
- example.style.pointerEvents = 'auto';
 
1
+ import { SamModel, AutoProcessor, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@huggingface/[email protected]';
2
+
3
+ // Reference the elements we will use
4
+ const statusLabel = document.getElementById('status');
5
+ const fileUpload = document.getElementById('upload');
6
+ const imageContainer = document.getElementById('container');
7
+ const example = document.getElementById('example');
8
+ const maskCanvas = document.getElementById('mask-output');
9
+ const uploadButton = document.getElementById('upload-button');
10
+ const resetButton = document.getElementById('reset-image');
11
+ const clearButton = document.getElementById('clear-points');
12
+ const cutButton = document.getElementById('cut-mask');
13
+
14
+ // Constants
15
+ const BASE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/';
16
+ const EXAMPLE_URL = BASE_URL + 'corgi.jpg';
17
+
18
+ // Preload star and cross images to avoid lag on first click
19
+ const star = new Image();
20
+ star.src = BASE_URL + 'star-icon.png';
21
+ star.className = 'icon';
22
+
23
+ const cross = new Image();
24
+ cross.src = BASE_URL + 'cross-icon.png';
25
+ cross.className = 'icon';
26
+
27
+ // State variables
28
+ let lastPoints = null;
29
+ let isDecoding = false;
30
+ let isMultiMaskMode = false;
31
+ let imageDataURI = null;
32
+ let imageInputs = null;
33
+ let imageEmbeddings = null;
34
+
35
+ async function decode() {
36
+ if (!imageInputs || !imageEmbeddings) {
37
+ return;
38
+ }
39
+ isDecoding = true;
40
+
41
+ // Prepare inputs for decoding
42
+ const reshaped = imageInputs.reshaped_input_sizes[0];
43
+ const points = lastPoints.map(x => [x.point[0] * reshaped[1], x.point[1] * reshaped[0]])
44
+ const labels = lastPoints.map(x => BigInt(x.label));
45
+
46
+ const input_points = new Tensor(
47
+ 'float32',
48
+ points.flat(Infinity),
49
+ [1, 1, points.length, 2],
50
+ )
51
+ const input_labels = new Tensor(
52
+ 'int64',
53
+ labels.flat(Infinity),
54
+ [1, 1, labels.length],
55
+ )
56
+
57
+ // Generate the mask
58
+ const { pred_masks, iou_scores } = await model({
59
+ ...imageEmbeddings,
60
+ input_points,
61
+ input_labels,
62
+ })
63
+
64
+ // Post-process the mask
65
+ const masks = await processor.post_process_masks(
66
+ pred_masks,
67
+ imageInputs.original_sizes,
68
+ imageInputs.reshaped_input_sizes,
69
+ );
70
+
71
+ const data = {
72
+ mask: RawImage.fromTensor(masks[0][0]),
73
+ scores: iou_scores.data,
74
+ };
75
+ isDecoding = false;
76
+
77
+ if (!isMultiMaskMode && lastPoints) {
78
+ // Perform decoding with the last point
79
+ decode();
80
+ lastPoints = null;
81
+ }
82
+
83
+ const { mask, scores } = data;
84
+
85
+ // Update canvas dimensions (if different)
86
+ if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
87
+ maskCanvas.width = mask.width;
88
+ maskCanvas.height = mask.height;
89
+ }
90
+
91
+ // Create context and allocate buffer for pixel data
92
+ const context = maskCanvas.getContext('2d');
93
+ const imageData = context.createImageData(maskCanvas.width, maskCanvas.height);
94
+
95
+ // Select best mask
96
+ const numMasks = scores.length; // 3
97
+ let bestIndex = 0;
98
+ for (let i = 1; i < numMasks; ++i) {
99
+ if (scores[i] > scores[bestIndex]) {
100
+ bestIndex = i;
101
+ }
102
+ }
103
+ statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;
104
+
105
+ // Fill mask with colour
106
+ const pixelData = imageData.data;
107
+ for (let i = 0; i < pixelData.length; ++i) {
108
+ if (mask.data[numMasks * i + bestIndex] === 1) {
109
+ const offset = 4 * i;
110
+ pixelData[offset] = 0; // red
111
+ pixelData[offset + 1] = 114; // green
112
+ pixelData[offset + 2] = 189; // blue
113
+ pixelData[offset + 3] = 255; // alpha
114
+ }
115
+ }
116
+
117
+ // Draw image data to context
118
+ context.putImageData(imageData, 0, 0);
119
+ }
120
+
121
+ function clearPointsAndMask() {
122
+ // Reset state
123
+ isMultiMaskMode = false;
124
+ lastPoints = null;
125
+
126
+ // Remove points from previous mask (if any)
127
+ document.querySelectorAll('.icon').forEach(e => e.remove());
128
+
129
+ // Disable cut button
130
+ cutButton.disabled = true;
131
+
132
+ // Reset mask canvas
133
+ maskCanvas.getContext('2d').clearRect(0, 0, maskCanvas.width, maskCanvas.height);
134
+ }
135
+ clearButton.addEventListener('click', clearPointsAndMask);
136
+
137
+ resetButton.addEventListener('click', () => {
138
+ // Update state
139
+ imageEmbeddings = null;
140
+ imageDataURI = null;
141
+
142
+ // Reset the state
143
+ imageInputs = null;
144
+ imageEmbeddings = null;
145
+ isDecoding = false;
146
+
147
+ // Clear points and mask (if present)
148
+ clearPointsAndMask();
149
+
150
+ // Update UI
151
+ cutButton.disabled = true;
152
+ imageContainer.style.backgroundImage = 'none';
153
+ uploadButton.style.display = 'flex';
154
+ statusLabel.textContent = 'Ready';
155
+ });
156
+
157
+ async function segment(data) {
158
+ statusLabel.textContent = 'Extracting image embedding...';
159
+
160
+ // Update state
161
+ imageEmbeddings = null;
162
+ imageDataURI = data;
163
+
164
+ // Update UI
165
+ imageContainer.style.backgroundImage = `url(${data})`;
166
+ uploadButton.style.display = 'none';
167
+ cutButton.disabled = true;
168
+
169
+ // Read the image and recompute image embeddings
170
+ const image = await RawImage.read(data);
171
+ imageInputs = await processor(image);
172
+ imageEmbeddings = await model.get_image_embeddings(imageInputs)
173
+
174
+ statusLabel.textContent = 'Embedding extracted!';
175
+ }
176
+
177
+ // Handle file selection
178
+ fileUpload.addEventListener('change', function (e) {
179
+ const file = e.target.files[0];
180
+ if (!file) {
181
+ return;
182
+ }
183
+
184
+ const reader = new FileReader();
185
+
186
+ // Set up a callback when the file is loaded
187
+ reader.onload = e2 => segment(e2.target.result);
188
+
189
+ reader.readAsDataURL(file);
190
+ });
191
+
192
+ example.addEventListener('click', (e) => {
193
+ e.preventDefault();
194
+ segment(EXAMPLE_URL);
195
+ });
196
+
197
+ function addIcon({ point, label }) {
198
+ const icon = (label === 1 ? star : cross).cloneNode();
199
+ icon.style.left = `${point[0] * 100}%`;
200
+ icon.style.top = `${point[1] * 100}%`;
201
+ imageContainer.appendChild(icon);
202
+ }
203
+
204
+ // Attach hover event to image container
205
+ imageContainer.addEventListener('mousedown', e => {
206
+ if (e.button !== 0 && e.button !== 2) {
207
+ return; // Ignore other buttons
208
+ }
209
+ if (!imageEmbeddings) {
210
+ return; // Ignore if not encoded yet
211
+ }
212
+ if (!isMultiMaskMode) {
213
+ lastPoints = [];
214
+ isMultiMaskMode = true;
215
+ cutButton.disabled = false;
216
+ }
217
+
218
+ const point = getPoint(e);
219
+ lastPoints.push(point);
220
+
221
+ // add icon
222
+ addIcon(point);
223
+
224
+ decode();
225
+ });
226
+
227
+
228
+ // Clamp a value inside a range [min, max]
229
+ function clamp(x, min = 0, max = 1) {
230
+ return Math.max(Math.min(x, max), min)
231
+ }
232
+
233
+ function getPoint(e) {
234
+ // Get bounding box
235
+ const bb = imageContainer.getBoundingClientRect();
236
+
237
+ // Get the mouse coordinates relative to the container
238
+ const mouseX = clamp((e.clientX - bb.left) / bb.width);
239
+ const mouseY = clamp((e.clientY - bb.top) / bb.height);
240
+
241
+ return {
242
+ point: [mouseX, mouseY],
243
+ label: e.button === 2 // right click
244
+ ? 0 // negative prompt
245
+ : 1, // positive prompt
246
+ }
247
+ }
248
+
249
+ // Do not show context menu on right click
250
+ imageContainer.addEventListener('contextmenu', e => {
251
+ e.preventDefault();
252
+ });
253
+
254
+ // Attach hover event to image container
255
+ imageContainer.addEventListener('mousemove', e => {
256
+ if (!imageEmbeddings || isMultiMaskMode) {
257
+ // Ignore mousemove events if the image is not encoded yet,
258
+ // or we are in multi-mask mode
259
+ return;
260
+ }
261
+ lastPoints = [getPoint(e)];
262
+
263
+ if (!isDecoding) {
264
+ decode(); // Only decode if we are not already decoding
265
+ }
266
+ });
267
+
268
+ // Handle cut button click
269
+ cutButton.addEventListener('click', () => {
270
+ const [w, h] = [maskCanvas.width, maskCanvas.height];
271
+
272
+ // Get the mask pixel data
273
+ const maskContext = maskCanvas.getContext('2d');
274
+ const maskPixelData = maskContext.getImageData(0, 0, w, h);
275
+
276
+ // Load the image
277
+ const image = new Image();
278
+ image.crossOrigin = 'anonymous';
279
+ image.onload = async () => {
280
+ // Create a new canvas to hold the image
281
+ const imageCanvas = new OffscreenCanvas(w, h);
282
+ const imageContext = imageCanvas.getContext('2d');
283
+ imageContext.drawImage(image, 0, 0, w, h);
284
+ const imagePixelData = imageContext.getImageData(0, 0, w, h);
285
+
286
+ // Create a new canvas to hold the cut-out
287
+ const cutCanvas = new OffscreenCanvas(w, h);
288
+ const cutContext = cutCanvas.getContext('2d');
289
+ const cutPixelData = cutContext.getImageData(0, 0, w, h);
290
+
291
+ // Copy the image pixel data to the cut canvas
292
+ for (let i = 3; i < maskPixelData.data.length; i += 4) {
293
+ if (maskPixelData.data[i] > 0) {
294
+ for (let j = 0; j < 4; ++j) {
295
+ const offset = i - j;
296
+ cutPixelData.data[offset] = imagePixelData.data[offset];
297
+ }
298
+ }
299
+ }
300
+ cutContext.putImageData(cutPixelData, 0, 0);
301
+
302
+ // Download image
303
+ const link = document.createElement('a');
304
+ link.download = 'image.png';
305
+ link.href = URL.createObjectURL(await cutCanvas.convertToBlob());
306
+ link.click();
307
+ link.remove();
308
+ }
309
+ image.src = imageDataURI;
310
+ });
311
+
312
+
313
+ const model_id = 'Xenova/slimsam-77-uniform';
314
+ statusLabel.textContent = 'Loading model...';
315
+ const model = await SamModel.from_pretrained(model_id, {
316
+ dtype: 'fp16',
317
+ device: 'webgpu',
318
+ });
319
+ const processor = await AutoProcessor.from_pretrained(model_id);
320
+ statusLabel.textContent = 'Ready';
321
+
322
+ // Enable the user interface
323
+ fileUpload.disabled = false;
324
+ uploadButton.style.opacity = 1;
325
+ example.style.pointerEvents = 'auto';