|
| 1 | +import os |
| 2 | +import os.path as path |
| 3 | +from typing import Callable, Optional, Tuple, List |
| 4 | +import torch |
| 5 | +import pandas as pd |
| 6 | +import numpy as np |
| 7 | +from torch.utils import data |
| 8 | + |
| 9 | +from .utils import download_file, unzip_file |
| 10 | + |
class CyclePowerPlant(data.Dataset):
    """`Combined Cycle Power Plant <https://archive.ics.uci.edu/ml/datasets/combined+cycle+power+plant>`_ dataset.

    Features consist of hourly average ambient variables Temperature (T),
    Ambient Pressure (AP), Relative Humidity (RH) and Exhaust Vacuum (V) to
    predict the net hourly electrical energy output (EP) of the plant.

    Every column of the spreadsheet except the last one is used as a feature;
    the last column is used as the regression target.

    Args:
        root (string): Parent directory of the dataset. The files themselves
            live in (and are downloaded to) the ``ccpp`` sub-directory of
            ``root``, which is created if it does not exist.
        download (bool, optional): If True, downloads the dataset from the
            internet and puts it in the dataset directory. If the dataset is
            already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in a
            torch.FloatTensor sample and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.

    Raises:
        RuntimeError: If the data file is missing after the optional download.
    """

    # Single source of truth for names that used to be repeated inline.
    _DATA_FILE = "Folds5x2_pp.xlsx"
    _ARCHIVE_URL = (
        "https://archive.ics.uci.edu/ml/machine-learning-databases/00294/CCPP.zip"
    )
    _ARCHIVE_DIR = "CCPP"  # top-level folder the archive is expected to extract to

    def __init__(
        self,
        root: str,
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        super().__init__()

        self.root = os.path.join(os.path.expanduser(root), "ccpp")
        os.makedirs(self.root, exist_ok=True)

        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        self._load_data()

    def __len__(self) -> int:
        """Return the number of samples (rows of the feature tensor)."""
        return self.data.size(0)

    def __getitem__(self, index: int) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Return the ``(sample, target)`` pair stored at ``index``.

        Args:
            index (int): Index

        Returns:
            Tuple[torch.FloatTensor, torch.FloatTensor]: (sample, target) where
            sample is the feature row and target is the matching value from
            the last spreadsheet column, each passed through its transform
            when one was supplied.
        """
        sample = self.data[index]
        label = self.targets[index]

        # Compare against None instead of relying on truthiness: a callable
        # that defines __len__ (e.g. an empty nn.Sequential) is falsy and
        # would otherwise be skipped silently.
        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return sample, label

    def _check_integrity(self) -> bool:
        """Return True when the data spreadsheet exists in ``self.root``.

        Only the presence of the file is checked; its content is not verified.
        """
        # isfile() is already False when the directory itself is missing.
        return os.path.isfile(os.path.join(self.root, self._DATA_FILE))

    def _load_data(self) -> None:
        """Read the spreadsheet into ``self.data`` and ``self.targets`` tensors."""
        # Named `frame` so the module-level `data` import is not shadowed.
        frame = pd.read_excel(os.path.join(self.root, self._DATA_FILE))
        values = frame.to_numpy()
        self.data = torch.tensor(values[:, :-1], dtype=torch.float)
        self.targets = torch.tensor(values[:, -1], dtype=torch.float)

    def download(self) -> None:
        """Downloads the dataset if not already present.

        The archive is fetched to ``<root>/data.zip``, extracted, and the
        files of the extracted ``CCPP`` folder are moved directly into the
        dataset directory.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        zip_file_path = os.path.join(self.root, "data.zip")
        try:
            download_file(self._ARCHIVE_URL, zip_file_path)
            unzip_file(zip_file_path, self.root)
        finally:
            # Never leave a (possibly partial) archive behind, even when the
            # download or the extraction fails.
            if os.path.exists(zip_file_path):
                os.remove(zip_file_path)

        source_dir = os.path.join(self.root, self._ARCHIVE_DIR)
        for filename in os.listdir(source_dir):
            # os.replace overwrites leftovers from an earlier interrupted run
            # on every platform; os.rename raises on Windows in that case.
            os.replace(
                os.path.join(source_dir, filename),
                os.path.join(self.root, filename),
            )

        os.rmdir(source_dir)
0 commit comments