forked from Lukeasargen/GarbageML
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path dataset_stats.py
More file actions
32 lines (25 loc) · 883 Bytes
/
dataset_stats.py
File metadata and controls
32 lines (25 loc) · 883 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import time
import numpy as np
from torchvision.datasets import ImageFolder
from util import pil_loader, prepare_image
# Estimate per-channel mean and standard deviation of the images under `root`,
# e.g. for use as normalization constants. Each image is weighted equally.
root = "data/full448"
num = 10000  # upper bound on how many images to measure

sample_ds = ImageFolder(root=root)
n = min(len(sample_ds), num)  # Go through the whole dataset if possible
if n == 0:
    raise SystemExit("No images found under {!r}".format(root))

# Running per-channel sums of each image's mean (E[x]) and mean of squares
# (E[x^2]). Averaging per-image *variances* instead would drop the spread of
# the per-image means and underestimate the dataset std; with both moments
# the total variance is recovered exactly as Var = E[x^2] - E[x]^2.
mean = 0.0
sq_mean = 0.0
t0 = time.time()
t1 = t0
for i, (img_path, _label) in enumerate(sample_ds.samples[:n]):
    # img in shape [H, W, C]; dividing by 255 assumes 8-bit channels
    img = np.array(pil_loader(img_path)) / 255.0
    mean += np.mean(img, axis=(0, 1))
    sq_mean += np.mean(np.square(img), axis=(0, 1))
    if (i + 1) % 100 == 0:
        t2 = time.time()
        print("{}/{} measured. Total time={:.2f}s. Images per second {:.2f}.".format(i+1, n, t2-t0, 100/(t2-t1)))
        t1 = t2

mean = mean / n
# Clamp tiny negative values that floating-point rounding can produce.
var = np.maximum(sq_mean / n - np.square(mean), 0.0)
print("mean = [{:4.3f}, {:4.3f}, {:4.3f}]".format(*mean))
print("std = [{:4.3f}, {:4.3f}, {:4.3f}]".format(*np.sqrt(var)))
print("var :", var)