-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmnist.py
More file actions
68 lines (54 loc) · 1.69 KB
/
mnist.py
File metadata and controls
68 lines (54 loc) · 1.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import struct
from util import save_data, new, prefix
import dtypes
# Dataset name; passed to util.prefix() to locate the raw files and to
# util.save_data() when persisting the parsed result.
name = "mnist"

# IDX magic numbers checked against the first 4 bytes of each file.
# The 0x08 byte matches the unsigned-byte ("B") payload parsed below.
# The low byte matches the number of dimensions read:
# 3 for images (count, rows, cols) and 1 for labels (count).
images_hash = 0x00000803
labels_hash = 0x00000801

# One entry per raw file: file name, expected magic number, and the
# key in the data container that receives the parsed records.
files = [
    {"file": file_name, "hash": magic, "key": target_key}
    for file_name, magic, target_key in (
        ("train-images-idx3-ubyte", images_hash, "X_train"),
        ("train-labels-idx1-ubyte", labels_hash, "y_train"),
        ("t10k-images-idx3-ubyte", images_hash, "X_test"),
        ("t10k-labels-idx1-ubyte", labels_hash, "y_test"),
    )
]
def _read_exact(f, size, file_path):
    """Read exactly ``size`` bytes from ``f``; raise ValueError if the file ends early."""
    buf = f.read(size)
    if len(buf) != size:
        raise ValueError(
            f"{file_path} is truncated: expected {size} bytes, got {len(buf)}"
        )
    return buf


def _read_int(f, file_path):
    """Read one 4-byte header field from ``f`` and decode it with ``dtypes.INT``."""
    return struct.unpack(dtypes.INT, _read_exact(f, 4, file_path))[0]


def process(data, file):
    """Parse one MNIST IDX file and append its records to ``data``.

    Args:
        data: Container whose ``data[file["key"]]`` is a list. Presumably it
            is created by ``util.new()`` (TODO confirm). Parsed records are
            appended to that list in file order.
        file: One entry of ``files``:
            ``"file"`` is the file name, resolved relative to ``prefix(name)``.
            ``"hash"`` is the expected IDX magic number.
            ``"key"`` is the ``data`` entry to fill.

    Each image is appended as a flat tuple of ``n_rows * n_cols`` pixel
    ints (0-255). Each label is appended as an int (0-255).

    Raises:
        AssertionError: The file's magic number differs from ``file["hash"]``.
        ValueError: The file is shorter than its header claims, or the magic
            number is neither ``images_hash`` nor ``labels_hash``.
    """
    file_path = prefix(name) + file["file"]
    expected_magic = file["hash"]
    key = file["key"]
    with open(file_path, "rb") as f:
        magic = _read_int(f, file_path)
        if magic != expected_magic:
            raise AssertionError(f"Magic number {magic} doesn't match for {file_path}")
        n = _read_int(f, file_path)
        if magic == images_hash:
            n_rows = _read_int(f, file_path)
            n_cols = _read_int(f, file_path)
            image_size = n_rows * n_cols
            pixels = _read_exact(f, n * image_size, file_path)
            records = data[key]
            # tuple(bytes_slice) gives the same int tuples that
            # struct.unpack("B"...) did, without first materialising a
            # single tuple holding every pixel in the file.
            for i in range(n):
                start = i * image_size
                records.append(tuple(pixels[start:start + image_size]))
        elif magic == labels_hash:
            # Extending a list with bytes appends one int per byte.
            data[key].extend(_read_exact(f, n, file_path))
        else:
            raise ValueError(f"Unsupported magic number {magic} for {file_path}")
if __name__ == "__main__":
    # Create the container, set its "split" flag, fill it from every raw
    # IDX file, then persist it under this dataset's name.
    # NOTE(review): "split" presumably marks the predefined train/test
    # partition (X_train/X_test keys); confirm against util.save_data.
    data = new()
    data["split"] = True
    # Plain loop: process() is called for its side effect on `data`, so
    # there is no list to build.
    for file_entry in files:
        process(data, file_entry)
    save_data(data, name)