Spaces:
Runtime error
Runtime error
Add support for MNIST
Browse files- dataset_tool.py +44 -11
dataset_tool.py
CHANGED
@@ -13,6 +13,7 @@ import os
|
|
13 |
import pickle
|
14 |
import sys
|
15 |
import tarfile
|
|
|
16 |
import zipfile
|
17 |
from pathlib import Path
|
18 |
from typing import Callable, Optional, Tuple, Union
|
@@ -165,6 +166,36 @@ def open_cifar10(tarball: str, *, max_images: Optional[int]):
|
|
165 |
|
166 |
#----------------------------------------------------------------------------
|
167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
def make_transform(
|
169 |
transform: Optional[str],
|
170 |
output_width: Optional[int],
|
@@ -225,10 +256,11 @@ def open_dataset(source, *, max_images: Optional[int]):
|
|
225 |
else:
|
226 |
return open_image_folder(source, max_images=max_images)
|
227 |
elif os.path.isfile(source):
|
228 |
-
if
|
229 |
return open_cifar10(source, max_images=max_images)
|
230 |
-
|
231 |
-
|
|
|
232 |
return open_image_zip(source, max_images=max_images)
|
233 |
else:
|
234 |
assert False, 'unknown archive type'
|
@@ -293,17 +325,18 @@ def convert_dataset(
|
|
293 |
The input dataset format is guessed from the --source argument:
|
294 |
|
295 |
\b
|
296 |
-
--source *_lmdb/
|
297 |
-
--source cifar-10-python.tar.gz
|
298 |
-
--source
|
299 |
-
--source
|
|
|
300 |
|
301 |
-
The output dataset format can be either an image folder or a zip archive.
|
302 |
-
the output format and path:
|
303 |
|
304 |
\b
|
305 |
-
--dest /path/to/dir
|
306 |
-
--dest /path/to/dataset.zip
|
307 |
|
308 |
Images within the dataset archive will be stored as uncompressed PNG.
|
309 |
|
|
|
13 |
import pickle
|
14 |
import sys
|
15 |
import tarfile
|
16 |
+
import gzip
|
17 |
import zipfile
|
18 |
from pathlib import Path
|
19 |
from typing import Callable, Optional, Tuple, Union
|
|
|
166 |
|
167 |
#----------------------------------------------------------------------------
|
168 |
|
169 |
+
def open_mnist(images_gz: str, *, max_images: Optional[int]):
|
170 |
+
labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
|
171 |
+
assert labels_gz != images_gz
|
172 |
+
images = []
|
173 |
+
labels = []
|
174 |
+
|
175 |
+
with gzip.open(images_gz, 'rb') as f:
|
176 |
+
images = np.frombuffer(f.read(), np.uint8, offset=16)
|
177 |
+
with gzip.open(labels_gz, 'rb') as f:
|
178 |
+
labels = np.frombuffer(f.read(), np.uint8, offset=8)
|
179 |
+
|
180 |
+
images = images.reshape(-1, 28, 28)
|
181 |
+
images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
|
182 |
+
assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
|
183 |
+
assert labels.shape == (60000,) and labels.dtype == np.uint8
|
184 |
+
assert np.min(images) == 0 and np.max(images) == 255
|
185 |
+
assert np.min(labels) == 0 and np.max(labels) == 9
|
186 |
+
|
187 |
+
max_idx = maybe_min(len(images), max_images)
|
188 |
+
|
189 |
+
def iterate_images():
|
190 |
+
for idx, img in enumerate(images):
|
191 |
+
yield dict(img=img, label=int(labels[idx]))
|
192 |
+
if idx >= max_idx-1:
|
193 |
+
break
|
194 |
+
|
195 |
+
return max_idx, iterate_images()
|
196 |
+
|
197 |
+
#----------------------------------------------------------------------------
|
198 |
+
|
199 |
def make_transform(
|
200 |
transform: Optional[str],
|
201 |
output_width: Optional[int],
|
|
|
256 |
else:
|
257 |
return open_image_folder(source, max_images=max_images)
|
258 |
elif os.path.isfile(source):
|
259 |
+
if os.path.basename(source) == 'cifar-10-python.tar.gz':
|
260 |
return open_cifar10(source, max_images=max_images)
|
261 |
+
elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
|
262 |
+
return open_mnist(source, max_images=max_images)
|
263 |
+
elif file_ext(source) == 'zip':
|
264 |
return open_image_zip(source, max_images=max_images)
|
265 |
else:
|
266 |
assert False, 'unknown archive type'
|
|
|
325 |
The input dataset format is guessed from the --source argument:
|
326 |
|
327 |
\b
|
328 |
+
--source *_lmdb/ Load LSUN dataset
|
329 |
+
--source cifar-10-python.tar.gz Load CIFAR-10 dataset
|
330 |
+
--source train-images-idx3-ubyte.gz Load MNIST dataset
|
331 |
+
--source path/ Recursively load all images from path/
|
332 |
+
--source dataset.zip Recursively load all images from dataset.zip
|
333 |
|
334 |
+
The output dataset format can be either an image folder or a zip archive.
|
335 |
+
Specifying the output format and path:
|
336 |
|
337 |
\b
|
338 |
+
--dest /path/to/dir Save output files under /path/to/dir
|
339 |
+
--dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
|
340 |
|
341 |
Images within the dataset archive will be stored as uncompressed PNG.
|
342 |
|