-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_sample.py
More file actions
31 lines (25 loc) · 875 Bytes
/
data_sample.py
File metadata and controls
31 lines (25 loc) · 875 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import numpy as np
class CategoriesSampler:
    """Yield episodic index batches drawn from a fixed set of classes.

    Every batch contains ``n_per`` distinct dataset indices for each class in
    ``chosen_classes``, sampled without replacement and independently for
    each batch. Indices are interleaved sample-major: for classes ``[a, b]``
    and ``n_per=2`` a batch is laid out as ``[a0, b0, a1, b1]``.

    NOTE(review): presumably passed as ``batch_sampler`` to a
    ``torch.utils.data.DataLoader``; confirm against callers.

    Args:
        label: Per-sample integer class labels (anything ``np.asarray``
            accepts, e.g. a list or a 1-D array).
        chosen_classes: Class ids to draw from in every batch. Re-read at
            the start of each ``__iter__`` call.
        n_batch: Number of batches produced per iteration; also ``len(self)``.
        n_per: Number of samples drawn per class per batch.
    """

    def __init__(self, label, chosen_classes, n_batch, n_per):
        self.n_batch = n_batch
        self.chosen_classes = chosen_classes
        self.n_per = n_per
        label = np.asarray(label)
        # m_ind[c] is a 1-D tensor of the dataset positions whose label == c,
        # built densely for c in [0, max label]. An empty label array gives an
        # empty list instead of crashing on max() of an empty sequence.
        n_classes = int(label.max()) + 1 if label.size else 0
        self.m_ind = []
        for c in range(n_classes):
            positions = np.argwhere(label == c).reshape(-1)
            self.m_ind.append(torch.from_numpy(positions))

    def __len__(self):
        return self.n_batch

    def _resolve_classes(self):
        """Return ``chosen_classes`` as a validated list of int class ids.

        Raises:
            ValueError: if no class is chosen; if a class id lies outside
                ``[0, len(self.m_ind))`` (a negative id would otherwise
                silently wrap around in list indexing); or if a class has
                fewer than ``n_per`` samples (which would otherwise make
                ``torch.stack`` fail or silently shrink the batch).
        """
        classes = [int(c) for c in self.chosen_classes]
        if not classes:
            raise ValueError("chosen_classes must contain at least one class")
        for c in classes:
            if not 0 <= c < len(self.m_ind):
                raise ValueError(
                    f"class id {c} is outside the labelled range [0, {len(self.m_ind)})"
                )
            available = len(self.m_ind[c])
            if available < self.n_per:
                raise ValueError(
                    f"class {c} has {available} samples, fewer than n_per={self.n_per}"
                )
        return classes

    def __iter__(self):
        # Validate once per iteration, before any random draws, so the RNG
        # consumption (and thus seeded reproducibility) is unchanged.
        classes = self._resolve_classes()
        for _ in range(self.n_batch):
            per_class = []
            for c in classes:
                indices = self.m_ind[c]
                pos = torch.randperm(len(indices))[:self.n_per]
                per_class.append(indices[pos])
            # (num_classes, n_per) -> transpose -> flatten yields sample-major
            # interleaving: one sample of every class, then the next, ...
            yield torch.stack(per_class).t().reshape(-1)