Skip to content

Commit b1ed6d3

Browse files
dheyay and mikeheddes authored
Add power plant regression dataset (#25)
* Added power plant regression dataset * Removed extra files * Updates Co-authored-by: mikeheddes <mikeheddes@gmail.com>
1 parent 44cb725 commit b1ed6d3

File tree

4 files changed

+116
-2
lines changed

4 files changed

+116
-2
lines changed

dev-requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,4 @@ numpy
66
sphinx
77
sphinx-rtd-theme
88
flake8
9-
pytest
10-
black
9+
pytest

docs/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ The hdc library provides many popular built-in datasets to work with.
1616
AirfoilSelfNoise
1717
EMGHandGestures
1818
PAMAP
19+
CyclePowerPlant

hdc/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from hdc.datasets.airfoil_self_noise import AirfoilSelfNoise
66
from hdc.datasets.emg_hand_gestures import EMGHandGestures
77
from hdc.datasets.pamap import PAMAP
8+
from hdc.datasets.ccpp import CyclePowerPlant
89

910
__all__ = [
1011
"BeijingAirQuality",
@@ -14,4 +15,5 @@
1415
"AirfoilSelfNoise",
1516
"EMGHandGestures",
1617
"PAMAP",
18+
"CyclePowerPlant",
1719
]

hdc/datasets/ccpp.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import os
2+
import os.path as path
3+
from typing import Callable, Optional, Tuple, List
4+
import torch
5+
import pandas as pd
6+
import numpy as np
7+
from torch.utils import data
8+
9+
from .utils import download_file, unzip_file
10+
11+
class CyclePowerPlant(data.Dataset):
    """`Combined cycle power plant dataset <https://archive.ics.uci.edu/ml/datasets/combined+cycle+power+plant>`_.

    Features consist of hourly average ambient variables Temperature (T), Ambient Pressure (AP),
    Relative Humidity (RH) and Exhaust Vacuum (V) to predict the net hourly electrical energy
    output (EP) of the plant.

    The data is read from ``Folds5x2_pp.xlsx`` inside ``<root>/ccpp``; every column except
    the last is used as a feature and the last column is used as the regression target.

    Args:
        root (string): Root directory of dataset where downloaded dataset exists
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a torch.FloatTensor
            and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        RuntimeError: If the data file is not present after the optional download step.
    """

    # Name of the spreadsheet that must be present in ``self.root``.
    _DATA_FILE = "Folds5x2_pp.xlsx"
    # Archive that contains the spreadsheet (inside a ``CCPP`` folder).
    _URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/00294/CCPP.zip"

    def __init__(
        self,
        root: str,
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        # All files of this dataset live in a dedicated "ccpp" sub-directory.
        self.root = os.path.expanduser(os.path.join(root, "ccpp"))
        os.makedirs(self.root, exist_ok=True)

        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        self._load_data()

    def __len__(self) -> int:
        """Return the number of samples (rows) in the dataset."""
        return self.data.size(0)

    def __getitem__(self, index: int) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """
        Args:
            index (int): Index

        Returns:
            Tuple[torch.FloatTensor, torch.FloatTensor]: (sample, target) where sample holds
            the feature values of the row and target is its regression value.
        """
        sample = self.data[index]
        label = self.targets[index]

        if self.transform:
            sample = self.transform(sample)

        if self.target_transform:
            label = self.target_transform(label)

        return sample, label

    def _check_integrity(self) -> bool:
        """Return True when the required data file exists in the root directory."""
        # isfile() is already False when the root directory itself is missing.
        return os.path.isfile(os.path.join(self.root, self._DATA_FILE))

    def _load_data(self) -> None:
        """Read the spreadsheet and populate ``self.data`` and ``self.targets``."""
        # Named ``frame`` so the imported ``torch.utils.data`` module is not shadowed.
        frame = pd.read_excel(os.path.join(self.root, self._DATA_FILE))
        values = frame.values
        self.data = torch.tensor(values[:, :-1], dtype=torch.float)
        self.targets = torch.tensor(values[:, -1], dtype=torch.float)

    def download(self) -> None:
        """Downloads the dataset if not already present"""

        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        zip_file_path = os.path.join(self.root, "data.zip")
        try:
            download_file(self._URL, zip_file_path)
            unzip_file(zip_file_path, self.root)
        finally:
            # Never leave a partial or corrupt archive behind, even when the
            # download or the extraction raised.
            if os.path.isfile(zip_file_path):
                os.remove(zip_file_path)

        # The archive extracts into a "CCPP" folder; move its content up into root.
        source_dir = os.path.join(self.root, "CCPP")
        for filename in os.listdir(source_dir):
            # os.replace overwrites an existing destination on every platform,
            # whereas os.rename raises on Windows in that case.
            os.replace(
                os.path.join(source_dir, filename), os.path.join(self.root, filename)
            )

        os.rmdir(source_dir)

0 commit comments

Comments
 (0)