42 changes: 23 additions & 19 deletions deeplearning2020/datasets.py
@@ -40,26 +40,29 @@
 class ImageWoof:
     BATCH_SIZE: int = 32
     CLASS_NAMES: np.ndarray = None
-    data_dir: pathlib.Path
     image_count: int = 0
     list_ds: "_tf.data.Dataset" = None
 
-    def __init__(self, dataset: str) -> None:
+    def __init__(self, dataset: str, file_path: str = None) -> None:
         if dataset not in ["train", "val"]:
             raise ValueError("Dataset not found")
 
-        file_path = tf.keras.utils.get_file(
-            origin="https://s3.amazonaws.com/fast-ai-imageclas/imagewoof2-320.tgz",
-            fname="imagewoof",
-            untar=True,
-        )
-        self.data_dir = pathlib.Path(file_path + "2-320/" + dataset)
-        print(self.data_dir)
-        self.image_count = len(list(self.data_dir.glob("*/*.JPEG")))
+        if file_path == None:
+            file_path = tf.keras.utils.get_file(
+                origin="https://s3.amazonaws.com/fast-ai-imageclas/imagewoof2-320.tgz",
+                fname="imagewoof",
+                untar=True,
+            )
+            data_dir = file_path + "2-320/" + dataset
+        else:
+            data_dir = file_path + "/imagewoof2-320/" + dataset
+        print(data_dir)
+        # Might not work on *nix systems, see https://github.com/tensorflow/tensorflow/issues/20557
+        self.image_count = len(list(tf.io.gfile.glob(data_dir + "/*/*.JPEG")))
         print(f"Loaded {self.image_count} images")
 
         self.raw_class_names = [
-            item.name for item in self.data_dir.glob("*") if item.name != "LICENSE.txt"
+            item.strip("/") for item in tf.io.gfile.listdir(data_dir) if item != "LICENSE.txt"
         ]
         self.raw_class_names.sort()
 
@@ -77,15 +80,15 @@ def __init__(self, dataset: str) -> None:
         )
 
         self.CLASS_NAMES = np.array([self.map_class(c) for c in self.raw_class_names])
-        self.list_ds = tf.data.Dataset.list_files(str(self.data_dir / "*/*"))
+        self.list_ds = tf.data.Dataset.list_files(data_dir + "/*/*")
 
     @classmethod
-    def train(cls: typing.Type[ImageWoofType]) -> ImageWoofType:
-        return cls("train")
+    def train(cls: typing.Type[ImageWoofType], data_dir: str = None) -> ImageWoofType:
+        return cls("train", data_dir)
 
     @classmethod
-    def validation(cls: typing.Type[ImageWoofType]) -> ImageWoofType:
-        return cls("val")
+    def validation(cls: typing.Type[ImageWoofType], data_dir: str = None) -> ImageWoofType:
+        return cls("val", data_dir)
 
     def map_class(self, raw_cls: str) -> str:
         return self.class_name_mapping[raw_cls]
@@ -117,10 +120,11 @@ def wrapped_load_data(self) -> "_tf.data.Dataset":
     @classmethod
     def load_data(
         cls: typing.Type[ImageWoofType],
+        data_dir: str = None
     ) -> typing.Tuple["_tf.data.Dataset", "_tf.data.Dataset", np.ndarray]:
-        train_ds = cls.train()
+        train_ds = cls.train(data_dir)
         return (
             train_ds.wrapped_load_data(),
-            cls.validation().wrapped_load_data(),
+            cls.validation(data_dir).wrapped_load_data(),
             train_ds.CLASS_NAMES,
         )
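
For reference, a minimal usage sketch of the API after this change (not part of the diff itself). It assumes the package is importable as deeplearning2020.datasets; the path /path/to/data is hypothetical and would need to already contain an extracted imagewoof2-320/ directory with train/ and val/ subfolders. Called without an argument, load_data keeps the previous behaviour and downloads the archive through tf.keras.utils.get_file.

# Usage sketch only -- not part of the patch above.
from deeplearning2020.datasets import ImageWoof

# Default behaviour (unchanged): download and cache imagewoof2-320.tgz.
train_ds, val_ds, class_names = ImageWoof.load_data()

# New in this change: reuse a local copy instead of downloading.
# "/path/to/data" is a hypothetical directory that already contains
# "imagewoof2-320/train" and "imagewoof2-320/val".
train_ds, val_ds, class_names = ImageWoof.load_data(data_dir="/path/to/data")

print(f"{len(class_names)} classes, first one: {class_names[0]}")

The same data_dir argument can also be passed to ImageWoof.train() or ImageWoof.validation() directly; load_data simply forwards it to both classmethods.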