Skip to content

Commit 58e5a9f

Browse files
committed
Fix target_transform usage and missing len method
1 parent bea6e5e commit 58e5a9f

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

hdc/datasets/airfoil_self_noise.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def __init__(
4848

4949
self._load_data()
5050

51+
def __len__(self) -> int:
52+
return self.data.size(0)
53+
5154
def __getitem__(self, index: int) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
5255
"""
5356
Args:
@@ -62,8 +65,8 @@ def __getitem__(self, index: int) -> Tuple[torch.FloatTensor, torch.FloatTensor]
6265
if self.transform:
6366
sample = self.transform(sample)
6467

65-
if self.transform:
66-
label = self.transform(label)
68+
if self.target_transform:
69+
label = self.target_transform(label)
6770

6871
return sample, label
6972

hdc/datasets/isolet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def __getitem__(self, index: int) -> Tuple[torch.FloatTensor, torch.LongTensor]:
9898
if self.transform:
9999
sample = self.transform(sample)
100100

101-
if self.transform:
102-
label = self.transform(label)
101+
if self.target_transform:
102+
label = self.target_transform(label)
103103

104104
return sample, label
105105

0 commit comments

Comments
 (0)