diff --git a/deeplearning2020/datasets.py b/deeplearning2020/datasets.py
index 1acbddc..a3f7958 100644
--- a/deeplearning2020/datasets.py
+++ b/deeplearning2020/datasets.py
@@ -40,26 +40,29 @@ class ImageWoof:
     BATCH_SIZE: int = 32
     CLASS_NAMES: np.ndarray = None
 
-    data_dir: pathlib.Path
     image_count: int = 0
     list_ds: "_tf.data.Dataset" = None
 
-    def __init__(self, dataset: str) -> None:
+    def __init__(self, dataset: str, file_path: str = None) -> None:
         if dataset not in ["train", "val"]:
             raise ValueError("Dataset not found")
-
-        file_path = tf.keras.utils.get_file(
-            origin="https://s3.amazonaws.com/fast-ai-imageclas/imagewoof2-320.tgz",
-            fname="imagewoof",
-            untar=True,
-        )
-        self.data_dir = pathlib.Path(file_path + "2-320/" + dataset)
-        print(self.data_dir)
-        self.image_count = len(list(self.data_dir.glob("*/*.JPEG")))
+
+        if file_path is None:
+            file_path = tf.keras.utils.get_file(
+                origin="https://s3.amazonaws.com/fast-ai-imageclas/imagewoof2-320.tgz",
+                fname="imagewoof",
+                untar=True,
+            )
+            data_dir = file_path + "2-320/" + dataset
+        else:
+            data_dir = file_path + "/imagewoof2-320/" + dataset
+        print(data_dir)
+        # Might not work on *nix systems, see https://github.com/tensorflow/tensorflow/issues/20557
+        self.image_count = len(list(tf.io.gfile.glob(data_dir + "/*/*.JPEG")))
         print(f"Loaded {self.image_count} images")
 
         self.raw_class_names = [
-            item.name for item in self.data_dir.glob("*") if item.name != "LICENSE.txt"
+            item.strip("/") for item in tf.io.gfile.listdir(data_dir) if item != "LICENSE.txt"
         ]
         self.raw_class_names.sort()
 
@@ -77,15 +80,15 @@ def __init__(self, dataset: str) -> None:
         )
         self.CLASS_NAMES = np.array([self.map_class(c) for c in self.raw_class_names])
 
-        self.list_ds = tf.data.Dataset.list_files(str(self.data_dir / "*/*"))
+        self.list_ds = tf.data.Dataset.list_files(data_dir + "/*/*")
 
     @classmethod
-    def train(cls: typing.Type[ImageWoofType]) -> ImageWoofType:
-        return cls("train")
+    def train(cls: typing.Type[ImageWoofType], data_dir: str = None) -> ImageWoofType:
+        return cls("train", data_dir)
 
     @classmethod
-    def validation(cls: typing.Type[ImageWoofType]) -> ImageWoofType:
-        return cls("val")
+    def validation(cls: typing.Type[ImageWoofType], data_dir: str = None) -> ImageWoofType:
+        return cls("val", data_dir)
 
     def map_class(self, raw_cls: str) -> str:
         return self.class_name_mapping[raw_cls]
@@ -117,10 +120,11 @@ def wrapped_load_data(self) -> "_tf.data.Dataset":
     @classmethod
     def load_data(
         cls: typing.Type[ImageWoofType],
+        data_dir: str = None,
     ) -> typing.Tuple["_tf.data.Dataset", "_tf.data.Dataset", np.ndarray]:
-        train_ds = cls.train()
+        train_ds = cls.train(data_dir)
        return (
             train_ds.wrapped_load_data(),
-            cls.validation().wrapped_load_data(),
+            cls.validation(data_dir).wrapped_load_data(),
             train_ds.CLASS_NAMES,
         )
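
A minimal usage sketch (not part of the diff) of how the new optional data_dir argument is expected to be used. The import path follows the repository layout (deeplearning2020/datasets.py), and /path/to/cache is an illustrative placeholder for a directory that already contains imagewoof2-320/:

    from deeplearning2020.datasets import ImageWoof

    # Default behaviour: download ImageWoof via tf.keras.utils.get_file, as before.
    train_data, val_data, class_names = ImageWoof.load_data()

    # New behaviour: reuse a pre-downloaded copy instead of downloading again.
    train_data, val_data, class_names = ImageWoof.load_data(data_dir="/path/to/cache")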