-
Notifications
You must be signed in to change notification settings - Fork 106
Expand file tree
/
Copy pathcaltech_dataset.py
More file actions
59 lines (45 loc) · 2.06 KB
/
caltech_dataset.py
File metadata and controls
59 lines (45 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from torchvision.datasets import VisionDataset
from PIL import Image
import os
import os.path
import sys
def pil_loader(path):
    '''
    Read the image file at `path` and return it converted to RGB.

    Args:
        path: location of the image file on disk.
    Returns:
        The result of `Image.open(...).convert('RGB')` for that file.
    '''
    # Pillow is handed an already-open file object instead of the path so the
    # handle is closed by the `with` block, which avoids a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as stream:
        # The conversion happens while the stream is still open.
        return Image.open(stream).convert('RGB')
class Caltech(VisionDataset):
    '''
    Caltech images selected through a split file ('train.txt' or 'test.txt').

    NOTE(review): the on-disk layout below is an assumption - confirm it
    against the actual data before relying on it:
      - `root` holds one sub-folder per class;
      - '<split>.txt' lives either inside `root` or in its parent folder and
        lists one image per line as '<class_folder>/<file_name>', relative
        to `root`;
      - the background class folder has a name starting with 'BACKGROUND'
        (presumably 'BACKGROUND_Google').

    Only (path, label) pairs are kept in memory; images are decoded lazily in
    `__getitem__`.

    Attributes:
        split: name of the split in use.
        classes: sorted class-folder names, background class excluded.
        class_to_idx: maps a class name to its label; labels start from 0.
        samples: list of (image_path, label) pairs in split-file order.
    '''

    # Folders whose name starts with this prefix are treated as background.
    _BACKGROUND_PREFIX = 'BACKGROUND'

    def __init__(self, root, split='train', transform=None, target_transform=None):
        '''
        Args:
            root: folder that contains the class sub-folders.
            split: name of the split file without the '.txt' extension.
            transform: optional callable applied to every loaded image.
            target_transform: optional callable applied to every label.
        Raises:
            FileNotFoundError: if '<split>.txt' cannot be located.
            RuntimeError: if the split file yields no usable sample.
        '''
        super(Caltech, self).__init__(root, transform=transform, target_transform=target_transform)

        self.split = split  # This defines the split you are going to use
                            # (split files are called 'train.txt' and 'test.txt')

        # The mapping is built from the sorted folder names, not from the
        # split file, so every split shares the same class -> label mapping.
        self.classes = sorted(
            name for name in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, name))
            and not name.upper().startswith(self._BACKGROUND_PREFIX)
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        split_path = self._locate_split_file(self.root, split)
        self.samples = []
        with open(split_path) as split_file:
            for line in split_file:
                rel_path = line.strip()
                if not rel_path:
                    continue  # tolerate blank lines
                class_name = rel_path.replace('\\', '/').split('/', 1)[0]
                label = self.class_to_idx.get(class_name)
                if label is None:
                    # Not one of the known classes: this is how entries of the
                    # background class are dropped.
                    continue
                self.samples.append((os.path.join(self.root, rel_path), label))

        if not self.samples:
            raise RuntimeError(
                "No usable sample listed in '{}' for root '{}'".format(split_path, self.root)
            )

    @staticmethod
    def _locate_split_file(root, split):
        '''Return the path of '<split>.txt', searching `root` and then its parent.'''
        file_name = '{}.txt'.format(split)
        parent = os.path.dirname(os.path.normpath(root))
        for folder in (root, parent):
            candidate = os.path.join(folder, file_name)
            if os.path.isfile(candidate):
                return candidate
        raise FileNotFoundError(
            "Split file '{}' not found in '{}' or '{}'".format(file_name, root, parent)
        )

    def __getitem__(self, index):
        '''
        __getitem__ should access an element through its index
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        Raises:
            IndexError: if `index` is outside the list of samples.
        '''
        path, label = self.samples[index]
        image = pil_loader(path)  # PIL Image in RGB mode

        # Applies preprocessing when accessing the image
        if self.transform is not None:
            image = self.transform(image)

        # The constructor accepts a target_transform, so honour it here.
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def __len__(self):
        '''
        The __len__ method returns the length of the dataset
        It is mandatory, as this is used by several other components
        '''
        return len(self.samples)