Okay, time to take what we’ve learned so far up a notch. I am going to work on a GAN to generate simple, low resolution images. In most of the examples/tutorials I saw on the web, the generator was trying to produce replicas of hand drawn digits. They used the MNIST Handwritten Digit Dataset. And I expect there might even be a copy of that on the PC somewhere.
But, I figured it would be more interesting to try and replicate images of the clothing from the dataset we used for the classification project.
Most of those examples/tutorials also used CNNs (convolutional neural networks) rather than the fully connected dense layers we used in the classifier. But for these fairly low resolution images I will, for now, stick with our fully connected linear transformation layers.
And as the discriminator here is a classifier (good image or bad image), its code will be similar to that from the earlier project, as I expect most of the setup will be as well.
I am not going to bother showing the setup in the file. A review of the previous posts on machine learning with PyTorch should show you what you will likely need. The rest will come from debugging any problems when running your code.
A Few Bits of Setup Code
That said…
Most of the code was pretty much copied from the Python module for the sine curve GAN (Project 2). With modifications based on the code for the clothing classifier (Project 1).
Because I want to use the datasets that were already downloaded for the classifier project, I added a variable pointing to the appropriate directory: pth_data = Path("/learn/mcl_pytorch/chap2/data"). And when loading the dataset I specified it should not be downloaded. If I got an error, I would sort out the issue (note: no error).
# Create dataset for training, do not download
ds_train = torchvision.datasets.FashionMNIST(
    root=pth_data,
    train=True,
    download=False,
    transform=transform
)
The loss function (for a binary classifier) and learning rate are:
loss_fn = nn.BCELoss()  # loss function for models
lr = 0.0001  # learning rate
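For reference, nn.BCELoss (with its default reduction="mean") averages the binary cross-entropy -[y·log(p) + (1 - y)·log(1 - p)] over the batch, where p is the discriminator's output and y the target label. A quick check, assuming only that torch is installed, comparing the library value with the formula computed by hand:

```python
import math

import torch
from torch import nn

loss_fn = nn.BCELoss()  # default reduction="mean"

# discriminator outputs (probabilities) and target labels, shape (batch, 1)
preds = torch.tensor([[0.5], [0.9], [0.2]])
targets = torch.tensor([[1.0], [1.0], [0.0]])

# the same quantity by hand: mean of -[y*log(p) + (1-y)*log(1-p)]
manual = -sum(
    y * math.log(p) + (1 - y) * math.log(1 - p)
    for p, y in zip(preds.flatten().tolist(), targets.flatten().tolist())
) / len(preds)

print(loss_fn(preds, targets).item(), manual)
```

With a target of 1 the loss reduces to -log(p), so a discriminator output of 0.5 costs log 2 (about 0.693). That is also why the generator loss, computed against real labels, shrinks as the discriminator grows more convinced the fakes are real.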
Function to Display Generated Images
We don’t really have any way to measure performance other than to visually inspect the images produced by the GAN’s generator model. So I am going to write a function to display the images or save to file (as I will want to include some of them in this post).
The code that follows has gone through several iterations to arrive at its current state. I started out with a batch size of 60 for training. Then repeated the training with a batch size of 32. Saving the two generators created by training with both batch sizes.
Initially the function was only getting the epoch number to use in naming the file. But I eventually realized I should put more information in the file name when I changed the batch size. But I was just using the global value for the batch size when doing so.
When I finally got around to loading the two generators from file and generating images, I couldn’t use that global value. As the function used whatever it was before I changed it to get the other generator. And, I didn’t want to continually change the initial global value and re-execute the module. I wanted to do that in a loop. So, I added batch size as a parameter. That was very late into my development efforts.
The differing batch sizes also meant I was generating differing sets of subplots: 30 images (a 5x6 grid) for a batch size of 60 and 32 images (a 4x8 grid) for a batch size of 32. So I also needed to use the batch size to generate some of my image parameters. Finally, as you will see, I was originally using the "gray" colour map. But I didn't like the dark background. So when I got around to that last bit of code using the saved generators, I switched to "gist_yarg".
And I wasn't displaying the images during training, as I didn't really want to sit here watching the training slowly progress. I settled on a value of 50 epochs and did not show the images during training. But of course at the end I did want them displayed. So I eventually added that final if block to close the figure or show it after saving to file.
Wow, my mouth runneth over. Here’s the code for that function.
# function to display batch of generated images during training
# and when generating images using model loaded from file
def show_images(ep, b_sz):
    # generate batch of fake images
    noise = torch.randn((b_sz, 100))
    noise = noise.to(device)
    # need fakes available to numpy/matplotlib (keep forgetting this)
    fakes = genatr(noise).cpu().detach()
    if b_sz == 60:
        _, axs = plt.subplots(5, 6, figsize=(10, 6))
        rw_sz = 6
        nbr_img = int(b_sz / 2)
    elif b_sz == 32:
        _, axs = plt.subplots(4, 8, figsize=(10, 6))
        rw_sz = 8
        nbr_img = b_sz
    else:
        raise ValueError(f"no subplot layout defined for batch size {b_sz}")
    rw = 0
    for i in range(nbr_img):
        # undo the [-1, 1] normalization so pixel values are back in [0, 1]
        img = (fakes[i] / 2 + 0.5).reshape(28, 28)
        if i > 0 and i % rw_sz == 0:
            rw += 1
        ax = axs[rw, i - (rw_sz * rw)]
        ax.imshow(img, cmap="gist_yarg")
        ax.set_xticks([])
        ax.set_yticks([])
    if ep != 0:
        # use the b_sz parameter, not the global, so the name matches the generator
        plt.savefig(img_dir / f"gan_{ep}_{b_sz}.png")
    if not ld_model:
        plt.close()
    else:
        plt.show()
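The manual row counter in show_images can also be expressed with Python's built-in divmod, which returns the (row, column) position directly. A small pure-Python check that the two approaches agree for both grid layouts used here; the helper names counter_positions and grid_positions are hypothetical, made up for this comparison:

```python
def counter_positions(nbr_img, rw_sz):
    """Row/column pairs produced by the row-counter logic in show_images."""
    positions = []
    rw = 0
    for i in range(nbr_img):
        if i > 0 and i % rw_sz == 0:
            rw += 1
        positions.append((rw, i - (rw_sz * rw)))
    return positions


def grid_positions(nbr_img, rw_sz):
    """Same positions via divmod: the quotient is the row, the remainder the column."""
    return [divmod(i, rw_sz) for i in range(nbr_img)]


# the two layouts used in show_images: 30 images in 5x6, 32 images in 4x8
for nbr_img, rw_sz in [(30, 6), (32, 8)]:
    assert counter_positions(nbr_img, rw_sz) == grid_positions(nbr_img, rw_sz)

print(grid_positions(32, 8)[-1])  # (3, 7): bottom-right cell of the 4x8 grid
```

In show_images the loop body could then use rw, col = divmod(i, rw_sz) and ax = axs[rw, col], with no separate counter to keep in step.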
Building the GAN
I am coding the model definitions and training in an if block controlled by a global variable. That allows me to either train the model or, in the else block, load one from file and generate images.
Model Classes and Instantiation
Now the discriminator and generator models do change a bit from the previous projects. More layers, for one. Also, we code the generator to effectively mirror the discriminator. That is, we run the layers in more or less the reverse order, using the same numbers of neurons for input and output but swapped around. (A web search provided the mirror insight.) Otherwise they are very similar. Note: I_SZ is a global variable initialized to 28 * 28, the number of pixels in each image.
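To make the "mirror" idea concrete, here is a sketch that builds both stacks from a single list of layer widths. The helper mlp and the width lists are hypothetical, not from the original module; they reproduce the same layer sizes as the classes below, and the check confirms the input/output shapes and parameter counts:

```python
import torch
from torch import nn

I_SZ = 28 * 28  # pixels per image
LATENT = 100    # length of the generator's noise vector


def mlp(widths, final_act, dropout=None):
    """Hypothetical helper: Linear layers of the given widths with ReLU between them."""
    layers = []
    for n_in, n_out in zip(widths[:-2], widths[1:-1]):
        layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        if dropout is not None:
            layers.append(nn.Dropout(p=dropout))
    layers += [nn.Linear(widths[-2], widths[-1]), final_act]
    return nn.Sequential(*layers)


def n_params(model):
    """Total number of trainable values (weights plus biases)."""
    return sum(p.numel() for p in model.parameters())


disc_widths = [I_SZ, 1024, 512, 256, 1]
# mirror: reverse the hidden widths, take noise in and put pixels out
gen_widths = [LATENT] + disc_widths[1:-1][::-1] + [I_SZ]

discrm = mlp(disc_widths, nn.Sigmoid(), dropout=0.3)
genatr = mlp(gen_widths, nn.Tanh())

fakes = genatr(torch.randn(4, LATENT))
scores = discrm(fakes)
print(gen_widths)                    # [100, 256, 512, 1024, 784]
print(fakes.shape, scores.shape)     # torch.Size([4, 784]) torch.Size([4, 1])
print(n_params(discrm), n_params(genatr))  # 1460225 1486352
```

Both networks come out at roughly 1.5 million parameters, so the mirroring also keeps the two opponents about the same size.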
The final activation layer for the discriminator uses the Sigmoid function, essentially giving us a probability value between 0 and 1. The generator uses the Tanh function, giving us values between -1 and 1. Why?
This is due to the fact that when generating the images, they are typically normalized to be either in the range [0,1] or [-1,1]. So if you want your output images to be in [0,1] you can use a sigmoid and if you want them to be in [-1,1] you can use tanh.
Why use tanh function at the last layer of generator in GAN?, answer by Roger Trullo
If you look at the code we used in the classifier, we normalized the image data around 0. And as I copied that transform to this module, we need to use Tanh.
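The classifier's transform isn't shown in this post, but the inverse used in show_images (fakes[i] / 2 + 0.5) suggests the common torchvision ToTensor() plus Normalize((0.5,), (0.5,)) combination. Normalize computes (x - mean) / std per channel, which maps ToTensor's [0, 1] pixels to [-1, 1], the same range Tanh produces. A torch-only sketch of that mapping and its inverse, assuming a mean and std of 0.5:

```python
import torch

MEAN, STD = 0.5, 0.5  # assumed values, as in transforms.Normalize((0.5,), (0.5,))


def normalize(x):
    """Map pixel values from [0, 1] to [-1, 1], as Normalize does per channel."""
    return (x - MEAN) / STD


def denormalize(x):
    """Inverse mapping; identical to fakes / 2 + 0.5 in show_images."""
    return x * STD + MEAN


pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])
scaled = normalize(pixels)
print(scaled)

# a Tanh output always lands inside [-1, 1], so it maps back into [0, 1]
gen_out = torch.tanh(torch.randn(1000))
restored = denormalize(gen_out)
print(restored.min().item(), restored.max().item())
```

Had the data been left in [0, 1], a Sigmoid output layer on the generator would have been the matching choice.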
# define Discriminator model class
class Discriminator(nn.Module):
    def __init__(self, do_p=0.3):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(I_SZ, 1024),
            nn.ReLU(),
            nn.Dropout(p=do_p),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=do_p),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(p=do_p),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        outp = self.model(x)
        return outp

# define Generator model class, it should "mirror" the discriminator
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, I_SZ),
            nn.Tanh()
        )

    def forward(self, x):
        outp = self.model(x)
        return outp

# instantiate the models, specify optimizer function
discrm = Discriminator()
discrm.to(device)
opt_d = torch.optim.Adam(discrm.parameters(), lr=lr)
genatr = Generator()
genatr.to(device)
opt_g = torch.optim.Adam(genatr.parameters(), lr=lr)
Helper Functions
As in the sine curve GAN we will also use functions for the various steps in training this GAN. They are virtually identical with the exception of some of the input/output handling.
# okay, some functions to facilitate training the gan
def train_disc_real(samples):
    # flatten each image to a vector and load 'real' samples to gpu
    r_smpl = samples.reshape(-1, I_SZ).to(device)
    # zero out the gradients on each iteration
    # don't want them accumulating over a number of training cycles or epochs
    opt_d.zero_grad()
    out_d = discrm(r_smpl)
    # calculate loss and backpropagate to update weights and biases
    loss_d = loss_fn(out_d, real_lbls)
    loss_d.backward()
    opt_d.step()
    return loss_d

def train_disc_fake():
    # generate some fake data
    noise = torch.randn((BATCH_SZ, 100))
    noise = noise.to(device)
    fakes = genatr(noise)
    # train discriminator on fake data; detach() stops backward() at the
    # discriminator instead of also computing gradients for the generator
    opt_d.zero_grad()
    out_d = discrm(fakes.detach())
    loss_d = loss_fn(out_d, fake_lbls)
    loss_d.backward()
    opt_d.step()
    return loss_d

def train_genr():
    # generate random vectors from latent space
    noise = torch.randn((BATCH_SZ, 100))
    noise = noise.to(device)
    # train generator
    opt_g.zero_grad()
    fakes = genatr(noise)
    out_disc = discrm(fakes)
    # want to convince discriminator fakes are real, so must use real labels
    loss_g = loss_fn(out_disc, real_lbls)
    loss_g.backward()
    opt_g.step()
    return loss_g, fakes
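When the discriminator is trained on fakes, passing fakes.detach() keeps loss_d.backward() from also computing gradients for the generator's parameters. In this GAN those stray gradients are harmless (train_genr calls opt_g.zero_grad() before the generator's gradients are used), but they are wasted work. A small sketch with toy stand-in models (hypothetical sizes, not the real Generator/Discriminator) shows the difference:

```python
import torch
from torch import nn

torch.manual_seed(0)
# toy stand-ins for the generator and discriminator
gen = nn.Sequential(nn.Linear(8, 16), nn.Tanh())
disc = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())
loss_fn = nn.BCELoss()
fake_lbls = torch.zeros(4, 1)


def gen_has_grads():
    """True if any generator parameter received a gradient."""
    return any(p.grad is not None for p in gen.parameters())


# without detach: the backward pass flows through the discriminator into the generator
fakes = gen(torch.randn(4, 8))
loss_fn(disc(fakes), fake_lbls).backward()
without_detach = gen_has_grads()

# clear all gradients, then repeat with the fakes cut off from the generator's graph
gen.zero_grad(set_to_none=True)
disc.zero_grad(set_to_none=True)
fakes = gen(torch.randn(4, 8))
loss_fn(disc(fakes.detach()), fake_lbls).backward()
with_detach = gen_has_grads()

print("without detach:", without_detach)  # True
print("with detach:", with_detach)        # False
```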
Training Loop
And the training loop is also pretty similar (copied from the sine curve GAN and refactored). One difference is generating and saving generator output every 10 epochs. Immediately after completing training I save both network models to files, so that if I decide to do more training, I can start where I left off. (Note: further reading after I wrote this post, well the draft, leads me to believe that if I wish to do further training, I should be saving the model and optimizer states, not the TorchScript versions of the models.)
st_tm = time.perf_counter()
for epoch in range(trn_len):
    agg_l_g = 0
    agg_l_d = 0
    # don't need training labels
    for i, (r_samp, _) in enumerate(trn_loader):
        # .item() gives plain floats, so the running totals don't keep autograd graphs alive
        loss_d = train_disc_real(r_samp)
        agg_l_d += loss_d.item()
        loss_d = train_disc_fake()
        agg_l_d += loss_d.item()
        loss_g, fakes = train_genr()
        agg_l_g += loss_g.item()
    # i is zero-based, so there were i + 1 batches this epoch
    d_loss = agg_l_d / (i + 1)
    g_loss = agg_l_g / (i + 1)
    if epoch % 10 == 9:
        print(f"epoch {epoch + 1}, d_loss: {d_loss}, g_loss: {g_loss}")
        show_images(epoch, BATCH_SZ)
nd_tm = time.perf_counter()
print(f"\ntime to train GAN : {nd_tm - st_tm}")
show_images(epoch, BATCH_SZ)
# Save models, using torchscript instead of model and parameters
if sv_model:
    g_script = torch.jit.script(genatr)
    fl_nm = Path(f"gen_clothing_{BATCH_SZ}_{trn_len}.pt")
    g_script.save(sv_dir / fl_nm)
    d_script = torch.jit.script(discrm)
    fl_nm = Path(f"disc_clothing_{BATCH_SZ}_{trn_len}.pt")
    d_script.save(sv_dir / fl_nm)
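On the note above about saving model and optimizer states: the pattern shown in the PyTorch saving/loading tutorial is to torch.save a dictionary of state_dicts and restore each one with load_state_dict. A self-contained sketch with toy stand-in models; the file name and the dictionary keys are my own choices, not from the original module:

```python
import os
import tempfile

import torch
from torch import nn

# toy stand-ins; the real module would use genatr, discrm, opt_g, opt_d
genatr = nn.Sequential(nn.Linear(100, 784), nn.Tanh())
discrm = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())
opt_g = torch.optim.Adam(genatr.parameters(), lr=0.0001)
opt_d = torch.optim.Adam(discrm.parameters(), lr=0.0001)

# take one step so the Adam optimizers have internal state worth saving
discrm(genatr(torch.randn(2, 100))).mean().backward()
opt_g.step()
opt_d.step()

ckpt_path = os.path.join(tempfile.mkdtemp(), "gan_checkpoint.pt")  # hypothetical name
torch.save({
    "epoch": 50,
    "genatr": genatr.state_dict(),
    "discrm": discrm.state_dict(),
    "opt_g": opt_g.state_dict(),
    "opt_d": opt_d.state_dict(),
}, ckpt_path)

# later: rebuild the same architecture and optimizer, then restore their state
# (the discriminator and opt_d are restored the same way)
genatr2 = nn.Sequential(nn.Linear(100, 784), nn.Tanh())
opt_g2 = torch.optim.Adam(genatr2.parameters(), lr=0.0001)
ckpt = torch.load(ckpt_path, map_location="cpu")
genatr2.load_state_dict(ckpt["genatr"])
opt_g2.load_state_dict(ckpt["opt_g"])
start_epoch = ckpt["epoch"]
genatr2.train()  # back to training mode before resuming
```

Restoring the optimizers matters for Adam in particular, since its running moment estimates would otherwise start again from zero.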
Training the GAN
As mentioned previously I trained the GAN twice with different batch sizes.
Batch Size 60
Well, as I had copied the code from another module, the batch size was initially 64, which of course caused training to crash. There are 60,000 samples in the training set, and that value is not evenly divisible by 64. The last batch of 32 items was not acceptable to PyTorch, given my code, and it raised an error. So after a bit of thinking, I changed it to 60.
Though I have since realized I could likely have determined the actual batch size in my training-related functions rather than using a fixed global variable. That would definitely have made my code a touch more flexible.
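Two ways around the fixed-size problem, sketched with a stand-in dataset of 60,000 items (the real code would pass ds_train): DataLoader's drop_last=True discards the short final batch, or the training functions can size their labels from the batch they actually receive. The helper make_labels is hypothetical:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# stand-in for the 60,000 FashionMNIST training images (1 value each to keep it small)
ds = TensorDataset(torch.zeros(60_000, 1))

loader = DataLoader(ds, batch_size=64)
batches = [b for (b,) in loader]
sizes = [b.size(0) for b in batches]
print(len(loader), sizes[-1])  # 938 32: 937 full batches plus a short one of 32

# option 1: drop the short final batch
loader_drop = DataLoader(ds, batch_size=64, drop_last=True)
print(len(loader_drop))  # 937


# option 2: build labels to match whatever batch size arrives
def make_labels(samples, real=True):
    """Hypothetical helper: a (batch, 1) tensor of ones (real) or zeros (fake)."""
    n = samples.size(0)
    fill = torch.ones if real else torch.zeros
    return fill((n, 1), device=samples.device)


last_batch = batches[-1]
print(make_labels(last_batch).shape)  # torch.Size([32, 1])
```

Option 2 would also let train_disc_fake and train_genr draw noise for the same number of samples as the real batch instead of a fixed BATCH_SZ.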
(mclp-3.12) PS F:\learn\mcl_pytorch\chap4> python gan_grayscale.py
epoch 10, d_loss: 0.31766918301582336, g_loss: 3.8361012935638428
epoch 20, d_loss: 0.7685217261314392, g_loss: 1.814321517944336
epoch 30, d_loss: 0.9799447059631348, g_loss: 1.3401504755020142
epoch 40, d_loss: 1.0692670345306396, g_loss: 1.1749638319015503
epoch 50, d_loss: 1.1013325452804565, g_loss: 1.1256409883499146
time to train GAN : 818.5133819000039
Those are not exactly great loss numbers for the generator. See below for some sample images generated during training.
Batch Size 32
(mclp-3.12) PS F:\learn\mcl_pytorch\chap4> python gan_grayscale.py
epoch 10, d_loss: 0.7847617864608765, g_loss: 1.8588207960128784
epoch 20, d_loss: 1.0907469987869263, g_loss: 1.1640613079071045
epoch 30, d_loss: 1.1679587364196777, g_loss: 1.0121952295303345
epoch 40, d_loss: 1.2104214429855347, g_loss: 0.9440439343452454
epoch 50, d_loss: 1.2266089916229248, g_loss: 0.915054440498352
time to train GAN : 1090.6042208000144
The generator loss value is significantly lower from the start. I am assuming that the smaller batches mean "more training" per epoch, since each pass through the data makes more weight updates, resulting in slightly better performance for each epoch. But of course that comes at the price of a longer period of time to complete training.
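The "more training per epoch" point is easy to quantify: each epoch processes 60,000 / batch-size batches (each batch making two discriminator updates and one generator update), and the timings come straight from the two logs. A pure-Python sketch of that arithmetic:

```python
N_TRAIN = 60_000
EPOCHS = 50
runs = {60: 818.5133819000039, 32: 1090.6042208000144}  # seconds, from the logs above

summary = {}
for b_sz, secs in runs.items():
    batches = N_TRAIN // b_sz                        # batches per epoch
    per_epoch = secs / EPOCHS                        # seconds per epoch
    per_batch_ms = secs / (batches * EPOCHS) * 1000  # milliseconds per batch
    summary[b_sz] = (batches, per_epoch, per_batch_ms)
    print(f"batch {b_sz}: {batches} batches/epoch, "
          f"{per_epoch:.1f} s/epoch, {per_batch_ms:.1f} ms/batch")
```

The smaller batch size gives about 1.9 times as many updates per epoch (1,875 versus 1,000) for only about 1.33 times the training time, because each smaller batch is cheaper to process.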
Images During Training
The images generated with the batch size of 60 are on the left, those for 32 on the right.
I think the images on the right (smaller batch size) are of better quality following each of the epochs sampled.
Load Saved Models and Generate Images
And, let’s have a look at what the generators can produce. I am loading them from file for each of the two images being generated.
else:
    # Load model and generate images
    g_tsts = {"t1": 32, "t2": 60}
    for t_lbl, t_bsz in g_tsts.items():
        fl_nm = Path(f"gen_clothing_{t_bsz}_{trn_len}.pt")
        genatr = torch.jit.load(sv_dir / fl_nm, map_location=device)
        # switch to evaluation (inference) mode before generating
        genatr.eval()
        show_images(t_lbl, t_bsz)
Once again, the images from the generator trained with the smaller batch size are on the right.
As displayed on the page, the images are not the best. But I think we can still see that the generators do a pretty good job given how little code was required to build and train them. And, relatively speaking, not too much time was required. (I left them running and went downstairs to work on supper.)
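A self-contained sketch of the same script/save/load/generate round trip used by the load-and-generate code, with a small stand-in generator (not the trained one) and generation wrapped in torch.no_grad(), which skips building the autograd graph that show_images otherwise discards with .detach(). The temporary file name is made up for the example:

```python
import os
import tempfile

import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# stand-in generator with the same input/output sizes as the real one
gen = nn.Sequential(nn.Linear(100, 256), nn.ReLU(), nn.Linear(256, 28 * 28), nn.Tanh())
fl_nm = os.path.join(tempfile.mkdtemp(), "gen_clothing_demo.pt")  # hypothetical file name
torch.jit.script(gen).save(fl_nm)

# load onto whichever device is available, switch to eval mode, then generate
genatr = torch.jit.load(fl_nm, map_location=device)
genatr.eval()
with torch.no_grad():
    fakes = genatr(torch.randn(32, 100, device=device)).cpu()

imgs = (fakes / 2 + 0.5).reshape(-1, 28, 28)  # back to [0, 1] pixel values
print(imgs.shape)  # torch.Size([32, 28, 28])
```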
Done
I have been thinking I might look at modifying the models for this project to use convolutional layers instead of the fully connected linear layers. I expect that might speed up and/or improve training of the generator.
But with the code and images this post is getting way too long. So calling this one fini.
Until next time, do enjoy the fascination of training a GAN. I am really quite happy getting this one coded and working with, in my view, relative ease.