diff --git a/utils/dataset.py b/utils/dataset.py index aed86a4..469da4e 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -44,7 +44,9 @@ def __getitem__(self, index): with open(image_path(self.images_root, filename, '.jpg'), 'rb') as f: image = load_image(f).convert('RGB') with open(image_path(self.labels_root, filename, '.png'), 'rb') as f: - label = load_image(f).convert('P') + label = load_image(f).convert('RGB') + r,g,b=label.split() + label=r if self.input_transform is not None: image = self.input_transform(image) @@ -88,7 +90,9 @@ def __getitem__(self, index): with open(image_path_city(self.images_root, filename), 'rb') as f: image = load_image(f).convert('RGB') with open(image_path_city(self.labels_root, filenameGt), 'rb') as f: - label = load_image(f).convert('P') + label = load_image(f).convert('RGB') + r,g,b=label.split() + label=r if self.co_transform is not None: image, label = self.co_transform(image, label)