@@ -21,46 +21,96 @@ install the "gpu" version of PyTorch.<br>
2121 import pytorch_pro_gan.PRO_GAN as pg
2222
2323 Use the modules ` pg.Generator ` , ` pg.Discriminator ` and
24- ` pg.ProGAN ` .
25-
26- Help on class ProGAN in module pro_gan_pytorch.PRO_GAN:
27-
28- class ProGAN(builtins.object)
29- | Wrapper around the Generator and the Discriminator
30- |
31- | Methods defined here:
32- |
33- | __init__(self, depth=7, latent_size=64, learning_rate=0.001, beta_1=0, beta_2=0.99, eps=1e-08, drift=0.001, n_critic=1, device=device(type='cpu'))
34- | constructor for the class
35- | :param depth: depth of the GAN (will be used for each generator and discriminator)
36- | :param latent_size: latent size of the manifold used by the GAN
37- | :param learning_rate: learning rate for Adam
38- | :param beta_1: beta_1 for Adam
39- | :param beta_2: beta_2 for Adam
40- | :param eps: epsilon for Adam
41- | :param n_critic: number of times to update discriminator
42- | :param device: device to run the GAN on (GPU / CPU)
43- |
44- | optimize_discriminator(self, noise, real_batch, depth, alpha)
45- | performs one step of weight update on discriminator using the batch of data
46- | :param noise: input noise of sample generation
47- | :param real_batch: real samples batch
48- | :param depth: current depth of optimization
49- | :param alpha: current alpha for fade-in
50- | :return: current loss (Wasserstein loss)
51- |
52- | optimize_generator(self, noise, depth, alpha)
53- | performs one step of weight update on generator for the given batch_size
54- | :param noise: input random noise required for generating samples
55- | :param depth: depth of the network at which optimization is done
56- | :param alpha: value of alpha for fade-in effect
57- | :return: current loss (Wasserstein estimate)
58- |
59- | ----------------------------------------------------------------------
60- | Data descriptors defined here:
61- |
62- | __dict__
63- | dictionary for instance variables (if defined)
64- |
65- | __weakref__
66- | list of weak references to the object (if defined)
24+ ` pg.ProGAN ` . Mostly, you'll only need the ProGAN module.
25+
26+ 4.) Example Code for CIFAR-10 dataset:
27+
28+ import torch as th
29+ import torchvision as tv
30+ import pro_gan_pytorch.PRO_GAN as pg
31+
32+ # select the device to be used for training
33+ device = th.device("cuda" if th.cuda.is_available() else "cpu")
34+ data_path = "cifar-10/"
35+
36+ def setup_data(batch_size, num_workers, download=False):
37+ """
38+ setup the CIFAR-10 dataset for training the CNN
39+ :param batch_size: batch_size for sgd
40+ :param num_workers: num_readers for data reading
41+ :param download: Boolean for whether to download the data
42+ :return: classes, trainloader, testloader => training and testing data loaders
43+ """
44+ # data setup:
45+ classes = ('plane', 'car', 'bird', 'cat', 'deer',
46+ 'dog', 'frog', 'horse', 'ship', 'truck')
47+
48+ transforms = tv.transforms.ToTensor()
49+
50+ trainset = tv.datasets.CIFAR10(root=data_path,
51+ transform=transforms,
52+ download=download)
53+ trainloader = th.utils.data.DataLoader(trainset, batch_size=batch_size,
54+ shuffle=True,
55+ num_workers=num_workers)
56+
57+ testset = tv.datasets.CIFAR10(root=data_path,
58+ transform=transforms, train=False,
59+ download=False)
60+ testloader = th.utils.data.DataLoader(testset, batch_size=batch_size,
61+ shuffle=True,
62+ num_workers=num_workers)
63+
64+ return classes, trainloader, testloader
65+
66+
67+ if __name__ == '__main__':
68+
69+ # some parameters:
70+ depth = 4
71+ num_epochs = 100 # number of epochs per depth (resolution)
72+ latent_size = 128
73+
74+ # get the data. Ignore the test data and their classes
75+ _, train_data_loader, _ = setup_data(batch_size=32, num_workers=3, download=True)
76+
77+ # ======================================================================
78+ # This line creates the PRO-GAN
79+ # ======================================================================
80+ pro_gan = pg.ProGAN(depth=depth, latent_size=latent_size, device=device)
81+ # ======================================================================
82+
83+ # train the pro_gan using the cifar-10 data
84+ for current_depth in range(depth):
85+ print("working on depth:", current_depth)
86+
87+ # note that the rest of the api indexes depth from 0
88+ for epoch in range(1, num_epochs + 1):
89+ print("\ncurrent_epoch: ", epoch)
90+
91+ # calculate the value of aplha for fade-in effect
92+ alpha = int(epoch / num_epochs)
93+
94+ # iterate over the dataset in batches:
95+ for i, batch in enumerate(train_data_loader, 1):
96+ images, _ = batch
97+ # generate some random noise:
98+ noise = th.randn(images.shape[0], latent_size)
99+
100+ # optimize discriminator:
101+ dis_loss = pro_gan.optimize_discriminator(noise, images, current_depth, alpha)
102+
103+ # optimize generator:
104+ gen_loss = pro_gan.optimize_generator(noise, current_depth, alpha)
105+
106+ print("Batch: %d dis_loss: %.3f gen_loss: %.3f"
107+ % (i, dis_loss, gen_loss))
108+
109+ print("epoch finished ...")
110+
111+ print("training complete ...")
112+
113+ # #TODO
114+ 1.) Add the conditional PRO_GAN module <br >
115+ 2.) Setup the travis - checker. (I have to figure out some good unit tests too : D lulz!) <br >
116+ 3.) Write an informative README.rst (although it is rarely read) <br >
0 commit comments