-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathprepare_data.py
More file actions
59 lines (45 loc) · 1.35 KB
/
prepare_data.py
File metadata and controls
59 lines (45 loc) · 1.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Prepare MNIST and CIFAR10 datasets for the experiments."""
import os
import shutil
import wget
import numpy as np
from sklearn.datasets import fetch_openml
def get_mnist(data_root):
    """Download MNIST from OpenML and dump it to disk as NumPy files.

    The two arrays returned by ``fetch_openml`` are written to
    ``<data_root>/mnist-784/images.npy`` and
    ``<data_root>/mnist-784/labels.npy``.

    Args:
        data_root: Directory under which the ``mnist-784`` folder is created.
    """
    # Local import keeps this function self-contained.
    import tempfile

    # Use a private scratch directory as the scikit-learn download cache.
    # Unlike a fixed "tmp" folder in the working directory, this never
    # deletes an unrelated pre-existing directory, and the cache is removed
    # even when the download or the dump below raises.
    with tempfile.TemporaryDirectory() as cache_dir:
        print("Downloading MNIST")
        X, y = fetch_openml(
            name="mnist_784",
            version=1,
            return_X_y=True,
            data_home=cache_dir,
            as_frame=False,
        )

        print("Unpacking MNIST")
        mnist_dir = os.path.join(data_root, "mnist-784")
        os.makedirs(mnist_dir, exist_ok=True)
        with open(os.path.join(mnist_dir, "images.npy"), "wb") as fh:
            np.save(fh, X)
        # NOTE(review): labels are stored exactly as returned by fetch_openml;
        # they may be strings (object dtype) -- confirm how consumers load them.
        with open(os.path.join(mnist_dir, "labels.npy"), "wb") as fh:
            np.save(fh, y)
def get_cifar10(data_root):
    """Download the CIFAR10 python archive and unpack it into ``data_root``.

    The archive is fetched with ``wget.download`` (``data_root`` is passed as
    the output location), extracted with ``tar`` into ``data_root``, and then
    deleted.

    Args:
        data_root: Download location and extraction target.

    Raises:
        subprocess.CalledProcessError: If ``tar`` exits with a non-zero status.
    """
    # Local import keeps this function self-contained.
    import subprocess

    print("Downloading CIFAR10")
    data_path = wget.download(
        "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz", data_root
    )

    print("\nUnpacking CIFAR10...")
    try:
        # Pass an argument list instead of a shell string so paths containing
        # spaces or shell metacharacters are handled correctly, and fail
        # loudly (check=True) rather than silently ignoring a broken extract.
        subprocess.run(["tar", "-xf", data_path, "-C", data_root], check=True)
    finally:
        # The archive is removed whether or not extraction succeeded.
        os.remove(data_path)
if __name__ == "__main__":
    # Target directory comes from NDL_DATA_ROOT and defaults to "data".
    data_root = os.environ.get("NDL_DATA_ROOT", "data")

    # Start from an empty directory: drop whatever is there, then recreate it.
    shutil.rmtree(data_root, ignore_errors=True)
    os.makedirs(data_root, exist_ok=True)

    # Fetch the datasets one after the other, MNIST first.
    for fetch_dataset in (get_mnist, get_cifar10):
        fetch_dataset(data_root)