diff --git a/guild.yml b/guild.yml index f3488dd..800a292 100644 --- a/guild.yml +++ b/guild.yml @@ -1,4 +1,6 @@ - package: nvae-package + data-files: + - vae/splits/*.json - model: nvae-mixture-logistic sourcecode: diff --git a/prepare.py b/prepare.py index 22efe06..c90c1a9 100644 --- a/prepare.py +++ b/prepare.py @@ -1,7 +1,7 @@ import argparse +from joblib import Parallel, delayed from pathlib import Path import shutil -import pandas as pd import torchvision from tqdm import tqdm @@ -15,22 +15,31 @@ def image_path(directory, index): return directory / f'{index}.png' -def save_images(dataset, directory): +def prepare_and_save(image, directory, index, left, top, crop_size): + return ( + image.crop((left, top, left + crop_size, top + crop_size)) + .resize((problem.settings.WIDTH, problem.settings.HEIGHT)) + .save(image_path(directory, index)) + ) + + +def save_images(dataset, directory, n_jobs): original_size = (178, 218) crop_size = 148 left = (original_size[0] - crop_size) // 2 top = (original_size[1] - crop_size) // 2 - for index, (image, _) in enumerate(tqdm(dataset)): - ( - image.crop((left, top, left + crop_size, top + crop_size)) - .resize((problem.settings.WIDTH, problem.settings.HEIGHT)) - .save(image_path(directory, index)) + Parallel(n_jobs)( + delayed(prepare_and_save)( + image, directory, index, left, top, crop_size ) + for index, (image, _) in enumerate(tqdm(dataset)) + ) if __name__ == '__main__': parser = argparse.ArgumentParser() + parser.add_argument('--n_jobs', default=4, type=int) args = parser.parse_args() dataset = torchvision.datasets.CelebA( @@ -39,5 +48,5 @@ def save_images(dataset, directory): directory = Path('prepared') directory.mkdir(parents=True) - save_images(dataset, directory) + save_images(dataset, directory, args.n_jobs) shutil.rmtree(CACHE_ROOT)